diff --git a/pkg/httpclient/client.go b/pkg/httpclient/client.go index 4e3720df1..b5e9eac85 100644 --- a/pkg/httpclient/client.go +++ b/pkg/httpclient/client.go @@ -29,10 +29,15 @@ func NewHTTPClient(opts ...Opt) *http.Client { // Enforce a consistent User-Agent header httpOptions.Header.Set("User-Agent", fmt.Sprintf("Cagent/%s (%s; %s)", version.Version, runtime.GOOS, runtime.GOARCH)) + // Disable automatic gzip: Go's default transport transparently compresses + // and decompresses responses, which is incompatible with SSE streaming. + // See https://github.com/docker/docker-agent/issues/1956 + rt := newTransport() + return &http.Client{ Transport: &userAgentTransport{ httpOptions: httpOptions, - rt: http.DefaultTransport, + rt: rt, }, } } @@ -90,6 +95,17 @@ func WithQuery(query url.Values) Opt { } } +// newTransport returns an HTTP transport with automatic gzip compression disabled. +func newTransport() http.RoundTripper { + t, ok := http.DefaultTransport.(*http.Transport) + if !ok { + return http.DefaultTransport + } + transport := t.Clone() + transport.DisableCompression = true + return transport +} + type userAgentTransport struct { httpOptions HTTPOptions rt http.RoundTripper diff --git a/pkg/httpclient/client_test.go b/pkg/httpclient/client_test.go index 949d52c51..dfe5c6f1f 100644 --- a/pkg/httpclient/client_test.go +++ b/pkg/httpclient/client_test.go @@ -9,23 +9,43 @@ import ( "github.com/stretchr/testify/require" ) -func TestWithModelName(t *testing.T) { +func TestHeaders(t *testing.T) { t.Parallel() tests := []struct { - name string - modelName string - wantSet bool + name string + opts []Opt + wantHeader string + wantValue string }{ { - name: "sets header when name is provided", - modelName: "my-fast-model", - wantSet: true, + name: "WithModel sets X-Cagent-Model", + opts: []Opt{WithModel("gpt-4o")}, + wantHeader: "X-Cagent-Model", + wantValue: "gpt-4o", }, { - name: "skips header when name is empty", - modelName: "", - wantSet: false, + name: "WithModelName sets X-Cagent-Model-Name", + opts: []Opt{WithModelName("my-fast-model")}, + wantHeader: "X-Cagent-Model-Name", + wantValue: "my-fast-model", + }, + { + name: "WithModelName skips header when empty", + opts: []Opt{WithModelName("")}, + wantHeader: "X-Cagent-Model-Name", + wantValue: "", + }, + { + name: "WithProvider sets X-Cagent-Provider", + opts: []Opt{WithProvider("openai")}, + wantHeader: "X-Cagent-Provider", + wantValue: "openai", + }, + { + name: "compression is disabled to support SSE streaming", + wantHeader: "Accept-Encoding", + wantValue: "", }, } @@ -33,51 +53,21 @@ func TestWithModelName(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - var capturedHeaders http.Header - srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - capturedHeaders = r.Header - })) - defer srv.Close() - - client := NewHTTPClient(WithModelName(tt.modelName)) - req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) - require.NoError(t, err) + headers := doRequest(t, tt.opts...) - resp, err := client.Do(req) - require.NoError(t, err) - defer func() { _ = resp.Body.Close() }() - - if tt.wantSet { - assert.Equal(t, tt.modelName, capturedHeaders.Get("X-Cagent-Model-Name")) + if tt.wantValue != "" { + assert.Equal(t, tt.wantValue, headers.Get(tt.wantHeader)) } else { - assert.Empty(t, capturedHeaders.Get("X-Cagent-Model-Name")) + assert.Empty(t, headers.Get(tt.wantHeader)) } }) } } -func TestWithModel(t *testing.T) { - t.Parallel() - - var capturedHeaders http.Header - srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - capturedHeaders = r.Header - })) - defer srv.Close() - - client := NewHTTPClient(WithModel("gpt-4o")) - req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) - require.NoError(t, err) - - resp, err := client.Do(req) - require.NoError(t, err) - defer func() { _ = resp.Body.Close() }() - - assert.Equal(t, "gpt-4o", capturedHeaders.Get("X-Cagent-Model")) -} - -func TestWithProvider(t *testing.T) { - t.Parallel() +// doRequest creates an HTTP client with the given options, sends a GET request +// to a test server, and returns the headers the server received. +func doRequest(t *testing.T, opts ...Opt) http.Header { + t.Helper() var capturedHeaders http.Header srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { @@ -85,7 +75,7 @@ func TestWithProvider(t *testing.T) { })) defer srv.Close() - client := NewHTTPClient(WithProvider("openai")) + client := NewHTTPClient(opts...) req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) require.NoError(t, err) @@ -93,5 +83,5 @@ func TestWithProvider(t *testing.T) { require.NoError(t, err) defer func() { _ = resp.Body.Close() }() - assert.Equal(t, "openai", capturedHeaders.Get("X-Cagent-Provider")) + return capturedHeaders }