diff --git a/intercept/responses/base.go b/intercept/responses/base.go index f1bc3ae7..da53c15a 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -38,21 +38,19 @@ const ( ) type responsesInterceptionBase struct { - id uuid.UUID - req *ResponsesNewParamsWrapper - reqPayload []byte - cfg config.OpenAI - model string - + id uuid.UUID // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header authHeaderName string + reqPayload ResponsesRequestPayload + cfg config.OpenAI recorder recorder.Recorder mcpProxy mcp.ServerProxier - logger slog.Logger - metrics metrics.Metrics - tracer trace.Tracer + + logger slog.Logger + metrics metrics.Metrics + tracer trace.Tracer } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { @@ -88,26 +86,37 @@ func (i *responsesInterceptionBase) ID() uuid.UUID { } func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { - i.logger = logger.With(slog.F("model", i.model)) + i.logger = logger.With(slog.F("model", i.Model())) i.recorder = recorder i.mcpProxy = mcpProxy } func (i *responsesInterceptionBase) Model() string { - return i.model + return i.reqPayload.model() } func (i *responsesInterceptionBase) CorrelatingToolCallID() *string { - if len(i.req.Input.OfInputItemList) == 0 { + items := gjson.GetBytes(i.reqPayload, "input") + if !items.IsArray() { + return nil + } + + arr := items.Array() + if len(arr) == 0 { + return nil + } + + last := arr[len(arr)-1] + if last.Get(string(constant.ValueOf[constant.Type]())).String() != string(constant.ValueOf[constant.FunctionCallOutput]()) { return nil } - // The tool result should be the last input message. 
- item := i.req.Input.OfInputItemList[len(i.req.Input.OfInputItemList)-1] - if item.OfFunctionCallOutput == nil { + callID := last.Get("call_id").String() + if callID == "" { return nil } - return &item.OfFunctionCallOutput.CallID + + return &callID } func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { @@ -122,13 +131,7 @@ func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streami } func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.ResponseWriter) error { - if i.req == nil { - err := errors.New("developer error: req is nil") - i.sendCustomErr(ctx, w, http.StatusInternalServerError, err) - return err - } - - if i.req.Background.Value { + if i.reqPayload.background() { err := fmt.Errorf("background requests are currently not supported by AI Bridge") i.sendCustomErr(ctx, w, http.StatusNotImplemented, err) return err @@ -161,7 +164,7 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o // eg. Codex CLI produces requests without ID set in reasoning items: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-item-reasoning-id // when re-encoded, ID field is set to empty string which results // in bad request while not sending ID field at all somehow works. - option.WithRequestBody("application/json", i.reqPayload), + option.WithRequestBody("application/json", []byte(i.reqPayload)), // copyMiddleware copies body of original response body to the buffer in responseCopier, // also reference to headers and status code is kept responseCopier. @@ -169,7 +172,7 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o // eliminating any possibility of JSON re-encoding issues. 
option.WithMiddleware(respCopy.copyMiddleware), } - if !i.req.Stream { + if !i.reqPayload.Stream() { opts = append(opts, option.WithRequestTimeout(requestTimeout)) } return opts @@ -182,77 +185,80 @@ func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (string, if i == nil { return "", false, errors.New("cannot get last user prompt: nil struct") } - if i.req == nil { + if i.reqPayload == nil { return "", false, errors.New("cannot get last user prompt: nil request struct") } - // 'input' field can be a string or array of objects: + // 'input' can be either a string or an array of input items: // https://platform.openai.com/docs/api-reference/responses/create#responses_create-input - - // Check string variant - if i.req.Input.OfString.Valid() { - return i.req.Input.OfString.Value, true, nil + inputItems := gjson.GetBytes(i.reqPayload, "input") + if !inputItems.Exists() || inputItems.Type == gjson.Null { + return "", false, nil } - // Fallback to parsing original bytes since golang SDK doesn't properly decode 'Input' field. - // If 'type' field of input item is not set it will be omitted from 'Input.OfInputItemList' - // It is an optional field according to API: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message - // example: fixtures/openai/responses/blocking/builtin_tool.txtar - inputItems := gjson.GetBytes(i.reqPayload, "input") + // String variant: treat the whole input as the user prompt. 
+ if inputItems.Type == gjson.String { + return inputItems.String(), true, nil + } + // Array variant: checking only the last input item if !inputItems.IsArray() { - if inputItems.Type == gjson.Null { - return "", false, nil - } - return "", false, fmt.Errorf("unexpected input type: %v", inputItems.Type.String()) + return "", false, fmt.Errorf("unexpected input type: %s", inputItems.Type) } inputItemsArr := inputItems.Array() if len(inputItemsArr) == 0 { return "", false, nil } - lastItem := inputItemsArr[len(inputItemsArr)-1] - // Request was likely not human-initiated. + lastItem := inputItemsArr[len(inputItemsArr)-1] if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) { + // Request was likely not initiated by a prompt but is an iteration of agentic loop. return "", false, nil } - // content can be a string or array of objects: + // Message content can be either a string or an array of typed content items: // https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content content := lastItem.Get(string(constant.ValueOf[constant.Content]())) + if !content.Exists() || content.Type == gjson.Null { + return "", false, nil + } + + // String variant: use it directly as the prompt. + if content.Type == gjson.String { + return content.Str, true, nil + } - // non array case, should be string if !content.IsArray() { - if content.Type == gjson.String { - return content.Str, true, nil - } - return "", false, fmt.Errorf("unexpected input content type: %v", content.Type.String()) + return "", false, fmt.Errorf("unexpected input content type: %s", content.Type) } var sb strings.Builder promptExists := false for _, c := range content.Array() { - // ignore inputs of not `input_text` type + // Ignore non-text content blocks such as images or files. 
if c.Get(string(constant.ValueOf[constant.Type]())).Str != string(constant.ValueOf[constant.InputText]()) { continue } text := c.Get(string(constant.ValueOf[constant.Text]())) - if text.Type == gjson.String { - promptExists = true - sb.WriteString(text.Str + "\n") - } else { + if text.Type != gjson.String { i.logger.Warn(ctx, fmt.Sprintf("unexpected input content array element text type: %v", text.Type)) + continue + } + + if promptExists { + sb.WriteByte('\n') } + promptExists = true + sb.WriteString(text.Str) } if !promptExists { return "", false, nil } - prompt := strings.TrimSuffix(sb.String(), "\n") - return prompt, true, nil + return sb.String(), true, nil } func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) { diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index ad0d59ab..9875f580 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -12,6 +12,7 @@ import ( "github.com/coder/aibridge/utils" "github.com/google/uuid" oairesponses "github.com/openai/openai-go/v3/responses" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -20,77 +21,38 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { tests := []struct { name string - input []oairesponses.ResponseInputItemUnionParam - expected *string + payload []byte + wantCall *string }{ { - name: "no input items", - input: nil, - expected: nil, + name: "no input", + payload: []byte(`{"model":"gpt-4o"}`), }, { - name: "no function_call_output items", - input: []oairesponses.ResponseInputItemUnionParam{ - { - OfMessage: &oairesponses.EasyInputMessageParam{ - Role: "user", - }, - }, - }, - expected: nil, + name: "empty input array", + payload: []byte(`{"model":"gpt-4o","input":[]}`), }, { - name: "single function_call_output", - input: []oairesponses.ResponseInputItemUnionParam{ - { - OfMessage: &oairesponses.EasyInputMessageParam{ - Role: "user", - }, - }, - { - 
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ - CallID: "call_abc", - }, - }, - }, - expected: utils.PtrTo("call_abc"), + name: "no function_call_output items", + payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"}]}`), }, { - name: "multiple function_call_outputs returns last", - input: []oairesponses.ResponseInputItemUnionParam{ - { - OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ - CallID: "call_first", - }, - }, - { - OfMessage: &oairesponses.EasyInputMessageParam{ - Role: "user", - }, - }, - { - OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ - CallID: "call_second", - }, - }, - }, - expected: utils.PtrTo("call_second"), + name: "single function_call_output", + payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_abc","output":"result"}]}`), + wantCall: utils.PtrTo("call_abc"), }, { - name: "last input is not a tool result", - input: []oairesponses.ResponseInputItemUnionParam{ - { - OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ - CallID: "call_first", - }, - }, - { - OfMessage: &oairesponses.EasyInputMessageParam{ - Role: "user", - }, - }, - }, - expected: nil, + name: "multiple function_call_outputs returns last", + payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_second","output":"r2"}]}`), + wantCall: utils.PtrTo("call_second"), + }, + { + name: "last input is not a tool result", + payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"}]}`), + }, + { + name: "missing call id", + payload: []byte(`{"input":[{"type":"function_call_output","output":"ok"}]}`), }, } @@ -98,17 +60,14 @@ func TestScanForCorrelatingToolCallID(t 
*testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() + rp, err := NewResponsesRequestPayload(tc.payload) + require.NoError(t, err) base := &responsesInterceptionBase{ - req: &ResponsesNewParamsWrapper{ - ResponseNewParams: oairesponses.ResponseNewParams{ - Input: oairesponses.ResponseNewParamsInputUnion{ - OfInputItemList: tc.input, - }, - }, - }, + reqPayload: rp, } - require.Equal(t, tc.expected, base.CorrelatingToolCallID()) + callID := base.CorrelatingToolCallID() + assert.Equal(t, tc.wantCall, callID) }) } } @@ -161,13 +120,10 @@ func TestLastUserPrompt(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - req := &ResponsesNewParamsWrapper{} - err := req.UnmarshalJSON(tc.reqPayload) + rp, err := NewResponsesRequestPayload(tc.reqPayload) require.NoError(t, err) - base := &responsesInterceptionBase{ - req: req, - reqPayload: tc.reqPayload, + reqPayload: rp, } prompt, promptFound, err := base.lastUserPrompt(t.Context()) @@ -253,13 +209,11 @@ func TestLastUserPromptNotFound(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - req := &ResponsesNewParamsWrapper{} - err := req.UnmarshalJSON(tc.reqPayload) + rp, err := NewResponsesRequestPayload(tc.reqPayload) require.NoError(t, err) base := &responsesInterceptionBase{ - req: req, - reqPayload: tc.reqPayload, + reqPayload: rp, } prompt, promptFound, err := base.lastUserPrompt(t.Context()) diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index b0625844..9d8c5f32 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -27,10 +27,8 @@ type BlockingResponsesInterceptor struct { func NewBlockingInterceptor( id uuid.UUID, - req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, - model string, clientHeaders http.Header, authHeaderName string, tracer trace.Tracer, @@ -38,10 +36,8 @@ func NewBlockingInterceptor( return &BlockingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ id: id, - 
req: req, reqPayload: reqPayload, cfg: cfg, - model: model, clientHeaders: clientHeaders, authHeaderName: authHeaderName, tracer: tracer, @@ -138,5 +134,6 @@ func (i *BlockingResponsesInterceptor) newResponse(ctx context.Context, srv resp ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) - return srv.New(ctx, i.req.ResponseNewParams, opts...) + // The body is overridden by option.WithRequestBody(reqPayload) in requestOptions + return srv.New(ctx, responses.ResponseNewParams{}, opts...) } diff --git a/intercept/responses/injected_tools.go b/intercept/responses/injected_tools.go index e3720230..fee27218 100644 --- a/intercept/responses/injected_tools.go +++ b/intercept/responses/injected_tools.go @@ -9,20 +9,19 @@ import ( "cdr.dev/slog/v3" "github.com/coder/aibridge/recorder" "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/packages/param" "github.com/openai/openai-go/v3/responses" "github.com/openai/openai-go/v3/shared/constant" - "github.com/tidwall/sjson" ) func (i *responsesInterceptionBase) injectTools() { - if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() { + if i.mcpProxy == nil || !i.hasInjectableTools() { return } i.disableParallelToolCalls() // Inject tools. + var injected []responses.ToolUnionParam for _, tool := range i.mcpProxy.ListTools() { var params map[string]any @@ -40,35 +39,34 @@ func (i *responsesInterceptionBase) injectTools() { params["required"] = tool.Required } - fn := responses.ToolUnionParam{ + injected = append(injected, responses.ToolUnionParam{ OfFunction: &responses.FunctionToolParam{ Name: tool.ID, Strict: openai.Bool(false), // TODO: configurable. 
Description: openai.String(tool.Description), Parameters: params, }, - } - - i.req.Tools = append(i.req.Tools, fn) + }) } - var err error - i.reqPayload, err = sjson.SetBytes(i.reqPayload, "tools", i.req.Tools) + updated, err := i.reqPayload.injectTools(injected) if err != nil { - i.logger.Warn(context.Background(), "failed to set tools", slog.Error(err)) + i.logger.Warn(context.Background(), "failed to inject tools", slog.Error(err)) + return } + i.reqPayload = updated } // disableParallelToolCalls disables parallel tool calls, to simplify the inner agentic loop. // This is best-effort, and failing to set this flag does not fail the request. // TODO: implement parallel tool calls. func (i *responsesInterceptionBase) disableParallelToolCalls() { - // Disable parallel tool calls to simplify inner agentic loop; best-effort. - var err error - i.reqPayload, err = sjson.SetBytes(i.reqPayload, "parallel_tool_calls", false) + updated, err := i.reqPayload.disableParallelToolCalls() if err != nil { i.logger.Warn(context.Background(), "failed to disable parallel_tool_calls", slog.Error(err)) + return } + i.reqPayload = updated } // handleInnerAgenticLoop orchestrates the inner agentic loop whereby injected tools @@ -120,53 +118,24 @@ func (i *responsesInterceptionBase) handleInjectedToolCalls(ctx context.Context, // prepareRequestForAgenticLoop prepares the request by setting the output of the given // response as input to the next request, in order for the tool call result(s) to make function correctly. func (i *responsesInterceptionBase) prepareRequestForAgenticLoop(ctx context.Context, response *responses.Response, toolResults []responses.ResponseInputItemUnionParam) error { - var err error - originalInputSize := len(i.req.Input.OfInputItemList) - - // Unset the string input; we need a list now. 
- if i.req.Input.OfString.Valid() { - // convert old string value to list item - i.req.Input.OfInputItemList = responses.ResponseInputParam{ - responses.ResponseInputItemParamOfMessage( - i.req.Input.OfString.Value, - responses.EasyInputMessageRoleUser, - ), - } - - // clear old value - i.req.Input.OfString = param.Opt[string]{} - } + // Collect new items to add: response outputs converted to input format + tool results. + var newItems []responses.ResponseInputItemUnionParam // OutputText is also available, but by definition the trigger for a function call is not a simple // text response from the model. for _, output := range response.Output { if inputItem := i.convertOutputToInput(output); inputItem != nil { - i.req.Input.OfInputItemList = append(i.req.Input.OfInputItemList, *inputItem) + newItems = append(newItems, *inputItem) } } + newItems = append(newItems, toolResults...) - for _, result := range toolResults { - i.req.Input.OfInputItemList = append(i.req.Input.OfInputItemList, result) - } - - // If original payload was in string format or was an empty list re-marshal whole input - if originalInputSize == 0 { - if i.reqPayload, err = sjson.SetBytes(i.reqPayload, "input", i.req.Input.OfInputItemList); err != nil { - i.logger.Error(ctx, "failure to marshal new input in inner agentic loop", slog.Error(err)) - return fmt.Errorf("failed to marshal input: %v", err) - } - return nil - } - - // Append newly added items to reqPayload field - // New items are appended to limit Input re-marshaling. - // See responsesInterceptionBase.requestOptions for more details about marshaling issues. 
- for j := originalInputSize; j < len(i.req.Input.OfInputItemList); j++ { - if i.reqPayload, err = sjson.SetBytes(i.reqPayload, "input.-1", i.req.Input.OfInputItemList[j]); err != nil { - i.logger.Error(ctx, "failure to marshal output item to new input in inner agentic loop", slog.Error(err)) - return fmt.Errorf("failed to marshal input: %v", err) - } + updated, err := i.reqPayload.appendInputItems(newItems) + if err != nil { + i.logger.Error(ctx, "failed to rewrite input in inner agentic loop", slog.Error(err)) + return fmt.Errorf("failed to rewrite input: %w", err) } + i.reqPayload = updated return nil } diff --git a/intercept/responses/paramswrap.go b/intercept/responses/paramswrap.go deleted file mode 100644 index 253910f2..00000000 --- a/intercept/responses/paramswrap.go +++ /dev/null @@ -1,28 +0,0 @@ -package responses - -import ( - "fmt" - - "github.com/openai/openai-go/v3/responses" - "github.com/tidwall/gjson" -) - -// ResponsesNewParamsWrapper exists because the "stream" param is not included -// in responses.ResponseNewParams. 
-type ResponsesNewParamsWrapper struct { - responses.ResponseNewParams - Stream bool `json:"stream,omitempty"` -} - -func (r *ResponsesNewParamsWrapper) UnmarshalJSON(raw []byte) error { - err := r.ResponseNewParams.UnmarshalJSON(raw) - if err != nil { - return fmt.Errorf("failed to unmarshal response params: %w", err) - } - - r.Stream = false - if stream := gjson.Get(string(raw), "stream"); stream.Bool() { - r.Stream = stream.Bool() - } - return nil -} diff --git a/intercept/responses/reqpayload.go b/intercept/responses/reqpayload.go new file mode 100644 index 00000000..b7535c35 --- /dev/null +++ b/intercept/responses/reqpayload.go @@ -0,0 +1,144 @@ +package responses + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/openai/openai-go/v3/responses" + "github.com/openai/openai-go/v3/shared/constant" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + reqPathBackground = "background" + reqPathInput = "input" + reqPathParallelToolCalls = "parallel_tool_calls" + reqPathStream = "stream" + reqPathTools = "tools" +) + +var reqPathModel = string(constant.ValueOf[constant.Model]()) + +// ResponsesRequestPayload is raw JSON bytes of a Responses API request. +// Methods provide package-specific reads and rewrites while preserving the +// original body for upstream pass-through. +// Note: No changes are made on schema error. 
+type ResponsesRequestPayload []byte + +func NewResponsesRequestPayload(raw []byte) (ResponsesRequestPayload, error) { + if len(bytes.TrimSpace(raw)) == 0 { + return nil, fmt.Errorf("empty request body") + } + if !json.Valid(raw) { + return nil, fmt.Errorf("invalid JSON payload") + } + + return ResponsesRequestPayload(raw), nil +} + +func (p ResponsesRequestPayload) Stream() bool { + return gjson.GetBytes(p, reqPathStream).Bool() +} + +func (p ResponsesRequestPayload) model() string { + return gjson.GetBytes(p, reqPathModel).String() +} + +func (p ResponsesRequestPayload) background() bool { + return gjson.GetBytes(p, reqPathBackground).Bool() +} + +func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam) (ResponsesRequestPayload, error) { + if len(injected) == 0 { + return p, nil + } + + existing, err := p.toolItems() + if err != nil { + return p, fmt.Errorf("failed to get existing tools: %w", err) + } + + allTools := make([]any, 0, len(existing)+len(injected)) + for _, item := range existing { + allTools = append(allTools, item) + } + for _, tool := range injected { + allTools = append(allTools, tool) + } + + return p.set(reqPathTools, allTools) +} + +func (p ResponsesRequestPayload) disableParallelToolCalls() (ResponsesRequestPayload, error) { + return p.set(reqPathParallelToolCalls, false) +} + +func (p ResponsesRequestPayload) appendInputItems(items []responses.ResponseInputItemUnionParam) (ResponsesRequestPayload, error) { + if len(items) == 0 { + return p, nil + } + + existing, err := p.inputItems() + if err != nil { + return p, fmt.Errorf("failed to get existing 'input' items: %w", err) + } + + allInput := make([]any, 0, len(existing)+len(items)) + allInput = append(allInput, existing...) 
+ for _, item := range items { + allInput = append(allInput, item) + } + + return p.set(reqPathInput, allInput) +} + +func (p ResponsesRequestPayload) inputItems() ([]any, error) { + input := gjson.GetBytes(p, reqPathInput) + if !input.Exists() || input.Type == gjson.Null { + return []any{}, nil + } + + if input.Type == gjson.String { + return []any{responses.ResponseInputItemParamOfMessage(input.String(), responses.EasyInputMessageRoleUser)}, nil + } + + if !input.IsArray() { + return nil, fmt.Errorf("unsupported 'input' type: %s", input.Type) + } + + items := input.Array() + existing := make([]any, 0, len(items)) + for _, item := range items { + existing = append(existing, json.RawMessage(item.Raw)) + } + + return existing, nil +} + +func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) { + tools := gjson.GetBytes(p, reqPathTools) + if !tools.Exists() { + return nil, nil + } + if !tools.IsArray() { + return nil, fmt.Errorf("unsupported 'tools' type: %s", tools.Type) + } + + items := tools.Array() + existing := make([]json.RawMessage, 0, len(items)) + for _, item := range items { + existing = append(existing, json.RawMessage(item.Raw)) + } + + return existing, nil +} + +func (p ResponsesRequestPayload) set(path string, value any) (ResponsesRequestPayload, error) { + b, err := sjson.SetBytes(p, path, value) + if err != nil { + return ResponsesRequestPayload(b), fmt.Errorf("failed to set value at path %s: %w", path, err) + } + return ResponsesRequestPayload(b), nil +} diff --git a/intercept/responses/reqpayload_test.go b/intercept/responses/reqpayload_test.go new file mode 100644 index 00000000..aa5e9069 --- /dev/null +++ b/intercept/responses/reqpayload_test.go @@ -0,0 +1,342 @@ +package responses + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/responses" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func 
TestNewResponsesRequestPayload(t *testing.T) { + t.Parallel() + + payloadWithWrongTypes := []byte(`{"model":123,"stream":"yes","input":42,"background":"nope"}`) + tests := []struct { + name string + raw []byte + want []byte + model string + stream bool + background bool + err string + }{ + { + name: "empty payload", + raw: nil, + want: nil, + err: "empty request body", + }, + { + name: "invalid json", + raw: []byte(`{broken`), + want: nil, + err: "invalid JSON payload", + }, + { + // ResponsesRequestPayload just checks for JSON validity, + // schema errors are not surfaced here and + // the original body is preserved for upstream handling + // similar to how reverse proxy would behave. + name: "wrong field types still wrap", + raw: payloadWithWrongTypes, + want: payloadWithWrongTypes, + model: "123", + stream: false, + background: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + payload, err := NewResponsesRequestPayload(tc.raw) + assert.EqualValues(t, tc.want, payload) + + if tc.err != "" { + require.ErrorContains(t, err, tc.err) + assert.Nil(t, payload) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.model, payload.model()) + assert.Equal(t, tc.stream, payload.Stream()) + assert.Equal(t, tc.background, payload.background()) + }) + } +} + +func TestInjectTools(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw []byte + injected []responses.ToolUnionParam + wantNames []string + wantErr string + wantSame bool + }{ + { + name: "appends to existing tools", + raw: []byte(`{"model":"gpt-4o","input":"hello","tools":[{"type":"function","name":"existing"}]}`), + injected: []responses.ToolUnionParam{injectedFunctionTool("injected")}, + wantNames: []string{"existing", "injected"}, + }, + { + name: "adds tools when none exist", + raw: []byte(`{"model":"gpt-4o","input":"hello"}`), + injected: []responses.ToolUnionParam{injectedFunctionTool("injected")}, + wantNames: 
[]string{"injected"}, + }, + { + name: "adds to empty tools array", + raw: []byte(`{"model":"gpt-4o","input":"hello","tools":[]}`), + injected: []responses.ToolUnionParam{injectedFunctionTool("injected")}, + wantNames: []string{"injected"}, + }, + { + name: "appends multiple injected tools", + raw: []byte(`{"model":"gpt-4o","input":"hello","tools":[{"type":"function","name":"existing"}]}`), + injected: []responses.ToolUnionParam{ + injectedFunctionTool("injected-one"), + injectedFunctionTool("injected-two"), + }, + wantNames: []string{"existing", "injected-one", "injected-two"}, + }, + { + name: "empty injected tools is no op", + raw: []byte(`{"model":"gpt-4o","input":"hello","tools":[{"type":"function","name":"existing"}]}`), + wantSame: true, + }, + { + name: "errors on unsupported tools shape", + raw: []byte(`{"model":"gpt-4o","input":"hello","tools":"bad"}`), + injected: []responses.ToolUnionParam{injectedFunctionTool("injected")}, + wantErr: "failed to get existing tools: unsupported 'tools' type: String", + wantSame: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := mustPayload(t, tc.raw) + updated, err := p.injectTools(tc.injected) + if tc.wantErr != "" { + require.EqualError(t, err, tc.wantErr) + } else { + require.NoError(t, err) + } + + if tc.wantSame { + require.Equal(t, p, updated) + } + for i, wantName := range tc.wantNames { + path := fmt.Sprintf("tools.%d.name", i) // name of the i-th element in tools array + require.Equal(t, wantName, gjson.GetBytes(updated, path).String()) + } + }) + } +} + +func TestDisableParallelToolCalls(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw []byte + }{ + { + name: "sets flag when not present", + raw: []byte(`{"model":"gpt-4o"}`), + }, + { + name: "overrides when already true", + raw: []byte(`{"model":"gpt-4o","parallel_tool_calls":true}`), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + 
+ p := mustPayload(t, tc.raw) + updated, err := p.disableParallelToolCalls() + require.NoError(t, err) + assert.False(t, gjson.GetBytes(updated, "parallel_tool_calls").Bool()) + }) + } +} + +func TestAppendInputItems(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw []byte + items []responses.ResponseInputItemUnionParam + wantErr string + wantSame bool + wantPaths map[string]string + }{ + { + name: "string input becomes user message", + raw: []byte(`{"model":"gpt-4o","input":"hello"}`), + items: []responses.ResponseInputItemUnionParam{responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done")}, + wantPaths: map[string]string{ + "input.0.role": "user", + "input.0.content": "hello", + "input.1.type": "function_call_output", + "input.1.call_id": "call_123", + }, + }, + { + name: "array input is preserved and appended", + raw: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hello"}]}`), + items: []responses.ResponseInputItemUnionParam{responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done")}, + wantPaths: map[string]string{ + "input.0.content": "hello", + "input.1.call_id": "call_123", + }, + }, + { + name: "unsupported input shape errors during rewrite", + raw: []byte(`{"model":"gpt-4o","input":123}`), + items: []responses.ResponseInputItemUnionParam{responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done")}, + wantErr: "failed to get existing 'input' items: unsupported 'input' type: Number", + wantSame: true, + }, + { + name: "missing input creates appended input", + raw: []byte(`{"model":"gpt-4o"}`), + items: []responses.ResponseInputItemUnionParam{responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done")}, + wantPaths: map[string]string{ + "input.0.type": "function_call_output", + "input.0.call_id": "call_123", + }, + }, + { + name: "null input creates appended input", + raw: []byte(`{"model":"gpt-4o","input":null}`), + items: 
[]responses.ResponseInputItemUnionParam{responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done")}, + wantPaths: map[string]string{ + "input.0.type": "function_call_output", + "input.0.call_id": "call_123", + }, + }, + { + name: "multiple output item types are appended in order", + raw: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hello"}]}`), + items: []responses.ResponseInputItemUnionParam{ + responses.ResponseInputItemParamOfCompaction("encrypted-content"), + responses.ResponseInputItemParamOfOutputMessage([]responses.ResponseOutputMessageContentUnionParam{ + { + OfOutputText: &responses.ResponseOutputTextParam{ + Annotations: []responses.ResponseOutputTextAnnotationUnionParam{}, + Text: "assistant text", + }, + }, + }, "msg_123", responses.ResponseOutputMessageStatusCompleted), + responses.ResponseInputItemParamOfFileSearchCall("fs_123", []string{"hello"}, "completed"), + responses.ResponseInputItemParamOfImageGenerationCall("img_123", "base64-image", "completed"), + }, + wantPaths: map[string]string{ + "input.0.content": "hello", + "input.1.type": "compaction", + "input.2.type": "message", + "input.2.id": "msg_123", + "input.2.content.0.type": "output_text", + "input.2.content.0.text": "assistant text", + "input.3.type": "file_search_call", + "input.3.id": "fs_123", + "input.4.type": "image_generation_call", + "input.4.id": "img_123", + }, + }, + { + name: "empty appended items is no op", + raw: []byte(`{"model":"gpt-4o","input":"hello"}`), + wantSame: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := mustPayload(t, tc.raw) + updated, err := p.appendInputItems(tc.items) + + if tc.wantErr != "" { + require.EqualError(t, err, tc.wantErr) + } else { + require.NoError(t, err) + } + + if tc.wantSame { + require.Equal(t, p, updated) + } + + for path, want := range tc.wantPaths { + require.Equal(t, want, gjson.GetBytes(updated, path).String()) + } + }) + } +} + +func 
TestChainedRewritesProduceValidJSON(t *testing.T) { + t.Parallel() + + p := mustPayload(t, []byte(`{"model":"gpt-4o","input":"hello"}`)) + p, err := p.injectTools([]responses.ToolUnionParam{{ + OfFunction: &responses.FunctionToolParam{ + Name: "tool_a", + Description: openai.String("tool"), + Strict: openai.Bool(false), + Parameters: map[string]any{ + "type": "object", + }, + }, + }}) + require.NoError(t, err) + p, err = p.disableParallelToolCalls() + require.NoError(t, err) + p, err = p.appendInputItems([]responses.ResponseInputItemUnionParam{ + responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done"), + }) + require.NoError(t, err) + + assert.True(t, json.Valid(p), "chained rewrites should produce valid JSON") + assert.Equal(t, "tool_a", gjson.GetBytes(p, "tools.0.name").String()) + assert.Equal(t, "call_123", gjson.GetBytes(p, "input.1.call_id").String()) + assert.False(t, gjson.GetBytes(p, "parallel_tool_calls").Bool()) +} + +func injectedFunctionTool(name string) responses.ToolUnionParam { + return responses.ToolUnionParam{ + OfFunction: &responses.FunctionToolParam{ + Name: name, + Description: openai.String("tool"), + Strict: openai.Bool(false), + Parameters: map[string]any{ + "type": "object", + }, + }, + } +} + +func mustPayload(t *testing.T, raw []byte) ResponsesRequestPayload { + t.Helper() + + payload, err := NewResponsesRequestPayload(raw) + require.NoError(t, err) + return payload +} diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 80fc5b24..5d899a41 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -34,10 +34,8 @@ type StreamingResponsesInterceptor struct { func NewStreamingInterceptor( id uuid.UUID, - req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, - model string, clientHeaders http.Header, authHeaderName string, tracer trace.Tracer, @@ -45,10 +43,8 @@ func NewStreamingInterceptor( return &StreamingResponsesInterceptor{ 
responsesInterceptionBase: responsesInterceptionBase{ id: id, - req: req, reqPayload: reqPayload, cfg: cfg, - model: model, clientHeaders: clientHeaders, authHeaderName: authHeaderName, tracer: tracer, @@ -214,5 +210,6 @@ func (i *StreamingResponsesInterceptor) newStream(ctx context.Context, srv respo ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() - return srv.NewStreaming(ctx, i.req.ResponseNewParams, opts...) + // The request body is overridden by option.WithRequestBody in requestOptions, so empty params are passed here. + return srv.NewStreaming(ctx, responses.ResponseNewParams{}, opts...) } diff --git a/provider/copilot.go b/provider/copilot.go index 4bdf6a29..0218acb2 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -158,15 +158,15 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac if err != nil { return nil, fmt.Errorf("read body: %w", err) } - var req responses.ResponsesNewParamsWrapper - if err := json.Unmarshal(payload, &req); err != nil { - return nil, fmt.Errorf("unmarshal responses request body: %w", err) + reqPayload, err := responses.NewResponsesRequestPayload(payload) + if err != nil { + return nil, fmt.Errorf("unmarshal request body: %w", err) } - if req.Stream { - interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, r.Header, p.AuthHeader(), tracer) + if reqPayload.Stream() { + interceptor = responses.NewStreamingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer) } default: diff --git a/provider/copilot_test.go b/provider/copilot_test.go index a5df7bd4..c2a11a50 100644 --- a/provider/copilot_test.go +++ b/provider/copilot_test.go @@ -217,7 +217,7 @@
func TestCopilot_CreateInterceptor(t *testing.T) { require.Error(t, err) require.Nil(t, interceptor) - assert.Contains(t, err.Error(), "unmarshal responses request body") + assert.Contains(t, err.Error(), "invalid JSON payload") }) t.Run("Responses_ClientHeaders", func(t *testing.T) { diff --git a/provider/openai.go b/provider/openai.go index dd68f0d9..f7668e14 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -115,14 +115,14 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace if err != nil { return nil, fmt.Errorf("read body: %w", err) } - var req responses.ResponsesNewParamsWrapper - if err := json.Unmarshal(payload, &req); err != nil { // TODO: should probably change to json.NewDecoder. + reqPayload, err := responses.NewResponsesRequestPayload(payload) + if err != nil { return nil, fmt.Errorf("unmarshal request body: %w", err) } - if req.Stream { - interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, p.AuthHeader(), tracer) + if reqPayload.Stream() { + interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.cfg, r.Header, p.AuthHeader(), tracer) } default: