Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion pkg/httpclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@ func NewHTTPClient(opts ...Opt) *http.Client {
// Enforce a consistent User-Agent header
httpOptions.Header.Set("User-Agent", fmt.Sprintf("Cagent/%s (%s; %s)", version.Version, runtime.GOOS, runtime.GOARCH))

// Disable automatic gzip: Go's default transport transparently compresses
// and decompresses responses, which is incompatible with SSE streaming.
// See https://github.com/docker/docker-agent/issues/1956
rt := newTransport()

return &http.Client{
Transport: &userAgentTransport{
httpOptions: httpOptions,
rt: http.DefaultTransport,
rt: rt,
},
}
}
Expand Down Expand Up @@ -90,6 +95,17 @@ func WithQuery(query url.Values) Opt {
}
}

// newTransport returns an HTTP transport with automatic gzip compression disabled.
func newTransport() http.RoundTripper {
t, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return http.DefaultTransport
}
transport := t.Clone()
transport.DisableCompression = true
return transport
}

type userAgentTransport struct {
httpOptions HTTPOptions
rt http.RoundTripper
Expand Down
90 changes: 40 additions & 50 deletions pkg/httpclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,89 +9,79 @@ import (
"github.com/stretchr/testify/require"
)

func TestWithModelName(t *testing.T) {
func TestHeaders(t *testing.T) {
t.Parallel()

tests := []struct {
name string
modelName string
wantSet bool
name string
opts []Opt
wantHeader string
wantValue string
}{
{
name: "sets header when name is provided",
modelName: "my-fast-model",
wantSet: true,
name: "WithModel sets X-Cagent-Model",
opts: []Opt{WithModel("gpt-4o")},
wantHeader: "X-Cagent-Model",
wantValue: "gpt-4o",
},
{
name: "skips header when name is empty",
modelName: "",
wantSet: false,
name: "WithModelName sets X-Cagent-Model-Name",
opts: []Opt{WithModelName("my-fast-model")},
wantHeader: "X-Cagent-Model-Name",
wantValue: "my-fast-model",
},
{
name: "WithModelName skips header when empty",
opts: []Opt{WithModelName("")},
wantHeader: "X-Cagent-Model-Name",
wantValue: "",
},
{
name: "WithProvider sets X-Cagent-Provider",
opts: []Opt{WithProvider("openai")},
wantHeader: "X-Cagent-Provider",
wantValue: "openai",
},
{
name: "compression is disabled to support SSE streaming",
wantHeader: "Accept-Encoding",
wantValue: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

var capturedHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header
}))
defer srv.Close()

client := NewHTTPClient(WithModelName(tt.modelName))
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
headers := doRequest(t, tt.opts...)

resp, err := client.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()

if tt.wantSet {
assert.Equal(t, tt.modelName, capturedHeaders.Get("X-Cagent-Model-Name"))
if tt.wantValue != "" {
assert.Equal(t, tt.wantValue, headers.Get(tt.wantHeader))
} else {
assert.Empty(t, capturedHeaders.Get("X-Cagent-Model-Name"))
assert.Empty(t, headers.Get(tt.wantHeader))
}
})
}
}

func TestWithModel(t *testing.T) {
t.Parallel()

var capturedHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header
}))
defer srv.Close()

client := NewHTTPClient(WithModel("gpt-4o"))
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()

assert.Equal(t, "gpt-4o", capturedHeaders.Get("X-Cagent-Model"))
}

func TestWithProvider(t *testing.T) {
t.Parallel()
// doRequest creates an HTTP client with the given options, sends a GET request
// to a test server, and returns the headers the server received.
func doRequest(t *testing.T, opts ...Opt) http.Header {
t.Helper()

var capturedHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header
}))
defer srv.Close()

client := NewHTTPClient(WithProvider("openai"))
client := NewHTTPClient(opts...)
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()

assert.Equal(t, "openai", capturedHeaders.Get("X-Cagent-Provider"))
return capturedHeaders
}