diff --git a/go.mod b/go.mod index ea7add929..9b542daa2 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.19.12 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2 github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 + github.com/aws/smithy-go v1.24.2 github.com/aymanbagabas/go-udiff v0.4.1 github.com/blevesearch/bleve/v2 v2.5.7 github.com/bmatcuk/doublestar/v4 v4.10.0 @@ -90,7 +91,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect - github.com/aws/smithy-go v1.24.2 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/bits-and-blooms/bitset v1.24.4 // indirect diff --git a/pkg/backoff/backoff.go b/pkg/backoff/backoff.go new file mode 100644 index 000000000..722fb1d4e --- /dev/null +++ b/pkg/backoff/backoff.go @@ -0,0 +1,60 @@ +// Package backoff provides exponential backoff calculation and +// context-aware sleep utilities. +package backoff + +import ( + "context" + "math/rand/v2" + "time" +) + +// Configuration constants for exponential backoff. +const ( + baseDelay = 200 * time.Millisecond + maxDelay = 2 * time.Second + factor = 2.0 + jitter = 0.1 + + // MaxRetryAfterWait caps how long we'll honor a Retry-After header to prevent + // a misbehaving server from blocking the agent for an unreasonable amount of time. + MaxRetryAfterWait = 60 * time.Second +) + +// Calculate returns the backoff duration for a given attempt (0-indexed). +// Uses exponential backoff with jitter. 
+func Calculate(attempt int) time.Duration { + if attempt < 0 { + attempt = 0 + } + + // Calculate exponential delay + delay := float64(baseDelay) + for range attempt { + delay *= factor + } + + // Cap at max delay + if delay > float64(maxDelay) { + delay = float64(maxDelay) + } + + // Add jitter (±10%) + j := delay * jitter * (2*rand.Float64() - 1) + delay += j + + return time.Duration(delay) +} + +// SleepWithContext sleeps for the specified duration, returning early if context is cancelled. +// Returns true if the sleep completed, false if it was interrupted by context cancellation. +func SleepWithContext(ctx context.Context, d time.Duration) bool { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-timer.C: + return true + case <-ctx.Done(): + return false + } +} diff --git a/pkg/backoff/backoff_test.go b/pkg/backoff/backoff_test.go new file mode 100644 index 000000000..e0a61d9c5 --- /dev/null +++ b/pkg/backoff/backoff_test.go @@ -0,0 +1,70 @@ +package backoff + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCalculate(t *testing.T) { + t.Parallel() + + tests := []struct { + attempt int + minExpected time.Duration + maxExpected time.Duration + }{ + {attempt: 0, minExpected: 180 * time.Millisecond, maxExpected: 220 * time.Millisecond}, + {attempt: 1, minExpected: 360 * time.Millisecond, maxExpected: 440 * time.Millisecond}, + {attempt: 2, minExpected: 720 * time.Millisecond, maxExpected: 880 * time.Millisecond}, + {attempt: 3, minExpected: 1440 * time.Millisecond, maxExpected: 1760 * time.Millisecond}, + {attempt: 10, minExpected: 1800 * time.Millisecond, maxExpected: 2200 * time.Millisecond}, // capped at 2s + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("attempt_%d", tt.attempt), func(t *testing.T) { + t.Parallel() + b := Calculate(tt.attempt) + assert.GreaterOrEqual(t, b, tt.minExpected, "backoff should be at least %v", tt.minExpected) + assert.LessOrEqual(t, b, 
tt.maxExpected, "backoff should be at most %v", tt.maxExpected) + }) + } + + t.Run("negative attempt treated as 0", func(t *testing.T) { + t.Parallel() + b := Calculate(-1) + assert.GreaterOrEqual(t, b, 180*time.Millisecond) + assert.LessOrEqual(t, b, 220*time.Millisecond) + }) +} + +func TestSleepWithContext(t *testing.T) { + t.Parallel() + + t.Run("completes normally", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + start := time.Now() + completed := SleepWithContext(ctx, 10*time.Millisecond) + elapsed := time.Since(start) + + assert.True(t, completed, "should complete normally") + assert.GreaterOrEqual(t, elapsed, 10*time.Millisecond) + }) + + t.Run("interrupted by context", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(t.Context()) + time.AfterFunc(10*time.Millisecond, cancel) + + start := time.Now() + completed := SleepWithContext(ctx, 1*time.Second) + elapsed := time.Since(start) + + assert.False(t, completed, "should be interrupted") + assert.Less(t, elapsed, 100*time.Millisecond, "should return quickly after cancel") + }) +} diff --git a/pkg/model/provider/bedrock/adapter.go b/pkg/model/provider/bedrock/adapter.go index 8f29cd78e..0c6552fd9 100644 --- a/pkg/model/provider/bedrock/adapter.go +++ b/pkg/model/provider/bedrock/adapter.go @@ -71,7 +71,7 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) { // Check for errors if err := a.stream.Err(); err != nil { slog.Debug("Bedrock stream: error on channel close", "error", err) - return chat.MessageStreamResponse{}, err + return chat.MessageStreamResponse{}, wrapBedrockError(err) } // If we have a pending finish reason but never got metadata, emit it now if a.pendingFinishReason != "" { diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index a083f1d47..8086d41de 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -219,7 +219,7 @@ func (c *Client) CreateChatCompletionStream( 
output, err := c.bedrockClient.ConverseStream(ctx, input) if err != nil { slog.Error("Bedrock ConverseStream failed", "error", err) - return nil, fmt.Errorf("bedrock converse stream failed: %w", err) + return nil, wrapBedrockError(fmt.Errorf("bedrock converse stream failed: %w", err)) } trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage diff --git a/pkg/model/provider/bedrock/wrap.go b/pkg/model/provider/bedrock/wrap.go new file mode 100644 index 000000000..84c06c313 --- /dev/null +++ b/pkg/model/provider/bedrock/wrap.go @@ -0,0 +1,35 @@ +package bedrock + +import ( + "errors" + + smithyhttp "github.com/aws/smithy-go/transport/http" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// wrapBedrockError wraps an AWS Bedrock SDK error in a *modelerrors.StatusError +// to carry HTTP status code metadata for the retry loop. +// The AWS SDK v2 exposes HTTP status via smithyhttp.ResponseError. +// Non-AWS errors (e.g., io.EOF, network errors) pass through unchanged. 
+func wrapBedrockError(err error) error { + if err == nil { + return nil + } + + var respErr *smithyhttp.ResponseError + if !errors.As(err, &respErr) { + return err + } + + var resp *smithyhttp.Response + if respErr.HTTPResponse() != nil { + resp = respErr.HTTPResponse() + } + + statusCode := respErr.HTTPStatusCode() + if resp != nil { + return modelerrors.WrapHTTPError(statusCode, resp.Response, err) + } + return modelerrors.WrapHTTPError(statusCode, nil, err) +} diff --git a/pkg/model/provider/bedrock/wrap_test.go b/pkg/model/provider/bedrock/wrap_test.go new file mode 100644 index 000000000..13f1a24d4 --- /dev/null +++ b/pkg/model/provider/bedrock/wrap_test.go @@ -0,0 +1,110 @@ +package bedrock + +import ( + "errors" + "fmt" + "net/http" + "testing" + "time" + + smithy "github.com/aws/smithy-go" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +func makeTestBedrockError(statusCode int, retryAfterValue string) error { + header := http.Header{} + if retryAfterValue != "" { + header.Set("Retry-After", retryAfterValue) + } + + httpResp := &http.Response{ + StatusCode: statusCode, + Header: header, + } + resp := &smithyhttp.Response{Response: httpResp} + + return &smithy.OperationError{ + ServiceID: "BedrockRuntime", + OperationName: "ConverseStream", + Err: &smithyhttp.ResponseError{ + Response: resp, + Err: &smithy.GenericAPIError{ + Code: "ThrottlingException", + Message: "Rate exceeded", + }, + }, + } +} + +func TestWrapBedrockError(t *testing.T) { + t.Parallel() + + t.Run("nil returns nil", func(t *testing.T) { + t.Parallel() + assert.NoError(t, wrapBedrockError(nil)) + }) + + t.Run("non-AWS error passes through unchanged", func(t *testing.T) { + t.Parallel() + orig := errors.New("some network error") + result := wrapBedrockError(orig) + assert.Equal(t, orig, result) + var se *modelerrors.StatusError + 
assert.NotErrorAs(t, result, &se) + }) + + t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) { + t.Parallel() + awsErr := makeTestBedrockError(429, "") + result := wrapBedrockError(awsErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + // Original error still accessible + assert.ErrorIs(t, result, awsErr) + }) + + t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) { + t.Parallel() + awsErr := makeTestBedrockError(429, "20") + result := wrapBedrockError(awsErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, 20*time.Second, se.RetryAfter) + }) + + t.Run("500 wraps with correct status code", func(t *testing.T) { + t.Parallel() + awsErr := makeTestBedrockError(500, "") + result := wrapBedrockError(awsErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 500, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + }) + + t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) { + t.Parallel() + awsErr := makeTestBedrockError(429, "15") + result := wrapBedrockError(awsErr) + retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(result) + assert.False(t, retryable) + assert.True(t, rateLimited) + assert.Equal(t, 15*time.Second, retryAfter) + }) + + t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) { + t.Parallel() + awsErr := makeTestBedrockError(500, "") + wrapped := fmt.Errorf("bedrock converse stream failed: %w", wrapBedrockError(awsErr)) + retryable, rateLimited, _ := modelerrors.ClassifyModelError(wrapped) + assert.True(t, retryable) + assert.False(t, rateLimited) + }) +} diff --git a/pkg/modelerrors/modelerrors.go b/pkg/modelerrors/modelerrors.go index 4283262e7..a24d8edb6 100644 --- 
a/pkg/modelerrors/modelerrors.go +++ b/pkg/modelerrors/modelerrors.go @@ -9,28 +9,12 @@ import ( "errors" "fmt" "log/slog" - "math/rand" "net" "net/http" "regexp" "strconv" "strings" "time" - - "github.com/anthropics/anthropic-sdk-go" - "google.golang.org/genai" -) - -// Backoff and retry-after configuration constants. -const ( - backoffBaseDelay = 200 * time.Millisecond - backoffMaxDelay = 2 * time.Second - backoffFactor = 2.0 - backoffJitter = 0.1 - - // MaxRetryAfterWait caps how long we'll honor a Retry-After header to prevent - // a misbehaving server from blocking the agent for an unreasonable amount of time. - MaxRetryAfterWait = 60 * time.Second ) // StatusError wraps an HTTP API error with structured metadata for retry decisions. @@ -46,7 +30,7 @@ type StatusError struct { } func (e *StatusError) Error() string { - return e.Err.Error() + return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Err.Error()) } func (e *StatusError) Unwrap() error { @@ -62,7 +46,7 @@ func WrapHTTPError(statusCode int, resp *http.Response, err error) error { } var retryAfter time.Duration if resp != nil { - retryAfter = ParseRetryAfterHeader(resp.Header.Get("Retry-After")) + retryAfter = parseRetryAfterHeader(resp.Header.Get("Retry-After")) } return &StatusError{ StatusCode: statusCode, @@ -91,6 +75,13 @@ type ContextOverflowError struct { Underlying error } +// NewContextOverflowError creates a ContextOverflowError wrapping the given +// underlying error. Use this constructor rather than building the struct +// directly so that future field additions don't break callers. 
+func NewContextOverflowError(underlying error) *ContextOverflowError { + return &ContextOverflowError{Underlying: underlying} +} + func (e *ContextOverflowError) Error() string { if e.Underlying == nil { return "context window overflow" @@ -161,32 +152,29 @@ func IsContextOverflowError(err error) bool { // statusCodeRegex matches HTTP status codes in error messages (e.g., "429", "500", ": 429 ") var statusCodeRegex = regexp.MustCompile(`\b([45]\d{2})\b`) -// ExtractHTTPStatusCode attempts to extract an HTTP status code from the error. -// Checks in order: -// 1. Known provider SDK error types (Anthropic, Gemini) -// 2. Regex parsing of error message (fallback for OpenAI and others) +// extractHTTPStatusCode attempts to extract an HTTP status code from the error. +// It returns the code of a wrapped *StatusError when present (the preferred path); +// otherwise it falls back to regex parsing of the error message. +// +// The regex matches 4xx/5xx codes at word boundaries +// (e.g., "429 Too Many Requests", "500 Internal Server Error"). // Returns 0 if no status code found. -func ExtractHTTPStatusCode(err error) int { +func extractHTTPStatusCode(err error) int { if err == nil { return 0 } - // Check Anthropic SDK error type (public) - if anthropicErr, ok := errors.AsType[*anthropic.Error](err); ok { - return anthropicErr.StatusCode - } - - // Check Google Gemini SDK error type (public) - if geminiErr, ok := errors.AsType[*genai.APIError](err); ok { - return geminiErr.Code + // Check for *StatusError first (preferred structured path). + var statusErr *StatusError + if errors.As(err, &statusErr) { + return statusErr.StatusCode } - // For other providers (OpenAI, etc.), extract from error message using regex + // Fallback: extract from error message using regex. 
// OpenAI SDK error format: `POST "/v1/...": 429 Too Many Requests {...}` matches := statusCodeRegex.FindStringSubmatch(err.Error()) if len(matches) >= 2 { - var code int - if _, err := fmt.Sscanf(matches[1], "%d", &code); err == nil { + if code, err := strconv.Atoi(matches[1]); err == nil { return code } } @@ -194,7 +182,7 @@ func ExtractHTTPStatusCode(err error) int { return 0 } -// IsRetryableStatusCode determines if an HTTP status code is retryable. +// isRetryableStatusCode determines if an HTTP status code is retryable. // Retryable means we should retry the SAME model with exponential backoff. // // Retryable status codes: @@ -205,7 +193,7 @@ func ExtractHTTPStatusCode(err error) int { // Non-retryable status codes (skip to next model immediately): // - 429 (rate limit) - provider is explicitly telling us to back off // - 4xx client errors (400, 401, 403, 404) - won't get better with retry -func IsRetryableStatusCode(statusCode int) bool { +func isRetryableStatusCode(statusCode int) bool { switch statusCode { case 500, 502, 503, 504: // Server errors return true @@ -220,7 +208,45 @@ func IsRetryableStatusCode(statusCode int) bool { } } -// IsRetryableModelError determines if an error should trigger a retry of the SAME model. +// retryablePatterns contains error message substrings that indicate a +// transient/retryable failure. Numeric status codes (500, 502, etc.) are +// handled separately by extractHTTPStatusCode + isRetryableStatusCode. 
+var retryablePatterns = []string{ + "timeout", // Generic timeout + "connection reset", // Connection reset + "connection refused", // Connection refused + "no such host", // DNS failure + "temporary failure", // Temporary failure + "service unavailable", // Service unavailable + "internal server error", // Server error + "bad gateway", // Gateway error + "gateway timeout", // Gateway timeout + "overloaded", // Server overloaded + "overloaded_error", // Server overloaded + "other side closed", // Connection closed by peer + "fetch failed", // Network fetch failure + "reset before headers", // Connection reset before headers received + "upstream connect", // Upstream connection error + "internal_error", // HTTP/2 INTERNAL_ERROR (stream-level) +} + +// nonRetryablePatterns contains error message substrings that indicate a +// permanent/non-retryable failure. Numeric status codes (429, 401, etc.) are +// handled separately by extractHTTPStatusCode. +var nonRetryablePatterns = []string{ + "rate limit", // Rate limit message + "too many requests", // Rate limit message + "throttl", // Throttling (rate limiting) + "quota", // Quota exceeded + "capacity", // Capacity issues (often rate-limit related) + "invalid", // Invalid request + "unauthorized", // Auth error + "authentication", // Auth error + "api key", // API key error +} + +// isRetryableModelError determines if an error should trigger a retry of the SAME model. +// It is used as a fallback by ClassifyModelError when no *StatusError is present. // // Retryable errors (retry same model with backoff): // - Network timeouts @@ -238,7 +264,7 @@ func IsRetryableStatusCode(statusCode int) bool { // // The key distinction is: 429 means "you're calling too fast, slow down" which // suggests we should try a different model, not keep hammering the same one. 
-func IsRetryableModelError(err error) bool { +func isRetryableModelError(err error) bool { if err == nil { return false } @@ -257,8 +283,8 @@ func IsRetryableModelError(err error) bool { } // First, try to extract HTTP status code from known SDK error types - if statusCode := ExtractHTTPStatusCode(err); statusCode != 0 { - retryable := IsRetryableStatusCode(statusCode) + if statusCode := extractHTTPStatusCode(err); statusCode != 0 { + retryable := isRetryableStatusCode(statusCode) slog.Debug("Classified error by status code", "status_code", statusCode, "retryable", retryable) @@ -274,35 +300,8 @@ func IsRetryableModelError(err error) bool { } } - // Fall back to message-pattern matching for errors without structured status codes errMsg := strings.ToLower(err.Error()) - // Retryable patterns (5xx, timeout, network issues) - // NOTE: 429 is explicitly NOT in this list - we skip to next model for rate limits - retryablePatterns := []string{ - "500", // Internal server error - "502", // Bad gateway - "503", // Service unavailable - "504", // Gateway timeout - "408", // Request timeout - "timeout", // Generic timeout - "connection reset", // Connection reset - "connection refused", // Connection refused - "no such host", // DNS failure - "temporary failure", // Temporary failure - "service unavailable", // Service unavailable - "internal server error", // Server error - "bad gateway", // Gateway error - "gateway timeout", // Gateway timeout - "overloaded", // Server overloaded - "overloaded_error", // Server overloaded - "other side closed", // Connection closed by peer - "fetch failed", // Network fetch failure - "reset before headers", // Connection reset before headers received - "upstream connect", // Upstream connection error - "internal_error", // HTTP/2 INTERNAL_ERROR (stream-level) - } - for _, pattern := range retryablePatterns { if strings.Contains(errMsg, pattern) { slog.Debug("Matched retryable error pattern", "pattern", pattern) @@ -310,24 +309,6 @@ func 
IsRetryableModelError(err error) bool { } } - // Non-retryable patterns (skip to next model immediately) - nonRetryablePatterns := []string{ - "429", // Rate limit - skip to next model - "rate limit", // Rate limit message - "too many requests", // Rate limit message - "throttl", // Throttling (rate limiting) - "quota", // Quota exceeded - "capacity", // Capacity issues (often rate-limit related) - "401", // Unauthorized - "403", // Forbidden - "404", // Not found - "400", // Bad request - "invalid", // Invalid request - "unauthorized", // Auth error - "authentication", // Auth error - "api key", // API key error - } - for _, pattern := range nonRetryablePatterns { if strings.Contains(errMsg, pattern) { slog.Debug("Matched non-retryable error pattern", "pattern", pattern) @@ -340,10 +321,10 @@ func IsRetryableModelError(err error) bool { return false } -// ParseRetryAfterHeader parses a Retry-After header value. +// parseRetryAfterHeader parses a Retry-After header value. // Supports both seconds (integer) and HTTP-date formats per RFC 7231 §7.1.3. // Returns 0 if the value is empty, invalid, or results in a non-positive duration. -func ParseRetryAfterHeader(value string) time.Duration { +func parseRetryAfterHeader(value string) time.Duration { if value == "" { return 0 } @@ -397,58 +378,19 @@ func ClassifyModelError(err error) (retryable, rateLimited bool, retryAfter time if statusErr.StatusCode == http.StatusTooManyRequests { return false, true, statusErr.RetryAfter } - return IsRetryableStatusCode(statusErr.StatusCode), false, 0 + return isRetryableStatusCode(statusErr.StatusCode), false, 0 } // Fallback: providers that don't yet wrap (e.g. Bedrock), or non-provider // errors (network, pattern matching). 
- statusCode := ExtractHTTPStatusCode(err) + statusCode := extractHTTPStatusCode(err) if statusCode == http.StatusTooManyRequests { return false, true, 0 // No Retry-After without StatusError } if statusCode != 0 { - return IsRetryableStatusCode(statusCode), false, 0 - } - return IsRetryableModelError(err), false, 0 -} - -// CalculateBackoff returns the backoff duration for a given attempt (0-indexed). -// Uses exponential backoff with jitter. -func CalculateBackoff(attempt int) time.Duration { - if attempt < 0 { - attempt = 0 - } - - // Calculate exponential delay - delay := float64(backoffBaseDelay) - for range attempt { - delay *= backoffFactor - } - - // Cap at max delay - if delay > float64(backoffMaxDelay) { - delay = float64(backoffMaxDelay) - } - - // Add jitter (±10%) - jitter := delay * backoffJitter * (2*rand.Float64() - 1) - delay += jitter - - return time.Duration(delay) -} - -// SleepWithContext sleeps for the specified duration, returning early if context is cancelled. -// Returns true if the sleep completed, false if it was interrupted by context cancellation. -func SleepWithContext(ctx context.Context, d time.Duration) bool { - timer := time.NewTimer(d) - defer timer.Stop() - - select { - case <-timer.C: - return true - case <-ctx.Done(): - return false + return isRetryableStatusCode(statusCode), false, 0 } + return isRetryableModelError(err), false, 0 } // FormatError returns a user-friendly error message for model errors. 
diff --git a/pkg/modelerrors/modelerrors_test.go b/pkg/modelerrors/modelerrors_test.go index 361c89f34..3df53d1bd 100644 --- a/pkg/modelerrors/modelerrors_test.go +++ b/pkg/modelerrors/modelerrors_test.go @@ -66,71 +66,11 @@ func TestIsRetryableModelError(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, tt.expected, IsRetryableModelError(tt.err), "IsRetryableModelError(%v)", tt.err) + assert.Equal(t, tt.expected, isRetryableModelError(tt.err), "isRetryableModelError(%v)", tt.err) }) } } -func TestCalculateBackoff(t *testing.T) { - t.Parallel() - - tests := []struct { - attempt int - minExpected time.Duration - maxExpected time.Duration - }{ - {attempt: 0, minExpected: 180 * time.Millisecond, maxExpected: 220 * time.Millisecond}, - {attempt: 1, minExpected: 360 * time.Millisecond, maxExpected: 440 * time.Millisecond}, - {attempt: 2, minExpected: 720 * time.Millisecond, maxExpected: 880 * time.Millisecond}, - {attempt: 3, minExpected: 1440 * time.Millisecond, maxExpected: 1760 * time.Millisecond}, - {attempt: 10, minExpected: 1800 * time.Millisecond, maxExpected: 2200 * time.Millisecond}, // capped at 2s - } - - for _, tt := range tests { - t.Run(fmt.Sprintf("attempt_%d", tt.attempt), func(t *testing.T) { - t.Parallel() - backoff := CalculateBackoff(tt.attempt) - assert.GreaterOrEqual(t, backoff, tt.minExpected, "backoff should be at least %v", tt.minExpected) - assert.LessOrEqual(t, backoff, tt.maxExpected, "backoff should be at most %v", tt.maxExpected) - }) - } - - t.Run("negative attempt treated as 0", func(t *testing.T) { - t.Parallel() - backoff := CalculateBackoff(-1) - assert.GreaterOrEqual(t, backoff, 180*time.Millisecond) - assert.LessOrEqual(t, backoff, 220*time.Millisecond) - }) -} - -func TestSleepWithContext(t *testing.T) { - t.Parallel() - - t.Run("completes normally", func(t *testing.T) { - t.Parallel() - ctx := t.Context() - start := time.Now() - completed := SleepWithContext(ctx, 
10*time.Millisecond) - elapsed := time.Since(start) - - assert.True(t, completed, "should complete normally") - assert.GreaterOrEqual(t, elapsed, 10*time.Millisecond) - }) - - t.Run("interrupted by context", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithCancel(t.Context()) - time.AfterFunc(10*time.Millisecond, cancel) - - start := time.Now() - completed := SleepWithContext(ctx, 1*time.Second) - elapsed := time.Since(start) - - assert.False(t, completed, "should be interrupted") - assert.Less(t, elapsed, 100*time.Millisecond, "should return quickly after cancel") - }) -} - func TestExtractHTTPStatusCode(t *testing.T) { t.Parallel() @@ -145,12 +85,16 @@ func TestExtractHTTPStatusCode(t *testing.T) { {name: "502 in message", err: errors.New("502 bad gateway"), expected: 502}, {name: "401 in message", err: errors.New("401 unauthorized"), expected: 401}, {name: "no status code", err: errors.New("connection refused"), expected: 0}, + // StatusError structural path + {name: "StatusError 429", err: &StatusError{StatusCode: 429, Err: errors.New("rate limited")}, expected: 429}, + {name: "StatusError 500", err: &StatusError{StatusCode: 500, Err: errors.New("server error")}, expected: 500}, + {name: "wrapped StatusError", err: fmt.Errorf("outer: %w", &StatusError{StatusCode: 503, Err: errors.New("unavailable")}), expected: 503}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, tt.expected, ExtractHTTPStatusCode(tt.err), "ExtractHTTPStatusCode(%v)", tt.err) + assert.Equal(t, tt.expected, extractHTTPStatusCode(tt.err), "extractHTTPStatusCode(%v)", tt.err) }) } } @@ -174,7 +118,7 @@ func TestIsRetryableStatusCode(t *testing.T) { for _, tt := range tests { t.Run(fmt.Sprintf("status_%d", tt.statusCode), func(t *testing.T) { t.Parallel() - assert.Equal(t, tt.expected, IsRetryableStatusCode(tt.statusCode), "IsRetryableStatusCode(%d)", tt.statusCode) + assert.Equal(t, tt.expected, isRetryableStatusCode(tt.statusCode), 
"isRetryableStatusCode(%d)", tt.statusCode) }) } } @@ -219,20 +163,28 @@ func TestContextOverflowError(t *testing.T) { t.Run("wraps underlying error", func(t *testing.T) { t.Parallel() underlying := errors.New("prompt is too long: 226360 tokens > 200000 maximum") - ctxErr := &ContextOverflowError{Underlying: underlying} + ctxErr := NewContextOverflowError(underlying) assert.Contains(t, ctxErr.Error(), "context window overflow") assert.Contains(t, ctxErr.Error(), "prompt is too long") assert.ErrorIs(t, ctxErr, underlying) }) - t.Run("errors.As works", func(t *testing.T) { + t.Run("nil underlying returns fallback message", func(t *testing.T) { + t.Parallel() + ctxErr := NewContextOverflowError(nil) + assert.Equal(t, "context window overflow", ctxErr.Error()) + assert.NoError(t, ctxErr.Unwrap()) + }) + + t.Run("errors.As works through wrapping", func(t *testing.T) { t.Parallel() underlying := errors.New("test error") - wrapped := fmt.Errorf("all models failed: %w", &ContextOverflowError{Underlying: underlying}) + wrapped := fmt.Errorf("all models failed: %w", NewContextOverflowError(underlying)) var ctxErr *ContextOverflowError - assert.ErrorAs(t, wrapped, &ctxErr) + require.ErrorAs(t, wrapped, &ctxErr) + assert.Equal(t, underlying, ctxErr.Underlying) }) } @@ -252,7 +204,7 @@ func TestIsRetryableModelError_ContextOverflow(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - assert.False(t, IsRetryableModelError(tt.err), "context overflow errors should not be retryable: %v", tt.err) + assert.False(t, isRetryableModelError(tt.err), "context overflow errors should not be retryable: %v", tt.err) }) } } @@ -267,13 +219,21 @@ func TestFormatError(t *testing.T) { t.Run("context overflow shows user-friendly message", func(t *testing.T) { t.Parallel() - err := &ContextOverflowError{Underlying: errors.New("prompt is too long")} + err := NewContextOverflowError(errors.New("prompt is too long")) msg := FormatError(err) assert.Contains(t, 
msg, "context window") assert.Contains(t, msg, "/compact") assert.NotContains(t, msg, "prompt is too long") }) + t.Run("wrapped context overflow shows user-friendly message", func(t *testing.T) { + t.Parallel() + err := fmt.Errorf("outer: %w", NewContextOverflowError(errors.New("prompt is too long"))) + msg := FormatError(err) + assert.Contains(t, msg, "context window") + assert.Contains(t, msg, "/compact") + }) + t.Run("generic error preserves message", func(t *testing.T) { t.Parallel() err := errors.New("authentication failed") @@ -301,15 +261,15 @@ func TestParseRetryAfterHeader(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := ParseRetryAfterHeader(tt.value) - assert.Equal(t, tt.expected, got, "ParseRetryAfterHeader(%q)", tt.value) + got := parseRetryAfterHeader(tt.value) + assert.Equal(t, tt.expected, got, "parseRetryAfterHeader(%q)", tt.value) }) } t.Run("HTTP-date in the future", func(t *testing.T) { t.Parallel() future := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat) - got := ParseRetryAfterHeader(future) + got := parseRetryAfterHeader(future) assert.Greater(t, got, 0*time.Second, "should return positive duration for future HTTP-date") assert.LessOrEqual(t, got, 11*time.Second, "should not exceed ~10s for near-future date") }) @@ -317,7 +277,7 @@ func TestParseRetryAfterHeader(t *testing.T) { t.Run("HTTP-date in the past", func(t *testing.T) { t.Parallel() past := time.Now().Add(-10 * time.Second).UTC().Format(http.TimeFormat) - got := ParseRetryAfterHeader(past) + got := parseRetryAfterHeader(past) assert.Equal(t, 0*time.Second, got, "should return 0 for past HTTP-date") }) } @@ -325,11 +285,11 @@ func TestParseRetryAfterHeader(t *testing.T) { func TestStatusError(t *testing.T) { t.Parallel() - t.Run("Error() delegates to wrapped error", func(t *testing.T) { + t.Run("Error() includes status code and wrapped message", func(t *testing.T) { t.Parallel() underlying := errors.New("rate limit 
exceeded") se := &StatusError{StatusCode: 429, Err: underlying} - assert.Equal(t, underlying.Error(), se.Error()) + assert.Equal(t, "HTTP 429: rate limit exceeded", se.Error()) }) t.Run("Unwrap() allows errors.Is traversal", func(t *testing.T) { @@ -375,7 +335,7 @@ func TestWrapHTTPError(t *testing.T) { require.ErrorAs(t, result, &se) assert.Equal(t, 429, se.StatusCode) assert.Equal(t, time.Duration(0), se.RetryAfter) - assert.Equal(t, origErr.Error(), se.Error()) + assert.Equal(t, "HTTP 429: rate limited", se.Error()) }) t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) { @@ -464,4 +424,15 @@ func TestClassifyModelError(t *testing.T) { assert.True(t, rateLimited) assert.Equal(t, 15*time.Second, retryAfterOut) }) + + t.Run("ContextOverflowError wrapping a StatusError is not retryable", func(t *testing.T) { + t.Parallel() + // A 400 StatusError whose message also triggers context overflow detection + statusErr := &StatusError{StatusCode: 400, Err: errors.New("prompt is too long")} + ctxErr := NewContextOverflowError(statusErr) + retryable, rateLimited, retryAfter := ClassifyModelError(ctxErr) + assert.False(t, retryable, "context overflow should never be retryable") + assert.False(t, rateLimited) + assert.Equal(t, time.Duration(0), retryAfter) + }) } diff --git a/pkg/runtime/fallback.go b/pkg/runtime/fallback.go index 546b3e32f..b6eb347d6 100644 --- a/pkg/runtime/fallback.go +++ b/pkg/runtime/fallback.go @@ -8,6 +8,7 @@ import ( "time" "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/backoff" "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/modelerrors" @@ -69,12 +70,12 @@ func logFallbackAttempt(agentName string, model modelWithFallback, attempt, maxR } // logRetryBackoff logs when we're backing off before a retry -func logRetryBackoff(agentName, modelID string, attempt int, backoff time.Duration) { +func logRetryBackoff(agentName, 
modelID string, attempt int, backoffDelay time.Duration) { slog.Debug("Backing off before retry", "agent", agentName, "model", modelID, "attempt", attempt+1, - "backoff", backoff) + "backoff", backoffDelay) } // getCooldownState returns the current cooldown state for an agent (thread-safe). @@ -222,9 +223,9 @@ func (r *LocalRuntime) tryModelWithFallback( // Apply backoff before retry (not on first attempt of each model) if attempt > 0 { - backoff := modelerrors.CalculateBackoff(attempt - 1) - logRetryBackoff(a.Name(), modelEntry.provider.ID(), attempt, backoff) - if !modelerrors.SleepWithContext(ctx, backoff) { + backoffDelay := backoff.Calculate(attempt - 1) + logRetryBackoff(a.Name(), modelEntry.provider.ID(), attempt, backoffDelay) + if !backoff.SleepWithContext(ctx, backoffDelay) { return streamResult{}, nil, ctx.Err() } } @@ -325,7 +326,7 @@ func (r *LocalRuntime) tryModelWithFallback( if lastErr != nil { wrapped := fmt.Errorf("all models failed: %w", lastErr) if modelerrors.IsContextOverflowError(lastErr) { - return streamResult{}, nil, &modelerrors.ContextOverflowError{Underlying: wrapped} + return streamResult{}, nil, modelerrors.NewContextOverflowError(wrapped) } return streamResult{}, nil, wrapped } @@ -382,14 +383,14 @@ func (r *LocalRuntime) handleModelError( // Opt-in enabled, no fallbacks → retry same model after honouring Retry-After (or backoff). 
waitDuration := retryAfter if waitDuration <= 0 { - waitDuration = modelerrors.CalculateBackoff(attempt) - } else if waitDuration > modelerrors.MaxRetryAfterWait { + waitDuration = backoff.Calculate(attempt) + } else if waitDuration > backoff.MaxRetryAfterWait { slog.Warn("Retry-After exceeds maximum, capping", "agent", a.Name(), "model", modelEntry.provider.ID(), "retry_after", retryAfter, - "max", modelerrors.MaxRetryAfterWait) - waitDuration = modelerrors.MaxRetryAfterWait + "max", backoff.MaxRetryAfterWait) + waitDuration = backoff.MaxRetryAfterWait } slog.Warn("Rate limited, retrying (opt-in enabled)", "agent", a.Name(), @@ -398,7 +399,7 @@ func (r *LocalRuntime) handleModelError( "wait", waitDuration, "retry_after_from_header", retryAfter > 0, "error", err) - if !modelerrors.SleepWithContext(ctx, waitDuration) { + if !backoff.SleepWithContext(ctx, waitDuration) { return retryDecisionReturn } return retryDecisionContinue