From ee35e1a96b1c36956095036aef50546c26e0b194 Mon Sep 17 00:00:00 2001 From: Alin D'Silva Date: Sat, 7 Feb 2026 21:55:12 +0000 Subject: [PATCH 1/5] feat: add custom headers support for provider configs Add Headers map[string]string to ProviderConfig, allowing custom HTTP headers on provider definitions. Headers flow through ProviderOpts to the OpenAI client with env var expansion (${VAR_NAME} syntax). Includes: - ProviderConfig.Headers field in config schema (v3 and latest) - Headers wiring in applyProviderDefaults - OpenAI client: headers parsing, env expansion, auth middleware for custom providers without token_key - Schema normalization (normalizeUnionTypes) for gateway compatibility - Handle both map[string]string and map[interface{}]interface{} YAML types --- agent-schema.json | 13 +++ examples/custom_provider.yaml | 44 +++---- pkg/config/latest/types.go | 3 + pkg/config/v3/types.go | 3 + pkg/model/provider/custom_headers_test.go | 135 ++++++++++++++++++++++ pkg/model/provider/openai/client.go | 75 +++++++++++- pkg/model/provider/openai/schema.go | 55 +++++++++ pkg/model/provider/provider.go | 7 ++ 8 files changed, 302 insertions(+), 33 deletions(-) create mode 100644 pkg/model/provider/custom_headers_test.go diff --git a/agent-schema.json b/agent-schema.json index b106a8509..30a8ba955 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -106,6 +106,19 @@ "examples": [ "CUSTOM_PROVIDER_API_KEY" ] + }, + "headers": { + "type": "object", + "description": "Custom HTTP headers to include in requests. Header values can reference environment variables using ${VAR_NAME} syntax.", + "additionalProperties": { + "type": "string" + }, + "examples": [ + { + "cf-aig-authorization": "Bearer ${CLOUDFLARE_AI_GATEWAY_TOKEN}", + "x-custom-header": "value" + } + ] } }, "required": [ diff --git a/examples/custom_provider.yaml b/examples/custom_provider.yaml index f173c5535..0ba20742c 100644 --- a/examples/custom_provider.yaml +++ b/examples/custom_provider.yaml @@ -6,45 +6,31 @@ # Define custom providers with reusable configuration providers: - # Example: A custom OpenAI Chat Completions compatible API gateway - my_gateway: - api_type: openai_chatcompletions # Use the Chat Completions API schema - base_url: https://api.example.com/ - token_key: API_KEY_ENV_VAR_NAME # Environment variable containing the API token - # Example: A custom OpenAI Responses compatible API gateway - responses_provider: - api_type: openai_responses - base_url: https://responses.example.com/ - token_key: API_KEY_ENV_VAR_NAME + # Example: Cloudflare AI Gateway with custom headers + cloudflare_gateway: + api_type: openai_chatcompletions + base_url: https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/compat + token_key: GOOGLE_API_KEY # Standard Authorization header for provider auth + headers: + # Custom header for gateway authentication with environment variable expansion + cf-aig-authorization: Bearer ${CLOUDFLARE_AI_GATEWAY_TOKEN} # Define models that use the custom providers models: - # Model using the custom gateway provider - gateway_gpt4o: - provider: my_gateway - model: gpt-4o - max_tokens: 32768 - temperature: 0.7 - # Model using the responses provider - responses_model: - provider: responses_provider - model: gpt-5 - max_tokens: 16000 + # Model using Cloudflare AI Gateway with custom headers + gemini_via_cloudflare: + provider: cloudflare_gateway + model: google-ai-studio/gemini-3-flash-preview + max_tokens: 8000 + temperature: 0.7 # Define agents that use the models agents: root: - model: responses_model + model: gemini_via_cloudflare description: Main assistant using the custom gateway instruction: | You are a helpful AI assistant. Be concise and helpful in your responses. - # Example using shorthand syntax: provider_name/model_name - # The provider defaults (base_url, token_key, api_type) are automatically applied - subagent: - model: my_gateway/gpt-4o-mini - description: Sub-agent for specialized tasks - instruction: | - You are a specialized assistant for specific tasks. diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index e3ff0dc6c..9da7b1641 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -135,6 +135,9 @@ type ProviderConfig struct { BaseURL string `json:"base_url"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` } // FallbackConfig represents fallback model configuration for an agent. diff --git a/pkg/config/v3/types.go b/pkg/config/v3/types.go index 1efcec154..90cb2a449 100644 --- a/pkg/config/v3/types.go +++ b/pkg/config/v3/types.go @@ -34,6 +34,9 @@ type ProviderConfig struct { BaseURL string `json:"base_url"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` } // AgentConfig represents a single agent configuration diff --git a/pkg/model/provider/custom_headers_test.go b/pkg/model/provider/custom_headers_test.go new file mode 100644 index 000000000..33133a0fb --- /dev/null +++ b/pkg/model/provider/custom_headers_test.go @@ -0,0 +1,135 @@ +package provider + +import ( + "testing" + + "github.com/docker/cagent/pkg/config/latest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplyProviderDefaults_WithHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + providerName string + providerCfg latest.ProviderConfig + modelCfg latest.ModelConfig + expectedHeaders map[string]string + headersInOpts bool + }{ + { + name: "custom provider with headers", + providerName: "custom", + providerCfg: latest.ProviderConfig{ + BaseURL: "https://gateway.example.com/v1", + Headers: map[string]string{ + "cf-aig-authorization": "Bearer token123", + "x-custom-header": "value", + }, + }, + modelCfg: latest.ModelConfig{ + Provider: "custom", + Model: "gpt-4o", + }, + expectedHeaders: map[string]string{ + "cf-aig-authorization": "Bearer token123", + "x-custom-header": "value", + }, + headersInOpts: true, + }, + { + name: "custom provider without headers", + providerName: "custom", + providerCfg: latest.ProviderConfig{ + BaseURL: "https://api.example.com/v1", + }, + modelCfg: latest.ModelConfig{ + Provider: "custom", + Model: "gpt-4o", + }, + headersInOpts: false, + }, + { + name: "custom provider with empty headers", + providerName: "custom", + providerCfg: latest.ProviderConfig{ + BaseURL: "https://api.example.com/v1", + Headers: map[string]string{}, + }, + modelCfg: latest.ModelConfig{ + Provider: "custom", + Model: "gpt-4o", + }, + headersInOpts: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + providers := map[string]latest.ProviderConfig{ + tt.providerName: tt.providerCfg, + } + + result := applyProviderDefaults(&tt.modelCfg, providers) + require.NotNil(t, result) + + if tt.headersInOpts { + require.NotNil(t, result.ProviderOpts, "ProviderOpts should not be nil") + headers, ok := result.ProviderOpts["headers"] + require.True(t, ok, "headers should be in ProviderOpts") + + headerMap, ok := headers.(map[string]string) + require.True(t, ok, "headers should be map[string]string") + assert.Equal(t, tt.expectedHeaders, headerMap, "headers should match") + } else { + if result.ProviderOpts != nil { + _, hasHeaders := result.ProviderOpts["headers"] + assert.False(t, hasHeaders, "headers should not be in ProviderOpts") + } + } + }) + } +} + +func TestApplyProviderDefaults_HeadersDoNotOverrideExisting(t *testing.T) { + t.Parallel() + + providerCfg := latest.ProviderConfig{ + BaseURL: "https://gateway.example.com/v1", + Headers: map[string]string{ + "x-provider-header": "from-provider", + }, + } + + modelCfg := latest.ModelConfig{ + Provider: "custom", + Model: "gpt-4o", + ProviderOpts: map[string]any{ + "headers": map[string]string{ + "x-model-header": "from-model", + }, + }, + } + + providers := map[string]latest.ProviderConfig{ + "custom": providerCfg, + } + + result := applyProviderDefaults(&modelCfg, providers) + require.NotNil(t, result) + + // Model config's headers should take precedence (not be overwritten) + require.NotNil(t, result.ProviderOpts) + headers, ok := result.ProviderOpts["headers"] + require.True(t, ok) + + headerMap, ok := headers.(map[string]string) + require.True(t, ok) + + // Should have model's header, not provider's header + assert.Equal(t, map[string]string{"x-model-header": "from-model"}, headerMap) +} diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 98ac96c9c..4ee7512e6 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "log/slog" + "net/http" "net/url" "strings" @@ -58,11 +59,23 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro return nil, fmt.Errorf("%s environment variable is required", cfg.TokenKey) } clientOptions = append(clientOptions, option.WithAPIKey(authToken)) - } else if isCustomProvider(cfg) { - // Custom provider (has api_type in ProviderOpts) without token_key - no auth - slog.Debug("Custom provider with no token_key, sending requests without authentication", + } else if !isCustomProvider(cfg) { + // Not a custom provider - use default OpenAI behavior (OPENAI_API_KEY from env) + // The OpenAI SDK will automatically look for OPENAI_API_KEY if no key is set + } else { + // Custom provider without token_key - prevent SDK from using OPENAI_API_KEY env var + // We need to explicitly set the API key to prevent the SDK from reading OPENAI_API_KEY + // but we don't want to send an Authorization header. The SDK doesn't send the header + // if we use option.WithAPIKey with a specific marker value and then remove it via middleware. + slog.Debug("Custom provider with no token_key, disabling OpenAI SDK authentication", "provider", cfg.Provider, "base_url", cfg.BaseURL) - clientOptions = append(clientOptions, option.WithAPIKey("")) + + // Use a custom HTTP client that removes the Authorization header + clientOptions = append(clientOptions, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + // Remove Authorization header for custom providers without token_key + req.Header.Del("Authorization") + return next(req) + })) } // Otherwise let the OpenAI SDK use its default behavior (OPENAI_API_KEY from env) @@ -85,6 +98,60 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL)) } + + // Apply custom headers from provider config if present + if cfg.ProviderOpts != nil { + if headers, exists := cfg.ProviderOpts["headers"]; exists { + // Handle both map[string]string and map[interface{}]interface{} from YAML parsing + headersMap := make(map[string]string) + + switch h := headers.(type) { + case map[string]string: + // Direct map[string]string - use as-is + headersMap = h + case map[interface{}]interface{}: + // YAML parsed as map[interface{}]interface{} - convert + for k, v := range h { + keyStr, okKey := k.(string) + valStr, okVal := v.(string) + if !okKey || !okVal { + slog.Error("Invalid header key/value type", + "key_type", fmt.Sprintf("%T", k), + "value_type", fmt.Sprintf("%T", v), + "provider", cfg.Provider) + return nil, fmt.Errorf("invalid header key/value type: key=%T, value=%T", k, v) + } + headersMap[keyStr] = valStr + } + default: + slog.Error("Invalid headers configuration - expected map[string]string or map[interface{}]interface{}", + "type", fmt.Sprintf("%T", headers), + "provider", cfg.Provider) + return nil, fmt.Errorf("invalid headers configuration: expected map[string]string, got %T", headers) + } + + if len(headersMap) > 0 { + slog.Debug("Applying custom headers", "count", len(headersMap), "provider", cfg.Provider) + for key, value := range headersMap { + // Expand environment variables in header values (e.g., ${VAR_NAME}) + expandedValue, err := environment.Expand(ctx, value, env) + if err != nil { + slog.Error("Failed to expand environment variable in header", + "header", key, + "value", value, + "error", err, + "provider", cfg.Provider) + return nil, fmt.Errorf("expanding header %s: %w", key, err) + } + clientOptions = append(clientOptions, option.WithHeader(key, expandedValue)) + slog.Debug("Applied custom header", + "header", key, + "provider", cfg.Provider) + } + } + } + } + httpClient := httpclient.NewHTTPClient() clientOptions = append(clientOptions, option.WithHTTPClient(httpClient)) diff --git a/pkg/model/provider/openai/schema.go b/pkg/model/provider/openai/schema.go index 28149ab3b..cb6513084 100644 --- a/pkg/model/provider/openai/schema.go +++ b/pkg/model/provider/openai/schema.go @@ -144,3 +144,58 @@ func fixSchemaArrayItems(schema shared.FunctionParameters) shared.FunctionParame return schema } + +// normalizeUnionTypes converts union types like ["array", "null"] back to simple types +// for compatibility with AI gateways that don't support JSON Schema union types. +// This is needed for Cloudflare AI Gateway and similar proxies. +func normalizeUnionTypes(schema shared.FunctionParameters) shared.FunctionParameters { + if schema == nil { + return schema + } + + // Convert union types at the current level + if typeArray, ok := schema["type"].([]any); ok { + if len(typeArray) == 2 { + // Find the non-null type + for _, t := range typeArray { + if tStr, ok := t.(string); ok && tStr != "null" { + schema["type"] = tStr + break + } + } + } + } else if typeArray, ok := schema["type"].([]string); ok { + if len(typeArray) == 2 { + // Find the non-null type + for _, t := range typeArray { + if t != "null" { + schema["type"] = t + break + } + } + } + } + + // Recursively handle properties + if propertiesValue, ok := schema["properties"]; ok { + if properties, ok := propertiesValue.(map[string]any); ok { + for _, propValue := range properties { + if prop, ok := propValue.(map[string]any); ok { + normalizeUnionTypes(prop) + } + } + } + } + + // Recursively handle items (for arrays) + if items, ok := schema["items"].(map[string]any); ok { + normalizeUnionTypes(items) + } + + // Recursively handle additionalProperties + if addProps, ok := schema["additionalProperties"].(map[string]any); ok { + normalizeUnionTypes(addProps) + } + + return schema +} diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index f43d49c1e..cf784ab6e 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -334,6 +334,13 @@ func applyProviderDefaults(cfg *latest.ModelConfig, customProviders map[string]l enhancedCfg.ProviderOpts["api_type"] = apiType } + // Copy custom headers from provider config if not already set + if _, hasHeaders := enhancedCfg.ProviderOpts["headers"]; !hasHeaders { + if len(providerCfg.Headers) > 0 { + enhancedCfg.ProviderOpts["headers"] = providerCfg.Headers + } + } + applyModelDefaults(&enhancedCfg) return &enhancedCfg } From 19147d981f6e318127d8aa67697cf4c1ce532b7d Mon Sep 17 00:00:00 2001 From: Alin D'Silva Date: Sat, 7 Feb 2026 21:55:18 +0000 Subject: [PATCH 2/5] fix: forward Start/Stop to inner toolsets in teamloader wrappers The filter, instructions, and toon toolset wrappers were not forwarding Start() and Stop() calls to their inner toolsets. This caused MCP tools to fail with 'toolset not started' errors in multi-agent configurations. --- pkg/teamloader/filter.go | 18 +++++ pkg/teamloader/filter_test.go | 102 ++++++++++++++++++++++++++++ pkg/teamloader/instructions.go | 17 +++++ pkg/teamloader/instructions_test.go | 54 +++++++++++++++ pkg/teamloader/toon.go | 16 +++++ pkg/teamloader/toon_test.go | 46 +++++++++++++ 6 files changed, 253 insertions(+) diff --git a/pkg/teamloader/filter.go b/pkg/teamloader/filter.go index fdbc717b7..3b6901292 100644 --- a/pkg/teamloader/filter.go +++ b/pkg/teamloader/filter.go @@ -53,6 +53,24 @@ func (f *filterTools) Unwrap() tools.ToolSet { return f.ToolSet } +// Start forwards the Start call to the inner toolset if it implements Startable. +// This is necessary because filterTools wraps toolsets (like MCP) that require +// initialization before their Tools() method can be called. +func (f *filterTools) Start(ctx context.Context) error { + if startable, ok := f.ToolSet.(tools.Startable); ok { + return startable.Start(ctx) + } + return nil +} + +// Stop forwards the Stop call to the inner toolset if it implements Startable. +func (f *filterTools) Stop(ctx context.Context) error { + if startable, ok := f.ToolSet.(tools.Startable); ok { + return startable.Stop(ctx) + } + return nil +} + // Instructions implements tools.Instructable by delegating to the inner toolset. func (f *filterTools) Instructions() string { return tools.GetInstructions(f.ToolSet) diff --git a/pkg/teamloader/filter_test.go b/pkg/teamloader/filter_test.go index 8e2727ea0..cc2ee3fa6 100644 --- a/pkg/teamloader/filter_test.go +++ b/pkg/teamloader/filter_test.go @@ -23,6 +23,23 @@ func (m *mockToolSet) Tools(ctx context.Context) ([]tools.Tool, error) { return nil, nil } +// startableToolSet is a mock that implements both ToolSet and Startable, +// like the real MCP toolset does. +type startableToolSet struct { + mockToolSet + started bool +} + +func (s *startableToolSet) Start(context.Context) error { + s.started = true + return nil +} + +func (s *startableToolSet) Stop(context.Context) error { + s.started = false + return nil +} + func TestWithToolsFilter_NilToolNames(t *testing.T) { inner := &mockToolSet{} @@ -167,3 +184,88 @@ func TestWithToolsFilter_NonInstructableInner(t *testing.T) { instructions := tools.GetInstructions(wrapped) assert.Empty(t, instructions) } + +func TestWithToolsFilter_ForwardsStartToStartableInner(t *testing.T) { + t.Parallel() + + inner := &startableToolSet{ + mockToolSet: mockToolSet{ + toolsFunc: func(context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}}, nil + }, + }, + } + + wrapped := WithToolsFilter(inner, "tool1") + + // Verify the inner toolset is not started yet + assert.False(t, inner.started) + + // The wrapped filterTools should satisfy Startable + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "filterTools should implement tools.Startable") + + // Start should forward to the inner toolset + err := startable.Start(t.Context()) + require.NoError(t, err) + assert.True(t, inner.started, "Start() should have been forwarded to inner toolset") + + // Stop should also forward + err = startable.Stop(t.Context()) + require.NoError(t, err) + assert.False(t, inner.started, "Stop() should have been forwarded to inner toolset") +} + +func TestWithToolsFilter_StartNoOpForNonStartableInner(t *testing.T) { + t.Parallel() + + inner := &mockToolSet{ + toolsFunc: func(context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "tool1"}}, nil + }, + } + + wrapped := WithToolsFilter(inner, "tool1") + + // Should still implement Startable + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "filterTools should implement tools.Startable") + + // Start/Stop should be no-ops without error + err := startable.Start(t.Context()) + require.NoError(t, err) + + err = startable.Stop(t.Context()) + require.NoError(t, err) +} + +func TestWithToolsFilter_StartableToolSetIntegration(t *testing.T) { + t.Parallel() + + // This test simulates the real wrapping: MCP → filterTools → StartableToolSet + inner := &startableToolSet{ + mockToolSet: mockToolSet{ + toolsFunc: func(context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}}, nil + }, + }, + } + + // Wrap in filterTools (like teamloader does) + filtered := WithToolsFilter(inner, "tool1") + + // Wrap in StartableToolSet (like agent.WithToolSets does) + startable := tools.NewStartable(filtered) + + // Start should propagate through: StartableToolSet → filterTools → startableToolSet + err := startable.Start(t.Context()) + require.NoError(t, err) + assert.True(t, startable.IsStarted(), "StartableToolSet should be started") + assert.True(t, inner.started, "Inner startable toolset should have been started") + + // Tools should work through the whole chain + result, err := startable.Tools(t.Context()) + require.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, "tool1", result[0].Name) +} diff --git a/pkg/teamloader/instructions.go b/pkg/teamloader/instructions.go index 0bb396536..00957c056 100644 --- a/pkg/teamloader/instructions.go +++ b/pkg/teamloader/instructions.go @@ -1,6 +1,7 @@ package teamloader import ( + "context" "strings" "github.com/docker/docker-agent/pkg/tools" @@ -33,6 +34,22 @@ func (a *replaceInstruction) Unwrap() tools.ToolSet { return a.ToolSet } +// Start forwards the Start call to the inner toolset if it implements Startable. +func (a *replaceInstruction) Start(ctx context.Context) error { + if startable, ok := a.ToolSet.(tools.Startable); ok { + return startable.Start(ctx) + } + return nil +} + +// Stop forwards the Stop call to the inner toolset if it implements Startable. +func (a *replaceInstruction) Stop(ctx context.Context) error { + if startable, ok := a.ToolSet.(tools.Startable); ok { + return startable.Stop(ctx) + } + return nil +} + func (a *replaceInstruction) Instructions() string { original := tools.GetInstructions(a.ToolSet) return strings.Replace(a.instruction, "{ORIGINAL_INSTRUCTIONS}", original, 1) diff --git a/pkg/teamloader/instructions_test.go b/pkg/teamloader/instructions_test.go index 5c8e04e37..3d1075e56 100644 --- a/pkg/teamloader/instructions_test.go +++ b/pkg/teamloader/instructions_test.go @@ -1,9 +1,11 @@ package teamloader import ( + "context" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/tools" ) @@ -44,3 +46,55 @@ func TestWithInstructions_add(t *testing.T) { assert.Equal(t, "Existing instructions\nMore instructions", tools.GetInstructions(wrapped)) } + +type startableInstructableToolSet struct { + toolSet + started bool +} + +func (s *startableInstructableToolSet) Start(_ context.Context) error { + s.started = true + return nil +} + +func (s *startableInstructableToolSet) Stop(_ context.Context) error { + s.started = false + return nil +} + +func TestWithInstructions_ForwardsStartToStartableInner(t *testing.T) { + t.Parallel() + + inner := &startableInstructableToolSet{ + toolSet: toolSet{instruction: "test"}, + } + + wrapped := WithInstructions(inner, "New instructions") + + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "replaceInstruction should implement tools.Startable") + + err := startable.Start(t.Context()) + require.NoError(t, err) + assert.True(t, inner.started, "Start() should have been forwarded to inner toolset") + + err = startable.Stop(t.Context()) + require.NoError(t, err) + assert.False(t, inner.started, "Stop() should have been forwarded to inner toolset") +} + +func TestWithInstructions_StartNoOpForNonStartableInner(t *testing.T) { + t.Parallel() + + inner := &toolSet{instruction: "test"} + wrapped := WithInstructions(inner, "New instructions") + + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "replaceInstruction should implement tools.Startable") + + err := startable.Start(t.Context()) + require.NoError(t, err) + + err = startable.Stop(t.Context()) + require.NoError(t, err) +} diff --git a/pkg/teamloader/toon.go b/pkg/teamloader/toon.go index ec949cf0c..7e2c0e93d 100644 --- a/pkg/teamloader/toon.go +++ b/pkg/teamloader/toon.go @@ -19,6 +19,22 @@ type toonTools struct { // Verify interface compliance var _ tools.Unwrapper = (*toonTools)(nil) +// Start forwards the Start call to the inner toolset if it implements Startable. +func (f *toonTools) Start(ctx context.Context) error { + if startable, ok := f.ToolSet.(tools.Startable); ok { + return startable.Start(ctx) + } + return nil +} + +// Stop forwards the Stop call to the inner toolset if it implements Startable. +func (f *toonTools) Stop(ctx context.Context) error { + if startable, ok := f.ToolSet.(tools.Startable); ok { + return startable.Stop(ctx) + } + return nil +} + func (f *toonTools) Tools(ctx context.Context) ([]tools.Tool, error) { allTools, err := f.ToolSet.Tools(ctx) if err != nil { diff --git a/pkg/teamloader/toon_test.go b/pkg/teamloader/toon_test.go index 52607eb73..3eb399ac6 100644 --- a/pkg/teamloader/toon_test.go +++ b/pkg/teamloader/toon_test.go @@ -70,3 +70,49 @@ func TestToon(t *testing.T) { }) } } + +func TestWithToon_ForwardsStartToStartableInner(t *testing.T) { + t.Parallel() + + inner := &startableToolSet{ + mockToolSet: mockToolSet{ + toolsFunc: func(ctx context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "tool1"}}, nil + }, + }, + } + + wrapped := WithToon(inner, "tool1") + + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "toonTools should implement tools.Startable") + + err := startable.Start(t.Context()) + require.NoError(t, err) + assert.True(t, inner.started, "Start() should have been forwarded to inner toolset") + + err = startable.Stop(t.Context()) + require.NoError(t, err) + assert.False(t, inner.started, "Stop() should have been forwarded to inner toolset") +} + +func TestWithToon_StartNoOpForNonStartableInner(t *testing.T) { + t.Parallel() + + inner := &mockToolSet{ + toolsFunc: func(ctx context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "tool1"}}, nil + }, + } + + wrapped := WithToon(inner, "tool1") + + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "toonTools should implement tools.Startable") + + err := startable.Start(t.Context()) + require.NoError(t, err) + + err = startable.Stop(t.Context()) + require.NoError(t, err) +} From 14eb42df260eb5a028a874fd565b0ca9d5d6e37b Mon Sep 17 00:00:00 2001 From: Alin D'Silva Date: Sat, 7 Feb 2026 21:55:24 +0000 Subject: [PATCH 3/5] fix: normalize anyOf schemas and add API error response body logging - Convert anyOf patterns like {anyOf: [{type:string},{type:null}]} to {type:string} for compatibility with AI gateways (e.g. Cloudflare) that don't support anyOf in tool parameter schemas. - Log HTTP response body on non-2xx API errors for easier debugging. --- pkg/model/provider/openai/api_type_test.go | 5 ++- pkg/model/provider/openai/schema.go | 16 +++++++- pkg/model/provider/openai/schema_test.go | 47 ++++++++++++++++++++++ pkg/runtime/streaming.go | 10 +++++ 4 files changed, 75 insertions(+), 3 deletions(-) diff --git a/pkg/model/provider/openai/api_type_test.go b/pkg/model/provider/openai/api_type_test.go index 239b649fb..163e9e290 100644 --- a/pkg/model/provider/openai/api_type_test.go +++ b/pkg/model/provider/openai/api_type_test.go @@ -237,6 +237,7 @@ func TestCustomProvider_WithoutTokenKey(t *testing.T) { mu.Lock() defer mu.Unlock() - // SDK sends "Bearer" with empty key - that's effectively no auth - assert.Equal(t, "Bearer", receivedAuth, "Should send empty bearer token when no token_key") + // When no token_key is set, our middleware strips the Authorization header + // to prevent empty Bearer tokens from being sent to custom providers + assert.Equal(t, "", receivedAuth, "Should strip Authorization header when no token_key") } diff --git a/pkg/model/provider/openai/schema.go b/pkg/model/provider/openai/schema.go index cb6513084..2db906f33 100644 --- a/pkg/model/provider/openai/schema.go +++ b/pkg/model/provider/openai/schema.go @@ -16,7 +16,7 @@ func ConvertParametersToSchema(params any) (shared.FunctionParameters, error) { return nil, err } - return fixSchemaArrayItems(removeFormatFields(makeAllRequired(p))), nil + return normalizeUnionTypes(fixSchemaArrayItems(removeFormatFields(makeAllRequired(p)))), nil } // walkSchema calls fn on the given schema node, then recursively walks into @@ -176,6 +176,20 @@ func normalizeUnionTypes(schema shared.FunctionParameters) shared.FunctionParame } } + // Convert anyOf patterns like {"anyOf": [{"type":"string"},{"type":"null"}]} to {"type":"string"} + // This is needed for Gemini via Cloudflare which doesn't support anyOf in tool parameters. + if anyOf, ok := schema["anyOf"].([]any); ok { + for _, item := range anyOf { + if itemMap, ok := item.(map[string]any); ok { + if typStr, ok := itemMap["type"].(string); ok && typStr != "null" { + schema["type"] = typStr + delete(schema, "anyOf") + break + } + } + } + } + // Recursively handle properties if propertiesValue, ok := schema["properties"]; ok { if properties, ok := propertiesValue.(map[string]any); ok { diff --git a/pkg/model/provider/openai/schema_test.go b/pkg/model/provider/openai/schema_test.go index 65562b774..634d9ee11 100644 --- a/pkg/model/provider/openai/schema_test.go +++ b/pkg/model/provider/openai/schema_test.go @@ -390,3 +390,50 @@ func TestFixSchemaArrayItems(t *testing.T) { "type": "object" }`, string(buf)) } + +func TestNormalizeUnionTypes_AnyOfPattern(t *testing.T) { + t.Parallel() + + // Simulate the anyOf pattern from MCP tool schemas (e.g., Optional[str] in Python) + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "source": map[string]any{ + "anyOf": []any{ + map[string]any{"type": "string"}, + map[string]any{"type": "null"}, + }, + "default": nil, + "title": "Source", + }, + "days": map[string]any{ + "anyOf": []any{ + map[string]any{"type": "integer"}, + map[string]any{"type": "null"}, + }, + "default": nil, + "title": "Days", + }, + "name": map[string]any{ + "type": "string", + "title": "Name", + }, + }, + } + + result := normalizeUnionTypes(schema) + props := result["properties"].(map[string]any) + + // anyOf should be converted to simple type + source := props["source"].(map[string]any) + assert.Equal(t, "string", source["type"]) + assert.Nil(t, source["anyOf"], "anyOf should be removed after normalization") + + days := props["days"].(map[string]any) + assert.Equal(t, "integer", days["type"]) + assert.Nil(t, days["anyOf"], "anyOf should be removed after normalization") + + // Regular type should be unchanged + name := props["name"].(map[string]any) + assert.Equal(t, "string", name["type"]) +} diff --git a/pkg/runtime/streaming.go b/pkg/runtime/streaming.go index 62b132076..fc40788c8 100644 --- a/pkg/runtime/streaming.go +++ b/pkg/runtime/streaming.go @@ -8,6 +8,8 @@ import ( "log/slog" "strings" + "github.com/openai/openai-go/v3" + "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/modelsdev" @@ -79,6 +81,14 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre break } if err != nil { + var apiErr *openai.Error + if errors.As(err, &apiErr) { + slog.Debug("Stream API error details", + "agent", a.Name(), + "status_code", apiErr.StatusCode, + "response_body", string(apiErr.DumpResponse(true)), + ) + } return streamResult{Stopped: true}, fmt.Errorf("error receiving from stream: %w", err) } From 11167d789a7f42383013dcb30db038acf9c542ba Mon Sep 17 00:00:00 2001 From: Alin D'Silva Date: Sat, 7 Feb 2026 21:55:33 +0000 Subject: [PATCH 4/5] feat: add custom headers and base_url env expansion to all providers Add custom headers support and ${VAR_NAME} expansion in base_url to the Gemini and Anthropic provider clients, matching the existing OpenAI client capability. Also add Headers field directly to ModelConfig for convenience (no separate providers section needed). - Gemini: read headers from ProviderOpts, expand env vars, set on genai.HTTPOptions; expand env vars in base_url - Anthropic: same pattern with option.WithHeader; expand env vars in base_url - ModelConfig.Headers: new field merged into ProviderOpts['headers'] with model-level taking precedence over provider-level - Updated JSON schema and config types (v3 + latest) --- agent-schema.json | 13 +++++++ pkg/config/gather.go | 31 +++++++++++++++ pkg/config/latest/types.go | 3 ++ pkg/config/v3/types.go | 2 + pkg/model/provider/anthropic/client.go | 37 +++++++++++++++++- pkg/model/provider/custom_headers_test.go | 2 +- pkg/model/provider/gemini/client.go | 47 ++++++++++++++++++++++- pkg/model/provider/provider.go | 29 +++++++++++++- pkg/model/provider/schema_test.go | 6 +-- 9 files changed, 163 insertions(+), 7 deletions(-) diff --git a/agent-schema.json b/agent-schema.json index 30a8ba955..8f89600eb 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -538,6 +538,19 @@ "type": "string", "description": "Token key for authentication" }, + "headers": { + "type": "object", + "description": "Custom HTTP headers to include in requests to this model's provider. Header values can reference environment variables using ${VAR_NAME} syntax.", + "additionalProperties": { + "type": "string" + }, + "examples": [ + { + "cf-aig-authorization": "Bearer ${CLOUDFLARE_AI_GATEWAY_TOKEN}", + "x-custom-header": "value" + } + ] + }, "provider_opts": { "type": "object", "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).", diff --git a/pkg/config/gather.go b/pkg/config/gather.go index a34119e2b..a05626073 100644 --- a/pkg/config/gather.go +++ b/pkg/config/gather.go @@ -6,6 +6,7 @@ import ( "fmt" "maps" "os" + "regexp" "slices" "strings" @@ -122,6 +123,36 @@ func addEnvVarsForModelConfig(model *latest.ModelConfig, customProviders map[str } } } + + // Gather env vars from headers (model-level and provider-level) + gatherEnvVarsFromHeaders(model.Headers, requiredEnv) + if customProviders != nil { + if provCfg, exists := customProviders[model.Provider]; exists { + gatherEnvVarsFromHeaders(provCfg.Headers, requiredEnv) + gatherEnvVarsFromString(provCfg.BaseURL, requiredEnv) + } + } +} + +// envVarPattern matches ${VAR} and $VAR references in strings. +var envVarPattern = regexp.MustCompile(`\$\{([^}]+)\}|\$([A-Za-z_][A-Za-z0-9_]*)`) + +// gatherEnvVarsFromHeaders extracts environment variable names referenced in header values. +func gatherEnvVarsFromHeaders(headers map[string]string, requiredEnv map[string]bool) { + for _, value := range headers { + gatherEnvVarsFromString(value, requiredEnv) + } +} + +// gatherEnvVarsFromString extracts environment variable names from a string containing $VAR or ${VAR}. +func gatherEnvVarsFromString(s string, requiredEnv map[string]bool) { + for _, match := range envVarPattern.FindAllStringSubmatch(s, -1) { + if match[1] != "" { + requiredEnv[match[1]] = true + } else if match[2] != "" { + requiredEnv[match[2]] = true + } + } } func GatherEnvVarsForTools(ctx context.Context, cfg *latest.Config) ([]string, error) { diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 9da7b1641..9cb71a13f 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -395,6 +395,9 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests to this model's provider. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` // ProviderOpts allows provider-specific options. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` diff --git a/pkg/config/v3/types.go b/pkg/config/v3/types.go index 90cb2a449..168a7560a 100644 --- a/pkg/config/v3/types.go +++ b/pkg/config/v3/types.go @@ -73,6 +73,8 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + Headers map[string]string `json:"headers,omitempty"` // ProviderOpts allows provider-specific options. Currently used for "dmr" provider only. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index 10e05b701..32c8e9575 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -153,8 +153,43 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro option.WithHTTPClient(httpclient.NewHTTPClient()), } if cfg.BaseURL != "" { - requestOptions = append(requestOptions, option.WithBaseURL(cfg.BaseURL)) + expandedBaseURL, err := environment.Expand(ctx, cfg.BaseURL, env) + if err != nil { + return nil, fmt.Errorf("expanding base_url: %w", err) + } + requestOptions = append(requestOptions, option.WithBaseURL(expandedBaseURL)) + } + + // Apply custom headers from provider config if present + if cfg.ProviderOpts != nil { + if headers, exists := cfg.ProviderOpts["headers"]; exists { + headersMap := make(map[string]string) + switch h := headers.(type) { + case map[string]string: + headersMap = h + case map[interface{}]interface{}: + for k, v := range h { + keyStr, okKey := k.(string) + valStr, okVal := v.(string) + if !okKey || !okVal { + return nil, fmt.Errorf("invalid header key/value type: key=%T, value=%T", k, v) + } + headersMap[keyStr] = valStr + } + default: + return nil, fmt.Errorf("invalid headers configuration: expected map[string]string, got %T", headers) + } + for key, value := range headersMap { + expandedValue, err := environment.Expand(ctx, value, env) + if err != nil { + return nil, fmt.Errorf("expanding header %s: %w", key, err) + } + requestOptions = append(requestOptions, option.WithHeader(key, expandedValue)) + slog.Debug("Applied custom header", "header", key, "provider", cfg.Provider) + } + } } + client := anthropic.NewClient(requestOptions...) anthropicClient.clientFn = func(context.Context) (anthropic.Client, error) { return client, nil diff --git a/pkg/model/provider/custom_headers_test.go b/pkg/model/provider/custom_headers_test.go index 33133a0fb..7e7058ca5 100644 --- a/pkg/model/provider/custom_headers_test.go +++ b/pkg/model/provider/custom_headers_test.go @@ -3,7 +3,7 @@ package provider import ( "testing" - "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 32746c406..1a2465f49 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -98,6 +98,50 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro httpClient = httpclient.NewHTTPClient() } + // Expand environment variables in base URL (e.g., ${VAR_NAME}) + baseURL := cfg.BaseURL + if baseURL != "" { + expanded, err := environment.Expand(ctx, baseURL, env) + if err != nil { + return nil, fmt.Errorf("expanding base_url: %w", err) + } + baseURL = expanded + } + + // Build custom headers from provider config + httpHeaders := make(http.Header) + if cfg.ProviderOpts != nil { + if headers, exists := cfg.ProviderOpts["headers"]; exists { + headersMap := make(map[string]string) + + switch h := headers.(type) { + case map[string]string: + headersMap = h + case map[interface{}]interface{}: + for k, v := range h { + keyStr, okKey := k.(string) + valStr, okVal := v.(string) + if !okKey || !okVal { + return nil, fmt.Errorf("invalid header key/value type: key=%T, value=%T", k, v) + } + headersMap[keyStr] = valStr + } + default: + return nil, fmt.Errorf("invalid headers configuration: expected map[string]string, got %T", headers) + } + + for key, value := range headersMap { + expandedValue, err := environment.Expand(ctx, value, env) + if err != nil { + return nil, fmt.Errorf("expanding header %q: %w", key, err) + } + httpHeaders.Set(key, expandedValue) + slog.Debug("Applied custom header", "header", key, "provider", cfg.Provider) + } + slog.Debug("Applying custom headers", "count", len(headersMap), "provider", cfg.Provider) + } + } + client, err := genai.NewClient(ctx, &genai.ClientConfig{ APIKey: apiKey, Project: project, @@ -105,7 +149,8 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro Backend: backend, HTTPClient: httpClient, HTTPOptions: genai.HTTPOptions{ - BaseURL: cfg.BaseURL, + BaseURL: baseURL, + Headers: httpHeaders, }, }) if err != nil { diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index cf784ab6e..c2d3305fb 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -334,13 +334,25 @@ func applyProviderDefaults(cfg *latest.ModelConfig, customProviders map[string]l enhancedCfg.ProviderOpts["api_type"] = apiType } - // Copy custom headers from provider config if not already set + // Copy custom headers from provider config if not already set in provider_opts if _, hasHeaders := enhancedCfg.ProviderOpts["headers"]; !hasHeaders { if len(providerCfg.Headers) > 0 { enhancedCfg.ProviderOpts["headers"] = providerCfg.Headers } } + // Merge model-level headers into provider_opts headers (model-level takes precedence) + if len(enhancedCfg.Headers) > 0 { + existing, _ := enhancedCfg.ProviderOpts["headers"].(map[string]string) + if existing == nil { + existing = make(map[string]string) + } + for k, v := range enhancedCfg.Headers { + existing[k] = v + } + enhancedCfg.ProviderOpts["headers"] = existing + } + applyModelDefaults(&enhancedCfg) return &enhancedCfg } @@ -358,6 +370,21 @@ func applyProviderDefaults(cfg *latest.ModelConfig, customProviders map[string]l } } + // Merge model-level headers into provider_opts for non-custom providers too + if len(enhancedCfg.Headers) > 0 { + if enhancedCfg.ProviderOpts == nil { + enhancedCfg.ProviderOpts = make(map[string]any) + } + existing, _ := enhancedCfg.ProviderOpts["headers"].(map[string]string) + if existing == nil { + existing = make(map[string]string) + } + for k, v := range enhancedCfg.Headers { + existing[k] = v + } + enhancedCfg.ProviderOpts["headers"] = existing + } + // Apply model-specific defaults (e.g., thinking budget for Claude/GPT models) applyModelDefaults(&enhancedCfg) return &enhancedCfg diff --git a/pkg/model/provider/schema_test.go b/pkg/model/provider/schema_test.go index e9150d123..b4125b7e1 100644 --- a/pkg/model/provider/schema_test.go +++ b/pkg/model/provider/schema_test.go @@ -220,20 +220,20 @@ func TestSchemaForOpenai(t *testing.T) { "direction": { "description": "Order", "enum": ["ASC", "DESC"], - "type": ["string", "null"] + "type": "string" }, "labels": { "description": "Filter", "items": { "type": "string" }, - "type": ["array", "null"] + "type": "array" }, "perPage": { "description": "Results", "maximum": 100, "minimum": 1, - "type": ["number", "null"] + "type": "number" }, "repo": { "description": "Repository", From 8d15048c4f47b4593eb87124b5a7c12d305c373e Mon Sep 17 00:00:00 2001 From: Alin D'Silva Date: Sun, 15 Mar 2026 20:41:05 +0000 Subject: [PATCH 5/5] fix: add Headers field to v4/v5/v6 config types and address PR review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add Headers field to ModelConfig and ProviderConfig in v4, v5, v6 config types so headers survive the JSON upgrade chain from v3 to latest. This was causing Cloudflare AI Gateway 401 errors because custom headers were silently dropped at the v3→v4 boundary. - Address Copilot review comments on PR #2108: 1. Remove response body from streaming error logs to avoid leaking sensitive data (pkg/runtime/streaming.go) 2. Deep-copy provider headers map before merging to avoid mutating shared config across models (pkg/model/provider/provider.go) 3. Gather env vars from model-level base_url in addition to provider base_url (pkg/config/gather.go) 4. Expand env vars in OpenAI/Azure base_url consistently with Anthropic/Gemini (pkg/model/provider/openai/client.go) 5. Redact header values from error logs to prevent credential leaks (pkg/model/provider/openai/client.go) 6. Tighten union type normalization to only collapse nullable patterns (exactly 2 options with one being null), preserving non-nullable unions (pkg/model/provider/openai/schema.go) --- pkg/config/gather.go | 3 +- pkg/config/v4/types.go | 5 +++ pkg/config/v5/types.go | 5 +++ pkg/config/v6/types.go | 5 +++ pkg/model/provider/custom_headers_test.go | 12 +++--- pkg/model/provider/openai/client.go | 13 ++++-- pkg/model/provider/openai/schema.go | 50 ++++++++++++++++------- pkg/model/provider/provider.go | 7 +++- pkg/runtime/streaming.go | 1 - 9 files changed, 75 insertions(+), 26 deletions(-) diff --git a/pkg/config/gather.go b/pkg/config/gather.go index a05626073..a4f0be5bb 100644 --- a/pkg/config/gather.go +++ b/pkg/config/gather.go @@ -124,8 +124,9 @@ func addEnvVarsForModelConfig(model *latest.ModelConfig, customProviders map[str } } - // Gather env vars from headers (model-level and provider-level) + // Gather env vars from headers (model-level and provider-level) and base URLs gatherEnvVarsFromHeaders(model.Headers, requiredEnv) + gatherEnvVarsFromString(model.BaseURL, requiredEnv) if customProviders != nil { if provCfg, exists := customProviders[model.Provider]; exists { gatherEnvVarsFromHeaders(provCfg.Headers, requiredEnv) diff --git a/pkg/config/v4/types.go b/pkg/config/v4/types.go index 548eda23b..78bc3d501 100644 --- a/pkg/config/v4/types.go +++ b/pkg/config/v4/types.go @@ -110,6 +110,9 @@ type ProviderConfig struct { BaseURL string `json:"base_url"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` } // FallbackConfig represents fallback model configuration for an agent. @@ -270,6 +273,8 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + Headers map[string]string `json:"headers,omitempty"` // ProviderOpts allows provider-specific options. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` diff --git a/pkg/config/v5/types.go b/pkg/config/v5/types.go index bc810ce36..312683d8c 100644 --- a/pkg/config/v5/types.go +++ b/pkg/config/v5/types.go @@ -112,6 +112,9 @@ type ProviderConfig struct { BaseURL string `json:"base_url"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` } // FallbackConfig represents fallback model configuration for an agent. @@ -369,6 +372,8 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + Headers map[string]string `json:"headers,omitempty"` // ProviderOpts allows provider-specific options. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` diff --git a/pkg/config/v6/types.go b/pkg/config/v6/types.go index 7caca6695..bd5b1a8c4 100644 --- a/pkg/config/v6/types.go +++ b/pkg/config/v6/types.go @@ -135,6 +135,9 @@ type ProviderConfig struct { BaseURL string `json:"base_url"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` } // FallbackConfig represents fallback model configuration for an agent. @@ -392,6 +395,8 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + Headers map[string]string `json:"headers,omitempty"` // ProviderOpts allows provider-specific options. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` diff --git a/pkg/model/provider/custom_headers_test.go b/pkg/model/provider/custom_headers_test.go index 7e7058ca5..cd13843a9 100644 --- a/pkg/model/provider/custom_headers_test.go +++ b/pkg/model/provider/custom_headers_test.go @@ -12,12 +12,12 @@ func TestApplyProviderDefaults_WithHeaders(t *testing.T) { t.Parallel() tests := []struct { - name string - providerName string - providerCfg latest.ProviderConfig - modelCfg latest.ModelConfig - expectedHeaders map[string]string - headersInOpts bool + name string + providerName string + providerCfg latest.ProviderConfig + modelCfg latest.ModelConfig + expectedHeaders map[string]string + headersInOpts bool }{ { name: "custom provider with headers", diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 4ee7512e6..8db298ee4 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -82,7 +82,11 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro if cfg.Provider == "azure" { // Azure configuration if cfg.BaseURL != "" { - clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL)) + expandedBaseURL, err := environment.Expand(ctx, cfg.BaseURL, env) + if err != nil { + return nil, fmt.Errorf("expanding base_url: %w", err) + } + clientOptions = append(clientOptions, option.WithBaseURL(expandedBaseURL)) } // Azure API version from provider opts @@ -95,7 +99,11 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro } } } else if cfg.BaseURL != "" { - clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL)) + expandedBaseURL, err := environment.Expand(ctx, cfg.BaseURL, env) + if err != nil { + return nil, fmt.Errorf("expanding base_url: %w", err) + } + clientOptions = append(clientOptions, option.WithBaseURL(expandedBaseURL)) } @@ -138,7 +146,6 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro if err != nil { slog.Error("Failed to expand environment variable in header", "header", key, - "value", value, "error", err, "provider", cfg.Provider) return nil, fmt.Errorf("expanding header %s: %w", key, err) diff --git a/pkg/model/provider/openai/schema.go b/pkg/model/provider/openai/schema.go index 2db906f33..1e04c9d91 100644 --- a/pkg/model/provider/openai/schema.go +++ b/pkg/model/provider/openai/schema.go @@ -154,40 +154,62 @@ func normalizeUnionTypes(schema shared.FunctionParameters) shared.FunctionParame } // Convert union types at the current level + // Only normalize nullable patterns: exactly 2 types where one is "null" if typeArray, ok := schema["type"].([]any); ok { if len(typeArray) == 2 { - // Find the non-null type + var hasNull bool + var nonNullType string for _, t := range typeArray { - if tStr, ok := t.(string); ok && tStr != "null" { - schema["type"] = tStr - break + if tStr, ok := t.(string); ok { + if tStr == "null" { + hasNull = true + } else { + nonNullType = tStr + } } } + if hasNull && nonNullType != "" { + schema["type"] = nonNullType + } } } else if typeArray, ok := schema["type"].([]string); ok { if len(typeArray) == 2 { - // Find the non-null type + var hasNull bool + var nonNullType string for _, t := range typeArray { - if t != "null" { - schema["type"] = t - break + if t == "null" { + hasNull = true + } else { + nonNullType = t } } + if hasNull && nonNullType != "" { + schema["type"] = nonNullType + } } } - // Convert anyOf patterns like {"anyOf": [{"type":"string"},{"type":"null"}]} to {"type":"string"} + // Convert nullable anyOf patterns like {"anyOf": [{"type":"string"},{"type":"null"}]} to {"type":"string"} + // Only normalize when there are exactly 2 alternatives and one is {"type":"null"}. // This is needed for Gemini via Cloudflare which doesn't support anyOf in tool parameters. - if anyOf, ok := schema["anyOf"].([]any); ok { + if anyOf, ok := schema["anyOf"].([]any); ok && len(anyOf) == 2 { + hasNull := false + var nonNullType string for _, item := range anyOf { if itemMap, ok := item.(map[string]any); ok { - if typStr, ok := itemMap["type"].(string); ok && typStr != "null" { - schema["type"] = typStr - delete(schema, "anyOf") - break + if typStr, ok := itemMap["type"].(string); ok { + if typStr == "null" { + hasNull = true + } else { + nonNullType = typStr + } } } } + if hasNull && nonNullType != "" { + schema["type"] = nonNullType + delete(schema, "anyOf") + } } // Recursively handle properties diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index c2d3305fb..6aa967146 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -335,9 +335,14 @@ func applyProviderDefaults(cfg *latest.ModelConfig, customProviders map[string]l } // Copy custom headers from provider config if not already set in provider_opts + // Deep-copy the map to avoid mutating the shared provider config if _, hasHeaders := enhancedCfg.ProviderOpts["headers"]; !hasHeaders { if len(providerCfg.Headers) > 0 { - enhancedCfg.ProviderOpts["headers"] = providerCfg.Headers + headersCopy := make(map[string]string, len(providerCfg.Headers)) + for k, v := range providerCfg.Headers { + headersCopy[k] = v + } + enhancedCfg.ProviderOpts["headers"] = headersCopy } } diff --git a/pkg/runtime/streaming.go b/pkg/runtime/streaming.go index fc40788c8..4a9b9eca8 100644 --- a/pkg/runtime/streaming.go +++ b/pkg/runtime/streaming.go @@ -86,7 +86,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre slog.Debug("Stream API error details", "agent", a.Name(), "status_code", apiErr.StatusCode, - "response_body", string(apiErr.DumpResponse(true)), ) } return streamResult{Stopped: true}, fmt.Errorf("error receiving from stream: %w", err)