Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,21 +264,47 @@ func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibco
}

// augmentRequestForBedrock rewrites the request payload so that AWS Bedrock
// accepts it. It is a no-op when no Bedrock configuration is set. Otherwise it
// performs, in order:
//
//  1. replace the model with i.Model(), since AWS Bedrock doesn't support
//     Anthropic's model names;
//  2. convert adaptive thinking to enabled with a budget, for models that
//     don't support adaptive thinking natively;
//  3. strip the fields that Bedrock does not accept.
//
// A failing step is logged as a warning and aborts the remaining steps; the
// changes made by steps that already succeeded stay applied.
func (i *interceptionBase) augmentRequestForBedrock() {
	if i.bedrockCfg == nil {
		return
	}

	// Resolve the model once so the payload rewrite and the capability check
	// below agree on the same value.
	model := i.Model()
	updated, err := i.reqPayload.withModel(model)
	if err != nil {
		i.logger.Warn(context.Background(), "failed to set model in request payload for Bedrock", slog.Error(err))
		return
	}
	i.reqPayload = updated

	if !bedrockModelSupportsAdaptiveThinking(model) {
		updated, err = i.reqPayload.convertAdaptiveThinkingForBedrock()
		if err != nil {
			i.logger.Warn(context.Background(), "failed to convert adaptive thinking for Bedrock", slog.Error(err))
			return
		}
		i.reqPayload = updated
	}

	// Strip fields that Bedrock does not accept.
	updated, err = i.reqPayload.removeUnsupportedBedrockFields()
	if err != nil {
		i.logger.Warn(context.Background(), "failed to remove unsupported fields for Bedrock", slog.Error(err))
		return
	}
	i.reqPayload = updated
}

// bedrockModelSupportsAdaptiveThinking reports whether the given Bedrock model
// ID contains one of the Claude 4.6 model identifiers, i.e. the models that
// accept the "adaptive" thinking type natively.
// See https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-adaptive-thinking.html
func bedrockModelSupportsAdaptiveThinking(model string) bool {
	adaptiveCapable := []string{
		"anthropic.claude-opus-4-6",
		"anthropic.claude-sonnet-4-6",
	}
	// A substring match is used, so prefixed or suffixed IDs also qualify.
	for _, marker := range adaptiveCapable {
		if strings.Contains(model, marker) {
			return true
		}
	}
	return false
}

// writeUpstreamError marshals and writes a given error.
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *ErrorResponse) {
if antErr == nil {
Expand Down
118 changes: 118 additions & 0 deletions intercept/messages/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,124 @@ func TestInjectTools_ParallelToolCalls(t *testing.T) {
})
}

// TestAugmentRequestForBedrock_AdaptiveThinking runs augmentRequestForBedrock
// against a table of request bodies and Bedrock model IDs. For every case it
// checks the resulting thinking.type and thinking.budget_tokens, that the model
// was replaced with the configured Bedrock model, and that every entry of
// bedrockUnsupportedFields has been removed from the payload.
func TestAugmentRequestForBedrock_AdaptiveThinking(t *testing.T) {
t.Parallel()

tests := []struct {
name string

bedrockModel string
requestBody string

// expectThinkingType is the expected thinking.type value.
// "" means the thinking field should not be present in the output.
expectThinkingType string
// expectBudgetTokens is the exact expected budget_tokens value.
// 0 means budget_tokens should not be present in the output.
expectBudgetTokens int64
}{
{
name: "non_4_6_model_with_adaptive_thinking_gets_converted",
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"adaptive"},"messages":[]}`,
expectThinkingType: "enabled",
expectBudgetTokens: 8000, // 10000 * 0.8 (default/high effort)
},
{
name: "non_4_6_model_with_adaptive_thinking_and_small_max_tokens_disables_thinking",
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
requestBody: `{"model":"claude-sonnet-4-5","max_tokens":1000,"thinking":{"type":"adaptive"},"messages":[]}`,
expectThinkingType: "disabled", // 1000 * 0.8 (default/high effort) = 800, below 1024 minimum
},
{
name: "opus_4_6_model_with_adaptive_thinking_is_not_converted",
bedrockModel: "anthropic.claude-opus-4-6-v1",
requestBody: `{"model":"claude-opus-4-6","max_tokens":10000,"thinking":{"type":"adaptive"},"messages":[]}`,
expectThinkingType: "adaptive",
expectBudgetTokens: 0,
},
{
name: "sonnet_4_6_model_with_adaptive_thinking_is_not_converted",
bedrockModel: "anthropic.claude-sonnet-4-6",
requestBody: `{"model":"claude-sonnet-4-6","max_tokens":10000,"thinking":{"type":"adaptive"},"messages":[]}`,
expectThinkingType: "adaptive",
expectBudgetTokens: 0,
},
{
name: "non_4_6_model_with_no_thinking_field_is_unchanged",
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"messages":[]}`,
expectThinkingType: "",
expectBudgetTokens: 0,
},
{
name: "non_4_6_model_with_enabled_thinking_is_unchanged",
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":5000},"messages":[]}`,
expectThinkingType: "enabled",
expectBudgetTokens: 5000, // already set, not recalculated
},
{
name: "non_4_6_model_with_output_config_strips_it_and_uses_effort_for_budget",
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"low"},"messages":[]}`,
expectThinkingType: "enabled",
expectBudgetTokens: 2000, // 10000 * 0.2 (low effort)
},
{
name: "4_6_model_with_output_config_strips_it",
bedrockModel: "anthropic.claude-opus-4-6-v1",
requestBody: `{"model":"claude-opus-4-6","max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"messages":[]}`,
expectThinkingType: "adaptive",
expectBudgetTokens: 0,
},
{
name: "all_unsupported_fields_are_stripped",
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"messages":[],"output_config":{"effort":"high"},"metadata":{"user_id":"u123"},"service_tier":"auto","container":"ctr_abc","inference_geo":"us","context_management":{"type":"auto"}}`,
expectThinkingType: "",
expectBudgetTokens: 0,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

// Build the interception by hand with only the fields set below.
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, tc.requestBody),
bedrockCfg: &config.AWSBedrock{
Model: tc.bedrockModel,
SmallFastModel: "anthropic.claude-haiku-3-5",
},
logger: slog.Make(),
}

i.augmentRequestForBedrock()

thinkingType := gjson.GetBytes(i.reqPayload, "thinking.type")
if tc.expectThinkingType == "" {
require.False(t, thinkingType.Exists())
} else {
require.Equal(t, tc.expectThinkingType, thinkingType.String())
}

budgetTokens := gjson.GetBytes(i.reqPayload, "thinking.budget_tokens")
if tc.expectBudgetTokens == 0 {
require.False(t, budgetTokens.Exists(), "budget_tokens should not be set")
} else {
require.Equal(t, tc.expectBudgetTokens, budgetTokens.Int())
}

// Model should always be set to the bedrock model.
require.Equal(t, tc.bedrockModel, gjson.GetBytes(i.reqPayload, "model").String())

// Unsupported fields should always be stripped for Bedrock.
for _, field := range bedrockUnsupportedFields {
require.False(t, gjson.GetBytes(i.reqPayload, field).Exists(), "%s should be removed for Bedrock", field)
}
})
}
}

func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload {
t.Helper()

Expand Down
96 changes: 96 additions & 0 deletions intercept/messages/reqpayload.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,19 @@ import (
const (
// Absolute JSON paths from the request root.
messagesReqPathMessages = "messages"
messagesReqPathMaxTokens = "max_tokens"
messagesReqPathModel = "model"
messagesReqPathOutputConfig = "output_config"
messagesReqPathOutputConfigEffort = "output_config.effort"
messagesReqPathMetadata = "metadata"
messagesReqPathServiceTier = "service_tier"
messagesReqPathContainer = "container"
messagesReqPathInferenceGeo = "inference_geo"
messagesReqPathContextManagement = "context_management"
messagesReqPathStream = "stream"
messagesReqPathThinking = "thinking"
messagesReqPathThinkingBudgetTokens = "thinking.budget_tokens"
messagesReqPathThinkingType = "thinking.type"
messagesReqPathToolChoice = "tool_choice"
messagesReqPathToolChoiceDisableParallel = "tool_choice.disable_parallel_tool_use"
messagesReqPathToolChoiceType = "tool_choice.type"
Expand All @@ -29,6 +40,12 @@ const (
messagesReqFieldType = "type"
)

// Values of the "thinking.type" request field that this package compares
// against (adaptive) or writes into the payload (disabled, enabled).
const (
constAdaptive = "adaptive"
constDisabled = "disabled"
constEnabled = "enabled"
)

var (
constAny = string(constant.ValueOf[constant.Any]())
constAuto = string(constant.ValueOf[constant.Auto]())
Expand All @@ -37,6 +54,21 @@ var (
constTool = string(constant.ValueOf[constant.Tool]())
constToolResult = string(constant.ValueOf[constant.ToolResult]())
constUser = string(anthropic.MessageParamRoleUser)

// bedrockUnsupportedFields are top-level fields present in the Anthropic Messages
// API that are absent from the Bedrock request body schema. Sending them results
// in a 400 "Extra inputs are not permitted" error.
//
// Anthropic API fields: https://platform.claude.com/docs/en/api/messages/create
// Bedrock request body: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html
bedrockUnsupportedFields = []string{
messagesReqPathOutputConfig, // requires beta header 'effort-2025-11-24'
messagesReqPathMetadata,
messagesReqPathServiceTier,
messagesReqPathContainer,
messagesReqPathInferenceGeo,
messagesReqPathContextManagement,
}
)

// MessagesRequestPayload is raw JSON bytes of an Anthropic Messages API request.
Expand Down Expand Up @@ -265,7 +297,71 @@ func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) {
return existing, nil
}

// convertAdaptiveThinkingForBedrock rewrites a thinking block of type
// "adaptive" into an explicit one. Payloads whose thinking.type is anything
// else, or absent, are returned untouched.
//
// The budget is max_tokens scaled by a ratio chosen from output_config.effort.
// If the resulting budget is below 1024 tokens, thinking is replaced with
// {"type":"disabled"}; otherwise it becomes {"type":"enabled"} with the
// computed budget_tokens. An error is returned, together with the unchanged
// payload, when max_tokens does not read as a positive number.
func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesRequestPayload, error) {
	if gjson.GetBytes(p, messagesReqPathThinkingType).String() != constAdaptive {
		return p, nil
	}

	limit := gjson.GetBytes(p, messagesReqPathMaxTokens).Int()
	if limit <= 0 {
		// max_tokens is required by messages API
		return p, fmt.Errorf("max_tokens: field required")
	}

	// Effort-to-ratio mapping adapted from OpenRouter:
	// https://openrouter.ai/docs/guides/best-practices/reasoning-tokens#reasoning-effort-level
	//
	// Start from the "high" ratio: high is the default effort, so it is also
	// what an absent or unrecognised effort falls back to.
	share := 0.8
	switch gjson.GetBytes(p, messagesReqPathOutputConfigEffort).String() {
	case "low":
		share = 0.2
	case "medium":
		share = 0.5
	case "max":
		share = 0.95
	}

	// budget_tokens must be ≥ 1024 && < max_tokens. A budget that reaches the
	// minimum is sent as-is; one that doesn't turns thinking off entirely
	// rather than forcing an artificially high budget that would starve the
	// output.
	// https://platform.claude.com/docs/en/api/messages/create#create.thinking
	// https://platform.claude.com/docs/en/build-with-claude/extended-thinking#how-to-use-extended-thinking
	budget := int64(float64(limit) * share)
	if budget >= 1024 {
		return p.set(messagesReqPathThinking, map[string]any{
			"type":          constEnabled,
			"budget_tokens": budget,
		})
	}

	return p.set(messagesReqPathThinking, map[string]string{"type": constDisabled})
}

// removeUnsupportedBedrockFields deletes every path listed in
// bedrockUnsupportedFields from the payload and returns the result. If a
// deletion fails, the receiver is returned unmodified alongside an error that
// names the offending field.
func (p MessagesRequestPayload) removeUnsupportedBedrockFields() (MessagesRequestPayload, error) {
	stripped := p
	for _, path := range bedrockUnsupportedFields {
		next, err := stripped.delete(path)
		if err != nil {
			return p, fmt.Errorf("removing %q: %w", path, err)
		}
		stripped = next
	}
	return stripped, nil
}

// set writes value at the given JSON path via sjson.SetBytes and returns the
// resulting bytes as a MessagesRequestPayload, together with sjson's error.
func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) {
	raw, setErr := sjson.SetBytes(p, path, value)
	return MessagesRequestPayload(raw), setErr
}

// delete removes the value at the given JSON path via sjson.DeleteBytes and
// returns the resulting bytes as a MessagesRequestPayload, together with
// sjson's error.
func (p MessagesRequestPayload) delete(path string) (MessagesRequestPayload, error) {
	raw, delErr := sjson.DeleteBytes(p, path)
	return MessagesRequestPayload(raw), delErr
}
70 changes: 70 additions & 0 deletions intercept/messages/reqpayload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,76 @@ func TestMessagesRequestPayloadInjectTools(t *testing.T) {
require.Equal(t, "ephemeral", toolItems[1].Get("cache_control.type").String())
}

// TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock feeds a table of
// request bodies through convertAdaptiveThinkingForBedrock and verifies the
// resulting thinking.type and thinking.budget_tokens, or that an error is
// returned.
func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) {
	t.Parallel()

	type scenario struct {
		name        string
		requestBody string

		// wantThinkingType == "" means the thinking field must be absent.
		wantThinkingType string
		// wantBudgetTokens == 0 means thinking.budget_tokens must be absent.
		wantBudgetTokens int64
		wantErr          bool
	}

	scenarios := []scenario{
		{
			name:             "no_thinking_field_is_no_op",
			requestBody:      `{"model":"claude-sonnet-4-5","max_tokens":10000,"messages":[]}`,
			wantThinkingType: "",
		},
		{
			name:             "non_adaptive_thinking_type_is_no_op",
			requestBody:      `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":5000},"messages":[]}`,
			wantThinkingType: "enabled",
			wantBudgetTokens: 5000,
		},
		{
			name:             "adaptive_with_no_effort_defaults_to_80%",
			requestBody:      `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"adaptive"},"messages":[]}`,
			wantThinkingType: "enabled",
			wantBudgetTokens: 8000, // 10000 * 0.8 (default/high effort)
		},
		{
			name:             "adaptive_with_explicit_effort_uses_correct_percentage",
			requestBody:      `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"low"},"messages":[]}`,
			wantThinkingType: "enabled",
			wantBudgetTokens: 2000, // 10000 * 0.2
		},
		{
			name:             "adaptive_disables_thinking_when_budget_below_minimum",
			requestBody:      `{"model":"claude-sonnet-4-5","max_tokens":512,"thinking":{"type":"adaptive"},"messages":[]}`,
			wantThinkingType: "disabled", // 512 * 0.8 = 409, below 1024 minimum
		},
		{
			name:        "adaptive_without_max_tokens_returns_error",
			requestBody: `{"model":"claude-sonnet-4-5","thinking":{"type":"adaptive"},"messages":[]}`,
			wantErr:     true,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()

			got, err := mustMessagesPayload(t, sc.requestBody).convertAdaptiveThinkingForBedrock()
			if sc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)

			// Presence of the thinking object must be the opposite of
			// "no thinking type expected".
			thinkingAbsentWanted := sc.wantThinkingType == ""
			thinking := gjson.GetBytes(got, messagesReqPathThinking)
			require.NotEqual(t, thinkingAbsentWanted, thinking.Exists(), "thinking should not be set")
			// A missing field reads back as the zero value, so this also
			// holds when no thinking type is expected.
			gotType := gjson.GetBytes(got, messagesReqPathThinkingType).String()
			require.Equal(t, sc.wantThinkingType, gotType)

			// Same pattern for budget_tokens: presence first, then value.
			budgetAbsentWanted := sc.wantBudgetTokens == 0
			budget := gjson.GetBytes(got, messagesReqPathThinkingBudgetTokens)
			require.NotEqual(t, budgetAbsentWanted, budget.Exists(), "budget_tokens should not be set")
			require.Equal(t, sc.wantBudgetTokens, budget.Int())
		})
	}
}

func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) {
t.Parallel()

Expand Down
Loading