diff --git a/intercept/messages/base.go b/intercept/messages/base.go index f2f0e70..f808302 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -264,21 +264,47 @@ func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibco } // augmentRequestForBedrock will change the model used for the request since AWS Bedrock doesn't support -// Anthropics' model names. +// Anthropics' model names. It also converts adaptive thinking to enabled with a budget for models that +// don't support adaptive thinking natively. func (i *interceptionBase) augmentRequestForBedrock() { if i.bedrockCfg == nil { return } - updated, err := i.reqPayload.withModel(i.Model()) + model := i.Model() + updated, err := i.reqPayload.withModel(model) if err != nil { i.logger.Warn(context.Background(), "failed to set model in request payload for Bedrock", slog.Error(err)) return } + i.reqPayload = updated + + if !bedrockModelSupportsAdaptiveThinking(model) { + updated, err = i.reqPayload.convertAdaptiveThinkingForBedrock() + if err != nil { + i.logger.Warn(context.Background(), "failed to convert adaptive thinking for Bedrock", slog.Error(err)) + return + } + i.reqPayload = updated + } + // Strip fields that Bedrock does not accept. + updated, err = i.reqPayload.removeUnsupportedBedrockFields() + if err != nil { + i.logger.Warn(context.Background(), "failed to remove unsupported fields for Bedrock", slog.Error(err)) + return + } i.reqPayload = updated } +// bedrockModelSupportsAdaptiveThinking returns true if the given Bedrock model ID +// supports the "adaptive" thinking type natively (i.e. Claude 4.6 models). +// See https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-adaptive-thinking.html +func bedrockModelSupportsAdaptiveThinking(model string) bool { + return strings.Contains(model, "anthropic.claude-opus-4-6") || + strings.Contains(model, "anthropic.claude-sonnet-4-6") +} + // writeUpstreamError marshals and writes a given error. func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *ErrorResponse) { if antErr == nil { diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index 3f25b6e..723a2b6 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -590,6 +590,124 @@ func TestInjectTools_ParallelToolCalls(t *testing.T) { }) } +func TestAugmentRequestForBedrock_AdaptiveThinking(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + + bedrockModel string + requestBody string + + expectThinkingType string + // expectBudgetTokens is the exact expected budget_tokens value. + // 0 means budget_tokens should not be present in the output. + expectBudgetTokens int64 + }{ + { + name: "non_4_6_model_with_adaptive_thinking_gets_converted", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"adaptive"},"messages":[]}`, + expectThinkingType: "enabled", + expectBudgetTokens: 8000, // 10000 * 0.8 (default/high effort) + }, + { + name: "non_4_6_model_with_adaptive_thinking_and_small_max_tokens_disables_thinking", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":1000,"thinking":{"type":"adaptive"},"messages":[]}`, + expectThinkingType: "disabled", // 1000 * 0.6 = 600, below 1024 minimum + }, + { + name: "opus_4_6_model_with_adaptive_thinking_is_not_converted", + bedrockModel: "anthropic.claude-opus-4-6-v1", + requestBody: `{"model":"claude-opus-4-6","max_tokens":10000,"thinking":{"type":"adaptive"},"messages":[]}`, + expectThinkingType: "adaptive", + expectBudgetTokens: 0, + }, + { + name: "sonnet_4_6_model_with_adaptive_thinking_is_not_converted", + bedrockModel: "anthropic.claude-sonnet-4-6", + requestBody: `{"model":"claude-sonnet-4-6","max_tokens":10000,"thinking":{"type":"adaptive"},"messages":[]}`, + expectThinkingType: "adaptive", + expectBudgetTokens: 0, + }, + { + name: "non_4_6_model_with_no_thinking_field_is_unchanged", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"messages":[]}`, + expectThinkingType: "", + expectBudgetTokens: 0, + }, + { + name: "non_4_6_model_with_enabled_thinking_is_unchanged", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":5000},"messages":[]}`, + expectThinkingType: "enabled", + expectBudgetTokens: 5000, // already set, not recalculated + }, + { + name: "non_4_6_model_with_output_config_strips_it_and_uses_effort_for_budget", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"low"},"messages":[]}`, + expectThinkingType: "enabled", + expectBudgetTokens: 2000, // 10000 * 0.2 (low effort) + }, + { + name: "4_6_model_with_output_config_strips_it", + bedrockModel: "anthropic.claude-opus-4-6-v1", + requestBody: `{"model":"claude-opus-4-6","max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"messages":[]}`, + expectThinkingType: "adaptive", + expectBudgetTokens: 0, + }, + { + name: "all_unsupported_fields_are_stripped", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"messages":[],"output_config":{"effort":"high"},"metadata":{"user_id":"u123"},"service_tier":"auto","container":"ctr_abc","inference_geo":"us","context_management":{"type":"auto"}}`, + expectThinkingType: "", + expectBudgetTokens: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, tc.requestBody), + bedrockCfg: &config.AWSBedrock{ + Model: tc.bedrockModel, + SmallFastModel: "anthropic.claude-haiku-3-5", + }, + logger: slog.Make(), + } + + i.augmentRequestForBedrock() + + thinkingType := gjson.GetBytes(i.reqPayload, "thinking.type") + if tc.expectThinkingType == "" { + require.False(t, thinkingType.Exists()) + } else { + require.Equal(t, tc.expectThinkingType, thinkingType.String()) + } + + budgetTokens := gjson.GetBytes(i.reqPayload, "thinking.budget_tokens") + if tc.expectBudgetTokens == 0 { + require.False(t, budgetTokens.Exists(), "budget_tokens should not be set") + } else { + require.Equal(t, tc.expectBudgetTokens, budgetTokens.Int()) + } + + // Model should always be set to the bedrock model. + require.Equal(t, tc.bedrockModel, gjson.GetBytes(i.reqPayload, "model").String()) + + // Unsupported fields should always be stripped for Bedrock. + for _, field := range bedrockUnsupportedFields { + require.False(t, gjson.GetBytes(i.reqPayload, field).Exists(), "%s should be removed for Bedrock", field) + } + }) + } +} + func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload { t.Helper() diff --git a/intercept/messages/reqpayload.go b/intercept/messages/reqpayload.go index cefddd8..2cba26b 100644 --- a/intercept/messages/reqpayload.go +++ b/intercept/messages/reqpayload.go @@ -14,8 +14,19 @@ import ( const ( // Absolute JSON paths from the request root. messagesReqPathMessages = "messages" + messagesReqPathMaxTokens = "max_tokens" messagesReqPathModel = "model" + messagesReqPathOutputConfig = "output_config" + messagesReqPathOutputConfigEffort = "output_config.effort" + messagesReqPathMetadata = "metadata" + messagesReqPathServiceTier = "service_tier" + messagesReqPathContainer = "container" + messagesReqPathInferenceGeo = "inference_geo" + messagesReqPathContextManagement = "context_management" messagesReqPathStream = "stream" + messagesReqPathThinking = "thinking" + messagesReqPathThinkingBudgetTokens = "thinking.budget_tokens" + messagesReqPathThinkingType = "thinking.type" messagesReqPathToolChoice = "tool_choice" messagesReqPathToolChoiceDisableParallel = "tool_choice.disable_parallel_tool_use" messagesReqPathToolChoiceType = "tool_choice.type" @@ -29,6 +40,12 @@ const ( messagesReqFieldType = "type" ) +const ( + constAdaptive = "adaptive" + constDisabled = "disabled" + constEnabled = "enabled" +) + var ( constAny = string(constant.ValueOf[constant.Any]()) constAuto = string(constant.ValueOf[constant.Auto]()) @@ -37,6 +54,21 @@ var ( constTool = string(constant.ValueOf[constant.Tool]()) constToolResult = string(constant.ValueOf[constant.ToolResult]()) constUser = string(anthropic.MessageParamRoleUser) + + // bedrockUnsupportedFields are top-level fields present in the Anthropic Messages + // API that are absent from the Bedrock request body schema. Sending them results + // in a 400 "Extra inputs are not permitted" error. + // + // Anthropic API fields: https://platform.claude.com/docs/en/api/messages/create + // Bedrock request body: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html + bedrockUnsupportedFields = []string{ + messagesReqPathOutputConfig, // requires beta header 'effort-2025-11-24' + messagesReqPathMetadata, + messagesReqPathServiceTier, + messagesReqPathContainer, + messagesReqPathInferenceGeo, + messagesReqPathContextManagement, + } ) // MessagesRequestPayload is raw JSON bytes of an Anthropic Messages API request. @@ -265,7 +297,71 @@ func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) { return existing, nil } +// convertAdaptiveThinkingForBedrock converts thinking.type "adaptive" to "enabled" with a calculated budget_tokens +func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesRequestPayload, error) { + thinkingType := gjson.GetBytes(p, messagesReqPathThinkingType) + if thinkingType.String() != constAdaptive { + return p, nil + } + + maxTokens := gjson.GetBytes(p, messagesReqPathMaxTokens).Int() + if maxTokens <= 0 { + // max_tokens is required by messages API + return p, fmt.Errorf("max_tokens: field required") + } + + effort := gjson.GetBytes(p, messagesReqPathOutputConfigEffort).String() + + // Effort-to-ratio mapping adapted from OpenRouter: + // https://openrouter.ai/docs/guides/best-practices/reasoning-tokens#reasoning-effort-level + var ratio float64 + switch effort { + case "low": + ratio = 0.2 + case "medium": + ratio = 0.5 + case "max": + ratio = 0.95 + default: // "high" or absent (high is the default effort) + ratio = 0.8 + } + + // budget_tokens must be ≥ 1024 && < max_tokens. If the calculated budget + // doesn't meet the minimum, disable thinking entirely rather than forcing + // an artificially high budget that would starve the output. + // https://platform.claude.com/docs/en/api/messages/create#create.thinking + // https://platform.claude.com/docs/en/build-with-claude/extended-thinking#how-to-use-extended-thinking + budgetTokens := int64(float64(maxTokens) * ratio) + if budgetTokens < 1024 { + return p.set(messagesReqPathThinking, map[string]string{"type": constDisabled}) + } + + return p.set(messagesReqPathThinking, map[string]any{ + "type": constEnabled, + "budget_tokens": budgetTokens, + }) +} + +// removeUnsupportedBedrockFields strips all top-level fields that Bedrock does +// not support from the payload. +func (p MessagesRequestPayload) removeUnsupportedBedrockFields() (MessagesRequestPayload, error) { + result := p + for _, field := range bedrockUnsupportedFields { + var err error + result, err = result.delete(field) + if err != nil { + return p, fmt.Errorf("removing %q: %w", field, err) + } + } + return result, nil +} + func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) { out, err := sjson.SetBytes(p, path, value) return MessagesRequestPayload(out), err } + +func (p MessagesRequestPayload) delete(path string) (MessagesRequestPayload, error) { + out, err := sjson.DeleteBytes(p, path) + return MessagesRequestPayload(out), err +} diff --git a/intercept/messages/reqpayload_test.go b/intercept/messages/reqpayload_test.go index 56cb7e9..eb41842 100644 --- a/intercept/messages/reqpayload_test.go +++ b/intercept/messages/reqpayload_test.go @@ -250,6 +250,76 @@ func TestMessagesRequestPayloadInjectTools(t *testing.T) { require.Equal(t, "ephemeral", toolItems[1].Get("cache_control.type").String()) } +func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedThinkingType string + expectedBudgetTokens int64 + expectError bool + }{ + { + name: "no_thinking_field_is_no_op", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"messages":[]}`, + expectedThinkingType: "", + }, + { + name: "non_adaptive_thinking_type_is_no_op", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":5000},"messages":[]}`, + expectedThinkingType: "enabled", + expectedBudgetTokens: 5000, + }, + { + name: "adaptive_with_no_effort_defaults_to_80%", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"adaptive"},"messages":[]}`, + expectedThinkingType: "enabled", + expectedBudgetTokens: 8000, // 10000 * 0.8 (default/high effort) + }, + { + name: "adaptive_with_explicit_effort_uses_correct_percentage", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"low"},"messages":[]}`, + expectedThinkingType: "enabled", + expectedBudgetTokens: 2000, // 10000 * 0.2 + }, + { + name: "adaptive_disables_thinking_when_budget_below_minimum", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":512,"thinking":{"type":"adaptive"},"messages":[]}`, + expectedThinkingType: "disabled", // 512 * 0.8 = 409, below 1024 minimum + }, + { + name: "adaptive_without_max_tokens_returns_error", + requestBody: `{"model":"claude-sonnet-4-5","thinking":{"type":"adaptive"},"messages":[]}`, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, tc.requestBody) + updatedPayload, err := payload.convertAdaptiveThinkingForBedrock() + if tc.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + + thinking := gjson.GetBytes(updatedPayload, messagesReqPathThinking) + require.NotEqual(t, tc.expectedThinkingType == "", thinking.Exists(), "thinking should not be set") + require.Equal(t, tc.expectedThinkingType, gjson.GetBytes(updatedPayload, messagesReqPathThinkingType).String()) // non existing field returns zero value + + budgetTokens := gjson.GetBytes(updatedPayload, messagesReqPathThinkingBudgetTokens) + require.NotEqual(t, tc.expectedBudgetTokens == 0, budgetTokens.Exists(), "budget_tokens should not be set") + require.Equal(t, tc.expectedBudgetTokens, budgetTokens.Int()) // non existing field returns zero value + }) + } +} + func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) { t.Parallel()