From 64082d547ec514e83796268791792fc8c5aef415 Mon Sep 17 00:00:00 2001 From: irisyann Date: Mon, 20 Apr 2026 11:57:10 +0800 Subject: [PATCH] feat: add source identifier to API requests for telemetry --- internal/api/client.go | 5 +++++ internal/api/client_test.go | 30 ++++++++++++++++++++++++++++++ internal/ws/client.go | 2 +- internal/ws/client_test.go | 36 ++++++++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 1 deletion(-) diff --git a/internal/api/client.go b/internal/api/client.go index c358883..4c5b21e 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -200,6 +200,11 @@ func (c *Client) baseURL() string { func (c *Client) get(ctx context.Context, path string, result any) error { url := c.baseURL() + path + if strings.Contains(path, "?") { + url += "&source=cli" + } else { + url += "?source=cli" + } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { diff --git a/internal/api/client_test.go b/internal/api/client_test.go index b3add9b..565b4da 100644 --- a/internal/api/client_test.go +++ b/internal/api/client_test.go @@ -356,6 +356,36 @@ func TestRetryAfterInvalidFallback(t *testing.T) { assert.ErrorIs(t, err, ErrRateLimited) } +func TestSourceCLIAppendedWithoutExistingQuery(t *testing.T) { + var gotSource string + c, srv := testClient(func(w http.ResponseWriter, r *http.Request) { + gotSource = r.URL.Query().Get("source") + w.WriteHeader(200) + _, _ = w.Write([]byte("{}")) + }) + defer srv.Close() + + var result map[string]any + _ = c.get(context.Background(), "/test", &result) + assert.Equal(t, "cli", gotSource) +} + +func TestSourceCLIAppendedWithExistingQuery(t *testing.T) { + var gotSource, gotExisting string + c, srv := testClient(func(w http.ResponseWriter, r *http.Request) { + gotSource = r.URL.Query().Get("source") + gotExisting = r.URL.Query().Get("foo") + w.WriteHeader(200) + _, _ = w.Write([]byte("{}")) + }) + defer srv.Close() + + var result map[string]any + _ = c.get(context.Background(), "/test?foo=bar", &result) + assert.Equal(t, "cli", gotSource) + assert.Equal(t, "bar", gotExisting) +} + func TestRequirePaid(t *testing.T) { cfg := &config.Config{Tier: config.TierDemo} c := NewClient(cfg) diff --git a/internal/ws/client.go b/internal/ws/client.go index d5be4bf..1b99fcf 100644 --- a/internal/ws/client.go +++ b/internal/ws/client.go @@ -134,7 +134,7 @@ func (c *Client) Close() error { // connect dials the WebSocket endpoint and waits for the welcome message. func (c *Client) connect(ctx context.Context) error { - url := c.wsURL + "?x_cg_pro_api_key=" + c.cfg.APIKey + url := c.wsURL + "?x_cg_pro_api_key=" + c.cfg.APIKey + "&source=cli" header := http.Header{} header.Set("User-Agent", c.UserAgent) diff --git a/internal/ws/client_test.go b/internal/ws/client_test.go index d8cc2a1..d847c58 100644 --- a/internal/ws/client_test.go +++ b/internal/ws/client_test.go @@ -379,6 +379,42 @@ func TestConnect_APIKeyInQueryParam(t *testing.T) { require.NoError(t, client.Close()) } +func TestConnect_SourceCLIQueryParam(t *testing.T) { + var receivedSource string + var mu sync.Mutex + + srv := newTestWSServer(t, func(conn *websocket.Conn) { + happyHandshake(t, conn) + for { + if _, _, err := conn.ReadMessage(); err != nil { + return + } + } + }) + origHandler := srv.Config.Handler + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + receivedSource = r.URL.Query().Get("source") + mu.Unlock() + origHandler.ServeHTTP(w, r) + }) + + client := NewClient(paidCfg(), []string{"bitcoin"}) + client.SetURL(wsURL(srv)) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := client.Connect(ctx) + require.NoError(t, err) + + mu.Lock() + assert.Equal(t, "cli", receivedSource) + mu.Unlock() + + require.NoError(t, client.Close()) +} + func TestConnect_UserAgentHeader(t *testing.T) { var gotUA string var mu sync.Mutex