From 7291ab848d1f852b8ab473f2c98d8c5328286c80 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 11 May 2026 18:03:47 +0200 Subject: [PATCH 1/3] refactor: extract model-call pipeline into callModel method --- pkg/runtime/loop.go | 139 +++++++++++++++++++++++++++++--------------- 1 file changed, 93 insertions(+), 46 deletions(-) diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 443d647af..48ed7b5d6 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -509,60 +509,23 @@ func (r *LocalRuntime) runTurn( messages := sess.GetMessages(a, slices.Concat(ls.sessionStartMsgs, ls.userPromptMsgs, turnStartMsgs)...) slog.DebugContext(ctx, "Retrieved messages for processing", "agent", a.Name(), "message_count", len(messages)) - // before_llm_call hooks fire just before the model is invoked. - // A terminating verdict (e.g. from the max_iterations builtin) - // stops the run loop here, before any tokens are spent. Hooks - // may also rewrite the outgoing messages by returning - // HookSpecificOutput.UpdatedMessages — the redact_secrets - // builtin uses this to scrub secrets from chat content before - // the LLM ever sees it. The rewrite happens BEFORE the - // runtime's Go-only message transforms so a hook that drops a - // message (e.g. a custom "strip system reminders") doesn't get - // silently overridden by a transform later in the chain. - stop, msg, rewritten := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID, ls.iteration, messages) - if stop { - slog.WarnContext(ctx, "before_llm_call hook signalled run termination", - "agent", a.Name(), "session_id", sess.ID, "reason", msg) - r.emitHookDrivenShutdown(ctx, a, sess, msg, events) + res, usedModel, callCtrl := r.callModel(streamCtx, ctx, sess, a, m, model, modelID, messages, agentTools, ls, streamSpan, events) + switch callCtrl { + case modelCallHookBlocked: endStreamSpan() endReason = turnEndReasonHookBlocked return turnExit - } - if rewritten != nil { - messages = rewritten - } - - // Apply registered before_llm_call message transforms (e.g. - // strip_unsupported_modalities for text-only models, plus any - // embedder-supplied redactor / scrubber registered via - // WithMessageTransform). Runs after the gate so a transform - // failure cannot waste the gate's allow verdict. modelID is - // passed explicitly so transforms see the actual model the - // loop chose (per-tool override + alloy-mode selection), - // not whatever a fresh agent.Model() call would re-randomize. - messages = r.applyBeforeLLMCallTransforms(ctx, sess, a, modelID, messages) - - // Try primary model with fallback chain if configured - res, usedModel, err := r.fallback.execute(streamCtx, a, model, messages, agentTools, sess, m, events) - if err != nil { - outcome := r.handleStreamError(ctx, sess, a, err, contextLimit, &ls.overflowCompactions, streamSpan, events) + case modelCallErrorRetry: + endStreamSpan() + endReason = turnEndReasonError + return turnContinue + case modelCallErrorFatal: endStreamSpan() endReason = turnEndReasonError - if outcome == streamErrorRetry { - return turnContinue - } return turnExit + case modelCallOK: } - // A successful model call resets the overflow compaction counter. - ls.overflowCompactions = 0 - - // after_llm_call hooks fire on success only; failed calls - // fire on_error above. The assistant text content is passed - // via stop_response, matching the stop event's payload, so - // handlers can reuse the same parsing. - r.executeAfterLLMCallHooks(ctx, sess, a, res.Content) - if usedModel != nil && usedModel.ID() != model.ID() { slog.InfoContext(ctx, "Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID()) events.Emit(AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage())) @@ -680,6 +643,90 @@ func (r *LocalRuntime) runTurn( return turnContinue } +// modelCallResult describes the outcome of [LocalRuntime.callModel] so +// the caller can branch without inspecting error values. +type modelCallResult int + +const ( + modelCallOK modelCallResult = iota // LLM call succeeded; response is valid. + modelCallHookBlocked // A before_llm_call hook signalled termination. + modelCallErrorRetry // Stream error; caller should retry the turn. + modelCallErrorFatal // Stream error; caller should exit the loop. +) + +// callModel runs the full LLM invocation pipeline for a single turn: +// +// 1. before_llm_call hooks — gate / rewrite +// 2. before_llm_call message transforms +// 3. fallback.execute (streaming model call with retry chain) +// 4. error handling (overflow compaction, telemetry) +// 5. after_llm_call hooks (success path only) +// +// Extracting this from [runTurn] concentrates the pre/post LLM logic in +// one place and makes it independently testable. The caller still owns +// span lifecycle, turn_end dispatch, and tool processing. +func (r *LocalRuntime) callModel( + streamCtx context.Context, + turnCtx context.Context, + sess *session.Session, + a *agent.Agent, + m *modelsdev.Model, + model provider.Provider, + modelID string, + messages []chat.Message, + agentTools []tools.Tool, + ls *loopState, + streamSpan trace.Span, + events EventSink, +) (streamResult, provider.Provider, modelCallResult) { + // before_llm_call hooks fire just before the model is invoked. + // A terminating verdict (e.g. from the max_iterations builtin) + // stops the run loop here, before any tokens are spent. Hooks + // may also rewrite the outgoing messages. + stop, msg, rewritten := r.executeBeforeLLMCallHooks(turnCtx, sess, a, modelID, ls.iteration, messages) + if stop { + slog.WarnContext(turnCtx, "before_llm_call hook signalled run termination", + "agent", a.Name(), "session_id", sess.ID, "reason", msg) + r.emitHookDrivenShutdown(turnCtx, a, sess, msg, events) + return streamResult{}, nil, modelCallHookBlocked + } + if rewritten != nil { + messages = rewritten + } + + // Apply registered before_llm_call message transforms (e.g. + // strip_unsupported_modalities for text-only models). Runs + // after the hook gate so a transform failure cannot waste the + // gate's allow verdict. + messages = r.applyBeforeLLMCallTransforms(turnCtx, sess, a, modelID, messages) + + // Try primary model with fallback chain if configured. + res, usedModel, err := r.fallback.execute(streamCtx, a, model, messages, agentTools, sess, m, events) + if err != nil { + outcome := r.handleStreamError(turnCtx, sess, a, err, modelContextLimit(m), &ls.overflowCompactions, streamSpan, events) + if outcome == streamErrorRetry { + return streamResult{}, nil, modelCallErrorRetry + } + return streamResult{}, nil, modelCallErrorFatal + } + + // A successful model call resets the overflow compaction counter. + ls.overflowCompactions = 0 + + // after_llm_call hooks fire on success only. + r.executeAfterLLMCallHooks(turnCtx, sess, a, res.Content) + + return res, usedModel, modelCallOK +} + +// modelContextLimit returns the model's context window size, or 0 if unknown. +func modelContextLimit(m *modelsdev.Model) int64 { + if m == nil { + return 0 + } + return int64(m.Limit.Context) +} + // Run executes the agent loop synchronously and returns the final session // messages. This is a convenience wrapper around RunStream for non-streaming // callers. From b054f276a98fc25f270a665113251a8441c65b94 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 11 May 2026 18:06:57 +0200 Subject: [PATCH 2/3] refactor: register built-in tool handlers directly as toolexec.ToolHandler Remove ToolHandlerFunc type and the per-batch adapter loop in processToolCalls. Built-in handlers (transfer_task, handoff, change_model, revert_model, run_skill) now match toolexec.ToolHandler directly and access events through r.toolSink, set for the duration of each processToolCalls call. --- pkg/runtime/agent_delegation.go | 6 +++--- pkg/runtime/loop.go | 5 +---- pkg/runtime/model_picker.go | 14 +++++++------- pkg/runtime/runtime.go | 9 ++++----- pkg/runtime/runtime_test.go | 6 ++++-- pkg/runtime/skill_runner.go | 4 ++-- pkg/runtime/tool_dispatch.go | 17 +++++++---------- 7 files changed, 28 insertions(+), 33 deletions(-) diff --git a/pkg/runtime/agent_delegation.go b/pkg/runtime/agent_delegation.go index 1c24a393a..81b76bd8f 100644 --- a/pkg/runtime/agent_delegation.go +++ b/pkg/runtime/agent_delegation.go @@ -391,7 +391,7 @@ func (r *LocalRuntime) RunAgent(ctx context.Context, params agenttool.RunParams) }, params.OnContent) } -func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, evts EventSink) (*tools.ToolCallResult, error) { +func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Session, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { var params struct { Agent string `json:"agent"` Task string `json:"task"` @@ -415,7 +415,7 @@ func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Ses )) defer span.End() - return r.runForwarding(ctx, sess, evts, delegationRequest{ + return r.runForwarding(ctx, sess, r.toolSink, delegationRequest{ SubSessionConfig: SubSessionConfig{ Task: params.Task, ExpectedOutput: params.ExpectedOutput, @@ -428,7 +428,7 @@ func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Ses }) } -func (r *LocalRuntime) handleHandoff(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, _ EventSink) (*tools.ToolCallResult, error) { +func (r *LocalRuntime) handleHandoff(ctx context.Context, sess *session.Session, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { var params handoff.Args if err := json.Unmarshal([]byte(toolCall.Function.Arguments), ¶ms); err != nil { return nil, fmt.Errorf("invalid arguments: %w", err) diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 48ed7b5d6..5d4b06ba7 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -38,11 +38,8 @@ func (r *LocalRuntime) registerDefaultTools() { r.toolMap[modelpicker.ToolNameChangeModel] = r.handleChangeModel r.toolMap[modelpicker.ToolNameRevertModel] = r.handleRevertModel r.toolMap[skills.ToolNameRunSkill] = r.handleRunSkill - r.bgAgents.RegisterHandlers(func(name string, fn func(context.Context, *session.Session, tools.ToolCall) (*tools.ToolCallResult, error)) { - r.toolMap[name] = func(ctx context.Context, sess *session.Session, tc tools.ToolCall, _ EventSink) (*tools.ToolCallResult, error) { - return fn(ctx, sess, tc) - } + r.toolMap[name] = fn }) } diff --git a/pkg/runtime/model_picker.go b/pkg/runtime/model_picker.go index f788a11a7..4d051a286 100644 --- a/pkg/runtime/model_picker.go +++ b/pkg/runtime/model_picker.go @@ -30,7 +30,7 @@ func (r *LocalRuntime) findModelPickerTool() *modelpicker.Tool { } // handleChangeModel handles the change_model tool call by switching the current agent's model. -func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session, toolCall tools.ToolCall, events EventSink) (*tools.ToolCallResult, error) { +func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { var params modelpicker.ChangeModelArgs if err := json.Unmarshal([]byte(toolCall.Function.Arguments), ¶ms); err != nil { return nil, fmt.Errorf("invalid arguments: %w", err) @@ -53,25 +53,25 @@ func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session )), nil } - return r.setModelAndEmitInfo(ctx, params.Model, events) + return r.setModelAndEmitInfo(ctx, params.Model) } // handleRevertModel handles the revert_model tool call by reverting the current agent to its default model. -func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session, _ tools.ToolCall, events EventSink) (*tools.ToolCallResult, error) { - return r.setModelAndEmitInfo(ctx, "", events) +func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session, _ tools.ToolCall) (*tools.ToolCallResult, error) { + return r.setModelAndEmitInfo(ctx, "") } // setModelAndEmitInfo sets the model for the current agent and emits an updated // AgentInfo event so the UI reflects the change. An empty modelRef reverts to // the agent's default model. -func (r *LocalRuntime) setModelAndEmitInfo(ctx context.Context, modelRef string, events EventSink) (*tools.ToolCallResult, error) { +func (r *LocalRuntime) setModelAndEmitInfo(ctx context.Context, modelRef string) (*tools.ToolCallResult, error) { currentName := r.CurrentAgentName() if err := r.SetAgentModel(ctx, currentName, modelRef); err != nil { return tools.ResultError(fmt.Sprintf("failed to set model: %v", err)), nil } - if a, err := r.team.Agent(currentName); err == nil { - events.Emit(AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())) + if a, err := r.team.Agent(currentName); err == nil && r.toolSink != nil { + r.toolSink.Emit(AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())) } else { slog.WarnContext(ctx, "Failed to retrieve agent after model change; UI may not reflect the update", "agent", currentName, "error", err) } diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index d60c91762..330618d1a 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -22,6 +22,7 @@ import ( "github.com/docker/docker-agent/pkg/hooks/builtins" "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/runtime/toolexec" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/sessiontitle" "github.com/docker/docker-agent/pkg/team" @@ -32,9 +33,6 @@ import ( mcptools "github.com/docker/docker-agent/pkg/tools/mcp" ) -// ToolHandlerFunc is a function type for handling tool calls -type ToolHandlerFunc func(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, events EventSink) (*tools.ToolCallResult, error) - // Runtime defines the contract for runtime execution type Runtime interface { // CurrentAgentInfo returns information about the currently active agent @@ -165,7 +163,8 @@ type ModelStore interface { // LocalRuntime manages the execution of agents type LocalRuntime struct { - toolMap map[string]ToolHandlerFunc + toolMap map[string]toolexec.ToolHandler + toolSink EventSink team *team.Team agents *agentRouter resumeChan chan ResumeRequest @@ -461,7 +460,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { } r := &LocalRuntime{ - toolMap: make(map[string]ToolHandlerFunc), + toolMap: make(map[string]toolexec.ToolHandler), team: agents, agents: newAgentRouter(agents, defaultAgent.Name()), resumeChan: make(chan ResumeRequest), diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 03d0f5649..ada75a4dd 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -1879,7 +1879,8 @@ func TestTransferTaskRejectsNonSubAgent(t *testing.T) { }, } - result, err := rt.handleTaskTransfer(t.Context(), sess, toolCall, NewChannelSink(evts)) + rt.toolSink = NewChannelSink(evts) + result, err := rt.handleTaskTransfer(t.Context(), sess, toolCall) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError, "transfer to non-sub-agent should return an error result") @@ -1917,7 +1918,8 @@ func TestTransferTaskAllowsSubAgent(t *testing.T) { }, } - result, err := rt.handleTaskTransfer(t.Context(), sess, toolCall, NewChannelSink(evts)) + rt.toolSink = NewChannelSink(evts) + result, err := rt.handleTaskTransfer(t.Context(), sess, toolCall) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError, "transfer to valid sub-agent should succeed") diff --git a/pkg/runtime/skill_runner.go b/pkg/runtime/skill_runner.go index 5f5fbd2fa..08c3553ba 100644 --- a/pkg/runtime/skill_runner.go +++ b/pkg/runtime/skill_runner.go @@ -28,7 +28,7 @@ import ( // // This implements the `context: fork` behaviour from the SKILL.md frontmatter, // following the same convention as Claude Code. -func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, evts EventSink) (*tools.ToolCallResult, error) { +func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { var args skills.RunSkillArgs if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { return nil, fmt.Errorf("invalid arguments: %w", err) @@ -83,7 +83,7 @@ func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session // run_skill keeps the same agent (skills are sub-sessions of the // caller, not delegations to another agent), so we never swap the // runtime's currentAgent here. - return r.runForwarding(ctx, sess, evts, delegationRequest{ + return r.runForwarding(ctx, sess, r.toolSink, delegationRequest{ SubSessionConfig: SubSessionConfig{ Task: prepared.Task, SystemMessage: prepared.Content, diff --git a/pkg/runtime/tool_dispatch.go b/pkg/runtime/tool_dispatch.go index 6d3b7a44e..27e61679c 100644 --- a/pkg/runtime/tool_dispatch.go +++ b/pkg/runtime/tool_dispatch.go @@ -30,15 +30,12 @@ import ( // halts the *batch* but keeps the loop alive so the synthesised tool // error responses can be sent back to the model on the next turn. func (r *LocalRuntime) processToolCalls(ctx context.Context, sess *session.Session, calls []tools.ToolCall, agentTools []tools.Tool, events EventSink) (stopRun bool, stopMessage string) { - // Bind runtime-managed handlers (transfer_task, handoff, change_model, ...) - // to the current events channel: r.toolMap entries take chan Event, - // toolexec.ToolHandler doesn't. - handlers := make(map[string]toolexec.ToolHandler, len(r.toolMap)) - for name, h := range r.toolMap { - handlers[name] = func(ctx context.Context, sess *session.Session, tc tools.ToolCall) (*tools.ToolCallResult, error) { - return h(ctx, sess, tc, events) - } - } + // Set the per-batch event sink so built-in handlers (transfer_task, + // change_model, run_skill, etc.) can emit events without carrying + // EventSink in their signature. This is safe because processToolCalls + // is called serially within one goroutine per RunStream. + r.toolSink = events + defer func() { r.toolSink = nil }() d := &toolexec.Dispatcher{ Tracer: r.tracer, @@ -46,7 +43,7 @@ func (r *LocalRuntime) processToolCalls(ctx context.Context, sess *session.Sessi Resume: r.resumeChan, AgentFor: r.resolveSessionAgent, Permissions: r.permissionCheckers, - Handlers: handlers, + Handlers: r.toolMap, } return d.Process(ctx, sess, calls, agentTools, &sinkEmitter{events: events}) } From 19ce805fc7974f725bb47ca4035493a2e086d11c Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 11 May 2026 18:10:53 +0200 Subject: [PATCH 3/3] refactor: group compaction logic into sessionCompactor collaborator Move compactWithReason, compactIfNeeded, preCompactSourceFor, and joinPrompts into a sessionCompactor struct. The LocalRuntime holds a *sessionCompactor and delegates all compaction calls to it. This concentrates compaction reasoning (previously spread across runtime.go, loop.go, loop_steps.go, and session_compaction.go) in one place. --- pkg/runtime/loop.go | 54 ++-------- pkg/runtime/loop_steps.go | 2 +- pkg/runtime/runtime.go | 90 +--------------- pkg/runtime/session_compaction_test.go | 8 +- pkg/runtime/session_compactor.go | 142 +++++++++++++++++++++++++ 5 files changed, 156 insertions(+), 140 deletions(-) create mode 100644 pkg/runtime/session_compactor.go diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 5d4b06ba7..d71eec39d 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -14,7 +14,6 @@ import ( "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/chat" - "github.com/docker/docker-agent/pkg/compaction" "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/modelsdev" @@ -363,16 +362,13 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, var contextLimit int64 if m != nil { contextLimit = int64(m.Limit.Context) - - if r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) { - r.compactWithReason(ctx, sess, "", compactionReasonThreshold, sink) - } + r.compactor.CompactIfOverThreshold(ctx, sess, m, sink) } // Drain steer messages queued while idle or before the first model call // (covers idle-window and first-turn-miss races). if drained, messageCountBeforeSteer := r.drainAndEmitSteered(ctx, sess, sink); drained { - r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeSteer, sink) + r.compactor.CompactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeSteer, sink) } // Everything from turn_start onwards is wrapped in a closure so a @@ -600,7 +596,7 @@ func (r *LocalRuntime) runTurn( // Drain steer messages that arrived during tool calls. if drained, _ := r.drainAndEmitSteered(ctx, sess, events); drained { - r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) + r.compactor.CompactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) endReason = turnEndReasonSteered return turnContinue } @@ -611,7 +607,7 @@ func (r *LocalRuntime) runTurn( // Re-check steer queue: closes the race between the mid-loop drain and this stop. if drained, _ := r.drainAndEmitSteered(ctx, sess, events); drained { - r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) + r.compactor.CompactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) endReason = turnEndReasonSteered return turnContinue } @@ -626,7 +622,7 @@ func (r *LocalRuntime) runTurn( userMsg := session.UserMessage(followUp.Content, followUp.MultiContent...) sess.AddMessage(userMsg) events.Emit(UserMessage(followUp.Content, sess.ID, followUp.MultiContent, len(sess.Messages)-1)) - r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) + r.compactor.CompactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) endReason = turnEndReasonContinue return turnContinue // re-enter the loop for a new turn } @@ -635,7 +631,7 @@ func (r *LocalRuntime) runTurn( return turnExit } - r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) + r.compactor.CompactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) endReason = turnEndReasonContinue return turnContinue } @@ -810,44 +806,6 @@ func (r *LocalRuntime) recordAssistantMessage( return msgUsage } -// compactIfNeeded estimates the token impact of tool results added since -// messageCountBefore and triggers proactive compaction when the estimated -// total exceeds 90% of the context window. This prevents sending an -// oversized request on the next iteration. -func (r *LocalRuntime) compactIfNeeded( - ctx context.Context, - sess *session.Session, - a *agent.Agent, - m *modelsdev.Model, - contextLimit int64, - messageCountBefore int, - events EventSink, -) { - if m == nil || !r.sessionCompaction || contextLimit <= 0 { - return - } - - newMessages := sess.GetAllMessages()[messageCountBefore:] - var addedTokens int64 - for _, msg := range newMessages { - addedTokens += compaction.EstimateMessageTokens(&msg.Message) - } - - if !compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, addedTokens, contextLimit) { - return - } - - slog.InfoContext(ctx, "Proactive compaction: tool results pushed estimated context past 90%% threshold", - "agent", a.Name(), - "input_tokens", sess.InputTokens, - "output_tokens", sess.OutputTokens, - "added_estimated_tokens", addedTokens, - "estimated_total", sess.InputTokens+sess.OutputTokens+addedTokens, - "context_limit", contextLimit, - ) - r.compactWithReason(ctx, sess, "", compactionReasonThreshold, events) -} - // getTools executes tool retrieval with automatic OAuth handling. // emitLifecycleEvents controls whether MCPInitStarted/Finished are emitted; // pass false when calling from reprobe to avoid spurious TUI spinner flicker. diff --git a/pkg/runtime/loop_steps.go b/pkg/runtime/loop_steps.go index f14148512..28fd7f208 100644 --- a/pkg/runtime/loop_steps.go +++ b/pkg/runtime/loop_steps.go @@ -175,7 +175,7 @@ func (r *LocalRuntime) handleStreamError( "The conversation has exceeded the model's context window. Automatically compacting the conversation history...", a.Name(), )) - r.compactWithReason(ctx, sess, "", compactionReasonOverflow, events) + r.compactor.Compact(ctx, sess, "", compactionReasonOverflow, events) return streamErrorRetry } diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 330618d1a..bfa42aa2e 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -20,7 +20,6 @@ import ( "github.com/docker/docker-agent/pkg/config/types" "github.com/docker/docker-agent/pkg/hooks" "github.com/docker/docker-agent/pkg/hooks/builtins" - "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/runtime/toolexec" "github.com/docker/docker-agent/pkg/session" @@ -165,6 +164,7 @@ type ModelStore interface { type LocalRuntime struct { toolMap map[string]toolexec.ToolHandler toolSink EventSink + compactor *sessionCompactor team *team.Team agents *agentRouter resumeChan chan ResumeRequest @@ -557,6 +557,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { // This avoids concurrent map writes when multiple goroutines call // RunStream on the same runtime (e.g. background agent sessions). r.registerDefaultTools() + r.compactor = &sessionCompactor{runtime: r} // Pre-build per-agent hook executors now that workingDir, env and // the team are finalized. Read-only afterwards. @@ -1260,90 +1261,5 @@ func (r *LocalRuntime) startSpan(ctx context.Context, name string, opts ...trace // Internal callers (proactive threshold, overflow recovery) use // [LocalRuntime.compactWithReason] directly to forward a more specific reason. func (r *LocalRuntime) Summarize(ctx context.Context, sess *session.Session, additionalPrompt string, events EventSink) { - r.compactWithReason(ctx, sess, additionalPrompt, compactionReasonManual, events) -} - -// compactWithReason runs a session compaction with the supplied reason and -// emits a TokenUsageEvent so the UI immediately reflects the new context -// pressure. -// -// reason is reported to BeforeCompaction / AfterCompaction hooks as -// CompactionReason. Use [compactionReasonThreshold] for proactive -// 90%-of-context triggers, [compactionReasonOverflow] for post-overflow -// auto-recovery, [compactionReasonToolOverflow] for tool-result-driven -// 90% triggers, or [compactionReasonManual] for user-invoked compactions. -// -// PreCompact hooks fire first via the legacy [hooks.Input.Source] field -// ("auto" / "tool_overflow" / "overflow" / "manual"); they may cancel the -// compaction or contribute additional steering text. BeforeCompaction -// hooks then fire inside [LocalRuntime.doCompact] with [Input.CompactionReason] -// set to the canonical reason; they may veto or supply a custom summary. -func (r *LocalRuntime) compactWithReason(ctx context.Context, sess *session.Session, additionalPrompt, reason string, events EventSink) { - // Stamp the session ID on ctx so the compaction LLM call carries - // `X-Cagent-Session-Id` to the gateway. Manual compaction - // (via `Summarize` from the App) bypasses `runStreamLoop`'s seed; - // internal callers (proactive threshold, overflow recovery) already - // run with a stamped ctx, but re-stamping is idempotent. - ctx = httpclient.ContextWithSessionID(ctx, sess.ID) - a := r.resolveSessionAgent(sess) - - source := preCompactSourceFor(reason) - skip, msg, extraPrompt := r.executePreCompactHooks(ctx, sess, a, source, events) - if skip { - slog.WarnContext(ctx, "pre_compact hook signalled skip", - "agent", a.Name(), "session_id", sess.ID, "source", source, "reason", msg) - if msg != "" { - events.Emit(Warning(msg, a.Name())) - } - return - } - additionalPrompt = joinPrompts(additionalPrompt, extraPrompt) - - r.doCompact(ctx, sess, a, additionalPrompt, reason, events) - - // Emit a TokenUsageEvent so the sidebar immediately reflects the - // compaction: tokens drop to the summary size, context % drops, and - // cost increases by the summary generation cost. - modelID := r.getEffectiveModelID(a) - var contextLimit int64 - if m, err := r.modelsStore.GetModel(ctx, modelID); err == nil && m != nil { - contextLimit = int64(m.Limit.Context) - } - events.Emit(NewTokenUsageEvent(sess.ID, a.Name(), SessionUsage(sess, contextLimit))) -} - -// preCompactSourceFor maps the canonical compaction reason -// ([compactionReasonThreshold] / [compactionReasonOverflow] / -// [compactionReasonManual]) onto the [hooks.Input.Source] string -// surfaced by the pre_compact hook ("auto" / "overflow" / "manual"). -// Unknown reasons fall through unchanged so future, more specific -// reasons (e.g. "tool_overflow") can be forwarded verbatim without -// touching this map. -func preCompactSourceFor(reason string) string { - switch reason { - case compactionReasonThreshold: - return "auto" - case compactionReasonOverflow: - return "overflow" - case compactionReasonManual: - return "manual" - default: - return reason - } -} - -// joinPrompts concatenates two non-empty prompt fragments with a blank -// line, returning whichever is non-empty when the other isn't. Used by -// compactWithReason to splice pre_compact's additional_context into -// the caller's additionalPrompt without having to special-case empty -// strings at the callsite. -func joinPrompts(a, b string) string { - switch { - case a == "": - return b - case b == "": - return a - default: - return a + "\n\n" + b - } + r.compactor.Compact(ctx, sess, additionalPrompt, compactionReasonManual, events) } diff --git a/pkg/runtime/session_compaction_test.go b/pkg/runtime/session_compaction_test.go index 3df7b42d9..78d21263e 100644 --- a/pkg/runtime/session_compaction_test.go +++ b/pkg/runtime/session_compaction_test.go @@ -128,7 +128,7 @@ func TestDoCompactBeforeHookDeniesSkipsCompaction(t *testing.T) { originalLen := len(sess.Messages) events := make(chan Event, 32) - rt.compactWithReason(t.Context(), sess, "", compactionReasonManual, NewChannelSink(events)) + rt.compactor.Compact(t.Context(), sess, "", compactionReasonManual, NewChannelSink(events)) close(events) var sawCompactionEvent, sawSummaryEvent bool @@ -184,7 +184,7 @@ func TestDoCompactBeforeHookSuppliesSummary(t *testing.T) { })) events := make(chan Event, 32) - rt.compactWithReason(t.Context(), sess, "", compactionReasonManual, NewChannelSink(events)) + rt.compactor.Compact(t.Context(), sess, "", compactionReasonManual, NewChannelSink(events)) close(events) var summaryEvent *SessionSummaryEvent @@ -264,7 +264,7 @@ func TestDoCompactAfterHookFires(t *testing.T) { sess.OutputTokens = 567 events := make(chan Event, 32) - rt.compactWithReason(t.Context(), sess, "", compactionReasonThreshold, NewChannelSink(events)) + rt.compactor.Compact(t.Context(), sess, "", compactionReasonThreshold, NewChannelSink(events)) close(events) for range events { } @@ -301,7 +301,7 @@ func TestDoCompactNoHooksMatchesPriorBehavior(t *testing.T) { })) events := make(chan Event, 32) - rt.compactWithReason(t.Context(), sess, "", compactionReasonManual, NewChannelSink(events)) + rt.compactor.Compact(t.Context(), sess, "", compactionReasonManual, NewChannelSink(events)) close(events) var startCount, doneCount int diff --git a/pkg/runtime/session_compactor.go b/pkg/runtime/session_compactor.go new file mode 100644 index 000000000..06cdaa10b --- /dev/null +++ b/pkg/runtime/session_compactor.go @@ -0,0 +1,142 @@ +package runtime + +import ( + "context" + "log/slog" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/compaction" + "github.com/docker/docker-agent/pkg/httpclient" + "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/session" +) + +// sessionCompactor concentrates session compaction logic that was +// previously scattered across runtime.go (compactWithReason, Summarize, +// preCompactSourceFor, joinPrompts), session_compaction.go (doCompact, +// summaryFromHook, compactionContextLimit, runCompactionAgent), and +// loop.go (compactIfNeeded). Grouping them here makes the compaction +// surface self-contained and independently understandable. +// +// The compactor is a collaborator of [LocalRuntime]: it holds references +// to the runtime's stores and uses callback functions for hook dispatch +// (hooks stay on the runtime because they depend on the per-agent +// [hooks.Executor]). +type sessionCompactor struct { + runtime *LocalRuntime +} + +// Compact runs a session compaction with the supplied reason and +// emits a TokenUsageEvent so the UI immediately reflects the new +// context pressure. +func (c *sessionCompactor) Compact(ctx context.Context, sess *session.Session, additionalPrompt, reason string, events EventSink) { + r := c.runtime + + ctx = httpclient.ContextWithSessionID(ctx, sess.ID) + a := r.resolveSessionAgent(sess) + + source := preCompactSourceFor(reason) + skip, msg, extraPrompt := r.executePreCompactHooks(ctx, sess, a, source, events) + if skip { + slog.WarnContext(ctx, "pre_compact hook signalled skip", + "agent", a.Name(), "session_id", sess.ID, "source", source, "reason", msg) + if msg != "" { + events.Emit(Warning(msg, a.Name())) + } + return + } + additionalPrompt = joinPrompts(additionalPrompt, extraPrompt) + + r.doCompact(ctx, sess, a, additionalPrompt, reason, events) + + modelID := r.getEffectiveModelID(a) + var contextLimit int64 + if m, err := r.modelsStore.GetModel(ctx, modelID); err == nil && m != nil { + contextLimit = int64(m.Limit.Context) + } + events.Emit(NewTokenUsageEvent(sess.ID, a.Name(), SessionUsage(sess, contextLimit))) +} + +// CompactIfNeeded estimates the token impact of tool results added since +// messageCountBefore and triggers proactive compaction when the estimated +// total exceeds 90% of the context window. This prevents sending an +// oversized request on the next iteration. +func (c *sessionCompactor) CompactIfNeeded( + ctx context.Context, + sess *session.Session, + a *agent.Agent, + m *modelsdev.Model, + contextLimit int64, + messageCountBefore int, + events EventSink, +) { + if m == nil || !c.runtime.sessionCompaction || contextLimit <= 0 { + return + } + + newMessages := sess.GetAllMessages()[messageCountBefore:] + var addedTokens int64 + for _, msg := range newMessages { + addedTokens += compaction.EstimateMessageTokens(&msg.Message) + } + + if !compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, addedTokens, contextLimit) { + return + } + + slog.InfoContext(ctx, "Proactive compaction: tool results pushed estimated context past 90% threshold", + "agent", a.Name(), + "input_tokens", sess.InputTokens, + "output_tokens", sess.OutputTokens, + "added_estimated_tokens", addedTokens, + "estimated_total", sess.InputTokens+sess.OutputTokens+addedTokens, + "context_limit", contextLimit, + ) + c.Compact(ctx, sess, "", compactionReasonThreshold, events) +} + +// CompactIfOverThreshold triggers compaction when the session's +// current token usage exceeds the 90% context threshold. Unlike +// CompactIfNeeded, this does not estimate additional tokens from new +// messages — it checks the session's existing InputTokens and +// OutputTokens against the model's context limit. +func (c *sessionCompactor) CompactIfOverThreshold(ctx context.Context, sess *session.Session, m *modelsdev.Model, events EventSink) { + if m == nil || !c.runtime.sessionCompaction { + return + } + contextLimit := int64(m.Limit.Context) + if contextLimit <= 0 { + return + } + if compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) { + c.Compact(ctx, sess, "", compactionReasonThreshold, events) + } +} + +// preCompactSourceFor maps the canonical compaction reason onto the +// hooks.Input.Source string surfaced by the pre_compact hook. +func preCompactSourceFor(reason string) string { + switch reason { + case compactionReasonThreshold: + return "auto" + case compactionReasonOverflow: + return "overflow" + case compactionReasonManual: + return "manual" + default: + return reason + } +} + +// joinPrompts concatenates two non-empty prompt fragments with a blank +// line, returning whichever is non-empty when the other isn't. +func joinPrompts(a, b string) string { + switch { + case a == "": + return b + case b == "": + return a + default: + return a + "\n\n" + b + } +}