From 578a48d8691c81b9ebbfb278695dd323eaeefded Mon Sep 17 00:00:00 2001 From: pandego <7780875+pandego@users.noreply.github.com> Date: Sat, 11 Apr 2026 18:22:59 +0200 Subject: [PATCH] fix(openai): merge custom-provider system prompts Signed-off-by: pandego <7780875+pandego@users.noreply.github.com> --- pkg/model/provider/openai/client.go | 16 ++++++++-- pkg/model/provider/openai/client_test.go | 37 +++++++++++++++++++++++- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index d2d058f58..463c0c2e1 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -178,10 +178,20 @@ func (c *Client) Close() { } } -// convertMessages converts chat.Message to openai.ChatCompletionMessageParamUnion -// using the shared oaistream implementation. +// convertMessages converts chat.Message to OpenAI chat-completions message params. +// Custom OpenAI-compatible providers may target local model runners that reject +// consecutive system or user messages, so we normalize those prompts to match +// the DMR provider behavior. +func convertMessages(ctx context.Context, cfg *latest.ModelConfig, messages []chat.Message) []openai.ChatCompletionMessageParamUnion { + openaiMessages := oaistream.ConvertMessages(ctx, messages, cfg.Model) + if isCustomProvider(cfg) { + return oaistream.MergeConsecutiveMessages(openaiMessages) + } + return openaiMessages +} + func (c *Client) convertMessages(ctx context.Context, messages []chat.Message) []openai.ChatCompletionMessageParamUnion { - return oaistream.ConvertMessages(ctx, messages, c.ModelConfig.Model) + return convertMessages(ctx, &c.ModelConfig, messages) } // CreateChatCompletionStream creates a streaming chat completion request diff --git a/pkg/model/provider/openai/client_test.go b/pkg/model/provider/openai/client_test.go index 6bc1dd49b..8ceb48b93 100644 --- a/pkg/model/provider/openai/client_test.go +++ b/pkg/model/provider/openai/client_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/tools" ) @@ -82,7 +83,7 @@ func TestConvertMessagesToResponseInput_AssistantTextWithToolCalls(t *testing.T) } func TestConvertMessagesToResponseInput_NoOrphans(t *testing.T) { - // All tool calls have matching results — no placeholder needed. + // All tool calls have matching results - no placeholder needed. messages := []chat.Message{ {Role: chat.MessageRoleUser, Content: "hello"}, { @@ -104,3 +105,37 @@ func TestConvertMessagesToResponseInput_NoOrphans(t *testing.T) { } assert.Equal(t, 1, outputCount, "should not inject extra outputs when all calls have results") } + +func TestConvertMessages_MergesConsecutiveSystemMessagesForCustomProviders(t *testing.T) { + cfg := &latest.ModelConfig{ + ProviderOpts: map[string]any{"api_type": "openai_chatcompletions"}, + } + messages := []chat.Message{ + {Role: chat.MessageRoleSystem, Content: "You are Bob, a coding expert"}, + {Role: chat.MessageRoleSystem, Content: "## Custom Shell Tools\n\n### execute_command"}, + {Role: chat.MessageRoleSystem, Content: "\n what-time-is-it\n"}, + {Role: chat.MessageRoleUser, Content: "what is your favourite colour?"}, + } + + result := convertMessages(t.Context(), cfg, messages) + require.Len(t, result, 2) + require.NotNil(t, result[0].OfSystem) + assert.Contains(t, result[0].OfSystem.Content.OfString.Value, "You are Bob, a coding expert") + assert.Contains(t, result[0].OfSystem.Content.OfString.Value, "Custom Shell Tools") + assert.Contains(t, result[0].OfSystem.Content.OfString.Value, "available_skills") + assert.NotNil(t, result[1].OfUser) +} + +func TestConvertMessages_PreservesConsecutiveSystemMessagesForOpenAIProvider(t *testing.T) { + messages := []chat.Message{ + {Role: chat.MessageRoleSystem, Content: "System 1"}, + {Role: chat.MessageRoleSystem, Content: "System 2"}, + {Role: chat.MessageRoleUser, Content: "hello"}, + } + + result := convertMessages(t.Context(), &latest.ModelConfig{}, messages) + require.Len(t, result, 3) + assert.NotNil(t, result[0].OfSystem) + assert.NotNil(t, result[1].OfSystem) + assert.NotNil(t, result[2].OfUser) +}