diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 09372ec7..f2f0e708 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -24,7 +24,6 @@ import ( "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" "github.com/coder/quartz" - "github.com/tidwall/sjson" "github.com/google/uuid" "go.opentelemetry.io/otel/attribute" @@ -34,9 +33,8 @@ import ( ) type interceptionBase struct { - id uuid.UUID - req *MessageNewParamsWrapper - payload []byte + id uuid.UUID + reqPayload MessagesRequestPayload cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock @@ -63,22 +61,11 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, } func (i *interceptionBase) CorrelatingToolCallID() *string { - if len(i.req.Messages) == 0 { - return nil - } - content := i.req.Messages[len(i.req.Messages)-1].Content - for idx := len(content) - 1; idx >= 0; idx-- { - block := content[idx] - if block.OfToolResult == nil { - continue - } - return &block.OfToolResult.ToolUseID - } - return nil + return i.reqPayload.correlatingToolCallID() } func (i *interceptionBase) Model() string { - if i.req == nil { + if len(i.reqPayload) == 0 { return "coder-aibridge-unknown" } @@ -90,7 +77,7 @@ func (i *interceptionBase) Model() string { return model } - return string(i.req.Model) + return i.reqPayload.model() } func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { @@ -106,7 +93,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) } func (i *interceptionBase) injectTools() { - if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() { + if i.mcpProxy == nil || !i.hasInjectableTools() { return } @@ -131,46 +118,23 @@ func (i *interceptionBase) injectTools() { // Prepend the injected tools in order to maintain any configured cache breakpoints. // The order of injected tools is expected to be stable, and therefore will not cause // any cache invalidation when prepended. - i.req.Tools = append(injectedTools, i.req.Tools...) - - var err error - i.payload, err = sjson.SetBytes(i.payload, "tools", i.req.Tools) + updated, err := i.reqPayload.injectTools(injectedTools) if err != nil { i.logger.Warn(context.Background(), "failed to set inject tools in request payload", slog.Error(err)) + return } + i.reqPayload = updated } func (i *interceptionBase) disableParallelToolCalls() { // Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches. // https://github.com/coder/aibridge/issues/2 - toolChoiceType := i.req.ToolChoice.GetType() - var toolChoiceTypeStr string - if toolChoiceType != nil { - toolChoiceTypeStr = *toolChoiceType - } - - switch toolChoiceTypeStr { - // If no tool_choice was defined, assume auto. - // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use. - case "", string(constant.ValueOf[constant.Auto]()): - // We only set OfAuto if no tool_choice was provided (the default). - // "auto" is the default when a zero value is provided, so we can safely disable parallel checks on it. - if i.req.ToolChoice.OfAuto == nil { - i.req.ToolChoice.OfAuto = &anthropic.ToolChoiceAutoParam{} - } - i.req.ToolChoice.OfAuto.DisableParallelToolUse = anthropic.Bool(true) - case string(constant.ValueOf[constant.Any]()): - i.req.ToolChoice.OfAny.DisableParallelToolUse = anthropic.Bool(true) - case string(constant.ValueOf[constant.Tool]()): - i.req.ToolChoice.OfTool.DisableParallelToolUse = anthropic.Bool(true) - case string(constant.ValueOf[constant.None]()): - // No-op; if tool_choice=none then tools are not used at all. - } - var err error - i.payload, err = sjson.SetBytes(i.payload, "tool_choice", i.req.ToolChoice) + updated, err := i.reqPayload.disableParallelToolCalls() if err != nil { i.logger.Warn(context.Background(), "failed to set tool_choice in request payload", slog.Error(err)) + return } + i.reqPayload = updated } // extractModelThoughts returns any thinking blocks that were returned in the response. @@ -201,7 +165,7 @@ func (i *interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recor // See `ANTHROPIC_SMALL_FAST_MODEL`: https://docs.anthropic.com/en/docs/claude-code/settings#environment-variables // https://docs.claude.com/en/docs/claude-code/costs#background-token-usage func (i *interceptionBase) isSmallFastModel() bool { - return strings.Contains(string(i.req.Model), "haiku") + return strings.Contains(i.reqPayload.model(), "haiku") } func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) { @@ -244,23 +208,12 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio return anthropic.NewMessageService(opts...), nil } -// withBody returns a per-request option that sends the current i.payload as the -// request body. This is called for each API request so that the latest payload (including -// any messages appended during the agentic tool loop) is always sent. +// withBody returns a per-request option that sends the current raw request +// payload as the request body. This is called for each API request so that the +// latest payload (including any messages appended during the agentic tool loop) +// is always sent. func (i *interceptionBase) withBody() option.RequestOption { - return option.WithRequestBody("application/json", i.payload) -} - -// syncPayloadMessages updates the raw payload's "messages" field to match the given messages. -// This must be called before the next API request in the agentic loop so that -// withBody() picks up the updated messages. -func (i *interceptionBase) syncPayloadMessages(messages []anthropic.MessageParam) error { - var err error - i.payload, err = sjson.SetBytes(i.payload, "messages", messages) - if err != nil { - return fmt.Errorf("sync payload messages: %w", err) - } - return nil + return option.WithRequestBody("application/json", []byte(i.reqPayload)) } func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) { @@ -317,13 +270,13 @@ func (i *interceptionBase) augmentRequestForBedrock() { return } - i.req.MessageNewParams.Model = anthropic.Model(i.Model()) - - var err error - i.payload, err = sjson.SetBytes(i.payload, "model", i.Model()) + updated, err := i.reqPayload.withModel(i.Model()) if err != nil { i.logger.Warn(context.Background(), "failed to set model in request payload for Bedrock", slog.Error(err)) + return } + + i.reqPayload = updated } // writeUpstreamError marshals and writes a given error. diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index cca890e0..3f25b6eb 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "cdr.dev/slog/v3" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge/config" @@ -11,84 +12,51 @@ import ( "github.com/coder/aibridge/utils" mcpgo "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func TestScanForCorrelatingToolCallID(t *testing.T) { t.Parallel() tests := []struct { - name string - messages []anthropic.MessageParam - expected *string + name string + requestBody string + expected *string }{ { - name: "no messages", - messages: nil, - expected: nil, + name: "no messages field", + requestBody: `{}`, + expected: nil, }, { - name: "last message has no tool_result blocks", - messages: []anthropic.MessageParam{ - anthropic.NewUserMessage(anthropic.NewTextBlock("hello")), - }, - expected: nil, + name: "messages string", + requestBody: `{"messages":"test"}`, + expected: nil, }, { - name: "single tool_result block", - messages: []anthropic.MessageParam{ - anthropic.NewUserMessage( - anthropic.ContentBlockParamUnion{ - OfToolResult: &anthropic.ToolResultBlockParam{ - ToolUseID: "toolu_abc", - Content: []anthropic.ToolResultBlockParamContentUnion{ - {OfText: &anthropic.TextBlockParam{Text: "result"}}, - }, - }, - }, - ), - }, - expected: utils.PtrTo("toolu_abc"), + name: "empty messages array", + requestBody: `{"messages":[]}`, + expected: nil, }, { - name: "multiple tool_result blocks returns last", - messages: []anthropic.MessageParam{ - anthropic.NewUserMessage( - anthropic.ContentBlockParamUnion{ - OfToolResult: &anthropic.ToolResultBlockParam{ - ToolUseID: "toolu_first", - Content: []anthropic.ToolResultBlockParamContentUnion{ - {OfText: &anthropic.TextBlockParam{Text: "first"}}, - }, - }, - }, - anthropic.NewTextBlock("some text"), - anthropic.ContentBlockParamUnion{ - OfToolResult: &anthropic.ToolResultBlockParam{ - ToolUseID: "toolu_second", - Content: []anthropic.ToolResultBlockParamContentUnion{ - {OfText: &anthropic.TextBlockParam{Text: "second"}}, - }, - }, - }, - ), - }, - expected: utils.PtrTo("toolu_second"), + name: "last message has no tool result blocks", + requestBody: `{"messages":[{"role":"user","content":"hello"}]}`, + expected: nil, }, { - name: "last message is not a tool result", - messages: []anthropic.MessageParam{ - anthropic.NewUserMessage( - anthropic.ContentBlockParamUnion{ - OfToolResult: &anthropic.ToolResultBlockParam{ - ToolUseID: "toolu_first", - Content: []anthropic.ToolResultBlockParamContentUnion{ - {OfText: &anthropic.TextBlockParam{Text: "first"}}, - }, - }, - }), - anthropic.NewUserMessage(anthropic.NewTextBlock("some text")), - }, - expected: nil, + name: "single tool result block", + requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_abc","content":"result"}]}]}`, + expected: utils.PtrTo("toolu_abc"), + }, + { + name: "multiple tool result blocks returns last", + requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"},{"type":"text","text":"ignored"},{"type":"tool_result","tool_use_id":"toolu_second","content":"second"}]}]}`, + expected: utils.PtrTo("toolu_second"), + }, + { + name: "last message is not a tool result", + requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"}]},{"role":"user","content":"some text"}]}`, + expected: nil, }, } @@ -97,11 +65,7 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { t.Parallel() base := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: tc.messages, - }, - }, + reqPayload: mustMessagesPayload(t, tc.requestBody), } require.Equal(t, tc.expected, base.CorrelatingToolCallID()) @@ -407,28 +371,19 @@ func TestInjectTools_CacheBreakpoints(t *testing.T) { // Request has existing tool with cache control, but no tools to inject. i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "existing_tool", - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: constant.ValueOf[constant.Ephemeral](), - }, - }, - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`), mcpProxy: &mockServerProxier{tools: nil}, + logger: slog.Make(), } i.injectTools() // Cache control should remain untouched since no tools were injected. - require.Len(t, i.req.Tools, 1) - require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[0].OfTool.CacheControl.Type) + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 1) + require.Equal(t, "existing_tool", toolItems[0].Get("name").String()) + require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[0].Get("cache_control.type").String()) }) t.Run("cache control breakpoint is preserved by prepending injected tools", func(t *testing.T) { @@ -436,36 +391,26 @@ func TestInjectTools_CacheBreakpoints(t *testing.T) { // Request has existing tool with cache control. i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "existing_tool", - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: constant.ValueOf[constant.Ephemeral](), - }, - }, - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{ {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, }, }, + logger: slog.Make(), } i.injectTools() - require.Len(t, i.req.Tools, 2) + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 2) // Injected tools are prepended. - require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name) - require.Zero(t, i.req.Tools[0].OfTool.CacheControl) + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Empty(t, toolItems[0].Get("cache_control.type").String()) // Original tool's cache control should be preserved at the end. - require.Equal(t, "existing_tool", i.req.Tools[1].OfTool.Name) - require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[1].OfTool.CacheControl.Type) + require.Equal(t, "existing_tool", toolItems[1].Get("name").String()) + require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[1].Get("cache_control.type").String()) }) // The cache breakpoint SHOULD be on the final tool, but may not be; we must preserve that intention. @@ -474,43 +419,29 @@ func TestInjectTools_CacheBreakpoints(t *testing.T) { // Request has multiple tools with cache control breakpoints. i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "tool_with_cache_1", - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: constant.ValueOf[constant.Ephemeral](), - }, - }, - }, - { - OfTool: &anthropic.ToolParam{ - Name: "tool_with_cache_2", - }, - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"tool_with_cache_1","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}},`+ + `{"name":"tool_with_cache_2","type":"custom","input_schema":{"type":"object","properties":{}}}]}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{ {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, }, }, + logger: slog.Make(), } i.injectTools() - require.Len(t, i.req.Tools, 3) + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 3) // Injected tool is prepended without cache control. - require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name) - require.Zero(t, i.req.Tools[0].OfTool.CacheControl) + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Empty(t, toolItems[0].Get("cache_control.type").String()) // Both original tools' cache controls should remain. - require.Equal(t, "tool_with_cache_1", i.req.Tools[1].OfTool.Name) - require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[1].OfTool.CacheControl.Type) - require.Equal(t, "tool_with_cache_2", i.req.Tools[2].OfTool.Name) - require.Zero(t, i.req.Tools[2].OfTool.CacheControl) + require.Equal(t, "tool_with_cache_1", toolItems[1].Get("name").String()) + require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[1].Get("cache_control.type").String()) + require.Equal(t, "tool_with_cache_2", toolItems[2].Get("name").String()) + require.Empty(t, toolItems[2].Get("cache_control.type").String()) }) t.Run("no cache control added when none originally set", func(t *testing.T) { @@ -518,33 +449,26 @@ func TestInjectTools_CacheBreakpoints(t *testing.T) { // Request has tools but none with cache control. i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "existing_tool_no_cache", - }, - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"existing_tool_no_cache","type":"custom","input_schema":{"type":"object","properties":{}}}]}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{ {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, }, }, + logger: slog.Make(), } i.injectTools() - require.Len(t, i.req.Tools, 2) + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 2) // Injected tool is prepended without cache control. - require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name) - require.Zero(t, i.req.Tools[0].OfTool.CacheControl) + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Empty(t, toolItems[0].Get("cache_control.type").String()) // Original tool remains at the end without cache control. - require.Equal(t, "existing_tool_no_cache", i.req.Tools[1].OfTool.Name) - require.Zero(t, i.req.Tools[1].OfTool.CacheControl) + require.Equal(t, "existing_tool_no_cache", toolItems[1].Get("name").String()) + require.Empty(t, toolItems[1].Get("cache_control.type").String()) }) } @@ -555,152 +479,126 @@ func TestInjectTools_ParallelToolCalls(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{ - Type: constant.ValueOf[constant.Auto](), - }, - }, - }, - }, - mcpProxy: &mockServerProxier{tools: nil}, // No tools to inject. + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"auto"}}`), + mcpProxy: &mockServerProxier{tools: nil}, // No tools to inject. + logger: slog.Make(), } i.injectTools() // Tool choice should remain unchanged - DisableParallelToolUse should not be set. - require.NotNil(t, i.req.ToolChoice.OfAuto) - require.False(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Valid()) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String()) + require.False(t, toolChoice.Get("disable_parallel_tool_use").Exists()) }) - t.Run("disables parallel tool use for auto tool choice (default)", func(t *testing.T) { + t.Run("disables parallel tool use for empty tool choice (default)", func(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - // No tool choice set (default). - }, - }, + reqPayload: mustMessagesPayload(t, `{}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, }, + logger: slog.Make(), } i.injectTools() - require.NotNil(t, i.req.ToolChoice.OfAuto) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Value) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) }) t.Run("disables parallel tool use for explicit auto tool choice", func(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{ - Type: constant.ValueOf[constant.Auto](), - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"auto"}}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, }, + logger: slog.Make(), } i.injectTools() - require.NotNil(t, i.req.ToolChoice.OfAuto) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Value) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) }) t.Run("disables parallel tool use for any tool choice", func(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfAny: &anthropic.ToolChoiceAnyParam{ - Type: constant.ValueOf[constant.Any](), - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"any"}}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, }, + logger: slog.Make(), } i.injectTools() - require.NotNil(t, i.req.ToolChoice.OfAny) - require.True(t, i.req.ToolChoice.OfAny.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfAny.DisableParallelToolUse.Value) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Any]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) }) t.Run("disables parallel tool use for tool choice type", func(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfTool: &anthropic.ToolChoiceToolParam{ - Type: constant.ValueOf[constant.Tool](), - Name: "specific_tool", - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"tool","name":"specific_tool"}}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, }, + logger: slog.Make(), } i.injectTools() - require.NotNil(t, i.req.ToolChoice.OfTool) - require.True(t, i.req.ToolChoice.OfTool.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfTool.DisableParallelToolUse.Value) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Tool]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) }) t.Run("no-op for none tool choice type", func(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfNone: &anthropic.ToolChoiceNoneParam{ - Type: constant.ValueOf[constant.None](), - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"none"}}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, }, + logger: slog.Make(), } i.injectTools() // Tools are still injected. - require.Len(t, i.req.Tools, 1) + require.Len(t, gjson.GetBytes(i.reqPayload, "tools").Array(), 1) // But no parallel tool use modification for "none" type. - require.Nil(t, i.req.ToolChoice.OfAuto) - require.Nil(t, i.req.ToolChoice.OfAny) - require.Nil(t, i.req.ToolChoice.OfTool) - require.NotNil(t, i.req.ToolChoice.OfNone) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.None]()), toolChoice.Get("type").String()) + require.False(t, toolChoice.Get("disable_parallel_tool_use").Exists()) }) } +func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload { + t.Helper() + + payload, err := NewMessagesRequestPayload([]byte(requestBody)) + require.NoError(t, err) + + return payload +} + // mockServerProxier is a test implementation of mcp.ServerProxier. type mockServerProxier struct { tools []*mcp.Tool diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 51a1a98d..7ed267cd 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -30,8 +30,7 @@ type BlockingInterception struct { func NewBlockingInterceptor( id uuid.UUID, - req *MessageNewParamsWrapper, - payload []byte, + reqPayload MessagesRequestPayload, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, @@ -40,8 +39,7 @@ func NewBlockingInterceptor( ) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ id: id, - req: req, - payload: payload, + reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, clientHeaders: clientHeaders, @@ -63,8 +61,8 @@ func (s *BlockingInterception) Streaming() bool { } func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { - if i.req == nil { - return fmt.Errorf("developer error: req is nil") + if len(i.reqPayload) == 0 { + return fmt.Errorf("developer error: request payload is empty") } ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) @@ -72,22 +70,19 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req i.injectTools() - var ( - prompt *string - err error - ) - // Track user prompt if not a small/fast model + var prompt *string if !i.isSmallFastModel() { - prompt, err = i.req.lastUserPrompt() - if err != nil { - i.logger.Warn(ctx, "failed to retrieve last user prompt", slog.Error(err)) + promptText, promptFound, promptErr := i.reqPayload.lastUserPrompt() + if promptErr != nil { + i.logger.Warn(ctx, "failed to retrieve last user prompt", slog.Error(promptErr)) + } else if promptFound { + prompt = &promptText } } - opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 600)} - // TODO(ssncferreira): inject actor headers directly in the client-header // middleware instead of using SDK options. + opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 600)} if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsAnthropicOpts(actor)...) } @@ -99,8 +94,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req return err } - messages := i.req.MessageNewParams - logger := i.logger.With(slog.F("model", i.req.Model)) + logger := i.logger.With(slog.F("model", i.Model())) var resp *anthropic.Message // Accumulate usage across the entire streaming interaction (including tool reinvocations). @@ -108,7 +102,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req for { // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) - resp, err = i.newMessage(ctx, svc, messages) + resp, err = i.newMessage(ctx, svc) if err != nil { if eventstream.IsConnError(err) { // Can't write a response, just error out. @@ -187,9 +181,8 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req break } - // Append the assistant's message (which contains the tool_use block) - // to the messages for the next API call. - messages.Messages = append(messages.Messages, resp.ToParam()) + var loopMessages []anthropic.MessageParam + loopMessages = append(loopMessages, resp.ToParam()) // Process each pending tool call. for _, tc := range pendingToolCalls { @@ -201,7 +194,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req if tool == nil { logger.Warn(ctx, "tool not found in manager", slog.F("tool", tc.Name)) // Continue to next tool call, but still append an error tool_result - messages.Messages = append(messages.Messages, + loopMessages = append(loopMessages, anthropic.NewUserMessage(anthropic.NewToolResultBlock(tc.ID, fmt.Sprintf("Error: tool %s not found", tc.Name), true)), ) continue @@ -222,7 +215,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req if err != nil { // Always provide a tool_result even if the tool call failed - messages.Messages = append(messages.Messages, + loopMessages = append(loopMessages, anthropic.NewUserMessage(anthropic.NewToolResultBlock(tc.ID, fmt.Sprintf("Error: calling tool: %v", err), true)), ) continue @@ -301,16 +294,16 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req } if len(toolResult.OfToolResult.Content) > 0 { - messages.Messages = append(messages.Messages, anthropic.NewUserMessage(toolResult)) + loopMessages = append(loopMessages, anthropic.NewUserMessage(toolResult)) } } - // Sync the raw payload with updated messages so that withBody() - // sends the updated payload on the next iteration. - if err := i.syncPayloadMessages(messages.Messages); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return fmt.Errorf("sync payload for agentic loop: %w", err) + updatedPayload, rewriteErr := i.reqPayload.appendedMessages(loopMessages) + if rewriteErr != nil { + http.Error(w, rewriteErr.Error(), http.StatusInternalServerError) + return fmt.Errorf("rewrite payload for agentic loop: %w", rewriteErr) } + i.reqPayload = updatedPayload } if resp == nil { @@ -336,9 +329,9 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req return nil } -func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) { +func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService) (_ *anthropic.Message, outErr error) { ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) - return svc.New(ctx, msgParams, i.withBody()) + return svc.New(ctx, anthropic.MessageNewParams{}, i.withBody()) } diff --git a/intercept/messages/paramswrap.go b/intercept/messages/paramswrap.go deleted file mode 100644 index bd5175aa..00000000 --- a/intercept/messages/paramswrap.go +++ /dev/null @@ -1,142 +0,0 @@ -package messages - -import ( - "encoding/json" - "errors" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/packages/param" -) - -// MessageNewParamsWrapper exists because the "stream" param is not included in anthropic.MessageNewParams. -type MessageNewParamsWrapper struct { - anthropic.MessageNewParams `json:""` - Stream bool `json:"stream,omitempty"` -} - -func (b MessageNewParamsWrapper) MarshalJSON() ([]byte, error) { - type shadow MessageNewParamsWrapper - return param.MarshalWithExtras(b, (*shadow)(&b), map[string]any{ - "stream": b.Stream, - }) -} - -func (b *MessageNewParamsWrapper) UnmarshalJSON(raw []byte) error { - // Parse JSON once and extract both stream field and do content conversion - // to avoid double-parsing the same payload. - var modifiedJSON map[string]any - if err := json.Unmarshal(raw, &modifiedJSON); err != nil { - return err - } - - // Extract stream field from already-parsed map - if stream, ok := modifiedJSON["stream"].(bool); ok { - b.Stream = stream - } - - // Convert string content to array format if needed - if _, hasMessages := modifiedJSON["messages"]; hasMessages { - convertStringContentRecursive(modifiedJSON) - } - - // Marshal back for SDK parsing - convertedRaw, err := json.Marshal(modifiedJSON) - if err != nil { - return err - } - - return b.MessageNewParams.UnmarshalJSON(convertedRaw) -} - -func (b *MessageNewParamsWrapper) lastUserPrompt() (*string, error) { - if b == nil { - return nil, errors.New("nil struct") - } - - if len(b.Messages) == 0 { - return nil, errors.New("no messages") - } - - // We only care if the last message was issued by a user. - msg := b.Messages[len(b.Messages)-1] - if msg.Role != anthropic.MessageParamRoleUser { - return nil, nil - } - - if len(msg.Content) == 0 { - return nil, nil - } - - // Walk backwards on "user"-initiated message content. Clients often inject - // content ahead of the actual prompt to provide context to the model, - // so the last item in the slice is most likely the user's prompt. - for i := len(msg.Content) - 1; i >= 0; i-- { - // Only text content is supported currently. - if textContent := msg.Content[i].GetText(); textContent != nil { - return textContent, nil - } - } - - return nil, nil -} - -// convertStringContentRecursive recursively scans JSON data and converts string "content" fields -// to proper text block arrays where needed for Anthropic SDK compatibility. -// Returns true if any modifications were made. -func convertStringContentRecursive(data any) bool { - modified := false - switch v := data.(type) { - case map[string]any: - // Check if this object has a "content" field with string value - if content, hasContent := v["content"]; hasContent { - if contentStr, isString := content.(string); isString { - // Check if this needs conversion based on context - if shouldConvertContentField(v) { - v["content"] = []map[string]any{ - { - "type": "text", - "text": contentStr, - }, - } - modified = true - } - } - } - - // Recursively process all values in the map - for _, value := range v { - if convertStringContentRecursive(value) { - modified = true - } - } - - case []any: - // Recursively process all items in the array - for _, item := range v { - if convertStringContentRecursive(item) { - modified = true - } - } - } - return modified -} - -// shouldConvertContentField determines if a "content" string field should be converted to text block array -func shouldConvertContentField(obj map[string]any) bool { - // Check if this is a message-level content (has "role" field) - if _, hasRole := obj["role"]; hasRole { - return true - } - - // Check if this is a tool_result block (but not mcp_tool_result which supports strings) - if objType, hasType := obj["type"].(string); hasType { - switch objType { - case "tool_result": - return true // Regular tool_result needs array format - case "mcp_tool_result": - return false // MCP tool_result supports strings - } - } - - return false -} diff --git a/intercept/messages/paramswrap_test.go b/intercept/messages/paramswrap_test.go deleted file mode 100644 index 7f8793d7..00000000 --- a/intercept/messages/paramswrap_test.go +++ /dev/null @@ -1,303 +0,0 @@ -package messages - -import ( - "testing" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/stretchr/testify/require" -) - -func TestMessageNewParamsWrapperUnmarshalJSON(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input string - expectedStream bool - checkContent func(t *testing.T, w *MessageNewParamsWrapper) - }{ - { - name: "message with string content converts to array", - input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"Hello world"}]}`, - expectedStream: false, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 1) - require.Equal(t, anthropic.MessageParamRoleUser, w.Messages[0].Role) - text := w.Messages[0].Content[0].GetText() - require.NotNil(t, text) - require.Equal(t, "Hello world", *text) - }, - }, - { - name: "stream field extracted", - input: `{"model":"claude-3","max_tokens":1000,"stream":true,"messages":[{"role":"user","content":"Hi"}]}`, - expectedStream: true, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 1) - }, - }, - { - name: "stream false", - input: `{"model":"claude-3","max_tokens":1000,"stream":false,"messages":[{"role":"user","content":"Hi"}]}`, - expectedStream: false, - checkContent: nil, - }, - { - name: "array content unchanged", - input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`, - expectedStream: false, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 1) - text := w.Messages[0].Content[0].GetText() - require.NotNil(t, text) - require.Equal(t, "Hello", *text) - }, - }, - { - name: "multiple messages with mixed content", - input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"First"},{"role":"assistant","content":[{"type":"text","text":"Response"}]},{"role":"user","content":"Second"}]}`, - expectedStream: false, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 3) - text0 := w.Messages[0].Content[0].GetText() - require.NotNil(t, text0) - require.Equal(t, "First", *text0) - text2 := w.Messages[2].Content[0].GetText() - require.NotNil(t, text2) - require.Equal(t, "Second", *text2) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var wrapper MessageNewParamsWrapper - err := wrapper.UnmarshalJSON([]byte(tt.input)) - require.NoError(t, err) - require.Equal(t, tt.expectedStream, wrapper.Stream) - if tt.checkContent != nil { - tt.checkContent(t, &wrapper) - } - }) - } -} - -func TestShouldConvertContentField(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - obj map[string]any - expected bool - }{ - { - name: "message with role", - obj: map[string]any{ - "role": "user", - "content": "test", - }, - expected: true, - }, - { - name: "tool_result type", - obj: map[string]any{ - "type": "tool_result", - "content": "result", - }, - expected: true, - }, - { - name: "mcp_tool_result type", - obj: map[string]any{ - "type": "mcp_tool_result", - "content": "result", - }, - expected: false, - }, - { - name: "other type", - obj: map[string]any{ - "type": "text", - "content": "text", - }, - expected: false, - }, - { - name: "no role or type", - obj: map[string]any{ - "content": "test", - }, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := shouldConvertContentField(tt.obj) - require.Equal(t, tt.expected, result) - }) - } -} - -func TestAnthropicLastUserPrompt(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - wrapper *MessageNewParamsWrapper - expected string - expectError bool - errorMsg string - }{ - { - name: "nil struct", - expectError: true, - errorMsg: "nil struct", - }, - { - name: "no messages", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{}, - }, - }, - expectError: true, - errorMsg: "no messages", - }, - { - name: "last message not from user", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("user message"), - }, - }, - { - Role: anthropic.MessageParamRoleAssistant, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("assistant message"), - }, - }, - }, - }, - }, - }, - { - name: "last user message with empty content", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{}, - }, - }, - }, - }, - }, - { - name: "last user message with single text content", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("Hello, world!"), - }, - }, - }, - }, - }, - expected: "Hello, world!", - }, - { - name: "last user message with multiple content blocks - text at end", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewImageBlockBase64("image/png", "base64data"), - anthropic.NewTextBlock("First text"), - anthropic.NewImageBlockBase64("image/jpeg", "moredata"), - anthropic.NewTextBlock("Last text"), - }, - }, - }, - }, - }, - expected: "Last text", - }, - { - name: "last user message with only non-text content", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewImageBlockBase64("image/png", "base64data"), - anthropic.NewImageBlockBase64("image/jpeg", "moredata"), - }, - }, - }, - }, - }, - }, - { - name: "multiple messages with last being user", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("First user message"), - }, - }, - { - Role: anthropic.MessageParamRoleAssistant, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("Assistant response"), - }, - }, - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("Second user message"), - }, - }, - }, - }, - }, - expected: "Second user message", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := tt.wrapper.lastUserPrompt() - - if tt.expectError { - require.Error(t, err) - require.Contains(t, err.Error(), tt.errorMsg) - require.Nil(t, result) - } else { - require.NoError(t, err) - // Check pointer equality - both nil or both non-nil - if tt.expected == "" { - require.Nil(t, result) - } else { - require.NotNil(t, result) - // The result should point to the same string from the content block - require.Equal(t, tt.expected, *result) - } - } - }) - } -} diff --git a/intercept/messages/reqpayload.go b/intercept/messages/reqpayload.go new file mode 100644 index 00000000..cefddd87 --- /dev/null +++ b/intercept/messages/reqpayload.go @@ -0,0 +1,271 @@ +package messages + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + // Absolute JSON paths from the request root. + messagesReqPathMessages = "messages" + messagesReqPathModel = "model" + messagesReqPathStream = "stream" + messagesReqPathToolChoice = "tool_choice" + messagesReqPathToolChoiceDisableParallel = "tool_choice.disable_parallel_tool_use" + messagesReqPathToolChoiceType = "tool_choice.type" + messagesReqPathTools = "tools" + + // Relative field names used within sub-objects. + messagesReqFieldContent = "content" + messagesReqFieldRole = "role" + messagesReqFieldText = "text" + messagesReqFieldToolUseID = "tool_use_id" + messagesReqFieldType = "type" +) + +var ( + constAny = string(constant.ValueOf[constant.Any]()) + constAuto = string(constant.ValueOf[constant.Auto]()) + constNone = string(constant.ValueOf[constant.None]()) + constText = string(constant.ValueOf[constant.Text]()) + constTool = string(constant.ValueOf[constant.Tool]()) + constToolResult = string(constant.ValueOf[constant.ToolResult]()) + constUser = string(anthropic.MessageParamRoleUser) +) + +// MessagesRequestPayload is raw JSON bytes of an Anthropic Messages API request. +// Methods provide package-specific reads and rewrites while preserving the +// original body for upstream pass-through. +type MessagesRequestPayload []byte + +func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) { + if len(bytes.TrimSpace(raw)) == 0 { + return nil, fmt.Errorf("messages empty request body") + } + if !json.Valid(raw) { + return nil, fmt.Errorf("messages invalid JSON request body") + } + + return MessagesRequestPayload(raw), nil +} + +func (p MessagesRequestPayload) Stream() bool { + return gjson.GetBytes(p, messagesReqPathStream).Bool() +} + +func (p MessagesRequestPayload) model() string { + return gjson.GetBytes(p, messagesReqPathModel).String() +} + +func (p MessagesRequestPayload) correlatingToolCallID() *string { + messages := gjson.GetBytes(p, messagesReqPathMessages) + if !messages.IsArray() { + return nil + } + + messageItems := messages.Array() + if len(messageItems) == 0 { + return nil + } + + content := messageItems[len(messageItems)-1].Get(messagesReqFieldContent) + if !content.IsArray() { + return nil + } + + contentItems := content.Array() + for idx := len(contentItems) - 1; idx >= 0; idx-- { + contentItem := contentItems[idx] + if contentItem.Get(messagesReqFieldType).String() != constToolResult { + continue + } + + toolUseID := contentItem.Get(messagesReqFieldToolUseID).String() + if toolUseID == "" { + continue + } + + return &toolUseID + } + + return nil +} + +// lastUserPrompt returns the prompt text from the last user message. If no prompt +// is found, it returns empty string, false, nil. Unexpected shapes are treated as +// unsupported and do not fail the request path. +func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) { + messages := gjson.GetBytes(p, messagesReqPathMessages) + if !messages.Exists() || messages.Type == gjson.Null { + return "", false, nil + } + if !messages.IsArray() { + return "", false, fmt.Errorf("unexpected messages type: %s", messages.Type) + } + + messageItems := messages.Array() + if len(messageItems) == 0 { + return "", false, nil + } + + lastMessage := messageItems[len(messageItems)-1] + if lastMessage.Get(messagesReqFieldRole).String() != constUser { + return "", false, nil + } + + content := lastMessage.Get(messagesReqFieldContent) + if !content.Exists() || content.Type == gjson.Null { + return "", false, nil + } + if content.Type == gjson.String { + return content.String(), true, nil + } + if !content.IsArray() { + return "", false, fmt.Errorf("unexpected message content type: %s", content.Type) + } + + contentItems := content.Array() + for idx := len(contentItems) - 1; idx >= 0; idx-- { + contentItem := contentItems[idx] + if contentItem.Get(messagesReqFieldType).String() != constText { + continue + } + + text := contentItem.Get(messagesReqFieldText) + if text.Type != gjson.String { + continue + } + + return text.String(), true, nil + } + + return "", false, nil +} + +func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) (MessagesRequestPayload, error) { + if len(injected) == 0 { + return p, nil + } + + existing, err := p.tools() + if err != nil { + return p, err + } + + allTools := make([]any, 0, len(injected)+len(existing)) + for _, tool := range injected { + allTools = append(allTools, tool) + } + for _, tool := range existing { + allTools = append(allTools, tool) + } + + return p.set(messagesReqPathTools, allTools) +} + +func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPayload, error) { + toolChoice := gjson.GetBytes(p, messagesReqPathToolChoice) + + // If no tool_choice was defined, assume auto. + // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use. + if !toolChoice.Exists() || toolChoice.Type == gjson.Null { + updated, err := p.set(messagesReqPathToolChoiceType, constAuto) + if err != nil { + return p, err + } + return updated.set(messagesReqPathToolChoiceDisableParallel, true) + } + if !toolChoice.IsObject() { + return p, fmt.Errorf("unsupported tool_choice type: %s", toolChoice.Type) + } + + toolChoiceType := gjson.GetBytes(p, messagesReqPathToolChoiceType) + if toolChoiceType.Exists() && toolChoiceType.Type != gjson.String { + return p, fmt.Errorf("unsupported tool_choice.type type: %s", toolChoiceType.Type) + } + + switch toolChoiceType.String() { + case "": + updated, err := p.set(messagesReqPathToolChoiceType, constAuto) + if err != nil { + return p, err + } + return updated.set(messagesReqPathToolChoiceDisableParallel, true) + case constAuto, constAny, constTool: + return p.set(messagesReqPathToolChoiceDisableParallel, true) + case constNone: + return p, nil + default: + return p, fmt.Errorf("unsupported tool_choice.type value: %q", toolChoiceType.String()) + } +} + +func (p MessagesRequestPayload) appendedMessages(messages []anthropic.MessageParam) (MessagesRequestPayload, error) { + if len(messages) == 0 { + return p, nil + } + + existing, err := p.messages() + if err != nil { + return p, err + } + + allMessages := make([]any, 0, len(existing)+len(messages)) + allMessages = append(allMessages, existing...) + for _, message := range messages { + allMessages = append(allMessages, message) + } + + return p.set(messagesReqPathMessages, allMessages) +} + +func (p MessagesRequestPayload) withModel(model string) (MessagesRequestPayload, error) { + return p.set(messagesReqPathModel, model) +} + +func (p MessagesRequestPayload) messages() ([]any, error) { + messages := gjson.GetBytes(p, messagesReqPathMessages) + if !messages.Exists() || messages.Type == gjson.Null { + return []any{}, nil + } + if !messages.IsArray() { + return nil, fmt.Errorf("unsupported messages type: %s", messages.Type) + } + + messageItems := messages.Array() + existing := make([]any, 0, len(messageItems)) + for _, item := range messageItems { + existing = append(existing, json.RawMessage(item.Raw)) + } + + return existing, nil +} + +func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) { + tools := gjson.GetBytes(p, messagesReqPathTools) + if !tools.Exists() || tools.Type == gjson.Null { + return nil, nil + } + if !tools.IsArray() { + return nil, fmt.Errorf("unsupported tools type: %s", tools.Type) + } + + toolItems := tools.Array() + existing := make([]json.RawMessage, 0, len(toolItems)) + for _, item := range toolItems { + existing = append(existing, json.RawMessage(item.Raw)) + } + + return existing, nil +} + +func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) { + out, err := sjson.SetBytes(p, path, value) + return MessagesRequestPayload(out), err +} diff --git a/intercept/messages/reqpayload_test.go b/intercept/messages/reqpayload_test.go new file mode 100644 index 00000000..56cb7e96 --- /dev/null +++ b/intercept/messages/reqpayload_test.go @@ -0,0 +1,343 @@ +package messages + +import ( + "testing" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/coder/aibridge/utils" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestNewMessagesRequestPayload(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody []byte + + expectError bool + }{ + { + name: "empty body", + requestBody: []byte(" \n\t "), + expectError: true, + }, + { + name: "invalid json", + requestBody: []byte(`{"model":`), + expectError: true, + }, + { + name: "valid json", + requestBody: []byte(`{"model":"claude-opus-4-5","max_tokens":1024}`), + expectError: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload, err := NewMessagesRequestPayload(testCase.requestBody) + if testCase.expectError { + require.Error(t, err) + require.Nil(t, payload) + return + } + + require.NoError(t, err) + require.Equal(t, MessagesRequestPayload(testCase.requestBody), payload) + }) + } +} + +func TestMessagesRequestPayloadStream(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedStream bool + }{ + { + name: "stream true", + requestBody: `{"stream":true}`, + expectedStream: true, + }, + { + name: "stream false", + requestBody: `{"stream":false}`, + expectedStream: false, + }, + { + name: "stream missing", + requestBody: `{"model":"claude-opus-4-5"}`, + expectedStream: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + require.Equal(t, testCase.expectedStream, payload.Stream()) + }) + } +} + +func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedPrompt string + + expectedFound bool + + expectError bool + }{ + { + name: "last user message string content", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`, + expectedPrompt: "hello", + expectedFound: true, + expectError: false, + }, + { + name: "last user message typed content returns last text block", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}},{"type":"text","text":"first"},{"type":"text","text":"last"}]}]}`, + expectedPrompt: "last", + expectedFound: true, + expectError: false, + }, + { + name: "last message not from user", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"assistant","content":"hello"}]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "no messages key", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "empty messages array", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "last user message with empty content array", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[]}]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "last user message with only non text content", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}},{"type":"image","source":{"type":"base64","media_type":"image/jpeg","data":"def"}}]}]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "multiple messages with last being user", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"first"},{"role":"assistant","content":[{"type":"text","text":"response"}]},{"role":"user","content":"second"}]}`, + expectedPrompt: "second", + expectedFound: true, + expectError: false, + }, + { + name: "messages wrong type returns error", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":{}}`, + expectedPrompt: "", + expectedFound: false, + expectError: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + prompt, found, err := payload.lastUserPrompt() + if testCase.expectError { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, testCase.expectedFound, found) + require.Equal(t, testCase.expectedPrompt, prompt) + }) + } +} + +func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedToolUseID *string + }{ + { + name: "no tool result block", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`, + expectedToolUseID: nil, + }, + { + name: "returns last tool result from final message", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"},{"type":"tool_result","tool_use_id":"toolu_second","content":"second"}]}]}`, + expectedToolUseID: utils.PtrTo("toolu_second"), + }, + { + name: "ignores earlier message tool result", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"}]},{"role":"assistant","content":"done"}]}`, + expectedToolUseID: nil, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + require.Equal(t, testCase.expectedToolUseID, payload.correlatingToolCallID()) + }) + } +} + +func TestMessagesRequestPayloadInjectTools(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`) + + updatedPayload, err := payload.injectTools([]anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "injected_tool", + Type: anthropic.ToolTypeCustom, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: map[string]interface{}{}, + }, + }, + }, + }) + require.NoError(t, err) + + toolItems := gjson.GetBytes(updatedPayload, "tools").Array() + require.Len(t, toolItems, 2) + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Equal(t, "existing_tool", toolItems[1].Get("name").String()) + require.Equal(t, "ephemeral", toolItems[1].Get("cache_control.type").String()) +} + +func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedType string + + expectedDisableParallel *bool + }{ + { + name: "defaults to auto when missing", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024}`, + expectedType: string(constant.ValueOf[constant.Auto]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "auto gets disabled", + requestBody: `{"tool_choice":{"type":"auto"}}`, + expectedType: string(constant.ValueOf[constant.Auto]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "any gets disabled", + requestBody: `{"tool_choice":{"type":"any"}}`, + expectedType: string(constant.ValueOf[constant.Any]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "tool gets disabled", + requestBody: `{"tool_choice":{"type":"tool","name":"abc"}}`, + expectedType: string(constant.ValueOf[constant.Tool]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "none remains unchanged", + requestBody: `{"tool_choice":{"type":"none"}}`, + expectedType: string(constant.ValueOf[constant.None]()), + expectedDisableParallel: nil, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + updatedPayload, err := payload.disableParallelToolCalls() + require.NoError(t, err) + + toolChoice := gjson.GetBytes(updatedPayload, "tool_choice") + require.Equal(t, testCase.expectedType, toolChoice.Get("type").String()) + + disableParallelResult := toolChoice.Get("disable_parallel_tool_use") + if testCase.expectedDisableParallel == nil { + require.False(t, disableParallelResult.Exists()) + return + } + + require.True(t, disableParallelResult.Exists()) + require.Equal(t, *testCase.expectedDisableParallel, disableParallelResult.Bool()) + }) + } +} + +func TestMessagesRequestPayloadAppendedMessages(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`) + + updatedPayload, err := payload.appendedMessages([]anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleAssistant, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock("assistant response"), + }, + }, + anthropic.NewUserMessage(anthropic.NewToolResultBlock("toolu_123", "tool output", false)), + }) + require.NoError(t, err) + + messageItems := gjson.GetBytes(updatedPayload, "messages").Array() + require.Len(t, messageItems, 3) + require.Equal(t, "hello", messageItems[0].Get("content").String()) + require.Equal(t, "assistant", messageItems[1].Get("role").String()) + require.Equal(t, "assistant response", messageItems[1].Get("content.0.text").String()) + require.Equal(t, "tool_result", messageItems[2].Get("content.0.type").String()) + require.Equal(t, "toolu_123", messageItems[2].Get("content.0.tool_use_id").String()) +} diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index b2900a75..d317c55b 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -36,8 +36,7 @@ type StreamingInterception struct { func NewStreamingInterceptor( id uuid.UUID, - req *MessageNewParamsWrapper, - payload []byte, + reqPayload MessagesRequestPayload, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, @@ -46,8 +45,7 @@ func NewStreamingInterceptor( ) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ id: id, - req: req, - payload: payload, + reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, clientHeaders: clientHeaders, @@ -88,8 +86,8 @@ func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.Key // results relayed to the SERVER. The response from the server will be handled synchronously, and this loop // can continue until all injected tool invocations are completed and the response is relayed to the client. func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { - if i.req == nil { - return fmt.Errorf("developer error: req is nil") + if len(i.reqPayload) == 0 { + return fmt.Errorf("developer error: request payload is empty") } ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) @@ -100,16 +98,17 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - logger := i.logger.With(slog.F("model", i.req.Model)) + logger := i.logger.With(slog.F("model", i.Model())) var ( - prompt *string - err error + prompt string + promptFound bool + err error ) // Claude Code uses a "small/fast model" for certain tasks. if !i.isSmallFastModel() { - prompt, err = i.req.lastUserPrompt() + prompt, promptFound, err = i.reqPayload.lastUserPrompt() if err != nil { logger.Warn(ctx, "failed to determine last user prompt", slog.Error(err)) } @@ -142,8 +141,6 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() - messages := i.req.MessageNewParams - // Accumulate usage across the entire streaming interaction (including tool reinvocations). var cumulativeUsage anthropic.Usage @@ -159,7 +156,7 @@ newStream: break } - stream := i.newStream(streamCtx, svc, messages) + stream := i.newStream(streamCtx, svc) var message anthropic.Message var lastToolName string @@ -278,7 +275,8 @@ newStream: // Process injected tools. if len(pendingToolCalls) > 0 { // Append the whole message from this stream as context since we'll be sending a new request with the tool results. - messages.Messages = append(messages.Messages, message.ToParam()) + var loopMessages []anthropic.MessageParam + loopMessages = append(loopMessages, message.ToParam()) for name, id := range pendingToolCalls { if i.mcpProxy == nil { @@ -332,7 +330,7 @@ newStream: if err != nil { // Always provide a tool_result even if the tool call failed - messages.Messages = append(messages.Messages, + loopMessages = append(loopMessages, anthropic.NewUserMessage(anthropic.NewToolResultBlock(id, fmt.Sprintf("Error calling tool: %v", err), true)), ) continue @@ -410,16 +408,18 @@ newStream: } if len(toolResult.OfToolResult.Content) > 0 { - messages.Messages = append(messages.Messages, anthropic.NewUserMessage(toolResult)) + loopMessages = append(loopMessages, anthropic.NewUserMessage(toolResult)) } } // Sync the raw payload with updated messages so that withBody() // sends the updated payload on the next iteration. - if syncErr := i.syncPayloadMessages(messages.Messages); syncErr != nil { + updatedPayload, syncErr := i.reqPayload.appendedMessages(loopMessages) + if syncErr != nil { lastErr = fmt.Errorf("sync payload for agentic loop: %w", syncErr) break } + i.reqPayload = updatedPayload // Causes a new stream to be run with updated messages. isFirst = false @@ -465,13 +465,14 @@ newStream: } } - if prompt != nil { + if promptFound { _ = i.recorder.RecordPromptUsage(ctx, &recorder.PromptUsageRecord{ InterceptionID: i.ID().String(), MsgID: message.ID, - Prompt: *prompt, + Prompt: prompt, }) - prompt = nil + prompt = "" + promptFound = false } if events.IsStreaming() { @@ -579,10 +580,10 @@ func (s *StreamingInterception) encodeForStream(payload []byte, typ string) []by return buf.Bytes() } -// newStream traces svc.NewStreaming(streamCtx, messages) -func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, messages anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEventUnion] { +// newStream traces svc.NewStreaming() call. +func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] { _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() - return svc.NewStreaming(ctx, messages, s.withBody()) + return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, s.withBody()) } diff --git a/provider/anthropic.go b/provider/anthropic.go index 4a79cd42..f4380648 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -1,8 +1,6 @@ package provider import ( - "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -102,8 +100,9 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr if err != nil { return nil, fmt.Errorf("read body: %w", err) } - var req messages.MessageNewParamsWrapper - if err := json.NewDecoder(bytes.NewReader(payload)).Decode(&req); err != nil { + + reqPayload, err := messages.NewMessagesRequestPayload(payload) + if err != nil { return nil, fmt.Errorf("unmarshal request body: %w", err) } @@ -111,10 +110,10 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr cfg.ExtraHeaders = extractAnthropicHeaders(r) var interceptor intercept.Interceptor - if req.Stream { - interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) + if reqPayload.Stream() { + interceptor = messages.NewStreamingInterceptor(id, reqPayload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) + interceptor = messages.NewBlockingInterceptor(id, reqPayload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil