Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.19.12
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9
github.com/aws/smithy-go v1.24.2
github.com/aymanbagabas/go-udiff v0.4.1
github.com/blevesearch/bleve/v2 v2.5.7
github.com/bmatcuk/doublestar/v4 v4.10.0
Expand Down Expand Up @@ -90,7 +91,6 @@ require (
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/bits-and-blooms/bitset v1.24.4 // indirect
Expand Down
60 changes: 60 additions & 0 deletions pkg/backoff/backoff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Package backoff provides exponential backoff calculation and
// context-aware sleep utilities.
package backoff

import (
"context"
"math/rand/v2"
"time"
)

// Configuration constants for exponential backoff.
const (
baseDelay = 200 * time.Millisecond
maxDelay = 2 * time.Second
factor = 2.0
jitter = 0.1

// MaxRetryAfterWait caps how long we'll honor a Retry-After header to prevent
// a misbehaving server from blocking the agent for an unreasonable amount of time.
MaxRetryAfterWait = 60 * time.Second
)

// Calculate returns the backoff duration for a given attempt (0-indexed).
// Uses exponential backoff with jitter.
func Calculate(attempt int) time.Duration {
if attempt < 0 {
attempt = 0
}

// Calculate exponential delay
delay := float64(baseDelay)
for range attempt {
delay *= factor
}

// Cap at max delay
if delay > float64(maxDelay) {
delay = float64(maxDelay)
}

// Add jitter (±10%)
j := delay * jitter * (2*rand.Float64() - 1)
delay += j

return time.Duration(delay)
}

// SleepWithContext sleeps for the specified duration, returning early if context is cancelled.
// Returns true if the sleep completed, false if it was interrupted by context cancellation.
func SleepWithContext(ctx context.Context, d time.Duration) bool {
timer := time.NewTimer(d)
defer timer.Stop()

select {
case <-timer.C:
return true
case <-ctx.Done():
return false
}
}
70 changes: 70 additions & 0 deletions pkg/backoff/backoff_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package backoff

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestCalculate(t *testing.T) {
t.Parallel()

tests := []struct {
attempt int
minExpected time.Duration
maxExpected time.Duration
}{
{attempt: 0, minExpected: 180 * time.Millisecond, maxExpected: 220 * time.Millisecond},
{attempt: 1, minExpected: 360 * time.Millisecond, maxExpected: 440 * time.Millisecond},
{attempt: 2, minExpected: 720 * time.Millisecond, maxExpected: 880 * time.Millisecond},
{attempt: 3, minExpected: 1440 * time.Millisecond, maxExpected: 1760 * time.Millisecond},
{attempt: 10, minExpected: 1800 * time.Millisecond, maxExpected: 2200 * time.Millisecond}, // capped at 2s
}

for _, tt := range tests {
t.Run(fmt.Sprintf("attempt_%d", tt.attempt), func(t *testing.T) {
t.Parallel()
b := Calculate(tt.attempt)
assert.GreaterOrEqual(t, b, tt.minExpected, "backoff should be at least %v", tt.minExpected)
assert.LessOrEqual(t, b, tt.maxExpected, "backoff should be at most %v", tt.maxExpected)
})
}

t.Run("negative attempt treated as 0", func(t *testing.T) {
t.Parallel()
b := Calculate(-1)
assert.GreaterOrEqual(t, b, 180*time.Millisecond)
assert.LessOrEqual(t, b, 220*time.Millisecond)
})
}

func TestSleepWithContext(t *testing.T) {
t.Parallel()

t.Run("completes normally", func(t *testing.T) {
t.Parallel()
ctx := t.Context()
start := time.Now()
completed := SleepWithContext(ctx, 10*time.Millisecond)
elapsed := time.Since(start)

assert.True(t, completed, "should complete normally")
assert.GreaterOrEqual(t, elapsed, 10*time.Millisecond)
})

t.Run("interrupted by context", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(t.Context())
time.AfterFunc(10*time.Millisecond, cancel)

start := time.Now()
completed := SleepWithContext(ctx, 1*time.Second)
elapsed := time.Since(start)

assert.False(t, completed, "should be interrupted")
assert.Less(t, elapsed, 100*time.Millisecond, "should return quickly after cancel")
})
}
2 changes: 1 addition & 1 deletion pkg/model/provider/bedrock/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
// Check for errors
if err := a.stream.Err(); err != nil {
slog.Debug("Bedrock stream: error on channel close", "error", err)
return chat.MessageStreamResponse{}, err
return chat.MessageStreamResponse{}, wrapBedrockError(err)
}
// If we have a pending finish reason but never got metadata, emit it now
if a.pendingFinishReason != "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/model/provider/bedrock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func (c *Client) CreateChatCompletionStream(
output, err := c.bedrockClient.ConverseStream(ctx, input)
if err != nil {
slog.Error("Bedrock ConverseStream failed", "error", err)
return nil, fmt.Errorf("bedrock converse stream failed: %w", err)
return nil, wrapBedrockError(fmt.Errorf("bedrock converse stream failed: %w", err))
}

trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
Expand Down
35 changes: 35 additions & 0 deletions pkg/model/provider/bedrock/wrap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package bedrock

import (
"errors"

smithyhttp "github.com/aws/smithy-go/transport/http"

"github.com/docker/docker-agent/pkg/modelerrors"
)

// wrapBedrockError wraps an AWS Bedrock SDK error in a *modelerrors.StatusError
// to carry HTTP status code metadata for the retry loop.
// The AWS SDK v2 exposes HTTP status via smithyhttp.ResponseError.
// Non-AWS errors (e.g., io.EOF, network errors) pass through unchanged.
func wrapBedrockError(err error) error {
if err == nil {
return nil
}

var respErr *smithyhttp.ResponseError
if !errors.As(err, &respErr) {
return err
}

var resp *smithyhttp.Response
if respErr.HTTPResponse() != nil {
resp = respErr.HTTPResponse()
}

statusCode := respErr.HTTPStatusCode()
if resp != nil {
return modelerrors.WrapHTTPError(statusCode, resp.Response, err)
}
return modelerrors.WrapHTTPError(statusCode, nil, err)
}
110 changes: 110 additions & 0 deletions pkg/model/provider/bedrock/wrap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package bedrock

import (
"errors"
"fmt"
"net/http"
"testing"
"time"

smithy "github.com/aws/smithy-go"
smithyhttp "github.com/aws/smithy-go/transport/http"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/modelerrors"
)

func makeTestBedrockError(statusCode int, retryAfterValue string) error {
header := http.Header{}
if retryAfterValue != "" {
header.Set("Retry-After", retryAfterValue)
}

httpResp := &http.Response{
StatusCode: statusCode,
Header: header,
}
resp := &smithyhttp.Response{Response: httpResp}

return &smithy.OperationError{
ServiceID: "BedrockRuntime",
OperationName: "ConverseStream",
Err: &smithyhttp.ResponseError{
Response: resp,
Err: &smithy.GenericAPIError{
Code: "ThrottlingException",
Message: "Rate exceeded",
},
},
}
}

func TestWrapBedrockError(t *testing.T) {
t.Parallel()

t.Run("nil returns nil", func(t *testing.T) {
t.Parallel()
assert.NoError(t, wrapBedrockError(nil))
})

t.Run("non-AWS error passes through unchanged", func(t *testing.T) {
t.Parallel()
orig := errors.New("some network error")
result := wrapBedrockError(orig)
assert.Equal(t, orig, result)
var se *modelerrors.StatusError
assert.NotErrorAs(t, result, &se)
})

t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) {
t.Parallel()
awsErr := makeTestBedrockError(429, "")
result := wrapBedrockError(awsErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
assert.Equal(t, 429, se.StatusCode)
assert.Equal(t, time.Duration(0), se.RetryAfter)
// Original error still accessible
assert.ErrorIs(t, result, awsErr)
})

t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) {
t.Parallel()
awsErr := makeTestBedrockError(429, "20")
result := wrapBedrockError(awsErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
assert.Equal(t, 429, se.StatusCode)
assert.Equal(t, 20*time.Second, se.RetryAfter)
})

t.Run("500 wraps with correct status code", func(t *testing.T) {
t.Parallel()
awsErr := makeTestBedrockError(500, "")
result := wrapBedrockError(awsErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
assert.Equal(t, 500, se.StatusCode)
assert.Equal(t, time.Duration(0), se.RetryAfter)
})

t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) {
t.Parallel()
awsErr := makeTestBedrockError(429, "15")
result := wrapBedrockError(awsErr)
retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(result)
assert.False(t, retryable)
assert.True(t, rateLimited)
assert.Equal(t, 15*time.Second, retryAfter)
})

t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) {
t.Parallel()
awsErr := makeTestBedrockError(500, "")
wrapped := fmt.Errorf("bedrock converse stream failed: %w", wrapBedrockError(awsErr))
retryable, rateLimited, _ := modelerrors.ClassifyModelError(wrapped)
assert.True(t, retryable)
assert.False(t, rateLimited)
})
}
Loading
Loading