diff --git a/provider/openai.go b/provider/openai.go index dd68f0d..0a2ccce 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -94,6 +94,20 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") defer tracing.EndSpanErr(span, &outErr) + cfg := p.cfg + + // In centralized mode, http.go strips Authorization (it carried the + // Coder token), so the header is absent and cfg keeps the centralized + // key. + // + // In BYOK mode, http.go only strips the BYOK header and leaves the + // user's LLM credentials intact. OpenAI uses Authorization: Bearer + // for both API keys and OAuth tokens, so we just extract the token + // and overwrite cfg.Key. + if bearer := r.Header.Get("Authorization"); bearer != "" { + cfg.Key = strings.TrimPrefix(bearer, "Bearer ") + } + var interceptor intercept.Interceptor path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix()) @@ -105,9 +119,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) } case routeResponses: @@ -120,9 +134,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace return nil, fmt.Errorf("unmarshal request body: %w", err) } if req.Stream { - interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, string(req.Model), r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, string(req.Model), r.Header, p.AuthHeader(), tracer) } default: @@ -146,6 +160,12 @@ func (p *OpenAI) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } + // BYOK: if the request already carries user-supplied credentials, + // do not overwrite them with the centralized key. + if headers.Get("Authorization") != "" { + return + } + headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key) } diff --git a/provider/openai_test.go b/provider/openai_test.go index 4add332..549cbc2 100644 --- a/provider/openai_test.go +++ b/provider/openai_test.go @@ -162,66 +162,121 @@ func generateResponsesPayload(payloadSize int, inputCount int, stream bool) []by func TestOpenAI_CreateInterceptor(t *testing.T) { t.Parallel() - tests := []struct { + routes := []struct { name string route string requestBody string responseBody string }{ { - name: "ChatCompletions_ClientHeaders", + name: "ChatCompletions", route: routeChatCompletions, requestBody: `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}`, responseBody: chatCompletionResponse, }, { - name: "Responses_ClientHeaders", + name: "Responses", route: routeResponses, requestBody: `{"model": "gpt-5", "input": "hello", "stream": false}`, responseBody: responsesAPIResponse, }, } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() + byokCases := []struct { + name string + setHeaders map[string]string + wantAuthorization string + }{ + { + name: "Centralized_UsesCentralizedKey", + setHeaders: map[string]string{}, + wantAuthorization: "Bearer test-key", + }, + { + name: "BYOK_BearerToken", + setHeaders: map[string]string{"Authorization": "Bearer user-oauth-token"}, + wantAuthorization: "Bearer user-oauth-token", + }, + } - var receivedHeaders http.Header + for _, route := range routes { + for _, bc := range byokCases { + t.Run(route.name+"_"+bc.name, func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(route.responseBody)) + require.NoError(t, err) + })) + t.Cleanup(mockUpstream.Close) + + provider := NewOpenAI(config.OpenAI{ + BaseURL: mockUpstream.URL, + Key: "test-key", + }) + + req := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+route.route, bytes.NewBufferString(route.requestBody)) + for k, v := range bc.setHeaders { + req.Header.Set(k, v) + } + w := httptest.NewRecorder() - mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(tc.responseBody)) + interceptor, err := provider.CreateInterceptor(w, req, testTracer) require.NoError(t, err) - })) - t.Cleanup(mockUpstream.Close) + require.NotNil(t, interceptor) + + logger := slog.Make() + interceptor.Setup(logger, &testutil.MockRecorder{}, nil) - provider := NewOpenAI(config.OpenAI{ - BaseURL: mockUpstream.URL, - Key: "test-key", + processReq := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+route.route, nil) + err = interceptor.ProcessRequest(w, processReq) + require.NoError(t, err) + + assert.Equal(t, bc.wantAuthorization, receivedHeaders.Get("Authorization")) }) + } + } +} + +func TestOpenAI_InjectAuthHeader_BYOK(t *testing.T) { + t.Parallel() - req := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+tc.route, bytes.NewBufferString(tc.requestBody)) - // Simulate a client sending its own auth credential, which must be replaced - // by aibridge with the configured provider key. - req.Header.Set("Authorization", "Bearer fake-client-bearer") - w := httptest.NewRecorder() + provider := NewOpenAI(config.OpenAI{Key: "centralized-key"}) - interceptor, err := provider.CreateInterceptor(w, req, testTracer) - require.NoError(t, err) - require.NotNil(t, interceptor) + tests := []struct { + name string + presetHeaders map[string]string + wantAuthorization string + }{ + { + name: "no pre-existing auth injects centralized key", + presetHeaders: map[string]string{}, + wantAuthorization: "Bearer centralized-key", + }, + { + name: "pre-existing Authorization is not overwritten", + presetHeaders: map[string]string{"Authorization": "Bearer user-oauth-token"}, + wantAuthorization: "Bearer user-oauth-token", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - logger := slog.Make() - interceptor.Setup(logger, &testutil.MockRecorder{}, nil) + headers := http.Header{} + for k, v := range tc.presetHeaders { + headers.Set(k, v) + } - processReq := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+tc.route, nil) - err = interceptor.ProcessRequest(w, processReq) - require.NoError(t, err) + provider.InjectAuthHeader(&headers) - // Verify aibridge's configured key was used and the client's auth credential was not forwarded. - assert.Equal(t, "Bearer test-key", receivedHeaders.Get("Authorization"), "upstream must receive configured provider key") - assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") + assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization")) }) } }