diff --git a/agent-schema.json b/agent-schema.json index b106a8509..8f89600eb 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -106,6 +106,19 @@ "examples": [ "CUSTOM_PROVIDER_API_KEY" ] + }, + "headers": { + "type": "object", + "description": "Custom HTTP headers to include in requests. Header values can reference environment variables using ${VAR_NAME} syntax.", + "additionalProperties": { + "type": "string" + }, + "examples": [ + { + "cf-aig-authorization": "Bearer ${CLOUDFLARE_AI_GATEWAY_TOKEN}", + "x-custom-header": "value" + } + ] } }, "required": [ @@ -525,6 +538,19 @@ "type": "string", "description": "Token key for authentication" }, + "headers": { + "type": "object", + "description": "Custom HTTP headers to include in requests to this model's provider. Header values can reference environment variables using ${VAR_NAME} syntax.", + "additionalProperties": { + "type": "string" + }, + "examples": [ + { + "cf-aig-authorization": "Bearer ${CLOUDFLARE_AI_GATEWAY_TOKEN}", + "x-custom-header": "value" + } + ] + }, "provider_opts": { "type": "object", "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).", diff --git a/examples/custom_provider.yaml b/examples/custom_provider.yaml index f173c5535..0ba20742c 100644 --- a/examples/custom_provider.yaml +++ b/examples/custom_provider.yaml @@ -6,45 +6,31 @@ # Define custom providers with reusable configuration providers: - # Example: A custom OpenAI Chat Completions compatible API gateway - my_gateway: - api_type: openai_chatcompletions # Use the Chat Completions API schema - base_url: https://api.example.com/ - token_key: API_KEY_ENV_VAR_NAME # Environment variable containing the API token - # Example: A custom OpenAI Responses compatible API gateway - responses_provider: - api_type: openai_responses - base_url: https://responses.example.com/ - token_key: API_KEY_ENV_VAR_NAME + # Example: Cloudflare AI Gateway with custom headers + cloudflare_gateway: + api_type: openai_chatcompletions + base_url: https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/compat + token_key: GOOGLE_API_KEY # Standard Authorization header for provider auth + headers: + # Custom header for gateway authentication with environment variable expansion + cf-aig-authorization: Bearer ${CLOUDFLARE_AI_GATEWAY_TOKEN} # Define models that use the custom providers models: - # Model using the custom gateway provider - gateway_gpt4o: - provider: my_gateway - model: gpt-4o - max_tokens: 32768 - temperature: 0.7 - # Model using the responses provider - responses_model: - provider: responses_provider - model: gpt-5 - max_tokens: 16000 + # Model using Cloudflare AI Gateway with custom headers + gemini_via_cloudflare: + provider: cloudflare_gateway + model: google-ai-studio/gemini-3-flash-preview + max_tokens: 8000 + temperature: 0.7 # Define agents that use the models agents: root: - model: responses_model + model: gemini_via_cloudflare description: Main assistant using the custom gateway instruction: | You are a helpful AI assistant. Be concise and helpful in your responses. - # Example using shorthand syntax: provider_name/model_name - # The provider defaults (base_url, token_key, api_type) are automatically applied - subagent: - model: my_gateway/gpt-4o-mini - description: Sub-agent for specialized tasks - instruction: | - You are a specialized assistant for specific tasks. diff --git a/pkg/config/gather.go b/pkg/config/gather.go index a34119e2b..a4f0be5bb 100644 --- a/pkg/config/gather.go +++ b/pkg/config/gather.go @@ -6,6 +6,7 @@ import ( "fmt" "maps" "os" + "regexp" "slices" "strings" @@ -122,6 +123,37 @@ func addEnvVarsForModelConfig(model *latest.ModelConfig, customProviders map[str } } } + + // Gather env vars from headers (model-level and provider-level) and base URLs + gatherEnvVarsFromHeaders(model.Headers, requiredEnv) + gatherEnvVarsFromString(model.BaseURL, requiredEnv) + if customProviders != nil { + if provCfg, exists := customProviders[model.Provider]; exists { + gatherEnvVarsFromHeaders(provCfg.Headers, requiredEnv) + gatherEnvVarsFromString(provCfg.BaseURL, requiredEnv) + } + } +} + +// envVarPattern matches ${VAR} and $VAR references in strings. +var envVarPattern = regexp.MustCompile(`\$\{([^}]+)\}|\$([A-Za-z_][A-Za-z0-9_]*)`) + +// gatherEnvVarsFromHeaders extracts environment variable names referenced in header values. +func gatherEnvVarsFromHeaders(headers map[string]string, requiredEnv map[string]bool) { + for _, value := range headers { + gatherEnvVarsFromString(value, requiredEnv) + } +} + +// gatherEnvVarsFromString extracts environment variable names from a string containing $VAR or ${VAR}. +func gatherEnvVarsFromString(s string, requiredEnv map[string]bool) { + for _, match := range envVarPattern.FindAllStringSubmatch(s, -1) { + if match[1] != "" { + requiredEnv[match[1]] = true + } else if match[2] != "" { + requiredEnv[match[2]] = true + } + } } func GatherEnvVarsForTools(ctx context.Context, cfg *latest.Config) ([]string, error) { diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index e3ff0dc6c..9cb71a13f 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -135,6 +135,9 @@ type ProviderConfig struct { BaseURL string `json:"base_url"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` } // FallbackConfig represents fallback model configuration for an agent. @@ -392,6 +395,9 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests to this model's provider. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` // ProviderOpts allows provider-specific options. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` diff --git a/pkg/config/v3/types.go b/pkg/config/v3/types.go index 1efcec154..168a7560a 100644 --- a/pkg/config/v3/types.go +++ b/pkg/config/v3/types.go @@ -34,6 +34,9 @@ type ProviderConfig struct { BaseURL string `json:"base_url"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` } // AgentConfig represents a single agent configuration @@ -70,6 +73,8 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + Headers map[string]string `json:"headers,omitempty"` // ProviderOpts allows provider-specific options. Currently used for "dmr" provider only. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` diff --git a/pkg/config/v4/types.go b/pkg/config/v4/types.go index 548eda23b..78bc3d501 100644 --- a/pkg/config/v4/types.go +++ b/pkg/config/v4/types.go @@ -110,6 +110,9 @@ type ProviderConfig struct { BaseURL string `json:"base_url"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` } // FallbackConfig represents fallback model configuration for an agent. @@ -270,6 +273,8 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + Headers map[string]string `json:"headers,omitempty"` // ProviderOpts allows provider-specific options. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` diff --git a/pkg/config/v5/types.go b/pkg/config/v5/types.go index bc810ce36..312683d8c 100644 --- a/pkg/config/v5/types.go +++ b/pkg/config/v5/types.go @@ -112,6 +112,9 @@ type ProviderConfig struct { BaseURL string `json:"base_url"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` } // FallbackConfig represents fallback model configuration for an agent. @@ -369,6 +372,8 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + Headers map[string]string `json:"headers,omitempty"` // ProviderOpts allows provider-specific options. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` diff --git a/pkg/config/v6/types.go b/pkg/config/v6/types.go index 7caca6695..bd5b1a8c4 100644 --- a/pkg/config/v6/types.go +++ b/pkg/config/v6/types.go @@ -135,6 +135,9 @@ type ProviderConfig struct { BaseURL string `json:"base_url"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + // Header values can reference environment variables using ${VAR_NAME} syntax. + Headers map[string]string `json:"headers,omitempty"` } // FallbackConfig represents fallback model configuration for an agent. @@ -392,6 +395,8 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` + // Headers allows custom HTTP headers to be included in requests. + Headers map[string]string `json:"headers,omitempty"` // ProviderOpts allows provider-specific options. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index 10e05b701..32c8e9575 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -153,8 +153,43 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro option.WithHTTPClient(httpclient.NewHTTPClient()), } if cfg.BaseURL != "" { - requestOptions = append(requestOptions, option.WithBaseURL(cfg.BaseURL)) + expandedBaseURL, err := environment.Expand(ctx, cfg.BaseURL, env) + if err != nil { + return nil, fmt.Errorf("expanding base_url: %w", err) + } + requestOptions = append(requestOptions, option.WithBaseURL(expandedBaseURL)) + } + + // Apply custom headers from provider config if present + if cfg.ProviderOpts != nil { + if headers, exists := cfg.ProviderOpts["headers"]; exists { + headersMap := make(map[string]string) + switch h := headers.(type) { + case map[string]string: + headersMap = h + case map[interface{}]interface{}: + for k, v := range h { + keyStr, okKey := k.(string) + valStr, okVal := v.(string) + if !okKey || !okVal { + return nil, fmt.Errorf("invalid header key/value type: key=%T, value=%T", k, v) + } + headersMap[keyStr] = valStr + } + default: + return nil, fmt.Errorf("invalid headers configuration: expected map[string]string, got %T", headers) + } + for key, value := range headersMap { + expandedValue, err := environment.Expand(ctx, value, env) + if err != nil { + return nil, fmt.Errorf("expanding header %s: %w", key, err) + } + requestOptions = append(requestOptions, option.WithHeader(key, expandedValue)) + slog.Debug("Applied custom header", "header", key, "provider", cfg.Provider) + } + } } + client := anthropic.NewClient(requestOptions...) anthropicClient.clientFn = func(context.Context) (anthropic.Client, error) { return client, nil diff --git a/pkg/model/provider/custom_headers_test.go b/pkg/model/provider/custom_headers_test.go new file mode 100644 index 000000000..cd13843a9 --- /dev/null +++ b/pkg/model/provider/custom_headers_test.go @@ -0,0 +1,135 @@ +package provider + +import ( + "testing" + + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplyProviderDefaults_WithHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + providerName string + providerCfg latest.ProviderConfig + modelCfg latest.ModelConfig + expectedHeaders map[string]string + headersInOpts bool + }{ + { + name: "custom provider with headers", + providerName: "custom", + providerCfg: latest.ProviderConfig{ + BaseURL: "https://gateway.example.com/v1", + Headers: map[string]string{ + "cf-aig-authorization": "Bearer token123", + "x-custom-header": "value", + }, + }, + modelCfg: latest.ModelConfig{ + Provider: "custom", + Model: "gpt-4o", + }, + expectedHeaders: map[string]string{ + "cf-aig-authorization": "Bearer token123", + "x-custom-header": "value", + }, + headersInOpts: true, + }, + { + name: "custom provider without headers", + providerName: "custom", + providerCfg: latest.ProviderConfig{ + BaseURL: "https://api.example.com/v1", + }, + modelCfg: latest.ModelConfig{ + Provider: "custom", + Model: "gpt-4o", + }, + headersInOpts: false, + }, + { + name: "custom provider with empty headers", + providerName: "custom", + providerCfg: latest.ProviderConfig{ + BaseURL: "https://api.example.com/v1", + Headers: map[string]string{}, + }, + modelCfg: latest.ModelConfig{ + Provider: "custom", + Model: "gpt-4o", + }, + headersInOpts: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + providers := map[string]latest.ProviderConfig{ + tt.providerName: tt.providerCfg, + } + + result := applyProviderDefaults(&tt.modelCfg, providers) + require.NotNil(t, result) + + if tt.headersInOpts { + require.NotNil(t, result.ProviderOpts, "ProviderOpts should not be nil") + headers, ok := result.ProviderOpts["headers"] + require.True(t, ok, "headers should be in ProviderOpts") + + headerMap, ok := headers.(map[string]string) + require.True(t, ok, "headers should be map[string]string") + assert.Equal(t, tt.expectedHeaders, headerMap, "headers should match") + } else { + if result.ProviderOpts != nil { + _, hasHeaders := result.ProviderOpts["headers"] + assert.False(t, hasHeaders, "headers should not be in ProviderOpts") + } + } + }) + } +} + +func TestApplyProviderDefaults_HeadersDoNotOverrideExisting(t *testing.T) { + t.Parallel() + + providerCfg := latest.ProviderConfig{ + BaseURL: "https://gateway.example.com/v1", + Headers: map[string]string{ + "x-provider-header": "from-provider", + }, + } + + modelCfg := latest.ModelConfig{ + Provider: "custom", + Model: "gpt-4o", + ProviderOpts: map[string]any{ + "headers": map[string]string{ + "x-model-header": "from-model", + }, + }, + } + + providers := map[string]latest.ProviderConfig{ + "custom": providerCfg, + } + + result := applyProviderDefaults(&modelCfg, providers) + require.NotNil(t, result) + + // Model config's headers should take precedence (not be overwritten) + require.NotNil(t, result.ProviderOpts) + headers, ok := result.ProviderOpts["headers"] + require.True(t, ok) + + headerMap, ok := headers.(map[string]string) + require.True(t, ok) + + // Should have model's header, not provider's header + assert.Equal(t, map[string]string{"x-model-header": "from-model"}, headerMap) +} diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 32746c406..1a2465f49 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -98,6 +98,50 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro httpClient = httpclient.NewHTTPClient() } + // Expand environment variables in base URL (e.g., ${VAR_NAME}) + baseURL := cfg.BaseURL + if baseURL != "" { + expanded, err := environment.Expand(ctx, baseURL, env) + if err != nil { + return nil, fmt.Errorf("expanding base_url: %w", err) + } + baseURL = expanded + } + + // Build custom headers from provider config + httpHeaders := make(http.Header) + if cfg.ProviderOpts != nil { + if headers, exists := cfg.ProviderOpts["headers"]; exists { + headersMap := make(map[string]string) + + switch h := headers.(type) { + case map[string]string: + headersMap = h + case map[interface{}]interface{}: + for k, v := range h { + keyStr, okKey := k.(string) + valStr, okVal := v.(string) + if !okKey || !okVal { + return nil, fmt.Errorf("invalid header key/value type: key=%T, value=%T", k, v) + } + headersMap[keyStr] = valStr + } + default: + return nil, fmt.Errorf("invalid headers configuration: expected map[string]string, got %T", headers) + } + + for key, value := range headersMap { + expandedValue, err := environment.Expand(ctx, value, env) + if err != nil { + return nil, fmt.Errorf("expanding header %q: %w", key, err) + } + httpHeaders.Set(key, expandedValue) + slog.Debug("Applied custom header", "header", key, "provider", cfg.Provider) + } + slog.Debug("Applying custom headers", "count", len(headersMap), "provider", cfg.Provider) + } + } + client, err := genai.NewClient(ctx, &genai.ClientConfig{ APIKey: apiKey, Project: project, @@ -105,7 +149,8 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro Backend: backend, HTTPClient: httpClient, HTTPOptions: genai.HTTPOptions{ - BaseURL: cfg.BaseURL, + BaseURL: baseURL, + Headers: httpHeaders, }, }) if err != nil { diff --git a/pkg/model/provider/openai/api_type_test.go b/pkg/model/provider/openai/api_type_test.go index 239b649fb..163e9e290 100644 --- a/pkg/model/provider/openai/api_type_test.go +++ b/pkg/model/provider/openai/api_type_test.go @@ -237,6 +237,7 @@ func TestCustomProvider_WithoutTokenKey(t *testing.T) { mu.Lock() defer mu.Unlock() - // SDK sends "Bearer" with empty key - that's effectively no auth - assert.Equal(t, "Bearer", receivedAuth, "Should send empty bearer token when no token_key") + // When no token_key is set, our middleware strips the Authorization header + // to prevent empty Bearer tokens from being sent to custom providers + assert.Equal(t, "", receivedAuth, "Should strip Authorization header when no token_key") } diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 98ac96c9c..8db298ee4 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "log/slog" + "net/http" "net/url" "strings" @@ -58,18 +59,34 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro return nil, fmt.Errorf("%s environment variable is required", cfg.TokenKey) } clientOptions = append(clientOptions, option.WithAPIKey(authToken)) - } else if isCustomProvider(cfg) { - // Custom provider (has api_type in ProviderOpts) without token_key - no auth - slog.Debug("Custom provider with no token_key, sending requests without authentication", + } else if !isCustomProvider(cfg) { + // Not a custom provider - use default OpenAI behavior (OPENAI_API_KEY from env) + // The OpenAI SDK will automatically look for OPENAI_API_KEY if no key is set + } else { + // Custom provider without token_key - prevent SDK from using OPENAI_API_KEY env var + // We need to explicitly set the API key to prevent the SDK from reading OPENAI_API_KEY + // but we don't want to send an Authorization header. The SDK doesn't send the header + // if we use option.WithAPIKey with a specific marker value and then remove it via middleware. + slog.Debug("Custom provider with no token_key, disabling OpenAI SDK authentication", "provider", cfg.Provider, "base_url", cfg.BaseURL) - clientOptions = append(clientOptions, option.WithAPIKey("")) + + // Use a custom HTTP client that removes the Authorization header + clientOptions = append(clientOptions, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + // Remove Authorization header for custom providers without token_key + req.Header.Del("Authorization") + return next(req) + })) } // Otherwise let the OpenAI SDK use its default behavior (OPENAI_API_KEY from env) if cfg.Provider == "azure" { // Azure configuration if cfg.BaseURL != "" { - clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL)) + expandedBaseURL, err := environment.Expand(ctx, cfg.BaseURL, env) + if err != nil { + return nil, fmt.Errorf("expanding base_url: %w", err) + } + clientOptions = append(clientOptions, option.WithBaseURL(expandedBaseURL)) } // Azure API version from provider opts @@ -82,7 +99,64 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro } } } else if cfg.BaseURL != "" { - clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL)) + expandedBaseURL, err := environment.Expand(ctx, cfg.BaseURL, env) + if err != nil { + return nil, fmt.Errorf("expanding base_url: %w", err) + } + clientOptions = append(clientOptions, option.WithBaseURL(expandedBaseURL)) + } + + + // Apply custom headers from provider config if present + if cfg.ProviderOpts != nil { + if headers, exists := cfg.ProviderOpts["headers"]; exists { + // Handle both map[string]string and map[interface{}]interface{} from YAML parsing + headersMap := make(map[string]string) + + switch h := headers.(type) { + case map[string]string: + // Direct map[string]string - use as-is + headersMap = h + case map[interface{}]interface{}: + // YAML parsed as map[interface{}]interface{} - convert + for k, v := range h { + keyStr, okKey := k.(string) + valStr, okVal := v.(string) + if !okKey || !okVal { + slog.Error("Invalid header key/value type", + "key_type", fmt.Sprintf("%T", k), + "value_type", fmt.Sprintf("%T", v), + "provider", cfg.Provider) + return nil, fmt.Errorf("invalid header key/value type: key=%T, value=%T", k, v) + } + headersMap[keyStr] = valStr + } + default: + slog.Error("Invalid headers configuration - expected map[string]string or map[interface{}]interface{}", + "type", fmt.Sprintf("%T", headers), + "provider", cfg.Provider) + return nil, fmt.Errorf("invalid headers configuration: expected map[string]string, got %T", headers) + } + + if len(headersMap) > 0 { + slog.Debug("Applying custom headers", "count", len(headersMap), "provider", cfg.Provider) + for key, value := range headersMap { + // Expand environment variables in header values (e.g., ${VAR_NAME}) + expandedValue, err := environment.Expand(ctx, value, env) + if err != nil { + slog.Error("Failed to expand environment variable in header", + "header", key, + "error", err, + "provider", cfg.Provider) + return nil, fmt.Errorf("expanding header %s: %w", key, err) + } + clientOptions = append(clientOptions, option.WithHeader(key, expandedValue)) + slog.Debug("Applied custom header", + "header", key, + "provider", cfg.Provider) + } + } + } } httpClient := httpclient.NewHTTPClient() diff --git a/pkg/model/provider/openai/schema.go b/pkg/model/provider/openai/schema.go index 28149ab3b..1e04c9d91 100644 --- a/pkg/model/provider/openai/schema.go +++ b/pkg/model/provider/openai/schema.go @@ -16,7 +16,7 @@ func ConvertParametersToSchema(params any) (shared.FunctionParameters, error) { return nil, err } - return fixSchemaArrayItems(removeFormatFields(makeAllRequired(p))), nil + return normalizeUnionTypes(fixSchemaArrayItems(removeFormatFields(makeAllRequired(p)))), nil } // walkSchema calls fn on the given schema node, then recursively walks into @@ -144,3 +144,94 @@ func fixSchemaArrayItems(schema shared.FunctionParameters) shared.FunctionParame return schema } + +// normalizeUnionTypes converts union types like ["array", "null"] back to simple types +// for compatibility with AI gateways that don't support JSON Schema union types. +// This is needed for Cloudflare AI Gateway and similar proxies. +func normalizeUnionTypes(schema shared.FunctionParameters) shared.FunctionParameters { + if schema == nil { + return schema + } + + // Convert union types at the current level + // Only normalize nullable patterns: exactly 2 types where one is "null" + if typeArray, ok := schema["type"].([]any); ok { + if len(typeArray) == 2 { + var hasNull bool + var nonNullType string + for _, t := range typeArray { + if tStr, ok := t.(string); ok { + if tStr == "null" { + hasNull = true + } else { + nonNullType = tStr + } + } + } + if hasNull && nonNullType != "" { + schema["type"] = nonNullType + } + } + } else if typeArray, ok := schema["type"].([]string); ok { + if len(typeArray) == 2 { + var hasNull bool + var nonNullType string + for _, t := range typeArray { + if t == "null" { + hasNull = true + } else { + nonNullType = t + } + } + if hasNull && nonNullType != "" { + schema["type"] = nonNullType + } + } + } + + // Convert nullable anyOf patterns like {"anyOf": [{"type":"string"},{"type":"null"}]} to {"type":"string"} + // Only normalize when there are exactly 2 alternatives and one is {"type":"null"}. + // This is needed for Gemini via Cloudflare which doesn't support anyOf in tool parameters. + if anyOf, ok := schema["anyOf"].([]any); ok && len(anyOf) == 2 { + hasNull := false + var nonNullType string + for _, item := range anyOf { + if itemMap, ok := item.(map[string]any); ok { + if typStr, ok := itemMap["type"].(string); ok { + if typStr == "null" { + hasNull = true + } else { + nonNullType = typStr + } + } + } + } + if hasNull && nonNullType != "" { + schema["type"] = nonNullType + delete(schema, "anyOf") + } + } + + // Recursively handle properties + if propertiesValue, ok := schema["properties"]; ok { + if properties, ok := propertiesValue.(map[string]any); ok { + for _, propValue := range properties { + if prop, ok := propValue.(map[string]any); ok { + normalizeUnionTypes(prop) + } + } + } + } + + // Recursively handle items (for arrays) + if items, ok := schema["items"].(map[string]any); ok { + normalizeUnionTypes(items) + } + + // Recursively handle additionalProperties + if addProps, ok := schema["additionalProperties"].(map[string]any); ok { + normalizeUnionTypes(addProps) + } + + return schema +} diff --git a/pkg/model/provider/openai/schema_test.go b/pkg/model/provider/openai/schema_test.go index 65562b774..634d9ee11 100644 --- a/pkg/model/provider/openai/schema_test.go +++ b/pkg/model/provider/openai/schema_test.go @@ -390,3 +390,50 @@ func TestFixSchemaArrayItems(t *testing.T) { "type": "object" }`, string(buf)) } + +func TestNormalizeUnionTypes_AnyOfPattern(t *testing.T) { + t.Parallel() + + // Simulate the anyOf pattern from MCP tool schemas (e.g., Optional[str] in Python) + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "source": map[string]any{ + "anyOf": []any{ + map[string]any{"type": "string"}, + map[string]any{"type": "null"}, + }, + "default": nil, + "title": "Source", + }, + "days": map[string]any{ + "anyOf": []any{ + map[string]any{"type": "integer"}, + map[string]any{"type": "null"}, + }, + "default": nil, + "title": "Days", + }, + "name": map[string]any{ + "type": "string", + "title": "Name", + }, + }, + } + + result := normalizeUnionTypes(schema) + props := result["properties"].(map[string]any) + + // anyOf should be converted to simple type + source := props["source"].(map[string]any) + assert.Equal(t, "string", source["type"]) + assert.Nil(t, source["anyOf"], "anyOf should be removed after normalization") + + days := props["days"].(map[string]any) + assert.Equal(t, "integer", days["type"]) + assert.Nil(t, days["anyOf"], "anyOf should be removed after normalization") + + // Regular type should be unchanged + name := props["name"].(map[string]any) + assert.Equal(t, "string", name["type"]) +} diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index f43d49c1e..6aa967146 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -334,6 +334,30 @@ func applyProviderDefaults(cfg *latest.ModelConfig, customProviders map[string]l enhancedCfg.ProviderOpts["api_type"] = apiType } + // Copy custom headers from provider config if not already set in provider_opts + // Deep-copy the map to avoid mutating the shared provider config + if _, hasHeaders := enhancedCfg.ProviderOpts["headers"]; !hasHeaders { + if len(providerCfg.Headers) > 0 { + headersCopy := make(map[string]string, len(providerCfg.Headers)) + for k, v := range providerCfg.Headers { + headersCopy[k] = v + } + enhancedCfg.ProviderOpts["headers"] = headersCopy + } + } + + // Merge model-level headers into provider_opts headers (model-level takes precedence) + if len(enhancedCfg.Headers) > 0 { + existing, _ := enhancedCfg.ProviderOpts["headers"].(map[string]string) + if existing == nil { + existing = make(map[string]string) + } + for k, v := range enhancedCfg.Headers { + existing[k] = v + } + enhancedCfg.ProviderOpts["headers"] = existing + } + applyModelDefaults(&enhancedCfg) return &enhancedCfg } @@ -351,6 +375,21 @@ func applyProviderDefaults(cfg *latest.ModelConfig, customProviders map[string]l } } + // Merge model-level headers into provider_opts for non-custom providers too + if len(enhancedCfg.Headers) > 0 { + if enhancedCfg.ProviderOpts == nil { + enhancedCfg.ProviderOpts = make(map[string]any) + } + existing, _ := enhancedCfg.ProviderOpts["headers"].(map[string]string) + if existing == nil { + existing = make(map[string]string) + } + for k, v := range enhancedCfg.Headers { + existing[k] = v + } + enhancedCfg.ProviderOpts["headers"] = existing + } + // Apply model-specific defaults (e.g., thinking budget for Claude/GPT models) applyModelDefaults(&enhancedCfg) return &enhancedCfg diff --git a/pkg/model/provider/schema_test.go b/pkg/model/provider/schema_test.go index e9150d123..b4125b7e1 100644 --- a/pkg/model/provider/schema_test.go +++ b/pkg/model/provider/schema_test.go @@ -220,20 +220,20 @@ func TestSchemaForOpenai(t *testing.T) { "direction": { "description": "Order", "enum": ["ASC", "DESC"], - "type": ["string", "null"] + "type": "string" }, "labels": { "description": "Filter", "items": { "type": "string" }, - "type": ["array", "null"] + "type": "array" }, "perPage": { "description": "Results", "maximum": 100, "minimum": 1, - "type": ["number", "null"] + "type": "number" }, "repo": { "description": "Repository", diff --git a/pkg/runtime/streaming.go b/pkg/runtime/streaming.go index 62b132076..4a9b9eca8 100644 --- a/pkg/runtime/streaming.go +++ b/pkg/runtime/streaming.go @@ -8,6 +8,8 @@ import ( "log/slog" "strings" + "github.com/openai/openai-go/v3" + "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/modelsdev" @@ -79,6 +81,13 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre break } if err != nil { + var apiErr *openai.Error + if errors.As(err, &apiErr) { + slog.Debug("Stream API error details", + "agent", a.Name(), + "status_code", apiErr.StatusCode, + ) + } return streamResult{Stopped: true}, fmt.Errorf("error receiving from stream: %w", err) } diff --git a/pkg/teamloader/filter.go b/pkg/teamloader/filter.go index fdbc717b7..3b6901292 100644 --- a/pkg/teamloader/filter.go +++ b/pkg/teamloader/filter.go @@ -53,6 +53,24 @@ func (f *filterTools) Unwrap() tools.ToolSet { return f.ToolSet } +// Start forwards the Start call to the inner toolset if it implements Startable. +// This is necessary because filterTools wraps toolsets (like MCP) that require +// initialization before their Tools() method can be called. +func (f *filterTools) Start(ctx context.Context) error { + if startable, ok := f.ToolSet.(tools.Startable); ok { + return startable.Start(ctx) + } + return nil +} + +// Stop forwards the Stop call to the inner toolset if it implements Startable. +func (f *filterTools) Stop(ctx context.Context) error { + if startable, ok := f.ToolSet.(tools.Startable); ok { + return startable.Stop(ctx) + } + return nil +} + // Instructions implements tools.Instructable by delegating to the inner toolset. func (f *filterTools) Instructions() string { return tools.GetInstructions(f.ToolSet) diff --git a/pkg/teamloader/filter_test.go b/pkg/teamloader/filter_test.go index 8e2727ea0..cc2ee3fa6 100644 --- a/pkg/teamloader/filter_test.go +++ b/pkg/teamloader/filter_test.go @@ -23,6 +23,23 @@ func (m *mockToolSet) Tools(ctx context.Context) ([]tools.Tool, error) { return nil, nil } +// startableToolSet is a mock that implements both ToolSet and Startable, +// like the real MCP toolset does. +type startableToolSet struct { + mockToolSet + started bool +} + +func (s *startableToolSet) Start(context.Context) error { + s.started = true + return nil +} + +func (s *startableToolSet) Stop(context.Context) error { + s.started = false + return nil +} + func TestWithToolsFilter_NilToolNames(t *testing.T) { inner := &mockToolSet{} @@ -167,3 +184,88 @@ func TestWithToolsFilter_NonInstructableInner(t *testing.T) { instructions := tools.GetInstructions(wrapped) assert.Empty(t, instructions) } + +func TestWithToolsFilter_ForwardsStartToStartableInner(t *testing.T) { + t.Parallel() + + inner := &startableToolSet{ + mockToolSet: mockToolSet{ + toolsFunc: func(context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}}, nil + }, + }, + } + + wrapped := WithToolsFilter(inner, "tool1") + + // Verify the inner toolset is not started yet + assert.False(t, inner.started) + + // The wrapped filterTools should satisfy Startable + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "filterTools should implement tools.Startable") + + // Start should forward to the inner toolset + err := startable.Start(t.Context()) + require.NoError(t, err) + assert.True(t, inner.started, "Start() should have been forwarded to inner toolset") + + // Stop should also forward + err = startable.Stop(t.Context()) + require.NoError(t, err) + assert.False(t, inner.started, "Stop() should have been forwarded to inner toolset") +} + +func TestWithToolsFilter_StartNoOpForNonStartableInner(t *testing.T) { + t.Parallel() + + inner := &mockToolSet{ + toolsFunc: func(context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "tool1"}}, nil + }, + } + + wrapped := WithToolsFilter(inner, "tool1") + + // Should still implement Startable + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "filterTools should implement tools.Startable") + + // Start/Stop should be no-ops without error + err := startable.Start(t.Context()) + require.NoError(t, err) + + err = startable.Stop(t.Context()) + require.NoError(t, err) +} + +func TestWithToolsFilter_StartableToolSetIntegration(t *testing.T) { + t.Parallel() + + // This test simulates the real wrapping: MCP → filterTools → StartableToolSet + inner := &startableToolSet{ + mockToolSet: mockToolSet{ + toolsFunc: func(context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}}, nil + }, + }, + } + + // Wrap in filterTools (like teamloader does) + filtered := WithToolsFilter(inner, "tool1") + + // Wrap in StartableToolSet (like agent.WithToolSets does) + startable := tools.NewStartable(filtered) + + // Start should propagate through: StartableToolSet → filterTools → startableToolSet + err := startable.Start(t.Context()) + require.NoError(t, err) + assert.True(t, startable.IsStarted(), "StartableToolSet should be started") + assert.True(t, inner.started, "Inner startable toolset should have been started") + + // Tools should work through the whole chain + result, err := startable.Tools(t.Context()) + require.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, "tool1", result[0].Name) +} diff --git a/pkg/teamloader/instructions.go b/pkg/teamloader/instructions.go index 0bb396536..00957c056 100644 --- a/pkg/teamloader/instructions.go +++ b/pkg/teamloader/instructions.go @@ -1,6 +1,7 @@ package teamloader import ( + "context" "strings" "github.com/docker/docker-agent/pkg/tools" @@ -33,6 +34,22 @@ func (a *replaceInstruction) Unwrap() tools.ToolSet { return a.ToolSet } +// Start forwards the Start call to the inner toolset if it implements Startable. +func (a *replaceInstruction) Start(ctx context.Context) error { + if startable, ok := a.ToolSet.(tools.Startable); ok { + return startable.Start(ctx) + } + return nil +} + +// Stop forwards the Stop call to the inner toolset if it implements Startable. +func (a *replaceInstruction) Stop(ctx context.Context) error { + if startable, ok := a.ToolSet.(tools.Startable); ok { + return startable.Stop(ctx) + } + return nil +} + func (a *replaceInstruction) Instructions() string { original := tools.GetInstructions(a.ToolSet) return strings.Replace(a.instruction, "{ORIGINAL_INSTRUCTIONS}", original, 1) diff --git a/pkg/teamloader/instructions_test.go b/pkg/teamloader/instructions_test.go index 5c8e04e37..3d1075e56 100644 --- a/pkg/teamloader/instructions_test.go +++ b/pkg/teamloader/instructions_test.go @@ -1,9 +1,11 @@ package teamloader import ( + "context" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/tools" ) @@ -44,3 +46,55 @@ func TestWithInstructions_add(t *testing.T) { assert.Equal(t, "Existing instructions\nMore instructions", tools.GetInstructions(wrapped)) } + +type startableInstructableToolSet struct { + toolSet + started bool +} + +func (s *startableInstructableToolSet) Start(_ context.Context) error { + s.started = true + return nil +} + +func (s *startableInstructableToolSet) Stop(_ context.Context) error { + s.started = false + return nil +} + +func TestWithInstructions_ForwardsStartToStartableInner(t *testing.T) { + t.Parallel() + + inner := &startableInstructableToolSet{ + toolSet: toolSet{instruction: "test"}, + } + + wrapped := WithInstructions(inner, "New instructions") + + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "replaceInstruction should implement tools.Startable") + + err := startable.Start(t.Context()) + require.NoError(t, err) + assert.True(t, inner.started, "Start() should have been forwarded to inner toolset") + + err = startable.Stop(t.Context()) + require.NoError(t, err) + assert.False(t, inner.started, "Stop() should have been forwarded to inner toolset") +} + +func TestWithInstructions_StartNoOpForNonStartableInner(t *testing.T) { + t.Parallel() + + inner := &toolSet{instruction: "test"} + wrapped := WithInstructions(inner, "New instructions") + + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "replaceInstruction should implement tools.Startable") + + err := startable.Start(t.Context()) + require.NoError(t, err) + + err = startable.Stop(t.Context()) + require.NoError(t, err) +} diff --git a/pkg/teamloader/toon.go b/pkg/teamloader/toon.go index ec949cf0c..7e2c0e93d 100644 --- a/pkg/teamloader/toon.go +++ b/pkg/teamloader/toon.go @@ -19,6 +19,22 @@ type toonTools struct { // Verify interface compliance var _ tools.Unwrapper = (*toonTools)(nil) +// Start forwards the Start call to the inner toolset if it implements Startable. +func (f *toonTools) Start(ctx context.Context) error { + if startable, ok := f.ToolSet.(tools.Startable); ok { + return startable.Start(ctx) + } + return nil +} + +// Stop forwards the Stop call to the inner toolset if it implements Startable. +func (f *toonTools) Stop(ctx context.Context) error { + if startable, ok := f.ToolSet.(tools.Startable); ok { + return startable.Stop(ctx) + } + return nil +} + func (f *toonTools) Tools(ctx context.Context) ([]tools.Tool, error) { allTools, err := f.ToolSet.Tools(ctx) if err != nil { diff --git a/pkg/teamloader/toon_test.go b/pkg/teamloader/toon_test.go index 52607eb73..3eb399ac6 100644 --- a/pkg/teamloader/toon_test.go +++ b/pkg/teamloader/toon_test.go @@ -70,3 +70,49 @@ func TestToon(t *testing.T) { }) } } + +func TestWithToon_ForwardsStartToStartableInner(t *testing.T) { + t.Parallel() + + inner := &startableToolSet{ + mockToolSet: mockToolSet{ + toolsFunc: func(ctx context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "tool1"}}, nil + }, + }, + } + + wrapped := WithToon(inner, "tool1") + + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "toonTools should implement tools.Startable") + + err := startable.Start(t.Context()) + require.NoError(t, err) + assert.True(t, inner.started, "Start() should have been forwarded to inner toolset") + + err = startable.Stop(t.Context()) + require.NoError(t, err) + assert.False(t, inner.started, "Stop() should have been forwarded to inner toolset") +} + +func TestWithToon_StartNoOpForNonStartableInner(t *testing.T) { + t.Parallel() + + inner := &mockToolSet{ + toolsFunc: func(ctx context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "tool1"}}, nil + }, + } + + wrapped := WithToon(inner, "tool1") + + startable, ok := wrapped.(tools.Startable) + require.True(t, ok, "toonTools should implement tools.Startable") + + err := startable.Start(t.Context()) + require.NoError(t, err) + + err = startable.Stop(t.Context()) + require.NoError(t, err) +}