From f95de3ce24dffb46e4e317ef40d236c175a08cdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Mon, 16 Mar 2026 12:58:24 +0000 Subject: [PATCH 1/2] chore: remove SDK struct usage from requests in messages API Replaces MessageNewParamsWrapper that wrapped anthropic.MessageNewParams with MessagesRequestPayload that operates on raw bytes. --- intercept/messages/base.go | 91 +--- intercept/messages/base_test.go | 524 ++++++++--------------- intercept/messages/blocking.go | 129 +++--- intercept/messages/paramswrap.go | 142 ------ intercept/messages/paramswrap_test.go | 303 ------------- intercept/messages/reqpayload.go | 271 ++++++++++++ intercept/messages/reqpayload_test.go | 343 +++++++++++++++ intercept/messages/streaming.go | 56 ++- internal/integrationtest/mockupstream.go | 15 +- provider/anthropic.go | 13 +- 10 files changed, 905 insertions(+), 982 deletions(-) delete mode 100644 intercept/messages/paramswrap.go delete mode 100644 intercept/messages/paramswrap_test.go create mode 100644 intercept/messages/reqpayload.go create mode 100644 intercept/messages/reqpayload_test.go diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 09372ec7..f2f0e708 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -24,7 +24,6 @@ import ( "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" "github.com/coder/quartz" - "github.com/tidwall/sjson" "github.com/google/uuid" "go.opentelemetry.io/otel/attribute" @@ -34,9 +33,8 @@ import ( ) type interceptionBase struct { - id uuid.UUID - req *MessageNewParamsWrapper - payload []byte + id uuid.UUID + reqPayload MessagesRequestPayload cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock @@ -63,22 +61,11 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, } func (i *interceptionBase) CorrelatingToolCallID() *string { - if len(i.req.Messages) == 0 { - return nil - } - content := i.req.Messages[len(i.req.Messages)-1].Content - for idx := len(content) - 1; idx >= 0; idx-- { - block := content[idx] - if block.OfToolResult == nil { - continue - } - return &block.OfToolResult.ToolUseID - } - return nil + return i.reqPayload.correlatingToolCallID() } func (i *interceptionBase) Model() string { - if i.req == nil { + if len(i.reqPayload) == 0 { return "coder-aibridge-unknown" } @@ -90,7 +77,7 @@ func (i *interceptionBase) Model() string { return model } - return string(i.req.Model) + return i.reqPayload.model() } func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { @@ -106,7 +93,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) } func (i *interceptionBase) injectTools() { - if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() { + if i.mcpProxy == nil || !i.hasInjectableTools() { return } @@ -131,46 +118,23 @@ func (i *interceptionBase) injectTools() { // Prepend the injected tools in order to maintain any configured cache breakpoints. // The order of injected tools is expected to be stable, and therefore will not cause // any cache invalidation when prepended. - i.req.Tools = append(injectedTools, i.req.Tools...) - - var err error - i.payload, err = sjson.SetBytes(i.payload, "tools", i.req.Tools) + updated, err := i.reqPayload.injectTools(injectedTools) if err != nil { i.logger.Warn(context.Background(), "failed to set inject tools in request payload", slog.Error(err)) + return } + i.reqPayload = updated } func (i *interceptionBase) disableParallelToolCalls() { // Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches. // https://github.com/coder/aibridge/issues/2 - toolChoiceType := i.req.ToolChoice.GetType() - var toolChoiceTypeStr string - if toolChoiceType != nil { - toolChoiceTypeStr = *toolChoiceType - } - - switch toolChoiceTypeStr { - // If no tool_choice was defined, assume auto. - // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use. - case "", string(constant.ValueOf[constant.Auto]()): - // We only set OfAuto if no tool_choice was provided (the default). - // "auto" is the default when a zero value is provided, so we can safely disable parallel checks on it. - if i.req.ToolChoice.OfAuto == nil { - i.req.ToolChoice.OfAuto = &anthropic.ToolChoiceAutoParam{} - } - i.req.ToolChoice.OfAuto.DisableParallelToolUse = anthropic.Bool(true) - case string(constant.ValueOf[constant.Any]()): - i.req.ToolChoice.OfAny.DisableParallelToolUse = anthropic.Bool(true) - case string(constant.ValueOf[constant.Tool]()): - i.req.ToolChoice.OfTool.DisableParallelToolUse = anthropic.Bool(true) - case string(constant.ValueOf[constant.None]()): - // No-op; if tool_choice=none then tools are not used at all. - } - var err error - i.payload, err = sjson.SetBytes(i.payload, "tool_choice", i.req.ToolChoice) + updated, err := i.reqPayload.disableParallelToolCalls() if err != nil { i.logger.Warn(context.Background(), "failed to set tool_choice in request payload", slog.Error(err)) + return } + i.reqPayload = updated } // extractModelThoughts returns any thinking blocks that were returned in the response. @@ -201,7 +165,7 @@ func (i *interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recor // See `ANTHROPIC_SMALL_FAST_MODEL`: https://docs.anthropic.com/en/docs/claude-code/settings#environment-variables // https://docs.claude.com/en/docs/claude-code/costs#background-token-usage func (i *interceptionBase) isSmallFastModel() bool { - return strings.Contains(string(i.req.Model), "haiku") + return strings.Contains(i.reqPayload.model(), "haiku") } func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) { @@ -244,23 +208,12 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio return anthropic.NewMessageService(opts...), nil } -// withBody returns a per-request option that sends the current i.payload as the -// request body. This is called for each API request so that the latest payload (including -// any messages appended during the agentic tool loop) is always sent. +// withBody returns a per-request option that sends the current raw request +// payload as the request body. This is called for each API request so that the +// latest payload (including any messages appended during the agentic tool loop) +// is always sent. func (i *interceptionBase) withBody() option.RequestOption { - return option.WithRequestBody("application/json", i.payload) -} - -// syncPayloadMessages updates the raw payload's "messages" field to match the given messages. -// This must be called before the next API request in the agentic loop so that -// withBody() picks up the updated messages. -func (i *interceptionBase) syncPayloadMessages(messages []anthropic.MessageParam) error { - var err error - i.payload, err = sjson.SetBytes(i.payload, "messages", messages) - if err != nil { - return fmt.Errorf("sync payload messages: %w", err) - } - return nil + return option.WithRequestBody("application/json", []byte(i.reqPayload)) } func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) { @@ -317,13 +270,13 @@ func (i *interceptionBase) augmentRequestForBedrock() { return } - i.req.MessageNewParams.Model = anthropic.Model(i.Model()) - - var err error - i.payload, err = sjson.SetBytes(i.payload, "model", i.Model()) + updated, err := i.reqPayload.withModel(i.Model()) if err != nil { i.logger.Warn(context.Background(), "failed to set model in request payload for Bedrock", slog.Error(err)) + return } + + i.reqPayload = updated } // writeUpstreamError marshals and writes a given error. diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index cca890e0..810e5a52 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "cdr.dev/slog/v3" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge/config" @@ -11,100 +12,63 @@ import ( "github.com/coder/aibridge/utils" mcpgo "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func TestScanForCorrelatingToolCallID(t *testing.T) { t.Parallel() - tests := []struct { - name string - messages []anthropic.MessageParam - expected *string + testCases := []struct { + name string + requestBody string + expectedToolID *string }{ { - name: "no messages", - messages: nil, - expected: nil, + name: "no messages field", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024}`, + expectedToolID: nil, }, { - name: "last message has no tool_result blocks", - messages: []anthropic.MessageParam{ - anthropic.NewUserMessage(anthropic.NewTextBlock("hello")), - }, - expected: nil, + name: "messages string", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":"test"}`, + expectedToolID: nil, }, { - name: "single tool_result block", - messages: []anthropic.MessageParam{ - anthropic.NewUserMessage( - anthropic.ContentBlockParamUnion{ - OfToolResult: &anthropic.ToolResultBlockParam{ - ToolUseID: "toolu_abc", - Content: []anthropic.ToolResultBlockParamContentUnion{ - {OfText: &anthropic.TextBlockParam{Text: "result"}}, - }, - }, - }, - ), - }, - expected: utils.PtrTo("toolu_abc"), + name: "empty messages array", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[]}`, + expectedToolID: nil, }, { - name: "multiple tool_result blocks returns last", - messages: []anthropic.MessageParam{ - anthropic.NewUserMessage( - anthropic.ContentBlockParamUnion{ - OfToolResult: &anthropic.ToolResultBlockParam{ - ToolUseID: "toolu_first", - Content: []anthropic.ToolResultBlockParamContentUnion{ - {OfText: &anthropic.TextBlockParam{Text: "first"}}, - }, - }, - }, - anthropic.NewTextBlock("some text"), - anthropic.ContentBlockParamUnion{ - OfToolResult: &anthropic.ToolResultBlockParam{ - ToolUseID: "toolu_second", - Content: []anthropic.ToolResultBlockParamContentUnion{ - {OfText: &anthropic.TextBlockParam{Text: "second"}}, - }, - }, - }, - ), - }, - expected: utils.PtrTo("toolu_second"), + name: "last message has no tool result blocks", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`, + expectedToolID: nil, }, { - name: "last message is not a tool result", - messages: []anthropic.MessageParam{ - anthropic.NewUserMessage( - anthropic.ContentBlockParamUnion{ - OfToolResult: &anthropic.ToolResultBlockParam{ - ToolUseID: "toolu_first", - Content: []anthropic.ToolResultBlockParamContentUnion{ - {OfText: &anthropic.TextBlockParam{Text: "first"}}, - }, - }, - }), - anthropic.NewUserMessage(anthropic.NewTextBlock("some text")), - }, - expected: nil, + name: "single tool result block", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_abc","content":"result"}]}]}`, + expectedToolID: utils.PtrTo("toolu_abc"), + }, + { + name: "multiple tool result blocks returns last", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"},{"type":"text","text":"ignored"},{"type":"tool_result","tool_use_id":"toolu_second","content":"second"}]}]}`, + expectedToolID: utils.PtrTo("toolu_second"), + }, + { + name: "last message is not a tool result", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"}]},{"role":"user","content":"some text"}]}`, + expectedToolID: nil, }, } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { t.Parallel() base := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: tc.messages, - }, - }, + reqPayload: mustMessagesPayload(t, testCase.requestBody), } - require.Equal(t, tc.expected, base.CorrelatingToolCallID()) + require.Equal(t, testCase.expectedToolID, base.CorrelatingToolCallID()) }) } } @@ -402,303 +366,171 @@ func TestAccumulateUsage(t *testing.T) { func TestInjectTools_CacheBreakpoints(t *testing.T) { t.Parallel() - t.Run("cache control preserved when no tools to inject", func(t *testing.T) { - t.Parallel() - - // Request has existing tool with cache control, but no tools to inject. - i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "existing_tool", - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: constant.ValueOf[constant.Ephemeral](), - }, - }, - }, - }, - }, - }, - mcpProxy: &mockServerProxier{tools: nil}, - } - - i.injectTools() - - // Cache control should remain untouched since no tools were injected. - require.Len(t, i.req.Tools, 1) - require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[0].OfTool.CacheControl.Type) - }) - - t.Run("cache control breakpoint is preserved by prepending injected tools", func(t *testing.T) { - t.Parallel() - - // Request has existing tool with cache control. - i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "existing_tool", - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: constant.ValueOf[constant.Ephemeral](), - }, - }, - }, - }, - }, - }, - mcpProxy: &mockServerProxier{ - tools: []*mcp.Tool{ - {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, - }, - }, - } + testCases := []struct { + name string + requestBody string + injectedTools []*mcp.Tool + expectedToolNames []string + expectedCacheControlTypes []string + }{ + { + name: "cache control preserved when no tools to inject", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[` + + `{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`, - i.injectTools() + injectedTools: nil, + expectedToolNames: []string{"existing_tool"}, + expectedCacheControlTypes: []string{string(constant.ValueOf[constant.Ephemeral]())}, + }, + { + name: "cache control breakpoint is preserved by prepending injected tools", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[` + + `{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`, - require.Len(t, i.req.Tools, 2) - // Injected tools are prepended. - require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name) - require.Zero(t, i.req.Tools[0].OfTool.CacheControl) - // Original tool's cache control should be preserved at the end. - require.Equal(t, "existing_tool", i.req.Tools[1].OfTool.Name) - require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[1].OfTool.CacheControl.Type) - }) + injectedTools: []*mcp.Tool{{ID: "injected_tool", Name: "injected", Description: "Injected tool"}}, + expectedToolNames: []string{"injected_tool", "existing_tool"}, + expectedCacheControlTypes: []string{"", string(constant.ValueOf[constant.Ephemeral]())}, + }, + { + name: "cache control breakpoint in non standard location is preserved", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[` + + `{"name":"tool_with_cache_1","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}},` + + `{"name":"tool_with_cache_2","type":"custom","input_schema":{"type":"object","properties":{}}}]}`, + + injectedTools: []*mcp.Tool{{ID: "injected_tool", Name: "injected", Description: "Injected tool"}}, + expectedToolNames: []string{"injected_tool", "tool_with_cache_1", "tool_with_cache_2"}, + expectedCacheControlTypes: []string{"", string(constant.ValueOf[constant.Ephemeral]()), ""}, + }, + { + name: "no cache control added when none originally set", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[` + + `{"name":"existing_tool_no_cache","type":"custom","input_schema":{"type":"object","properties":{}}}]}`, - // The cache breakpoint SHOULD be on the final tool, but may not be; we must preserve that intention. - t.Run("cache control breakpoint in non-standard location is preserved", func(t *testing.T) { - t.Parallel() - - // Request has multiple tools with cache control breakpoints. - i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "tool_with_cache_1", - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: constant.ValueOf[constant.Ephemeral](), - }, - }, - }, - { - OfTool: &anthropic.ToolParam{ - Name: "tool_with_cache_2", - }, - }, - }, - }, - }, - mcpProxy: &mockServerProxier{ - tools: []*mcp.Tool{ - {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, - }, - }, - } + injectedTools: []*mcp.Tool{{ID: "injected_tool", Name: "injected", Description: "Injected tool"}}, + expectedToolNames: []string{"injected_tool", "existing_tool_no_cache"}, + expectedCacheControlTypes: []string{"", ""}, + }, + } - i.injectTools() - - require.Len(t, i.req.Tools, 3) - // Injected tool is prepended without cache control. - require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name) - require.Zero(t, i.req.Tools[0].OfTool.CacheControl) - // Both original tools' cache controls should remain. - require.Equal(t, "tool_with_cache_1", i.req.Tools[1].OfTool.Name) - require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[1].OfTool.CacheControl.Type) - require.Equal(t, "tool_with_cache_2", i.req.Tools[2].OfTool.Name) - require.Zero(t, i.req.Tools[2].OfTool.CacheControl) - }) + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() - t.Run("no cache control added when none originally set", func(t *testing.T) { - t.Parallel() - - // Request has tools but none with cache control. - i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "existing_tool_no_cache", - }, - }, - }, - }, - }, - mcpProxy: &mockServerProxier{ - tools: []*mcp.Tool{ - {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, - }, - }, - } + base := &interceptionBase{ + reqPayload: mustMessagesPayload(t, testCase.requestBody), + mcpProxy: &mockServerProxier{tools: testCase.injectedTools}, + logger: slog.Make(), + } - i.injectTools() + base.injectTools() - require.Len(t, i.req.Tools, 2) - // Injected tool is prepended without cache control. - require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name) - require.Zero(t, i.req.Tools[0].OfTool.CacheControl) - // Original tool remains at the end without cache control. - require.Equal(t, "existing_tool_no_cache", i.req.Tools[1].OfTool.Name) - require.Zero(t, i.req.Tools[1].OfTool.CacheControl) - }) + toolItems := gjson.GetBytes(base.reqPayload, "tools").Array() + require.Len(t, toolItems, len(testCase.expectedToolNames)) + for idx := range toolItems { + require.Equal(t, testCase.expectedToolNames[idx], toolItems[idx].Get("name").String()) + require.Equal(t, testCase.expectedCacheControlTypes[idx], toolItems[idx].Get("cache_control.type").String()) + } + }) + } } func TestInjectTools_ParallelToolCalls(t *testing.T) { t.Parallel() - t.Run("does not modify tool choice when no tools to inject", func(t *testing.T) { - t.Parallel() - - i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{ - Type: constant.ValueOf[constant.Auto](), - }, - }, - }, - }, - mcpProxy: &mockServerProxier{tools: nil}, // No tools to inject. - } - - i.injectTools() - - // Tool choice should remain unchanged - DisableParallelToolUse should not be set. - require.NotNil(t, i.req.ToolChoice.OfAuto) - require.False(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Valid()) - }) - - t.Run("disables parallel tool use for auto tool choice (default)", func(t *testing.T) { - t.Parallel() - - i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - // No tool choice set (default). - }, - }, - mcpProxy: &mockServerProxier{ - tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, - }, - } - - i.injectTools() - - require.NotNil(t, i.req.ToolChoice.OfAuto) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Value) - }) - - t.Run("disables parallel tool use for explicit auto tool choice", func(t *testing.T) { - t.Parallel() - - i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{ - Type: constant.ValueOf[constant.Auto](), - }, - }, - }, - }, - mcpProxy: &mockServerProxier{ - tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, - }, - } - - i.injectTools() + testCases := []struct { + name string + requestBody string + injectedTools []*mcp.Tool + expectedToolChoiceType string + expectedDisableParallel *bool + expectedToolCount int + }{ + { + name: "does not modify tool choice when no tools to inject", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tool_choice":{"type":"auto"}}`, + injectedTools: nil, + expectedToolChoiceType: string(constant.ValueOf[constant.Auto]()), + expectedDisableParallel: nil, + expectedToolCount: 0, + }, + { + name: "disables parallel tool use for auto tool choice default", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`, + injectedTools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, + expectedToolChoiceType: string(constant.ValueOf[constant.Auto]()), + expectedDisableParallel: utils.PtrTo(true), + expectedToolCount: 1, + }, + { + name: "disables parallel tool use for explicit auto tool choice", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tool_choice":{"type":"auto"}}`, + injectedTools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, + expectedToolChoiceType: string(constant.ValueOf[constant.Auto]()), + expectedDisableParallel: utils.PtrTo(true), + expectedToolCount: 1, + }, + { + name: "disables parallel tool use for any tool choice", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tool_choice":{"type":"any"}}`, + injectedTools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, + expectedToolChoiceType: string(constant.ValueOf[constant.Any]()), + expectedDisableParallel: utils.PtrTo(true), + expectedToolCount: 1, + }, + { + name: "disables parallel tool use for tool choice type", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tool_choice":{"type":"tool","name":"specific_tool"}}`, + injectedTools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, + expectedToolChoiceType: string(constant.ValueOf[constant.Tool]()), + expectedDisableParallel: utils.PtrTo(true), + expectedToolCount: 1, + }, + { + name: "no op for none tool choice type", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tool_choice":{"type":"none"}}`, + injectedTools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, + expectedToolChoiceType: string(constant.ValueOf[constant.None]()), + expectedDisableParallel: nil, + expectedToolCount: 1, + }, + } - require.NotNil(t, i.req.ToolChoice.OfAuto) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Value) - }) + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() - t.Run("disables parallel tool use for any tool choice", func(t *testing.T) { - t.Parallel() - - i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfAny: &anthropic.ToolChoiceAnyParam{ - Type: constant.ValueOf[constant.Any](), - }, - }, - }, - }, - mcpProxy: &mockServerProxier{ - tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, - }, - } + base := &interceptionBase{ + reqPayload: mustMessagesPayload(t, testCase.requestBody), + mcpProxy: &mockServerProxier{tools: testCase.injectedTools}, + logger: slog.Make(), + } - i.injectTools() + base.injectTools() - require.NotNil(t, i.req.ToolChoice.OfAny) - require.True(t, i.req.ToolChoice.OfAny.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfAny.DisableParallelToolUse.Value) - }) + require.Len(t, gjson.GetBytes(base.reqPayload, "tools").Array(), testCase.expectedToolCount) - t.Run("disables parallel tool use for tool choice type", func(t *testing.T) { - t.Parallel() - - i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfTool: &anthropic.ToolChoiceToolParam{ - Type: constant.ValueOf[constant.Tool](), - Name: "specific_tool", - }, - }, - }, - }, - mcpProxy: &mockServerProxier{ - tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, - }, - } + toolChoice := gjson.GetBytes(base.reqPayload, "tool_choice") + require.Equal(t, testCase.expectedToolChoiceType, toolChoice.Get("type").String()) - i.injectTools() + disableParallelResult := toolChoice.Get("disable_parallel_tool_use") + if testCase.expectedDisableParallel == nil { + require.False(t, disableParallelResult.Exists()) + return + } - require.NotNil(t, i.req.ToolChoice.OfTool) - require.True(t, i.req.ToolChoice.OfTool.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfTool.DisableParallelToolUse.Value) - }) + require.True(t, disableParallelResult.Exists()) + require.Equal(t, *testCase.expectedDisableParallel, disableParallelResult.Bool()) + }) + } +} - t.Run("no-op for none tool choice type", func(t *testing.T) { - t.Parallel() - - i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfNone: &anthropic.ToolChoiceNoneParam{ - Type: constant.ValueOf[constant.None](), - }, - }, - }, - }, - mcpProxy: &mockServerProxier{ - tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, - }, - } +func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload { + t.Helper() - i.injectTools() + payload, err := NewMessagesRequestPayload([]byte(requestBody)) + require.NoError(t, err) - // Tools are still injected. - require.Len(t, i.req.Tools, 1) - // But no parallel tool use modification for "none" type. - require.Nil(t, i.req.ToolChoice.OfAuto) - require.Nil(t, i.req.ToolChoice.OfAny) - require.Nil(t, i.req.ToolChoice.OfTool) - require.NotNil(t, i.req.ToolChoice.OfNone) - }) + return payload } // mockServerProxier is a test implementation of mcp.ServerProxier. diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 51a1a98d..a051af45 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -30,8 +30,7 @@ type BlockingInterception struct { func NewBlockingInterceptor( id uuid.UUID, - req *MessageNewParamsWrapper, - payload []byte, + reqPayload MessagesRequestPayload, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, @@ -40,8 +39,7 @@ func NewBlockingInterceptor( ) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ id: id, - req: req, - payload: payload, + reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, clientHeaders: clientHeaders, @@ -63,8 +61,8 @@ func (s *BlockingInterception) Streaming() bool { } func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { - if i.req == nil { - return fmt.Errorf("developer error: req is nil") + if len(i.reqPayload) == 0 { + return fmt.Errorf("developer error: request payload is empty") } ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) @@ -72,46 +70,39 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req i.injectTools() - var ( - prompt *string - err error - ) - // Track user prompt if not a small/fast model + var prompt *string if !i.isSmallFastModel() { - prompt, err = i.req.lastUserPrompt() - if err != nil { - i.logger.Warn(ctx, "failed to retrieve last user prompt", slog.Error(err)) + promptText, promptFound, promptErr := i.reqPayload.lastUserPrompt() + if promptErr != nil { + i.logger.Warn(ctx, "failed to retrieve last user prompt", slog.Error(promptErr)) + } else if promptFound { + prompt = &promptText } } - opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 600)} - // TODO(ssncferreira): inject actor headers directly in the client-header // middleware instead of using SDK options. + requestOptions := []option.RequestOption{option.WithRequestTimeout(time.Second * 600)} if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { - opts = append(opts, intercept.ActorHeadersAsAnthropicOpts(actor)...) + requestOptions = append(requestOptions, intercept.ActorHeadersAsAnthropicOpts(actor)...) } - svc, err := i.newMessagesService(ctx, opts...) + svc, err := i.newMessagesService(ctx, requestOptions...) if err != nil { err = fmt.Errorf("create anthropic client: %w", err) http.Error(w, err.Error(), http.StatusInternalServerError) return err } - messages := i.req.MessageNewParams - logger := i.logger.With(slog.F("model", i.req.Model)) + logger := i.logger.With(slog.F("model", i.Model())) var resp *anthropic.Message - // Accumulate usage across the entire streaming interaction (including tool reinvocations). var cumulativeUsage anthropic.Usage for { - // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) - resp, err = i.newMessage(ctx, svc, messages) + resp, err = i.newMessage(ctx, svc) if err != nil { if eventstream.IsConnError(err) { - // Can't write a response, just error out. return fmt.Errorf("upstream connection closed: %w", err) } @@ -160,8 +151,8 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req // Handle tool calls. var pendingToolCalls []anthropic.ToolUseBlock - for _, c := range resp.Content { - toolUse := c.AsToolUse() + for _, contentBlock := range resp.Content { + toolUse := contentBlock.AsToolUse() if toolUse.ID == "" { continue } @@ -171,7 +162,6 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req continue } - // If tool is not injected, track it since the client will be handling it. _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: resp.ID, @@ -182,88 +172,78 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req }) } - // If no injected tool calls, we're done. if len(pendingToolCalls) == 0 { break } - // Append the assistant's message (which contains the tool_use block) - // to the messages for the next API call. - messages.Messages = append(messages.Messages, resp.ToParam()) + var loopMessages []anthropic.MessageParam + loopMessages = append(loopMessages, resp.ToParam()) - // Process each pending tool call. - for _, tc := range pendingToolCalls { + for _, toolCall := range pendingToolCalls { if i.mcpProxy == nil { continue } - tool := i.mcpProxy.GetTool(tc.Name) + tool := i.mcpProxy.GetTool(toolCall.Name) if tool == nil { - logger.Warn(ctx, "tool not found in manager", slog.F("tool", tc.Name)) - // Continue to next tool call, but still append an error tool_result - messages.Messages = append(messages.Messages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(tc.ID, fmt.Sprintf("Error: tool %s not found", tc.Name), true)), + logger.Warn(ctx, "tool not found in manager", slog.F("tool", toolCall.Name)) + loopMessages = append(loopMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(toolCall.ID, fmt.Sprintf("Error: tool %s not found", toolCall.Name), true)), ) continue } - res, err := tool.Call(ctx, tc.Input, i.tracer) - + toolResultResponse, toolCallErr := tool.Call(ctx, toolCall.Input, i.tracer) _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: resp.ID, - ToolCallID: tc.ID, + ToolCallID: toolCall.ID, ServerURL: &tool.ServerURL, Tool: tool.Name, - Args: tc.Input, + Args: toolCall.Input, Injected: true, - InvocationError: err, + InvocationError: toolCallErr, }) - if err != nil { - // Always provide a tool_result even if the tool call failed - messages.Messages = append(messages.Messages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(tc.ID, fmt.Sprintf("Error: calling tool: %v", err), true)), + if toolCallErr != nil { + loopMessages = append(loopMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(toolCall.ID, fmt.Sprintf("Error: calling tool: %v", toolCallErr), true)), ) continue } - // Process tool result toolResult := anthropic.ContentBlockParamUnion{ OfToolResult: &anthropic.ToolResultBlockParam{ - ToolUseID: tc.ID, + ToolUseID: toolCall.ID, IsError: anthropic.Bool(false), }, } var hasValidResult bool - for _, content := range res.Content { - switch cb := content.(type) { + for _, toolContent := range toolResultResponse.Content { + switch contentBlock := toolContent.(type) { case mcplib.TextContent: toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ OfText: &anthropic.TextBlockParam{ - Text: cb.Text, + Text: contentBlock.Text, }, }) hasValidResult = true - // TODO: is there a more correct way of handling these non-text content responses? case mcplib.EmbeddedResource: - switch resource := cb.Resource.(type) { + switch resource := contentBlock.Resource.(type) { case mcplib.TextResourceContents: - val := fmt.Sprintf("Binary resource (MIME: %s, URI: %s): %s", - resource.MIMEType, resource.URI, resource.Text) + value := fmt.Sprintf("Binary resource (MIME: %s, URI: %s): %s", resource.MIMEType, resource.URI, resource.Text) toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ OfText: &anthropic.TextBlockParam{ - Text: val, + Text: value, }, }) hasValidResult = true case mcplib.BlobResourceContents: - val := fmt.Sprintf("Binary resource (MIME: %s, URI: %s): %s", - resource.MIMEType, resource.URI, resource.Blob) + value := fmt.Sprintf("Binary resource (MIME: %s, URI: %s): %s", resource.MIMEType, resource.URI, resource.Blob) toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ OfText: &anthropic.TextBlockParam{ - Text: val, + Text: value, }, }) hasValidResult = true @@ -278,7 +258,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req hasValidResult = true } default: - i.logger.Warn(ctx, "not handling non-text tool result", slog.F("type", fmt.Sprintf("%T", cb))) + i.logger.Warn(ctx, "not handling non-text tool result", slog.F("type", fmt.Sprintf("%T", contentBlock))) toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ OfText: &anthropic.TextBlockParam{ Text: "Error: unsupported tool result type", @@ -289,9 +269,8 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req } } - // If no content was processed, still add a tool_result if !hasValidResult { - i.logger.Warn(ctx, "no tool result added", slog.F("content_len", len(res.Content)), slog.F("is_error", res.IsError)) + i.logger.Warn(ctx, "no tool result added", slog.F("content_len", len(toolResultResponse.Content)), slog.F("is_error", toolResultResponse.IsError)) toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ OfText: &anthropic.TextBlockParam{ Text: "Error: no valid tool result content", @@ -301,44 +280,42 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req } if len(toolResult.OfToolResult.Content) > 0 { - messages.Messages = append(messages.Messages, anthropic.NewUserMessage(toolResult)) + loopMessages = append(loopMessages, anthropic.NewUserMessage(toolResult)) } } - // Sync the raw payload with updated messages so that withBody() - // sends the updated payload on the next iteration. - if err := i.syncPayloadMessages(messages.Messages); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return fmt.Errorf("sync payload for agentic loop: %w", err) + updatedPayload, rewriteErr := i.reqPayload.appendedMessages(loopMessages) + if rewriteErr != nil { + http.Error(w, rewriteErr.Error(), http.StatusInternalServerError) + return fmt.Errorf("rewrite payload for agentic loop: %w", rewriteErr) } + i.reqPayload = updatedPayload } if resp == nil { return nil } - // Overwrite response identifier since proxy obscures injected tool call invocations. - sj, err := sjson.Set(resp.RawJSON(), "id", i.ID().String()) + responseJSON, err := sjson.Set(resp.RawJSON(), "id", i.ID().String()) if err != nil { return fmt.Errorf("marshal response id failed: %w", err) } - // Overwrite the response's usage with the cumulative usage across any inner loops which invokes injected MCP tools. - sj, err = sjson.Set(sj, "usage", cumulativeUsage) + responseJSON, err = sjson.Set(responseJSON, "usage", cumulativeUsage) if err != nil { return fmt.Errorf("marshal response usage failed: %w", err) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(sj)) + _, _ = w.Write([]byte(responseJSON)) return nil } -func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) { +func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService) (_ *anthropic.Message, outErr error) { ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) - return svc.New(ctx, msgParams, i.withBody()) + return svc.New(ctx, anthropic.MessageNewParams{}, i.withBody()) } diff --git a/intercept/messages/paramswrap.go b/intercept/messages/paramswrap.go deleted file mode 100644 index bd5175aa..00000000 --- a/intercept/messages/paramswrap.go +++ /dev/null @@ -1,142 +0,0 @@ -package messages - -import ( - "encoding/json" - "errors" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/packages/param" -) - -// MessageNewParamsWrapper exists because the "stream" param is not included in anthropic.MessageNewParams. -type MessageNewParamsWrapper struct { - anthropic.MessageNewParams `json:""` - Stream bool `json:"stream,omitempty"` -} - -func (b MessageNewParamsWrapper) MarshalJSON() ([]byte, error) { - type shadow MessageNewParamsWrapper - return param.MarshalWithExtras(b, (*shadow)(&b), map[string]any{ - "stream": b.Stream, - }) -} - -func (b *MessageNewParamsWrapper) UnmarshalJSON(raw []byte) error { - // Parse JSON once and extract both stream field and do content conversion - // to avoid double-parsing the same payload. - var modifiedJSON map[string]any - if err := json.Unmarshal(raw, &modifiedJSON); err != nil { - return err - } - - // Extract stream field from already-parsed map - if stream, ok := modifiedJSON["stream"].(bool); ok { - b.Stream = stream - } - - // Convert string content to array format if needed - if _, hasMessages := modifiedJSON["messages"]; hasMessages { - convertStringContentRecursive(modifiedJSON) - } - - // Marshal back for SDK parsing - convertedRaw, err := json.Marshal(modifiedJSON) - if err != nil { - return err - } - - return b.MessageNewParams.UnmarshalJSON(convertedRaw) -} - -func (b *MessageNewParamsWrapper) lastUserPrompt() (*string, error) { - if b == nil { - return nil, errors.New("nil struct") - } - - if len(b.Messages) == 0 { - return nil, errors.New("no messages") - } - - // We only care if the last message was issued by a user. - msg := b.Messages[len(b.Messages)-1] - if msg.Role != anthropic.MessageParamRoleUser { - return nil, nil - } - - if len(msg.Content) == 0 { - return nil, nil - } - - // Walk backwards on "user"-initiated message content. Clients often inject - // content ahead of the actual prompt to provide context to the model, - // so the last item in the slice is most likely the user's prompt. - for i := len(msg.Content) - 1; i >= 0; i-- { - // Only text content is supported currently. - if textContent := msg.Content[i].GetText(); textContent != nil { - return textContent, nil - } - } - - return nil, nil -} - -// convertStringContentRecursive recursively scans JSON data and converts string "content" fields -// to proper text block arrays where needed for Anthropic SDK compatibility. -// Returns true if any modifications were made. -func convertStringContentRecursive(data any) bool { - modified := false - switch v := data.(type) { - case map[string]any: - // Check if this object has a "content" field with string value - if content, hasContent := v["content"]; hasContent { - if contentStr, isString := content.(string); isString { - // Check if this needs conversion based on context - if shouldConvertContentField(v) { - v["content"] = []map[string]any{ - { - "type": "text", - "text": contentStr, - }, - } - modified = true - } - } - } - - // Recursively process all values in the map - for _, value := range v { - if convertStringContentRecursive(value) { - modified = true - } - } - - case []any: - // Recursively process all items in the array - for _, item := range v { - if convertStringContentRecursive(item) { - modified = true - } - } - } - return modified -} - -// shouldConvertContentField determines if a "content" string field should be converted to text block array -func shouldConvertContentField(obj map[string]any) bool { - // Check if this is a message-level content (has "role" field) - if _, hasRole := obj["role"]; hasRole { - return true - } - - // Check if this is a tool_result block (but not mcp_tool_result which supports strings) - if objType, hasType := obj["type"].(string); hasType { - switch objType { - case "tool_result": - return true // Regular tool_result needs array format - case "mcp_tool_result": - return false // MCP tool_result supports strings - } - } - - return false -} diff --git a/intercept/messages/paramswrap_test.go b/intercept/messages/paramswrap_test.go deleted file mode 100644 index 7f8793d7..00000000 --- a/intercept/messages/paramswrap_test.go +++ /dev/null @@ -1,303 +0,0 @@ -package messages - -import ( - "testing" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/stretchr/testify/require" -) - -func TestMessageNewParamsWrapperUnmarshalJSON(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input string - expectedStream bool - checkContent func(t *testing.T, w *MessageNewParamsWrapper) - }{ - { - name: "message with string content converts to array", - input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"Hello world"}]}`, - expectedStream: false, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 1) - require.Equal(t, anthropic.MessageParamRoleUser, w.Messages[0].Role) - text := w.Messages[0].Content[0].GetText() - require.NotNil(t, text) - require.Equal(t, "Hello world", *text) - }, - }, - { - name: "stream field extracted", - input: `{"model":"claude-3","max_tokens":1000,"stream":true,"messages":[{"role":"user","content":"Hi"}]}`, - expectedStream: true, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 1) - }, - }, - { - name: "stream false", - input: `{"model":"claude-3","max_tokens":1000,"stream":false,"messages":[{"role":"user","content":"Hi"}]}`, - expectedStream: false, - checkContent: nil, - }, - { - name: "array content unchanged", - input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`, - expectedStream: false, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 1) - text := w.Messages[0].Content[0].GetText() - require.NotNil(t, text) - require.Equal(t, "Hello", *text) - }, - }, - { - name: "multiple messages with mixed content", - input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"First"},{"role":"assistant","content":[{"type":"text","text":"Response"}]},{"role":"user","content":"Second"}]}`, - expectedStream: false, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 3) - text0 := w.Messages[0].Content[0].GetText() - require.NotNil(t, text0) - require.Equal(t, "First", *text0) - text2 := w.Messages[2].Content[0].GetText() - require.NotNil(t, text2) - require.Equal(t, "Second", *text2) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var wrapper MessageNewParamsWrapper - err := wrapper.UnmarshalJSON([]byte(tt.input)) - require.NoError(t, err) - require.Equal(t, tt.expectedStream, wrapper.Stream) - if tt.checkContent != nil { - tt.checkContent(t, &wrapper) - } - }) - } -} - -func TestShouldConvertContentField(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - obj map[string]any - expected bool - }{ - { - name: "message with role", - obj: map[string]any{ - "role": "user", - "content": "test", - }, - expected: true, - }, - { - name: "tool_result type", - obj: map[string]any{ - "type": "tool_result", - "content": "result", - }, - expected: true, - }, - { - name: "mcp_tool_result type", - obj: map[string]any{ - "type": "mcp_tool_result", - "content": "result", - }, - expected: false, - }, - { - name: "other type", - obj: map[string]any{ - "type": "text", - "content": "text", - }, - expected: false, - }, - { - name: "no role or type", - obj: map[string]any{ - "content": "test", - }, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := shouldConvertContentField(tt.obj) - require.Equal(t, tt.expected, result) - }) - } -} - -func TestAnthropicLastUserPrompt(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - wrapper *MessageNewParamsWrapper - expected string - expectError bool - errorMsg string - }{ - { - name: "nil struct", - expectError: true, - errorMsg: "nil struct", - }, - { - name: "no messages", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{}, - }, - }, - expectError: true, - errorMsg: "no messages", - }, - { - name: "last message not from user", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("user message"), - }, - }, - { - Role: anthropic.MessageParamRoleAssistant, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("assistant message"), - }, - }, - }, - }, - }, - }, - { - name: "last user message with empty content", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{}, - }, - }, - }, - }, - }, - { - name: "last user message with single text content", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("Hello, world!"), - }, - }, - }, - }, - }, - expected: "Hello, world!", - }, - { - name: "last user message with multiple content blocks - text at end", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewImageBlockBase64("image/png", "base64data"), - anthropic.NewTextBlock("First text"), - anthropic.NewImageBlockBase64("image/jpeg", "moredata"), - anthropic.NewTextBlock("Last text"), - }, - }, - }, - }, - }, - expected: "Last text", - }, - { - name: "last user message with only non-text content", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewImageBlockBase64("image/png", "base64data"), - anthropic.NewImageBlockBase64("image/jpeg", "moredata"), - }, - }, - }, - }, - }, - }, - { - name: "multiple messages with last being user", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("First user message"), - }, - }, - { - Role: anthropic.MessageParamRoleAssistant, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("Assistant response"), - }, - }, - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("Second user message"), - }, - }, - }, - }, - }, - expected: "Second user message", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := tt.wrapper.lastUserPrompt() - - if tt.expectError { - require.Error(t, err) - require.Contains(t, err.Error(), tt.errorMsg) - require.Nil(t, result) - } else { - require.NoError(t, err) - // Check pointer equality - both nil or both non-nil - if tt.expected == "" { - require.Nil(t, result) - } else { - require.NotNil(t, result) - // The result should point to the same string from the content block - require.Equal(t, tt.expected, *result) - } - } - }) - } -} diff --git a/intercept/messages/reqpayload.go b/intercept/messages/reqpayload.go new file mode 100644 index 00000000..cefddd87 --- /dev/null +++ b/intercept/messages/reqpayload.go @@ -0,0 +1,271 @@ +package messages + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + // Absolute JSON paths from the request root. + messagesReqPathMessages = "messages" + messagesReqPathModel = "model" + messagesReqPathStream = "stream" + messagesReqPathToolChoice = "tool_choice" + messagesReqPathToolChoiceDisableParallel = "tool_choice.disable_parallel_tool_use" + messagesReqPathToolChoiceType = "tool_choice.type" + messagesReqPathTools = "tools" + + // Relative field names used within sub-objects. + messagesReqFieldContent = "content" + messagesReqFieldRole = "role" + messagesReqFieldText = "text" + messagesReqFieldToolUseID = "tool_use_id" + messagesReqFieldType = "type" +) + +var ( + constAny = string(constant.ValueOf[constant.Any]()) + constAuto = string(constant.ValueOf[constant.Auto]()) + constNone = string(constant.ValueOf[constant.None]()) + constText = string(constant.ValueOf[constant.Text]()) + constTool = string(constant.ValueOf[constant.Tool]()) + constToolResult = string(constant.ValueOf[constant.ToolResult]()) + constUser = string(anthropic.MessageParamRoleUser) +) + +// MessagesRequestPayload is raw JSON bytes of an Anthropic Messages API request. +// Methods provide package-specific reads and rewrites while preserving the +// original body for upstream pass-through. +type MessagesRequestPayload []byte + +func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) { + if len(bytes.TrimSpace(raw)) == 0 { + return nil, fmt.Errorf("messages empty request body") + } + if !json.Valid(raw) { + return nil, fmt.Errorf("messages invalid JSON request body") + } + + return MessagesRequestPayload(raw), nil +} + +func (p MessagesRequestPayload) Stream() bool { + return gjson.GetBytes(p, messagesReqPathStream).Bool() +} + +func (p MessagesRequestPayload) model() string { + return gjson.GetBytes(p, messagesReqPathModel).String() +} + +func (p MessagesRequestPayload) correlatingToolCallID() *string { + messages := gjson.GetBytes(p, messagesReqPathMessages) + if !messages.IsArray() { + return nil + } + + messageItems := messages.Array() + if len(messageItems) == 0 { + return nil + } + + content := messageItems[len(messageItems)-1].Get(messagesReqFieldContent) + if !content.IsArray() { + return nil + } + + contentItems := content.Array() + for idx := len(contentItems) - 1; idx >= 0; idx-- { + contentItem := contentItems[idx] + if contentItem.Get(messagesReqFieldType).String() != constToolResult { + continue + } + + toolUseID := contentItem.Get(messagesReqFieldToolUseID).String() + if toolUseID == "" { + continue + } + + return &toolUseID + } + + return nil +} + +// lastUserPrompt returns the prompt text from the last user message. If no prompt +// is found, it returns empty string, false, nil. Unexpected shapes are treated as +// unsupported and do not fail the request path. +func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) { + messages := gjson.GetBytes(p, messagesReqPathMessages) + if !messages.Exists() || messages.Type == gjson.Null { + return "", false, nil + } + if !messages.IsArray() { + return "", false, fmt.Errorf("unexpected messages type: %s", messages.Type) + } + + messageItems := messages.Array() + if len(messageItems) == 0 { + return "", false, nil + } + + lastMessage := messageItems[len(messageItems)-1] + if lastMessage.Get(messagesReqFieldRole).String() != constUser { + return "", false, nil + } + + content := lastMessage.Get(messagesReqFieldContent) + if !content.Exists() || content.Type == gjson.Null { + return "", false, nil + } + if content.Type == gjson.String { + return content.String(), true, nil + } + if !content.IsArray() { + return "", false, fmt.Errorf("unexpected message content type: %s", content.Type) + } + + contentItems := content.Array() + for idx := len(contentItems) - 1; idx >= 0; idx-- { + contentItem := contentItems[idx] + if contentItem.Get(messagesReqFieldType).String() != constText { + continue + } + + text := contentItem.Get(messagesReqFieldText) + if text.Type != gjson.String { + continue + } + + return text.String(), true, nil + } + + return "", false, nil +} + +func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) (MessagesRequestPayload, error) { + if len(injected) == 0 { + return p, nil + } + + existing, err := p.tools() + if err != nil { + return p, err + } + + allTools := make([]any, 0, len(injected)+len(existing)) + for _, tool := range injected { + allTools = append(allTools, tool) + } + for _, tool := range existing { + allTools = append(allTools, tool) + } + + return p.set(messagesReqPathTools, allTools) +} + +func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPayload, error) { + toolChoice := gjson.GetBytes(p, messagesReqPathToolChoice) + + // If no tool_choice was defined, assume auto. + // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use. + if !toolChoice.Exists() || toolChoice.Type == gjson.Null { + updated, err := p.set(messagesReqPathToolChoiceType, constAuto) + if err != nil { + return p, err + } + return updated.set(messagesReqPathToolChoiceDisableParallel, true) + } + if !toolChoice.IsObject() { + return p, fmt.Errorf("unsupported tool_choice type: %s", toolChoice.Type) + } + + toolChoiceType := gjson.GetBytes(p, messagesReqPathToolChoiceType) + if toolChoiceType.Exists() && toolChoiceType.Type != gjson.String { + return p, fmt.Errorf("unsupported tool_choice.type type: %s", toolChoiceType.Type) + } + + switch toolChoiceType.String() { + case "": + updated, err := p.set(messagesReqPathToolChoiceType, constAuto) + if err != nil { + return p, err + } + return updated.set(messagesReqPathToolChoiceDisableParallel, true) + case constAuto, constAny, constTool: + return p.set(messagesReqPathToolChoiceDisableParallel, true) + case constNone: + return p, nil + default: + return p, fmt.Errorf("unsupported tool_choice.type value: %q", toolChoiceType.String()) + } +} + +func (p MessagesRequestPayload) appendedMessages(messages []anthropic.MessageParam) (MessagesRequestPayload, error) { + if len(messages) == 0 { + return p, nil + } + + existing, err := p.messages() + if err != nil { + return p, err + } + + allMessages := make([]any, 0, len(existing)+len(messages)) + allMessages = append(allMessages, existing...) + for _, message := range messages { + allMessages = append(allMessages, message) + } + + return p.set(messagesReqPathMessages, allMessages) +} + +func (p MessagesRequestPayload) withModel(model string) (MessagesRequestPayload, error) { + return p.set(messagesReqPathModel, model) +} + +func (p MessagesRequestPayload) messages() ([]any, error) { + messages := gjson.GetBytes(p, messagesReqPathMessages) + if !messages.Exists() || messages.Type == gjson.Null { + return []any{}, nil + } + if !messages.IsArray() { + return nil, fmt.Errorf("unsupported messages type: %s", messages.Type) + } + + messageItems := messages.Array() + existing := make([]any, 0, len(messageItems)) + for _, item := range messageItems { + existing = append(existing, json.RawMessage(item.Raw)) + } + + return existing, nil +} + +func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) { + tools := gjson.GetBytes(p, messagesReqPathTools) + if !tools.Exists() || tools.Type == gjson.Null { + return nil, nil + } + if !tools.IsArray() { + return nil, fmt.Errorf("unsupported tools type: %s", tools.Type) + } + + toolItems := tools.Array() + existing := make([]json.RawMessage, 0, len(toolItems)) + for _, item := range toolItems { + existing = append(existing, json.RawMessage(item.Raw)) + } + + return existing, nil +} + +func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) { + out, err := sjson.SetBytes(p, path, value) + return MessagesRequestPayload(out), err +} diff --git a/intercept/messages/reqpayload_test.go b/intercept/messages/reqpayload_test.go new file mode 100644 index 00000000..56cb7e96 --- /dev/null +++ b/intercept/messages/reqpayload_test.go @@ -0,0 +1,343 @@ +package messages + +import ( + "testing" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/coder/aibridge/utils" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestNewMessagesRequestPayload(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody []byte + + expectError bool + }{ + { + name: "empty body", + requestBody: []byte(" \n\t "), + expectError: true, + }, + { + name: "invalid json", + requestBody: []byte(`{"model":`), + expectError: true, + }, + { + name: "valid json", + requestBody: []byte(`{"model":"claude-opus-4-5","max_tokens":1024}`), + expectError: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload, err := NewMessagesRequestPayload(testCase.requestBody) + if testCase.expectError { + require.Error(t, err) + require.Nil(t, payload) + return + } + + require.NoError(t, err) + require.Equal(t, MessagesRequestPayload(testCase.requestBody), payload) + }) + } +} + +func TestMessagesRequestPayloadStream(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedStream bool + }{ + { + name: "stream true", + requestBody: `{"stream":true}`, + expectedStream: true, + }, + { + name: "stream false", + requestBody: `{"stream":false}`, + expectedStream: false, + }, + { + name: "stream missing", + requestBody: `{"model":"claude-opus-4-5"}`, + expectedStream: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + require.Equal(t, testCase.expectedStream, payload.Stream()) + }) + } +} + +func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedPrompt string + + expectedFound bool + + expectError bool + }{ + { + name: "last user message string content", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`, + expectedPrompt: "hello", + expectedFound: true, + expectError: false, + }, + { + name: "last user message typed content returns last text block", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}},{"type":"text","text":"first"},{"type":"text","text":"last"}]}]}`, + expectedPrompt: "last", + expectedFound: true, + expectError: false, + }, + { + name: "last message not from user", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"assistant","content":"hello"}]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "no messages key", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "empty messages array", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "last user message with empty content array", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[]}]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "last user message with only non text content", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}},{"type":"image","source":{"type":"base64","media_type":"image/jpeg","data":"def"}}]}]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "multiple messages with last being user", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"first"},{"role":"assistant","content":[{"type":"text","text":"response"}]},{"role":"user","content":"second"}]}`, + expectedPrompt: "second", + expectedFound: true, + expectError: false, + }, + { + name: "messages wrong type returns error", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":{}}`, + expectedPrompt: "", + expectedFound: false, + expectError: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + prompt, found, err := payload.lastUserPrompt() + if testCase.expectError { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, testCase.expectedFound, found) + require.Equal(t, testCase.expectedPrompt, prompt) + }) + } +} + +func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedToolUseID *string + }{ + { + name: "no tool result block", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`, + expectedToolUseID: nil, + }, + { + name: "returns last tool result from final message", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"},{"type":"tool_result","tool_use_id":"toolu_second","content":"second"}]}]}`, + expectedToolUseID: utils.PtrTo("toolu_second"), + }, + { + name: "ignores earlier message tool result", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"}]},{"role":"assistant","content":"done"}]}`, + expectedToolUseID: nil, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + require.Equal(t, testCase.expectedToolUseID, payload.correlatingToolCallID()) + }) + } +} + +func TestMessagesRequestPayloadInjectTools(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`) + + updatedPayload, err := payload.injectTools([]anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "injected_tool", + Type: anthropic.ToolTypeCustom, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: map[string]interface{}{}, + }, + }, + }, + }) + require.NoError(t, err) + + toolItems := gjson.GetBytes(updatedPayload, "tools").Array() + require.Len(t, toolItems, 2) + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Equal(t, "existing_tool", toolItems[1].Get("name").String()) + require.Equal(t, "ephemeral", toolItems[1].Get("cache_control.type").String()) +} + +func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedType string + + expectedDisableParallel *bool + }{ + { + name: "defaults to auto when missing", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024}`, + expectedType: string(constant.ValueOf[constant.Auto]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "auto gets disabled", + requestBody: `{"tool_choice":{"type":"auto"}}`, + expectedType: string(constant.ValueOf[constant.Auto]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "any gets disabled", + requestBody: `{"tool_choice":{"type":"any"}}`, + expectedType: string(constant.ValueOf[constant.Any]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "tool gets disabled", + requestBody: `{"tool_choice":{"type":"tool","name":"abc"}}`, + expectedType: string(constant.ValueOf[constant.Tool]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "none remains unchanged", + requestBody: `{"tool_choice":{"type":"none"}}`, + expectedType: string(constant.ValueOf[constant.None]()), + expectedDisableParallel: nil, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + updatedPayload, err := payload.disableParallelToolCalls() + require.NoError(t, err) + + toolChoice := gjson.GetBytes(updatedPayload, "tool_choice") + require.Equal(t, testCase.expectedType, toolChoice.Get("type").String()) + + disableParallelResult := toolChoice.Get("disable_parallel_tool_use") + if testCase.expectedDisableParallel == nil { + require.False(t, disableParallelResult.Exists()) + return + } + + require.True(t, disableParallelResult.Exists()) + require.Equal(t, *testCase.expectedDisableParallel, disableParallelResult.Bool()) + }) + } +} + +func TestMessagesRequestPayloadAppendedMessages(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`) + + updatedPayload, err := payload.appendedMessages([]anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleAssistant, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock("assistant response"), + }, + }, + anthropic.NewUserMessage(anthropic.NewToolResultBlock("toolu_123", "tool output", false)), + }) + require.NoError(t, err) + + messageItems := gjson.GetBytes(updatedPayload, "messages").Array() + require.Len(t, messageItems, 3) + require.Equal(t, "hello", messageItems[0].Get("content").String()) + require.Equal(t, "assistant", messageItems[1].Get("role").String()) + require.Equal(t, "assistant response", messageItems[1].Get("content.0.text").String()) + require.Equal(t, "tool_result", messageItems[2].Get("content.0.type").String()) + require.Equal(t, "toolu_123", messageItems[2].Get("content.0.tool_use_id").String()) +} diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index b2900a75..482e2f5a 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -36,8 +36,7 @@ type StreamingInterception struct { func NewStreamingInterceptor( id uuid.UUID, - req *MessageNewParamsWrapper, - payload []byte, + reqPayload MessagesRequestPayload, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, @@ -46,8 +45,7 @@ func NewStreamingInterceptor( ) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ id: id, - req: req, - payload: payload, + reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, clientHeaders: clientHeaders, @@ -88,8 +86,8 @@ func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.Key // results relayed to the SERVER. The response from the server will be handled synchronously, and this loop // can continue until all injected tool invocations are completed and the response is relayed to the client. func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { - if i.req == nil { - return fmt.Errorf("developer error: req is nil") + if len(i.reqPayload) == 0 { + return fmt.Errorf("developer error: request payload is empty") } ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) @@ -100,7 +98,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - logger := i.logger.With(slog.F("model", i.req.Model)) + logger := i.logger.With(slog.F("model", i.Model())) var ( prompt *string @@ -109,9 +107,11 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re // Claude Code uses a "small/fast model" for certain tasks. if !i.isSmallFastModel() { - prompt, err = i.req.lastUserPrompt() - if err != nil { - logger.Warn(ctx, "failed to determine last user prompt", slog.Error(err)) + promptText, promptFound, promptErr := i.reqPayload.lastUserPrompt() + if promptErr != nil { + logger.Warn(ctx, "failed to determine last user prompt", slog.Error(promptErr)) + } else if promptFound { + prompt = &promptText } // Only inject tools into "actual" request. @@ -142,8 +142,6 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() - messages := i.req.MessageNewParams - // Accumulate usage across the entire streaming interaction (including tool reinvocations). var cumulativeUsage anthropic.Usage @@ -159,7 +157,7 @@ newStream: break } - stream := i.newStream(streamCtx, svc, messages) + stream := i.newStream(streamCtx, svc) var message anthropic.Message var lastToolName string @@ -278,7 +276,8 @@ newStream: // Process injected tools. if len(pendingToolCalls) > 0 { // Append the whole message from this stream as context since we'll be sending a new request with the tool results. - messages.Messages = append(messages.Messages, message.ToParam()) + var loopMessages []anthropic.MessageParam + loopMessages = append(loopMessages, message.ToParam()) for name, id := range pendingToolCalls { if i.mcpProxy == nil { @@ -296,11 +295,9 @@ newStream: continue } - var ( - input json.RawMessage - foundTool bool - foundTools int - ) + var input json.RawMessage + var foundTool bool + var foundTools int for _, block := range message.Content { switch variant := block.AsAny().(type) { case anthropic.ToolUseBlock: @@ -331,14 +328,12 @@ newStream: }) if err != nil { - // Always provide a tool_result even if the tool call failed - messages.Messages = append(messages.Messages, + loopMessages = append(loopMessages, anthropic.NewUserMessage(anthropic.NewToolResultBlock(id, fmt.Sprintf("Error calling tool: %v", err), true)), ) continue } - // Process tool result toolResult := anthropic.ContentBlockParamUnion{ OfToolResult: &anthropic.ToolResultBlockParam{ ToolUseID: id, @@ -398,7 +393,6 @@ newStream: } } - // If no content was processed, still add a tool_result if !hasValidResult { logger.Warn(ctx, "no tool result added", slog.F("content_len", len(res.Content)), slog.F("is_error", res.IsError)) toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ @@ -410,16 +404,16 @@ newStream: } if len(toolResult.OfToolResult.Content) > 0 { - messages.Messages = append(messages.Messages, anthropic.NewUserMessage(toolResult)) + loopMessages = append(loopMessages, anthropic.NewUserMessage(toolResult)) } } - // Sync the raw payload with updated messages so that withBody() - // sends the updated payload on the next iteration. - if syncErr := i.syncPayloadMessages(messages.Messages); syncErr != nil { - lastErr = fmt.Errorf("sync payload for agentic loop: %w", syncErr) + updatedPayload, rewriteErr := i.reqPayload.appendedMessages(loopMessages) + if rewriteErr != nil { + lastErr = fmt.Errorf("rewrite payload for agentic loop: %w", rewriteErr) break } + i.reqPayload = updatedPayload // Causes a new stream to be run with updated messages. isFirst = false @@ -579,10 +573,10 @@ func (s *StreamingInterception) encodeForStream(payload []byte, typ string) []by return buf.Bytes() } -// newStream traces svc.NewStreaming(streamCtx, messages) -func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, messages anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEventUnion] { +// newStream traces svc.NewStreaming(streamCtx, anthropic.MessageNewParams{}). +func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] { _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() - return svc.NewStreaming(ctx, messages, s.withBody()) + return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, s.withBody()) } diff --git a/internal/integrationtest/mockupstream.go b/internal/integrationtest/mockupstream.go index a658b054..40b00663 100644 --- a/internal/integrationtest/mockupstream.go +++ b/internal/integrationtest/mockupstream.go @@ -16,7 +16,6 @@ import ( "sync/atomic" "testing" - "github.com/anthropics/anthropic-sdk-go" "github.com/coder/aibridge/fixtures" "github.com/openai/openai-go/v3" "github.com/stretchr/testify/require" @@ -296,14 +295,14 @@ func validateOpenAIResponses(t *testing.T, body []byte, msgAndArgs ...any) { } // validateAnthropicMessages validates that an Anthropic messages request -// has all required fields. -// See https://github.com/anthropics/anthropic-sdk-go. +// has the required top-level fields while remaining tolerant of raw payload +// shapes that do not round-trip through anthropic.MessageNewParams. func validateAnthropicMessages(t *testing.T, body []byte, msgAndArgs ...any) { t.Helper() - var req anthropic.MessageNewParams - require.NoError(t, json.Unmarshal(body, &req), msgAndArgs...) - require.NotEmpty(t, req.Model, "model is required", msgAndArgs) - require.NotEmpty(t, req.Messages, "messages is required", msgAndArgs) - require.NotZero(t, req.MaxTokens, "max_tokens is required", msgAndArgs) + var requestBody map[string]any + require.NoError(t, json.Unmarshal(body, &requestBody), msgAndArgs...) + require.NotEmpty(t, requestBody["model"], "model is required", msgAndArgs) + require.NotEmpty(t, requestBody["messages"], "messages is required", msgAndArgs) + require.NotZero(t, requestBody["max_tokens"], "max_tokens is required", msgAndArgs) } diff --git a/provider/anthropic.go b/provider/anthropic.go index 4a79cd42..f4380648 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -1,8 +1,6 @@ package provider import ( - "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -102,8 +100,9 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr if err != nil { return nil, fmt.Errorf("read body: %w", err) } - var req messages.MessageNewParamsWrapper - if err := json.NewDecoder(bytes.NewReader(payload)).Decode(&req); err != nil { + + reqPayload, err := messages.NewMessagesRequestPayload(payload) + if err != nil { return nil, fmt.Errorf("unmarshal request body: %w", err) } @@ -111,10 +110,10 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr cfg.ExtraHeaders = extractAnthropicHeaders(r) var interceptor intercept.Interceptor - if req.Stream { - interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) + if reqPayload.Stream() { + interceptor = messages.NewStreamingInterceptor(id, reqPayload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) + interceptor = messages.NewBlockingInterceptor(id, reqPayload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil From d6094f860f3ce01214edbbc6a6cbf40edefb7e65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Tue, 17 Mar 2026 18:14:04 +0000 Subject: [PATCH 2/2] revert table tests + variable name changes --- intercept/messages/base_test.go | 398 +++++++++++++---------- intercept/messages/blocking.go | 74 +++-- intercept/messages/streaming.go | 41 ++- internal/integrationtest/mockupstream.go | 15 +- 4 files changed, 309 insertions(+), 219 deletions(-) diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index 810e5a52..3f25b6eb 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -18,57 +18,57 @@ import ( func TestScanForCorrelatingToolCallID(t *testing.T) { t.Parallel() - testCases := []struct { - name string - requestBody string - expectedToolID *string + tests := []struct { + name string + requestBody string + expected *string }{ { - name: "no messages field", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024}`, - expectedToolID: nil, + name: "no messages field", + requestBody: `{}`, + expected: nil, }, { - name: "messages string", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":"test"}`, - expectedToolID: nil, + name: "messages string", + requestBody: `{"messages":"test"}`, + expected: nil, }, { - name: "empty messages array", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[]}`, - expectedToolID: nil, + name: "empty messages array", + requestBody: `{"messages":[]}`, + expected: nil, }, { - name: "last message has no tool result blocks", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`, - expectedToolID: nil, + name: "last message has no tool result blocks", + requestBody: `{"messages":[{"role":"user","content":"hello"}]}`, + expected: nil, }, { - name: "single tool result block", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_abc","content":"result"}]}]}`, - expectedToolID: utils.PtrTo("toolu_abc"), + name: "single tool result block", + requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_abc","content":"result"}]}]}`, + expected: utils.PtrTo("toolu_abc"), }, { - name: "multiple tool result blocks returns last", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"},{"type":"text","text":"ignored"},{"type":"tool_result","tool_use_id":"toolu_second","content":"second"}]}]}`, - expectedToolID: utils.PtrTo("toolu_second"), + name: "multiple tool result blocks returns last", + requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"},{"type":"text","text":"ignored"},{"type":"tool_result","tool_use_id":"toolu_second","content":"second"}]}]}`, + expected: utils.PtrTo("toolu_second"), }, { - name: "last message is not a tool result", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"}]},{"role":"user","content":"some text"}]}`, - expectedToolID: nil, + name: "last message is not a tool result", + requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"}]},{"role":"user","content":"some text"}]}`, + expected: nil, }, } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { t.Parallel() base := &interceptionBase{ - reqPayload: mustMessagesPayload(t, testCase.requestBody), + reqPayload: mustMessagesPayload(t, tc.requestBody), } - require.Equal(t, testCase.expectedToolID, base.CorrelatingToolCallID()) + require.Equal(t, tc.expected, base.CorrelatingToolCallID()) }) } } @@ -366,162 +366,228 @@ func TestAccumulateUsage(t *testing.T) { func TestInjectTools_CacheBreakpoints(t *testing.T) { t.Parallel() - testCases := []struct { - name string - requestBody string - injectedTools []*mcp.Tool - expectedToolNames []string - expectedCacheControlTypes []string - }{ - { - name: "cache control preserved when no tools to inject", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[` + - `{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`, + t.Run("cache control preserved when no tools to inject", func(t *testing.T) { + t.Parallel() - injectedTools: nil, - expectedToolNames: []string{"existing_tool"}, - expectedCacheControlTypes: []string{string(constant.ValueOf[constant.Ephemeral]())}, - }, - { - name: "cache control breakpoint is preserved by prepending injected tools", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[` + - `{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`, + // Request has existing tool with cache control, but no tools to inject. + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`), + mcpProxy: &mockServerProxier{tools: nil}, + logger: slog.Make(), + } - injectedTools: []*mcp.Tool{{ID: "injected_tool", Name: "injected", Description: "Injected tool"}}, - expectedToolNames: []string{"injected_tool", "existing_tool"}, - expectedCacheControlTypes: []string{"", string(constant.ValueOf[constant.Ephemeral]())}, - }, - { - name: "cache control breakpoint in non standard location is preserved", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[` + - `{"name":"tool_with_cache_1","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}},` + - `{"name":"tool_with_cache_2","type":"custom","input_schema":{"type":"object","properties":{}}}]}`, - - injectedTools: []*mcp.Tool{{ID: "injected_tool", Name: "injected", Description: "Injected tool"}}, - expectedToolNames: []string{"injected_tool", "tool_with_cache_1", "tool_with_cache_2"}, - expectedCacheControlTypes: []string{"", string(constant.ValueOf[constant.Ephemeral]()), ""}, - }, - { - name: "no cache control added when none originally set", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[` + - `{"name":"existing_tool_no_cache","type":"custom","input_schema":{"type":"object","properties":{}}}]}`, + i.injectTools() - injectedTools: []*mcp.Tool{{ID: "injected_tool", Name: "injected", Description: "Injected tool"}}, - expectedToolNames: []string{"injected_tool", "existing_tool_no_cache"}, - expectedCacheControlTypes: []string{"", ""}, - }, - } + // Cache control should remain untouched since no tools were injected. + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 1) + require.Equal(t, "existing_tool", toolItems[0].Get("name").String()) + require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[0].Get("cache_control.type").String()) + }) - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - t.Parallel() + t.Run("cache control breakpoint is preserved by prepending injected tools", func(t *testing.T) { + t.Parallel() + + // Request has existing tool with cache control. + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`), + mcpProxy: &mockServerProxier{ + tools: []*mcp.Tool{ + {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, + }, + }, + logger: slog.Make(), + } - base := &interceptionBase{ - reqPayload: mustMessagesPayload(t, testCase.requestBody), - mcpProxy: &mockServerProxier{tools: testCase.injectedTools}, - logger: slog.Make(), - } + i.injectTools() - base.injectTools() + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 2) + // Injected tools are prepended. + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Empty(t, toolItems[0].Get("cache_control.type").String()) + // Original tool's cache control should be preserved at the end. + require.Equal(t, "existing_tool", toolItems[1].Get("name").String()) + require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[1].Get("cache_control.type").String()) + }) - toolItems := gjson.GetBytes(base.reqPayload, "tools").Array() - require.Len(t, toolItems, len(testCase.expectedToolNames)) - for idx := range toolItems { - require.Equal(t, testCase.expectedToolNames[idx], toolItems[idx].Get("name").String()) - require.Equal(t, testCase.expectedCacheControlTypes[idx], toolItems[idx].Get("cache_control.type").String()) - } - }) - } + // The cache breakpoint SHOULD be on the final tool, but may not be; we must preserve that intention. + t.Run("cache control breakpoint in non-standard location is preserved", func(t *testing.T) { + t.Parallel() + + // Request has multiple tools with cache control breakpoints. + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"tool_with_cache_1","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}},`+ + `{"name":"tool_with_cache_2","type":"custom","input_schema":{"type":"object","properties":{}}}]}`), + mcpProxy: &mockServerProxier{ + tools: []*mcp.Tool{ + {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, + }, + }, + logger: slog.Make(), + } + + i.injectTools() + + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 3) + // Injected tool is prepended without cache control. + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Empty(t, toolItems[0].Get("cache_control.type").String()) + // Both original tools' cache controls should remain. + require.Equal(t, "tool_with_cache_1", toolItems[1].Get("name").String()) + require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[1].Get("cache_control.type").String()) + require.Equal(t, "tool_with_cache_2", toolItems[2].Get("name").String()) + require.Empty(t, toolItems[2].Get("cache_control.type").String()) + }) + + t.Run("no cache control added when none originally set", func(t *testing.T) { + t.Parallel() + + // Request has tools but none with cache control. + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"existing_tool_no_cache","type":"custom","input_schema":{"type":"object","properties":{}}}]}`), + mcpProxy: &mockServerProxier{ + tools: []*mcp.Tool{ + {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, + }, + }, + logger: slog.Make(), + } + + i.injectTools() + + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 2) + // Injected tool is prepended without cache control. + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Empty(t, toolItems[0].Get("cache_control.type").String()) + // Original tool remains at the end without cache control. + require.Equal(t, "existing_tool_no_cache", toolItems[1].Get("name").String()) + require.Empty(t, toolItems[1].Get("cache_control.type").String()) + }) } func TestInjectTools_ParallelToolCalls(t *testing.T) { t.Parallel() - testCases := []struct { - name string - requestBody string - injectedTools []*mcp.Tool - expectedToolChoiceType string - expectedDisableParallel *bool - expectedToolCount int - }{ - { - name: "does not modify tool choice when no tools to inject", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tool_choice":{"type":"auto"}}`, - injectedTools: nil, - expectedToolChoiceType: string(constant.ValueOf[constant.Auto]()), - expectedDisableParallel: nil, - expectedToolCount: 0, - }, - { - name: "disables parallel tool use for auto tool choice default", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`, - injectedTools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, - expectedToolChoiceType: string(constant.ValueOf[constant.Auto]()), - expectedDisableParallel: utils.PtrTo(true), - expectedToolCount: 1, - }, - { - name: "disables parallel tool use for explicit auto tool choice", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tool_choice":{"type":"auto"}}`, - injectedTools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, - expectedToolChoiceType: string(constant.ValueOf[constant.Auto]()), - expectedDisableParallel: utils.PtrTo(true), - expectedToolCount: 1, - }, - { - name: "disables parallel tool use for any tool choice", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tool_choice":{"type":"any"}}`, - injectedTools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, - expectedToolChoiceType: string(constant.ValueOf[constant.Any]()), - expectedDisableParallel: utils.PtrTo(true), - expectedToolCount: 1, - }, - { - name: "disables parallel tool use for tool choice type", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tool_choice":{"type":"tool","name":"specific_tool"}}`, - injectedTools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, - expectedToolChoiceType: string(constant.ValueOf[constant.Tool]()), - expectedDisableParallel: utils.PtrTo(true), - expectedToolCount: 1, - }, - { - name: "no op for none tool choice type", - requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tool_choice":{"type":"none"}}`, - injectedTools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, - expectedToolChoiceType: string(constant.ValueOf[constant.None]()), - expectedDisableParallel: nil, - expectedToolCount: 1, - }, - } + t.Run("does not modify tool choice when no tools to inject", func(t *testing.T) { + t.Parallel() - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - t.Parallel() + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"auto"}}`), + mcpProxy: &mockServerProxier{tools: nil}, // No tools to inject. + logger: slog.Make(), + } - base := &interceptionBase{ - reqPayload: mustMessagesPayload(t, testCase.requestBody), - mcpProxy: &mockServerProxier{tools: testCase.injectedTools}, - logger: slog.Make(), - } + i.injectTools() + + // Tool choice should remain unchanged - DisableParallelToolUse should not be set. + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String()) + require.False(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + }) + + t.Run("disables parallel tool use for empty tool choice (default)", func(t *testing.T) { + t.Parallel() - base.injectTools() + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, `{}`), + mcpProxy: &mockServerProxier{ + tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, + }, + logger: slog.Make(), + } - require.Len(t, gjson.GetBytes(base.reqPayload, "tools").Array(), testCase.expectedToolCount) + i.injectTools() - toolChoice := gjson.GetBytes(base.reqPayload, "tool_choice") - require.Equal(t, testCase.expectedToolChoiceType, toolChoice.Get("type").String()) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) + }) - disableParallelResult := toolChoice.Get("disable_parallel_tool_use") - if testCase.expectedDisableParallel == nil { - require.False(t, disableParallelResult.Exists()) - return - } + t.Run("disables parallel tool use for explicit auto tool choice", func(t *testing.T) { + t.Parallel() - require.True(t, disableParallelResult.Exists()) - require.Equal(t, *testCase.expectedDisableParallel, disableParallelResult.Bool()) - }) - } + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"auto"}}`), + mcpProxy: &mockServerProxier{ + tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, + }, + logger: slog.Make(), + } + + i.injectTools() + + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) + }) + + t.Run("disables parallel tool use for any tool choice", func(t *testing.T) { + t.Parallel() + + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"any"}}`), + mcpProxy: &mockServerProxier{ + tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, + }, + logger: slog.Make(), + } + + i.injectTools() + + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Any]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) + }) + + t.Run("disables parallel tool use for tool choice type", func(t *testing.T) { + t.Parallel() + + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"tool","name":"specific_tool"}}`), + mcpProxy: &mockServerProxier{ + tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, + }, + logger: slog.Make(), + } + + i.injectTools() + + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Tool]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) + }) + + t.Run("no-op for none tool choice type", func(t *testing.T) { + t.Parallel() + + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"none"}}`), + mcpProxy: &mockServerProxier{ + tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, + }, + logger: slog.Make(), + } + + i.injectTools() + + // Tools are still injected. + require.Len(t, gjson.GetBytes(i.reqPayload, "tools").Array(), 1) + // But no parallel tool use modification for "none" type. + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.None]()), toolChoice.Get("type").String()) + require.False(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + }) } func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload { diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index a051af45..7ed267cd 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -82,12 +82,12 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req // TODO(ssncferreira): inject actor headers directly in the client-header // middleware instead of using SDK options. - requestOptions := []option.RequestOption{option.WithRequestTimeout(time.Second * 600)} + opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 600)} if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { - requestOptions = append(requestOptions, intercept.ActorHeadersAsAnthropicOpts(actor)...) + opts = append(opts, intercept.ActorHeadersAsAnthropicOpts(actor)...) } - svc, err := i.newMessagesService(ctx, requestOptions...) + svc, err := i.newMessagesService(ctx, opts...) if err != nil { err = fmt.Errorf("create anthropic client: %w", err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -97,12 +97,15 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req logger := i.logger.With(slog.F("model", i.Model())) var resp *anthropic.Message + // Accumulate usage across the entire streaming interaction (including tool reinvocations). var cumulativeUsage anthropic.Usage for { + // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) resp, err = i.newMessage(ctx, svc) if err != nil { if eventstream.IsConnError(err) { + // Can't write a response, just error out. return fmt.Errorf("upstream connection closed: %w", err) } @@ -151,8 +154,8 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req // Handle tool calls. var pendingToolCalls []anthropic.ToolUseBlock - for _, contentBlock := range resp.Content { - toolUse := contentBlock.AsToolUse() + for _, c := range resp.Content { + toolUse := c.AsToolUse() if toolUse.ID == "" { continue } @@ -162,6 +165,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req continue } + // If tool is not injected, track it since the client will be handling it. _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: resp.ID, @@ -172,6 +176,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req }) } + // If no injected tool calls, we're done. if len(pendingToolCalls) == 0 { break } @@ -179,71 +184,79 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req var loopMessages []anthropic.MessageParam loopMessages = append(loopMessages, resp.ToParam()) - for _, toolCall := range pendingToolCalls { + // Process each pending tool call. + for _, tc := range pendingToolCalls { if i.mcpProxy == nil { continue } - tool := i.mcpProxy.GetTool(toolCall.Name) + tool := i.mcpProxy.GetTool(tc.Name) if tool == nil { - logger.Warn(ctx, "tool not found in manager", slog.F("tool", toolCall.Name)) + logger.Warn(ctx, "tool not found in manager", slog.F("tool", tc.Name)) + // Continue to next tool call, but still append an error tool_result loopMessages = append(loopMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(toolCall.ID, fmt.Sprintf("Error: tool %s not found", toolCall.Name), true)), + anthropic.NewUserMessage(anthropic.NewToolResultBlock(tc.ID, fmt.Sprintf("Error: tool %s not found", tc.Name), true)), ) continue } - toolResultResponse, toolCallErr := tool.Call(ctx, toolCall.Input, i.tracer) + res, err := tool.Call(ctx, tc.Input, i.tracer) + _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: resp.ID, - ToolCallID: toolCall.ID, + ToolCallID: tc.ID, ServerURL: &tool.ServerURL, Tool: tool.Name, - Args: toolCall.Input, + Args: tc.Input, Injected: true, - InvocationError: toolCallErr, + InvocationError: err, }) - if toolCallErr != nil { + if err != nil { + // Always provide a tool_result even if the tool call failed loopMessages = append(loopMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(toolCall.ID, fmt.Sprintf("Error: calling tool: %v", toolCallErr), true)), + anthropic.NewUserMessage(anthropic.NewToolResultBlock(tc.ID, fmt.Sprintf("Error: calling tool: %v", err), true)), ) continue } + // Process tool result toolResult := anthropic.ContentBlockParamUnion{ OfToolResult: &anthropic.ToolResultBlockParam{ - ToolUseID: toolCall.ID, + ToolUseID: tc.ID, IsError: anthropic.Bool(false), }, } var hasValidResult bool - for _, toolContent := range toolResultResponse.Content { - switch contentBlock := toolContent.(type) { + for _, content := range res.Content { + switch cb := content.(type) { case mcplib.TextContent: toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ OfText: &anthropic.TextBlockParam{ - Text: contentBlock.Text, + Text: cb.Text, }, }) hasValidResult = true + // TODO: is there a more correct way of handling these non-text content responses? case mcplib.EmbeddedResource: - switch resource := contentBlock.Resource.(type) { + switch resource := cb.Resource.(type) { case mcplib.TextResourceContents: - value := fmt.Sprintf("Binary resource (MIME: %s, URI: %s): %s", resource.MIMEType, resource.URI, resource.Text) + val := fmt.Sprintf("Binary resource (MIME: %s, URI: %s): %s", + resource.MIMEType, resource.URI, resource.Text) toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ OfText: &anthropic.TextBlockParam{ - Text: value, + Text: val, }, }) hasValidResult = true case mcplib.BlobResourceContents: - value := fmt.Sprintf("Binary resource (MIME: %s, URI: %s): %s", resource.MIMEType, resource.URI, resource.Blob) + val := fmt.Sprintf("Binary resource (MIME: %s, URI: %s): %s", + resource.MIMEType, resource.URI, resource.Blob) toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ OfText: &anthropic.TextBlockParam{ - Text: value, + Text: val, }, }) hasValidResult = true @@ -258,7 +271,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req hasValidResult = true } default: - i.logger.Warn(ctx, "not handling non-text tool result", slog.F("type", fmt.Sprintf("%T", contentBlock))) + i.logger.Warn(ctx, "not handling non-text tool result", slog.F("type", fmt.Sprintf("%T", cb))) toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ OfText: &anthropic.TextBlockParam{ Text: "Error: unsupported tool result type", @@ -269,8 +282,9 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req } } + // If no content was processed, still add a tool_result if !hasValidResult { - i.logger.Warn(ctx, "no tool result added", slog.F("content_len", len(toolResultResponse.Content)), slog.F("is_error", toolResultResponse.IsError)) + i.logger.Warn(ctx, "no tool result added", slog.F("content_len", len(res.Content)), slog.F("is_error", res.IsError)) toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ OfText: &anthropic.TextBlockParam{ Text: "Error: no valid tool result content", @@ -296,19 +310,21 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req return nil } - responseJSON, err := sjson.Set(resp.RawJSON(), "id", i.ID().String()) + // Overwrite response identifier since proxy obscures injected tool call invocations. + sj, err := sjson.Set(resp.RawJSON(), "id", i.ID().String()) if err != nil { return fmt.Errorf("marshal response id failed: %w", err) } - responseJSON, err = sjson.Set(responseJSON, "usage", cumulativeUsage) + // Overwrite the response's usage with the cumulative usage across any inner loops which invokes injected MCP tools. + sj, err = sjson.Set(sj, "usage", cumulativeUsage) if err != nil { return fmt.Errorf("marshal response usage failed: %w", err) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(responseJSON)) + _, _ = w.Write([]byte(sj)) return nil } diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 482e2f5a..d317c55b 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -101,17 +101,16 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re logger := i.logger.With(slog.F("model", i.Model())) var ( - prompt *string - err error + prompt string + promptFound bool + err error ) // Claude Code uses a "small/fast model" for certain tasks. if !i.isSmallFastModel() { - promptText, promptFound, promptErr := i.reqPayload.lastUserPrompt() - if promptErr != nil { - logger.Warn(ctx, "failed to determine last user prompt", slog.Error(promptErr)) - } else if promptFound { - prompt = &promptText + prompt, promptFound, err = i.reqPayload.lastUserPrompt() + if err != nil { + logger.Warn(ctx, "failed to determine last user prompt", slog.Error(err)) } // Only inject tools into "actual" request. @@ -295,9 +294,11 @@ newStream: continue } - var input json.RawMessage - var foundTool bool - var foundTools int + var ( + input json.RawMessage + foundTool bool + foundTools int + ) for _, block := range message.Content { switch variant := block.AsAny().(type) { case anthropic.ToolUseBlock: @@ -328,12 +329,14 @@ newStream: }) if err != nil { + // Always provide a tool_result even if the tool call failed loopMessages = append(loopMessages, anthropic.NewUserMessage(anthropic.NewToolResultBlock(id, fmt.Sprintf("Error calling tool: %v", err), true)), ) continue } + // Process tool result toolResult := anthropic.ContentBlockParamUnion{ OfToolResult: &anthropic.ToolResultBlockParam{ ToolUseID: id, @@ -393,6 +396,7 @@ newStream: } } + // If no content was processed, still add a tool_result if !hasValidResult { logger.Warn(ctx, "no tool result added", slog.F("content_len", len(res.Content)), slog.F("is_error", res.IsError)) toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.ToolResultBlockParamContentUnion{ @@ -408,9 +412,11 @@ newStream: } } - updatedPayload, rewriteErr := i.reqPayload.appendedMessages(loopMessages) - if rewriteErr != nil { - lastErr = fmt.Errorf("rewrite payload for agentic loop: %w", rewriteErr) + // Sync the raw payload with updated messages so that withBody() + // sends the updated payload on the next iteration. + updatedPayload, syncErr := i.reqPayload.appendedMessages(loopMessages) + if syncErr != nil { + lastErr = fmt.Errorf("sync payload for agentic loop: %w", syncErr) break } i.reqPayload = updatedPayload @@ -459,13 +465,14 @@ newStream: } } - if prompt != nil { + if promptFound { _ = i.recorder.RecordPromptUsage(ctx, &recorder.PromptUsageRecord{ InterceptionID: i.ID().String(), MsgID: message.ID, - Prompt: *prompt, + Prompt: prompt, }) - prompt = nil + prompt = "" + promptFound = false } if events.IsStreaming() { @@ -573,7 +580,7 @@ func (s *StreamingInterception) encodeForStream(payload []byte, typ string) []by return buf.Bytes() } -// newStream traces svc.NewStreaming(streamCtx, anthropic.MessageNewParams{}). +// newStream traces svc.NewStreaming() call. func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] { _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() diff --git a/internal/integrationtest/mockupstream.go b/internal/integrationtest/mockupstream.go index 40b00663..a658b054 100644 --- a/internal/integrationtest/mockupstream.go +++ b/internal/integrationtest/mockupstream.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "testing" + "github.com/anthropics/anthropic-sdk-go" "github.com/coder/aibridge/fixtures" "github.com/openai/openai-go/v3" "github.com/stretchr/testify/require" @@ -295,14 +296,14 @@ func validateOpenAIResponses(t *testing.T, body []byte, msgAndArgs ...any) { } // validateAnthropicMessages validates that an Anthropic messages request -// has the required top-level fields while remaining tolerant of raw payload -// shapes that do not round-trip through anthropic.MessageNewParams. +// has all required fields. +// See https://github.com/anthropics/anthropic-sdk-go. func validateAnthropicMessages(t *testing.T, body []byte, msgAndArgs ...any) { t.Helper() - var requestBody map[string]any - require.NoError(t, json.Unmarshal(body, &requestBody), msgAndArgs...) - require.NotEmpty(t, requestBody["model"], "model is required", msgAndArgs) - require.NotEmpty(t, requestBody["messages"], "messages is required", msgAndArgs) - require.NotZero(t, requestBody["max_tokens"], "max_tokens is required", msgAndArgs) + var req anthropic.MessageNewParams + require.NoError(t, json.Unmarshal(body, &req), msgAndArgs...) + require.NotEmpty(t, req.Model, "model is required", msgAndArgs) + require.NotEmpty(t, req.Messages, "messages is required", msgAndArgs) + require.NotZero(t, req.MaxTokens, "max_tokens is required", msgAndArgs) }