diff --git a/intercept/responses/base.go b/intercept/responses/base.go index da53c15..e127d1f 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -20,7 +19,6 @@ import ( "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" - "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" "github.com/coder/quartz" @@ -48,9 +46,8 @@ type responsesInterceptionBase struct { recorder recorder.Recorder mcpProxy mcp.ServerProxier - logger slog.Logger - metrics metrics.Metrics - tracer trace.Tracer + logger slog.Logger + tracer trace.Tracer } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { @@ -96,27 +93,7 @@ func (i *responsesInterceptionBase) Model() string { } func (i *responsesInterceptionBase) CorrelatingToolCallID() *string { - items := gjson.GetBytes(i.reqPayload, "input") - if !items.IsArray() { - return nil - } - - arr := items.Array() - if len(arr) == 0 { - return nil - } - - last := arr[len(arr)-1] - if last.Get(string(constant.ValueOf[constant.Type]())).String() != string(constant.ValueOf[constant.FunctionCallOutput]()) { - return nil - } - - callID := last.Get("call_id").String() - if callID == "" { - return nil - } - - return &callID + return i.reqPayload.correlatingToolCallID() } func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { @@ -178,89 +155,6 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o return opts } -// lastUserPrompt returns input text with "user" role from last input item -// or string input value if it is present + bool indicating if input was found or not. -// If no such input was found empty string + false is returned. -func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (string, bool, error) { - if i == nil { - return "", false, errors.New("cannot get last user prompt: nil struct") - } - if i.reqPayload == nil { - return "", false, errors.New("cannot get last user prompt: nil request struct") - } - - // 'input' can be either a string or an array of input items: - // https://platform.openai.com/docs/api-reference/responses/create#responses_create-input - inputItems := gjson.GetBytes(i.reqPayload, "input") - if !inputItems.Exists() || inputItems.Type == gjson.Null { - return "", false, nil - } - - // String variant: treat the whole input as the user prompt. - if inputItems.Type == gjson.String { - return inputItems.String(), true, nil - } - - // Array variant: checking only the last input item - if !inputItems.IsArray() { - return "", false, fmt.Errorf("unexpected input type: %s", inputItems.Type) - } - - inputItemsArr := inputItems.Array() - if len(inputItemsArr) == 0 { - return "", false, nil - } - - lastItem := inputItemsArr[len(inputItemsArr)-1] - if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) { - // Request was likely not initiated by a prompt but is an iteration of agentic loop. - return "", false, nil - } - - // Message content can be either a string or an array of typed content items: - // https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content - content := lastItem.Get(string(constant.ValueOf[constant.Content]())) - if !content.Exists() || content.Type == gjson.Null { - return "", false, nil - } - - // String variant: use it directly as the prompt. - if content.Type == gjson.String { - return content.Str, true, nil - } - - if !content.IsArray() { - return "", false, fmt.Errorf("unexpected input content type: %s", content.Type) - } - - var sb strings.Builder - promptExists := false - for _, c := range content.Array() { - // Ignore non-text content blocks such as images or files. - if c.Get(string(constant.ValueOf[constant.Type]())).Str != string(constant.ValueOf[constant.InputText]()) { - continue - } - - text := c.Get(string(constant.ValueOf[constant.Text]())) - if text.Type != gjson.String { - i.logger.Warn(ctx, fmt.Sprintf("unexpected input content array element text type: %v", text.Type)) - continue - } - - if promptExists { - sb.WriteByte('\n') - } - promptExists = true - sb.WriteString(text.Str) - } - - if !promptExists { - return "", false, nil - } - - return sb.String(), true, nil -} - func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) { if responseID == "" { i.logger.Warn(ctx, "got empty response ID, skipping prompt recording") diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index 9875f58..c1585e0 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -6,229 +6,13 @@ import ( "time" "cdr.dev/slog/v3" - "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/recorder" - "github.com/coder/aibridge/utils" "github.com/google/uuid" oairesponses "github.com/openai/openai-go/v3/responses" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestScanForCorrelatingToolCallID(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - payload []byte - wantCall *string - }{ - { - name: "no input", - payload: []byte(`{"model":"gpt-4o"}`), - }, - { - name: "empty input array", - payload: []byte(`{"model":"gpt-4o","input":[]}`), - }, - { - name: "no function_call_output items", - payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"}]}`), - }, - { - name: "single function_call_output", - payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_abc","output":"result"}]}`), - wantCall: utils.PtrTo("call_abc"), - }, - { - name: "multiple function_call_outputs returns last", - payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_second","output":"r2"}]}`), - wantCall: utils.PtrTo("call_second"), - }, - { - name: "last input is not a tool result", - payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"}]}`), - }, - { - name: "missing call id", - payload: []byte(`{"input":[{"type":"function_call_output","output":"ok"}]}`), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - rp, err := NewResponsesRequestPayload(tc.payload) - require.NoError(t, err) - base := &responsesInterceptionBase{ - reqPayload: rp, - } - - callID := base.CorrelatingToolCallID() - assert.Equal(t, tc.wantCall, callID) - }) - } -} - -func TestLastUserPrompt(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - reqPayload []byte - expect string - }{ - { - name: "input_empty_string", - reqPayload: []byte(`{"input": ""}`), - expect: "", - }, - { - name: "input_array_content_empty_string", - reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`), - expect: "", - }, - { - name: "input_array_content_array_empty_string", - reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": ""}] } ] }`), - }, - { - name: "input_array_content_array_multiple_inputs", - reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": "a"}, {"type": "input_text", "text": "b"}] } ] }`), - expect: "a\nb", - }, - { - name: "simple_string_input", - reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), - expect: "tell me a joke", - }, - { - name: "array_single_input_string", - reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSingleBuiltinTool), - expect: "Is 3 + 5 a prime number? Use the add function to calculate the sum.", - }, - { - name: "array_multiple_items_content_objects", - reqPayload: fixtures.Request(t, fixtures.OaiResponsesStreamingCodex), - expect: "hello", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - rp, err := NewResponsesRequestPayload(tc.reqPayload) - require.NoError(t, err) - base := &responsesInterceptionBase{ - reqPayload: rp, - } - - prompt, promptFound, err := base.lastUserPrompt(t.Context()) - require.NoError(t, err) - require.Equal(t, tc.expect, prompt) - require.True(t, promptFound) - }) - } -} - -func TestLastUserPromptNotFound(t *testing.T) { - t.Parallel() - - t.Run("nil_struct", func(t *testing.T) { - t.Parallel() - - var base *responsesInterceptionBase - prompt, promptFound, err := base.lastUserPrompt(t.Context()) - require.Error(t, err) - require.Empty(t, prompt) - require.False(t, promptFound) - require.Contains(t, "cannot get last user prompt: nil struct", err.Error()) - }) - - t.Run("nil_request", func(t *testing.T) { - t.Parallel() - - base := responsesInterceptionBase{} - prompt, promptFound, err := base.lastUserPrompt(t.Context()) - require.Error(t, err) - require.Empty(t, prompt) - require.False(t, promptFound) - require.Contains(t, "cannot get last user prompt: nil request struct", err.Error()) - }) - - // Cases where the user prompt is not found / wrong format. - tests := []struct { - name string - reqPayload []byte - expectErr string - }{ - { - name: "non_existing_input", - reqPayload: []byte(`{"model": "gpt-4o"}`), - }, - { - name: "input_empty_array", - reqPayload: []byte(`{"model": "gpt-4o", "input": []}`), - }, - { - name: "input_integer", - reqPayload: []byte(`{"model": "gpt-4o", "input": 123}`), - expectErr: "unexpected input type", - }, - { - name: "no_user_role", - reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`), - }, - { - name: "user_with_empty_content_array", - reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": []}]}`), - }, - { - name: "input_array_integer", - reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": 123}]}`), - expectErr: "unexpected input content type", - }, - { - name: "user_with_non_input_text_content", - reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`), - }, - { - name: "user_content_not_last", - reqPayload: []byte(`{"model": "gpt-4o", "input": [ {"role": "user", "content":"input"}, {"role": "assistant", "content": "hello"} ]}`), - }, - { - name: "input_array_content_array_integer", - reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": 123}] } ] }`), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - rp, err := NewResponsesRequestPayload(tc.reqPayload) - require.NoError(t, err) - - base := &responsesInterceptionBase{ - reqPayload: rp, - } - - prompt, promptFound, err := base.lastUserPrompt(t.Context()) - if tc.expectErr != "" { - require.Error(t, err) - require.Contains(t, err.Error(), tc.expectErr) - } else { - require.NoError(t, err) - } - require.Empty(t, prompt) - require.False(t, promptFound) - }) - } -} - func TestRecordPrompt(t *testing.T) { t.Parallel() diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 9d8c5f3..6294431 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -27,7 +27,7 @@ type BlockingResponsesInterceptor struct { func NewBlockingInterceptor( id uuid.UUID, - reqPayload []byte, + reqPayload ResponsesRequestPayload, cfg config.OpenAI, clientHeaders http.Header, authHeaderName string, @@ -74,7 +74,7 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * firstResponseID string ) - prompt, promptFound, err := i.lastUserPrompt(ctx) + prompt, promptFound, err := i.reqPayload.lastUserPrompt(ctx, i.logger) if err != nil { i.logger.Warn(ctx, "failed to get user prompt", slog.Error(err)) } diff --git a/intercept/responses/reqpayload.go b/intercept/responses/reqpayload.go index b7535c3..238f356 100644 --- a/intercept/responses/reqpayload.go +++ b/intercept/responses/reqpayload.go @@ -2,9 +2,12 @@ package responses import ( "bytes" + "context" "encoding/json" "fmt" + "strings" + "cdr.dev/slog/v3" "github.com/openai/openai-go/v3/responses" "github.com/openai/openai-go/v3/shared/constant" "github.com/tidwall/gjson" @@ -13,13 +16,24 @@ import ( const ( reqPathBackground = "background" + reqPathCallID = "call_id" + reqPathRole = "role" reqPathInput = "input" reqPathParallelToolCalls = "parallel_tool_calls" reqPathStream = "stream" reqPathTools = "tools" ) -var reqPathModel = string(constant.ValueOf[constant.Model]()) +var ( + constFunctionCallOutput = string(constant.ValueOf[constant.FunctionCallOutput]()) + constInputText = string(constant.ValueOf[constant.InputText]()) + constUser = string(constant.ValueOf[constant.User]()) + + reqPathContent = string(constant.ValueOf[constant.Content]()) + reqPathModel = string(constant.ValueOf[constant.Model]()) + reqPathText = string(constant.ValueOf[constant.Text]()) + reqPathType = string(constant.ValueOf[constant.Type]()) +) // ResponsesRequestPayload is raw JSON bytes of a Responses API request. // Methods provide package-specific reads and rewrites while preserving the @@ -50,6 +64,108 @@ func (p ResponsesRequestPayload) background() bool { return gjson.GetBytes(p, reqPathBackground).Bool() } +func (p ResponsesRequestPayload) correlatingToolCallID() *string { + items := gjson.GetBytes(p, reqPathInput) + if !items.IsArray() { + return nil + } + + arr := items.Array() + if len(arr) == 0 { + return nil + } + + last := arr[len(arr)-1] + if last.Get(reqPathType).String() != constFunctionCallOutput { + return nil + } + + callID := last.Get(reqPathCallID).String() + if callID == "" { + return nil + } + + return &callID +} + +// LastUserPrompt returns input text with the "user" role from the last input +// item, or the string input value if present. If no prompt is found, it returns +// empty string, false, nil. Unexpected shapes are treated as unsupported and do +// not fail the request path. +func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog.Logger) (string, bool, error) { + inputItems := gjson.GetBytes(p, reqPathInput) + if !inputItems.Exists() || inputItems.Type == gjson.Null { + return "", false, nil + } + + // 'input' can be either a string or an array of input items: + // https://platform.openai.com/docs/api-reference/responses/create#responses_create-input + + // String variant: treat the whole input as the user prompt. + if inputItems.Type == gjson.String { + return inputItems.String(), true, nil + } + + // Array variant: checking only the last input item + if !inputItems.IsArray() { + return "", false, fmt.Errorf("unexpected input type: %s", inputItems.Type) + } + + inputItemsArr := inputItems.Array() + if len(inputItemsArr) == 0 { + return "", false, nil + } + + lastItem := inputItemsArr[len(inputItemsArr)-1] + if lastItem.Get(reqPathRole).Str != constUser { + // Request was likely not initiated by a prompt but is an iteration of agentic loop. + return "", false, nil + } + + // Message content can be either a string or an array of typed content items: + // https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content + content := lastItem.Get(reqPathContent) + if !content.Exists() || content.Type == gjson.Null { + return "", false, nil + } + + // String variant: use it directly as the prompt. + if content.Type == gjson.String { + return content.Str, true, nil + } + + if !content.IsArray() { + return "", false, fmt.Errorf("unexpected input content type: %s", content.Type) + } + + var sb strings.Builder + promptExists := false + for _, c := range content.Array() { + // Ignore non-text content blocks such as images or files. + if c.Get(reqPathType).Str != constInputText { + continue + } + + text := c.Get(reqPathText) + if text.Type != gjson.String { + logger.Warn(ctx, fmt.Sprintf("unexpected input content array element text type: %v", text.Type)) + continue + } + + if promptExists { + sb.WriteByte('\n') + } + promptExists = true + sb.WriteString(text.Str) + } + + if !promptExists { + return "", false, nil + } + + return sb.String(), true, nil +} + func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam) (ResponsesRequestPayload, error) { if len(injected) == 0 { return p, nil @@ -136,9 +252,9 @@ func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) { } func (p ResponsesRequestPayload) set(path string, value any) (ResponsesRequestPayload, error) { - b, err := sjson.SetBytes(p, path, value) + updated, err := sjson.SetBytes(p, path, value) if err != nil { - return ResponsesRequestPayload(b), fmt.Errorf("failed to set value at path %s: %w", path, err) + return p, fmt.Errorf("failed to set value at path %s: %w", path, err) } - return ResponsesRequestPayload(b), nil + return updated, nil } diff --git a/intercept/responses/reqpayload_test.go b/intercept/responses/reqpayload_test.go index aa5e906..b0338fb 100644 --- a/intercept/responses/reqpayload_test.go +++ b/intercept/responses/reqpayload_test.go @@ -5,6 +5,9 @@ import ( "fmt" "testing" + "cdr.dev/slog/v3" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/utils" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/responses" "github.com/stretchr/testify/assert" @@ -56,7 +59,6 @@ func TestNewResponsesRequestPayload(t *testing.T) { t.Parallel() payload, err := NewResponsesRequestPayload(tc.raw) - assert.EqualValues(t, tc.want, payload) if tc.err != "" { require.ErrorContains(t, err, tc.err) @@ -65,6 +67,8 @@ func TestNewResponsesRequestPayload(t *testing.T) { } require.NoError(t, err) + require.NotNil(t, payload) + assert.EqualValues(t, tc.want, payload) assert.Equal(t, tc.model, payload.model()) assert.Equal(t, tc.stream, payload.Stream()) assert.Equal(t, tc.background, payload.background()) @@ -72,6 +76,186 @@ func TestNewResponsesRequestPayload(t *testing.T) { } } +func TestCorrelatingToolCallID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + payload []byte + wantCall *string + }{ + { + name: "no input items", + payload: []byte(`{"model":"gpt-4o"}`), + }, + { + name: "empty input array", + payload: []byte(`{"model":"gpt-4o","input":[]}`), + }, + { + name: "no function_call_output items", + payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"}]}`), + }, + { + name: "single function_call_output", + payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_abc","output":"result"}]}`), + wantCall: utils.PtrTo("call_abc"), + }, + { + name: "multiple function_call_outputs returns last", + payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_second","output":"r2"}]}`), + wantCall: utils.PtrTo("call_second"), + }, + { + name: "last input is not a tool result", + payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"}]}`), + }, + { + name: "missing call id", + payload: []byte(`{"input":[{"type":"function_call_output","output":"ok"}]}`), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + callID := mustPayload(t, tc.payload).correlatingToolCallID() + assert.Equal(t, tc.wantCall, callID) + }) + } +} + +func TestLastUserPrompt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + reqPayload []byte + expect string + found bool + expectErr string + }{ + { + name: "no input", + reqPayload: []byte(`{}`), + found: false, + }, + { + name: "input null", + reqPayload: []byte(`{"input": null}`), + found: false, + }, + { + name: "empty input array", + reqPayload: []byte(`{"input": []}`), + found: false, + }, + { + name: "input empty string", + reqPayload: []byte(`{"input": ""}`), + expect: "", + found: true, + }, + { + name: "input array content empty string", + reqPayload: []byte(`{"input": [{"role": "user", "content": ""}]}`), + expect: "", + found: true, + }, + { + name: "input array content array empty string", + reqPayload: []byte(`{"input": [ { "role": "user", "content": [{"type": "input_text", "text": ""}] } ] }`), + expect: "", + found: true, + }, + { + name: "input array content array multiple inputs", + reqPayload: []byte(`{"input": [ { "role": "user", "content": [{"type": "input_text", "text": "a"}, {"type": "input_text", "text": "b"}] } ] }`), + expect: "a\nb", + found: true, + }, + { + name: "simple string input", + reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), + expect: "tell me a joke", + found: true, + }, + { + name: "array single input string", + reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSingleBuiltinTool), + expect: "Is 3 + 5 a prime number? Use the add function to calculate the sum.", + found: true, + }, + { + name: "array multiple items content objects", + reqPayload: fixtures.Request(t, fixtures.OaiResponsesStreamingCodex), + expect: "hello", + found: true, + }, + { + name: "input integer", + reqPayload: []byte(`{"input": 123}`), + expectErr: "unexpected input type", + }, + { + name: "no user role", + reqPayload: []byte(`{"input": [{"role": "assistant", "content": "hello"}]}`), + found: false, + }, + { + name: "user with empty content array", + reqPayload: []byte(`{"input": [{"role": "user", "content": []}]}`), + found: false, + }, + { + name: "user content missing", + reqPayload: []byte(`{"input": [{"role": "user"}]}`), + found: false, + }, + { + name: "user content null", + reqPayload: []byte(`{"input": [{"role": "user", "content": null}]}`), + found: false, + }, + { + name: "input array integer", + reqPayload: []byte(`{"input": [{"role": "user", "content": 123}]}`), + expectErr: "unexpected input content type", + }, + { + name: "user with non input_text content", + reqPayload: []byte(`{"input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`), + found: false, + }, + { + name: "user content not last", + reqPayload: []byte(`{"input": [ {"role": "user", "content":"input"}, {"role": "assistant", "content": "hello"} ]}`), + found: false, + }, + { + name: "input array content array integer", + reqPayload: []byte(`{"input": [ { "role": "user", "content": [{"type": "input_text", "text": 123}] } ] }`), + found: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + prompt, promptFound, err := mustPayload(t, tc.reqPayload).lastUserPrompt(t.Context(), slog.Make()) + if tc.expectErr != "" { + require.ErrorContains(t, err, tc.expectErr) + return + } + require.NoError(t, err) + require.Equal(t, tc.expect, prompt) + require.Equal(t, tc.found, promptFound) + }) + } +} + func TestInjectTools(t *testing.T) { t.Parallel() @@ -137,7 +321,7 @@ func TestInjectTools(t *testing.T) { } if tc.wantSame { - require.Equal(t, p, updated) + require.EqualValues(t, tc.raw, updated) } for i, wantName := range tc.wantNames { path := fmt.Sprintf("tools.%d.name", i) // name of the i-th element in tools array @@ -282,7 +466,7 @@ func TestAppendInputItems(t *testing.T) { } if tc.wantSame { - require.Equal(t, p, updated) + require.EqualValues(t, tc.raw, updated) } for path, want := range tc.wantPaths { diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 5d899a4..dbd5067 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -34,7 +34,7 @@ type StreamingResponsesInterceptor struct { func NewStreamingInterceptor( id uuid.UUID, - reqPayload []byte, + reqPayload ResponsesRequestPayload, cfg config.OpenAI, clientHeaders http.Header, authHeaderName string, @@ -92,7 +92,7 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r var innerLoopErr error var streamErr error - prompt, promptFound, err := i.lastUserPrompt(ctx) + prompt, promptFound, err := i.reqPayload.lastUserPrompt(ctx, i.logger) if err != nil { i.logger.Warn(ctx, "failed to get user prompt", slog.Error(err)) }