From 785a1042aa5951170381d151be909d00544f41c6 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 7 May 2026 17:59:27 +0200 Subject: [PATCH] refactor: extract SnapshotController so the runtime no longer brokers /undo Drops WithSnapshots in favour of builtins.RegisterSnapshot returning a controller the embedder threads into both the runtime (as an AutoInjector) and the App. LocalRuntime no longer carries snapshot methods, and pkg/runtime/snapshot.go is gone. Fixes #2701 --- cmd/root/run.go | 62 ++++++++- lint/hook_builtins_registered.go | 11 +- pkg/app/app.go | 23 +++- pkg/app/app_test.go | 76 ++++++++--- pkg/app/undo.go | 72 +++++----- pkg/hooks/builtins/builtins.go | 63 +++++---- pkg/hooks/builtins/builtins_test.go | 80 +++++++++-- pkg/hooks/builtins/redact_secrets_test.go | 3 +- pkg/hooks/builtins/snapshot.go | 154 +++++++++++++++++----- pkg/hooks/builtins/snapshot_test.go | 30 ++--- pkg/runtime/hooks.go | 36 ++++- pkg/runtime/runtime.go | 101 +++++++++----- pkg/runtime/snapshot.go | 68 ---------- 13 files changed, 514 insertions(+), 265 deletions(-) delete mode 100644 pkg/runtime/snapshot.go diff --git a/cmd/root/run.go b/cmd/root/run.go index ef9f3251c..1b77e6f62 100644 --- a/cmd/root/run.go +++ b/cmd/root/run.go @@ -18,6 +18,8 @@ import ( "github.com/docker/docker-agent/pkg/app" "github.com/docker/docker-agent/pkg/cli" "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/hooks/builtins" pathx "github.com/docker/docker-agent/pkg/path" "github.com/docker/docker-agent/pkg/paths" "github.com/docker/docker-agent/pkg/permissions" @@ -68,6 +70,16 @@ type runExecFlags struct { // from user config settings. Nil when no global permissions are configured. globalPermissions *permissions.Checker snapshotsEnabled bool + + // snapshotController is the [builtins.SnapshotController] for the + // initial App: it is wired into the initial runtime as an + // auto-injector and into the App via app.WithSnapshotController so + // /undo, /snapshots, /reset drive the same instance that captures + // the checkpoints. Sub-runtimes created by [createSessionSpawner] + // build their own controller (and registry) so each spawned + // session has independent snapshot state; that controller is local + // to the spawner closure and never reaches this field. + snapshotController builtins.SnapshotController } func newRunCmd() *cobra.Command { @@ -319,12 +331,35 @@ func (f *runExecFlags) runtimeOpts(loadResult *teamloader.LoadResult, runConfig runtime.WithTracer(otel.Tracer(AppName)), runtime.WithModelSwitcherConfig(modelSwitcherCfg), } - if f.snapshotsEnabled { - opts = append(opts, runtime.WithSnapshots(true)) - } return opts } +// snapshotRuntimeOpts wires the snapshot builtin into a runtime. +// Returns the [runtime.Opt]s that hand the registry and the +// [builtins.SnapshotController] auto-injector to the runtime, plus +// the controller itself for the embedder to pass to the App via +// [app.WithSnapshotController]. When snapshots aren't enabled, +// returns no opts and a nil controller so callers don't have to +// branch on f.snapshotsEnabled themselves. +// +// A fresh registry is created here rather than reused across runtimes +// so the spawner-created sub-runtimes get their own snapshot state +// (each spawned session has independent /undo history). +func (f *runExecFlags) snapshotRuntimeOpts() ([]runtime.Opt, builtins.SnapshotController, error) { + if !f.snapshotsEnabled { + return nil, nil, nil + } + reg := hooks.NewRegistry() + ctrl, err := builtins.RegisterSnapshot(reg, true) + if err != nil { + return nil, nil, fmt.Errorf("register snapshot builtin: %w", err) + } + return []runtime.Opt{ + runtime.WithHooksRegistry(reg), + runtime.WithAutoInjector(ctrl), + }, ctrl, nil +} + func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadResult *teamloader.LoadResult, req CreateSessionRequest) (runtime.Runtime, *session.Session, error) { t := loadResult.Team @@ -350,10 +385,16 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadRes return nil, nil, fmt.Errorf("creating session store: %w", err) } - localRt, err := runtime.New(t, f.runtimeOpts(loadResult, &f.runConfig, sessStore, agentName)...) + rtOpts, ctrl, err := f.snapshotRuntimeOpts() + if err != nil { + return nil, nil, err + } + runtimeOpts := append(f.runtimeOpts(loadResult, &f.runConfig, sessStore, agentName), rtOpts...) + localRt, err := runtime.New(t, runtimeOpts...) if err != nil { return nil, nil, fmt.Errorf("creating runtime: %w", err) } + f.snapshotController = ctrl var sess *session.Session if req.ResumeSessionID != "" { @@ -463,6 +504,9 @@ func (f *runExecFlags) buildAppOpts(args []string) ([]app.Opt, error) { if f.exitAfterResponse { opts = append(opts, app.WithExitAfterFirstResponse()) } + if f.snapshotController != nil { + opts = append(opts, app.WithSnapshotController(f.snapshotController)) + } return opts, nil } @@ -504,7 +548,12 @@ func (f *runExecFlags) createSessionSpawner(agentSource config.Source, sessStore t.SetPermissions(permissions.Merge(t.Permissions(), f.globalPermissions)) } - localRt, err := runtime.New(t, f.runtimeOpts(loadResult, runConfigCopy, sessStore, agt.Name())...) + rtOpts, ctrl, err := f.snapshotRuntimeOpts() + if err != nil { + return nil, nil, nil, err + } + runtimeOpts := append(f.runtimeOpts(loadResult, runConfigCopy, sessStore, agt.Name()), rtOpts...) + localRt, err := runtime.New(t, runtimeOpts...) if err != nil { return nil, nil, nil, err } @@ -525,6 +574,9 @@ func (f *runExecFlags) createSessionSpawner(agentSource config.Source, sessStore if gen := localRt.TitleGenerator(); gen != nil { appOpts = append(appOpts, app.WithTitleGenerator(gen)) } + if ctrl != nil { + appOpts = append(appOpts, app.WithSnapshotController(ctrl)) + } a := app.New(spawnCtx, localRt, newSess, appOpts...) diff --git a/lint/hook_builtins_registered.go b/lint/hook_builtins_registered.go index 71826fe8c..e85ac1de7 100644 --- a/lint/hook_builtins_registered.go +++ b/lint/hook_builtins_registered.go @@ -8,7 +8,10 @@ import ( // HookBuiltinsRegistered enforces that every builtin-name constant declared // under pkg/hooks/builtins/ is wired into the package's Register() function -// in pkg/hooks/builtins/builtins.go. +// in pkg/hooks/builtins/builtins.go — with one exception: the snapshot +// builtin ships its own entry point ([builtins.RegisterSnapshot]) because +// it returns a [SnapshotController] for embedders, so its declaration in +// snapshot.go is intentionally not registered through Register(). // // Each in-process builtin lives in its own file with a name constant and an // implementation: @@ -71,12 +74,12 @@ var HookBuiltinsRegistered = &cop.Func{ } // exportedBuiltinNames returns the identifiers of every exported `const Name = "..."` -// declaration in pkg/hooks/builtins/, excluding builtins.go itself and any -// test files (which is not where new builtins land). +// declaration in pkg/hooks/builtins/, excluding builtins.go itself, snapshot.go +// (which has its own RegisterSnapshot entry point), and any test files. func exportedBuiltinNames(p *cop.Pass) ([]string, error) { files, err := p.ParseDir(".", cop.ParseDirOptions{ SkipTests: true, - SkipFiles: []string{"builtins.go"}, + SkipFiles: []string{"builtins.go", "snapshot.go"}, }) if err != nil { return nil, err diff --git a/pkg/app/app.go b/pkg/app/app.go index f009e3243..fbcb5f974 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -22,6 +22,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/cli" "github.com/docker/docker-agent/pkg/config/types" + "github.com/docker/docker-agent/pkg/hooks/builtins" "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/sessiontitle" @@ -41,10 +42,11 @@ type App struct { events chan tea.Msg throttleDuration time.Duration cancel context.CancelFunc - currentAgentModel string // Tracks the current agent's model ID from AgentInfoEvent - exitAfterFirstResponse bool // Exit TUI after first assistant response completes - titleGenerating atomic.Bool // True when title generation is in progress - titleGen *sessiontitle.Generator // Title generator for local runtime (nil for remote) + currentAgentModel string // Tracks the current agent's model ID from AgentInfoEvent + exitAfterFirstResponse bool // Exit TUI after first assistant response completes + titleGenerating atomic.Bool // True when title generation is in progress + titleGen *sessiontitle.Generator // Title generator for local runtime (nil for remote) + snapshotController builtins.SnapshotController // Drives /undo, /snapshots, /reset; nil for runtimes that don't capture snapshots subsMu sync.Mutex subs []chan tea.Msg @@ -92,6 +94,19 @@ func WithTitleGenerator(gen *sessiontitle.Generator) Opt { } } +// WithSnapshotController plumbs in the [builtins.SnapshotController] +// the App uses to drive /undo, /snapshots, /reset. Pass the same +// controller to the runtime via runtime.WithAutoInjector so the +// instance that captures the checkpoints is the one the TUI commands +// drive. Pass nil (or omit the option) for runtimes that don't capture +// snapshots; the App then reports SnapshotsEnabled()==false and the +// related commands silently no-op. +func WithSnapshotController(c builtins.SnapshotController) Opt { + return func(a *App) { + a.snapshotController = c + } +} + func New(ctx context.Context, rt runtime.Runtime, sess *session.Session, opts ...Opt) *App { app := &App{ runtime: rt, diff --git a/pkg/app/app_test.go b/pkg/app/app_test.go index b76045f98..c88ceb6a1 100644 --- a/pkg/app/app_test.go +++ b/pkg/app/app_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/docker/docker-agent/pkg/hooks" "github.com/docker/docker-agent/pkg/hooks/builtins" "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/session" @@ -18,12 +19,13 @@ import ( mcptools "github.com/docker/docker-agent/pkg/tools/mcp" ) -// mockRuntime is a minimal mock for testing App without a real runtime +// mockRuntime is a minimal mock for testing App without a real runtime. +// Snapshot operations are NOT modeled here: they are driven through a +// [builtins.SnapshotController] passed to the App via WithSnapshotController, +// so the mock runtime stays small and focused on the runtime.Runtime +// surface. type mockRuntime struct { - store session.Store - undoFiles int - undoOK bool - undoErr error + store session.Store } func (m *mockRuntime) CurrentAgentInfo(ctx context.Context) runtime.CurrentAgentInfo { @@ -80,18 +82,35 @@ func (m *mockRuntime) Stop() {} func (m *mockRuntime) Steer(_ runtime.QueuedMessage) error { return nil } func (m *mockRuntime) FollowUp(_ runtime.QueuedMessage) error { return nil } func (m *mockRuntime) TogglePause(context.Context) (bool, error) { return false, nil } -func (m *mockRuntime) UndoLastSnapshot(context.Context, *session.Session) (int, bool, error) { - return m.undoFiles, m.undoOK, m.undoErr -} -func (m *mockRuntime) SnapshotsEnabled() bool { return true } -func (m *mockRuntime) ListSnapshots(*session.Session) []builtins.SnapshotInfo { return nil } -func (m *mockRuntime) ResetSnapshot(context.Context, *session.Session, int) (int, bool, error) { - return m.undoFiles, m.undoOK, m.undoErr -} // Verify mockRuntime implements runtime.Runtime var _ runtime.Runtime = (*mockRuntime)(nil) +// stubSnapshotController is a tiny SnapshotController used by the app +// tests to drive /undo without spinning up a real shadow-git +// repository. enabled gates SnapshotsEnabled(), and the (files, ok, +// err) tuple is returned verbatim from UndoLast / Reset so each test +// can assert the result-shaping logic in [snapshotResult]. +type stubSnapshotController struct { + enabled bool + files int + ok bool + err error +} + +func (s *stubSnapshotController) Enabled() bool { return s.enabled } +func (s *stubSnapshotController) UndoLast(context.Context, string, string) (int, bool, error) { + return s.files, s.ok, s.err +} + +func (s *stubSnapshotController) List(string) []builtins.SnapshotInfo { return nil } +func (s *stubSnapshotController) Reset(context.Context, string, string, int) (int, bool, error) { + return s.files, s.ok, s.err +} +func (s *stubSnapshotController) AutoInject(*hooks.Config) {} + +var _ builtins.SnapshotController = (*stubSnapshotController)(nil) + func TestApp_NewSession_PreservesToolsApproved(t *testing.T) { t.Parallel() @@ -247,7 +266,9 @@ func TestApp_UndoLastSnapshot(t *testing.T) { t.Parallel() ctx := t.Context() - app := New(ctx, &mockRuntime{undoFiles: 2, undoOK: true}, session.New()) + app := New(ctx, &mockRuntime{}, session.New(), + WithSnapshotController(&stubSnapshotController{enabled: true, files: 2, ok: true}), + ) result, err := app.UndoLastSnapshot(ctx) require.NoError(t, err) assert.Equal(t, 2, result.RestoredFiles) @@ -257,17 +278,36 @@ func TestApp_UndoLastSnapshot_NoSnapshot(t *testing.T) { t.Parallel() ctx := t.Context() - app := New(ctx, &mockRuntime{}, session.New()) + app := New(ctx, &mockRuntime{}, session.New(), + WithSnapshotController(&stubSnapshotController{enabled: true}), + ) _, err := app.UndoLastSnapshot(ctx) assert.ErrorIs(t, err, ErrNothingToUndo) } +func TestApp_UndoLastSnapshot_NoController(t *testing.T) { + t.Parallel() + + // Without a SnapshotController the App reports nothing to undo, + // so the same UI affordance can light up regardless of which + // runtime the embedder paired the App with. + ctx := t.Context() + app := New(ctx, &mockRuntime{}, session.New()) + _, err := app.UndoLastSnapshot(ctx) + require.ErrorIs(t, err, ErrNothingToUndo) + assert.False(t, app.SnapshotsEnabled()) +} + func TestApp_SnapshotsEnabled_DoesNotRequireSession(t *testing.T) { t.Parallel() - // SnapshotsEnabled answers a runtime-capability question; it must not - // silently return false just because no session is attached. - app := &App{runtime: &mockRuntime{}, session: nil} + // SnapshotsEnabled answers a controller-capability question; it + // must not silently return false just because no session is attached. + app := &App{ + runtime: &mockRuntime{}, + session: nil, + snapshotController: &stubSnapshotController{enabled: true}, + } assert.True(t, app.SnapshotsEnabled()) } diff --git a/pkg/app/undo.go b/pkg/app/undo.go index 8c8b86cd4..ca1023a34 100644 --- a/pkg/app/undo.go +++ b/pkg/app/undo.go @@ -4,9 +4,7 @@ import ( "context" "errors" "fmt" - - "github.com/docker/docker-agent/pkg/hooks/builtins" - "github.com/docker/docker-agent/pkg/session" + "os" ) var ErrNothingToUndo = errors.New("nothing to undo") @@ -15,50 +13,30 @@ type UndoSnapshotResult struct { RestoredFiles int } -// snapshotRuntime is the subset of the runtime API that the App needs to -// drive snapshot commands. Runtimes that don't capture snapshots (e.g. -// remote runtimes) simply don't implement this interface and the related -// commands are then disabled in the UI. -type snapshotRuntime interface { - SnapshotsEnabled() bool - UndoLastSnapshot(ctx context.Context, sess *session.Session) (files int, ok bool, err error) - ListSnapshots(sess *session.Session) []builtins.SnapshotInfo - ResetSnapshot(ctx context.Context, sess *session.Session, keep int) (files int, ok bool, err error) -} - -// snapshotRuntime returns the runtime's snapshot interface, or nil when the -// runtime doesn't support snapshots at all (e.g. remote runtimes). -func (a *App) snapshotRuntime() snapshotRuntime { - r, _ := a.runtime.(snapshotRuntime) - return r -} - -// SnapshotsEnabled reports whether automatic shadow-git snapshots are active -// for the current runtime. The answer doesn't depend on having an active -// session: it's a runtime/configuration capability check. +// SnapshotsEnabled reports whether automatic shadow-git snapshots are +// active. The answer is a controller-level capability check and does +// not depend on having an active session attached. func (a *App) SnapshotsEnabled() bool { - r := a.snapshotRuntime() - return r != nil && r.SnapshotsEnabled() + return a.snapshotController != nil && a.snapshotController.Enabled() } -// UndoLastSnapshot restores the files captured in the most recent snapshot. +// UndoLastSnapshot restores the files captured in the most recent +// snapshot checkpoint for the current session. func (a *App) UndoLastSnapshot(ctx context.Context) (UndoSnapshotResult, error) { - r := a.snapshotRuntime() - if r == nil || a.session == nil { + if a.snapshotController == nil || a.session == nil { return UndoSnapshotResult{}, ErrNothingToUndo } - return snapshotResult(r.UndoLastSnapshot(ctx, a.session)) + return snapshotResult(a.snapshotController.UndoLast(ctx, a.session.ID, a.snapshotCwd())) } -// ListSnapshots returns the file count of every snapshot captured during the -// current session, oldest first. Returns nil when no snapshots exist or when -// the runtime doesn't support them. +// ListSnapshots returns the file count of every snapshot captured during +// the current session, oldest first. Returns nil when no snapshots exist +// or when no controller is configured. func (a *App) ListSnapshots() []int { - r := a.snapshotRuntime() - if r == nil || a.session == nil { + if a.snapshotController == nil || a.session == nil { return nil } - infos := r.ListSnapshots(a.session) + infos := a.snapshotController.List(a.session.ID) counts := make([]int, len(infos)) for i, info := range infos { counts[i] = info.Files @@ -67,14 +45,26 @@ func (a *App) ListSnapshots() []int { } // ResetSnapshot reverts every checkpoint past index keep so the workspace -// returns to the state captured at that snapshot. keep == 0 resets to the -// original pre-agent state. +// returns to the state captured at that snapshot. keep == 0 resets to +// the original pre-agent state. func (a *App) ResetSnapshot(ctx context.Context, keep int) (UndoSnapshotResult, error) { - r := a.snapshotRuntime() - if r == nil || a.session == nil { + if a.snapshotController == nil || a.session == nil { return UndoSnapshotResult{}, ErrNothingToUndo } - return snapshotResult(r.ResetSnapshot(ctx, a.session, keep)) + return snapshotResult(a.snapshotController.Reset(ctx, a.session.ID, a.snapshotCwd(), keep)) +} + +// snapshotCwd resolves the working directory the snapshot operations +// should run against. Sessions carry their own WorkingDir (set by the +// embedder when the session is constructed); if it's empty we fall +// back to os.Getwd so snapshot commands keep working in setups that +// don't propagate a working dir on the session. +func (a *App) snapshotCwd() string { + if a.session != nil && a.session.WorkingDir != "" { + return a.session.WorkingDir + } + cwd, _ := os.Getwd() + return cwd } // snapshotResult adapts the (files, ok, err) tuple returned by snapshot diff --git a/pkg/hooks/builtins/builtins.go b/pkg/hooks/builtins/builtins.go index b5af32656..520ac4383 100644 --- a/pkg/hooks/builtins/builtins.go +++ b/pkg/hooks/builtins/builtins.go @@ -15,7 +15,9 @@ // - snapshot (session_start, // turn_start, turn_end, // pre_tool_use, post_tool_use, -// session_end) — shadow-git snapshots +// session_end) — shadow-git snapshots. Installed via +// [RegisterSnapshot] (separate entry point) so the embedder receives +// a [SnapshotController] to drive /undo, /snapshots, /reset. // - redact_secrets (pre_tool_use, // before_llm_call, // tool_response_transform) — scrub secrets @@ -28,11 +30,10 @@ // Reference any of them from a hook YAML entry as // `{type: builtin, command: ""}`. The runtime additionally // auto-injects add_date / add_environment_info / add_prompt_files / -// redact_secrets from the matching agent flags, and snapshot from -// global user config. Setting redact_secrets at the agent level is -// exactly equivalent to writing -// the three matching hook entries by hand — -// [ApplyAgentDefaults] performs the auto-injection. +// redact_secrets from the matching agent flags via [ApplyAgentDefaults]. +// snapshot auto-injection lives on the controller returned by +// [RegisterSnapshot] and is plumbed into the runtime as an +// [AutoInjector], not as another bool on [AgentDefaults]. // // turn_start builtins recompute every turn (date, git state). // session_start builtins run once per session for context that's @@ -53,13 +54,13 @@ import ( "github.com/docker/docker-agent/pkg/hooks" ) -// Register installs the stock builtin hooks on r and returns the -// shared [*Snapshots] tracker so the caller (typically the runtime) -// can drive /undo, /list-snapshots, and /reset against the same -// in-memory checkpoint history the snapshot hook is writing to. -func Register(r *hooks.Registry) (*Snapshots, error) { - snapshots := NewSnapshots() - if err := errors.Join( +// Register installs the stock builtin hooks on r. +// +// Note: the snapshot builtin is NOT installed by Register. It ships +// its own entry point ([RegisterSnapshot]) so the embedder receives a +// [SnapshotController] for driving /undo, /snapshots, /reset. +func Register(r *hooks.Registry) error { + return errors.Join( r.RegisterBuiltin(AddDate, addDate), r.RegisterBuiltin(AddEnvironmentInfo, addEnvironmentInfo), r.RegisterBuiltin(AddPromptFiles, addPromptFiles), @@ -69,17 +70,13 @@ func Register(r *hooks.Registry) (*Snapshots, error) { r.RegisterBuiltin(AddUserInfo, addUserInfo), r.RegisterBuiltin(AddRecentCommits, addRecentCommits), r.RegisterBuiltin(MaxIterations, maxIterations), - r.RegisterBuiltin(Snapshot, snapshots.Hook), r.RegisterBuiltin(RedactSecrets, redactSecrets), r.RegisterBuiltin(HTTPPost, httpPost), - ); err != nil { - return nil, err - } - return snapshots, nil + ) } // AgentDefaults captures defaults that map onto stock builtin hook entries. -// Pass each AgentConfig.AddXxx flag as-is; Snapshot comes from runtime/global config. +// Pass each AgentConfig.AddXxx flag as-is. type AgentDefaults struct { AddDate bool AddEnvironmentInfo bool @@ -91,15 +88,29 @@ type AgentDefaults struct { // makes the auto-injection idempotent against an explicit YAML // entry that already names the same builtin. RedactSecrets bool - // Snapshot auto-injects shadow-git snapshots at - // turn boundaries. session_end is included to garbage-collect the - // shadow repository; undo history remains available after a response stops. - Snapshot bool +} + +// AutoInjector adds default hooks to an agent's hook configuration. +// The runtime invokes AutoInject for every registered injector when +// it builds per-agent executors, so a builtin that wants to be +// auto-wired only needs to ship its own AutoInjector and let the +// embedder plumb it in via runtime.WithAutoInjector. +// +// The snapshot controller returned by [RegisterSnapshot] satisfies +// this interface and is the canonical use case today; future builtins +// can opt in the same way without growing the central +// [ApplyAgentDefaults] table. +type AutoInjector interface { + AutoInject(cfg *hooks.Config) } // ApplyAgentDefaults appends the stock builtin hook entries implied by // d to cfg. A nil cfg is treated as empty. Returns nil iff no hook // (user-configured or auto-injected) is present. +// +// Snapshot auto-injection is handled separately via [SnapshotController] +// (an [AutoInjector]) so it can be configured by the embedder rather +// than by another bool on AgentDefaults. func ApplyAgentDefaults(cfg *hooks.Config, d AgentDefaults) *hooks.Config { if cfg == nil { cfg = &hooks.Config{} @@ -113,12 +124,6 @@ func ApplyAgentDefaults(cfg *hooks.Config, d AgentDefaults) *hooks.Config { if d.AddEnvironmentInfo { cfg.SessionStart = append(cfg.SessionStart, builtinHook(AddEnvironmentInfo)) } - if d.Snapshot { - cfg.SessionStart = append(cfg.SessionStart, builtinHook(Snapshot)) - cfg.TurnStart = append(cfg.TurnStart, builtinHook(Snapshot)) - cfg.TurnEnd = append(cfg.TurnEnd, builtinHook(Snapshot)) - cfg.SessionEnd = append(cfg.SessionEnd, builtinHook(Snapshot)) - } if d.RedactSecrets { // Wire all three legs at once. The same builtin handles each // event — it dispatches on input.HookEventName — so a single diff --git a/pkg/hooks/builtins/builtins_test.go b/pkg/hooks/builtins/builtins_test.go index 8fa1dc43c..200c7545a 100644 --- a/pkg/hooks/builtins/builtins_test.go +++ b/pkg/hooks/builtins/builtins_test.go @@ -21,8 +21,7 @@ func TestRegisterInstallsAllBuiltins(t *testing.T) { t.Parallel() r := hooks.NewRegistry() - _, err := builtins.Register(r) - require.NoError(t, err) + require.NoError(t, builtins.Register(r)) for _, name := range []string{ builtins.AddDate, @@ -34,7 +33,6 @@ func TestRegisterInstallsAllBuiltins(t *testing.T) { builtins.AddUserInfo, builtins.AddRecentCommits, builtins.MaxIterations, - builtins.Snapshot, builtins.RedactSecrets, builtins.HTTPPost, } { @@ -177,8 +175,7 @@ func TestAddPromptFilesNoArgsIsNoop(t *testing.T) { func lookup(t *testing.T, name string) hooks.BuiltinFunc { t.Helper() r := hooks.NewRegistry() - _, err := builtins.Register(r) - require.NoError(t, err) + require.NoError(t, builtins.Register(r)) fn, ok := r.LookupBuiltin(name) require.True(t, ok, "builtin %q must be registered", name) require.NotNil(t, fn) @@ -221,13 +218,57 @@ func TestApplyAgentDefaultsInjectsExpectedEvents(t *testing.T) { assert.Equal(t, builtins.AddEnvironmentInfo, cfg.SessionStart[0].Command) } -// TestApplyAgentDefaultsInjectsSnapshotHooks verifies the global snapshot -// default wires turn-boundary capture plus session_end shadow-repo cleanup. -func TestApplyAgentDefaultsInjectsSnapshotHooks(t *testing.T) { +// TestRegisterSnapshotInstallsBuiltin verifies that the dedicated +// snapshot entry point installs the snapshot builtin and returns a +// controller wired up to the registered hook. +func TestRegisterSnapshotInstallsBuiltin(t *testing.T) { t.Parallel() - cfg := builtins.ApplyAgentDefaults(nil, builtins.AgentDefaults{Snapshot: true}) - require.NotNil(t, cfg) + r := hooks.NewRegistry() + ctrl, err := builtins.RegisterSnapshot(r, true) + require.NoError(t, err) + require.NotNil(t, ctrl) + assert.True(t, ctrl.Enabled()) + + fn, ok := r.LookupBuiltin(builtins.Snapshot) + assert.True(t, ok, "snapshot must be registered by RegisterSnapshot") + assert.NotNil(t, fn) +} + +// TestRegisterSnapshotDisabledStillExposesController verifies that an +// embedder can install the snapshot builtin without auto-injection, in +// which case the controller still exists (so /undo etc. work for hooks +// the user wired manually) but Enabled() reports false. +func TestRegisterSnapshotDisabledStillExposesController(t *testing.T) { + t.Parallel() + + r := hooks.NewRegistry() + ctrl, err := builtins.RegisterSnapshot(r, false) + require.NoError(t, err) + require.NotNil(t, ctrl) + assert.False(t, ctrl.Enabled()) + + _, ok := r.LookupBuiltin(builtins.Snapshot) + assert.True(t, ok) +} + +// TestSnapshotControllerAutoInjectWiresFourEvents verifies that the +// controller's AutoInject mounts the snapshot hook on session_start, +// turn_start, turn_end, and session_end — the four boundaries needed +// to bracket every session and every turn. Per-tool capture stays +// opt-in via YAML. +func TestSnapshotControllerAutoInjectWiresFourEvents(t *testing.T) { + t.Parallel() + + r := hooks.NewRegistry() + ctrl, err := builtins.RegisterSnapshot(r, true) + require.NoError(t, err) + + inj, ok := ctrl.(builtins.AutoInjector) + require.True(t, ok, "controller must satisfy AutoInjector") + + cfg := &hooks.Config{} + inj.AutoInject(cfg) require.Len(t, cfg.SessionStart, 1) require.Len(t, cfg.TurnStart, 1) require.Len(t, cfg.TurnEnd, 1) @@ -238,6 +279,25 @@ func TestApplyAgentDefaultsInjectsSnapshotHooks(t *testing.T) { assert.Equal(t, builtins.Snapshot, cfg.SessionEnd[0].Command) } +// TestSnapshotControllerAutoInjectDisabledIsNoop verifies that a +// controller constructed with enabled=false makes no changes to cfg, +// so an embedder can pass it unconditionally to the runtime as an +// AutoInjector and rely on the bool to gate auto-injection. +func TestSnapshotControllerAutoInjectDisabledIsNoop(t *testing.T) { + t.Parallel() + + r := hooks.NewRegistry() + ctrl, err := builtins.RegisterSnapshot(r, false) + require.NoError(t, err) + + inj, ok := ctrl.(builtins.AutoInjector) + require.True(t, ok) + + cfg := &hooks.Config{} + inj.AutoInject(cfg) + assert.True(t, cfg.IsEmpty(), "disabled controller must not inject any hooks") +} + func TestApplyAgentDefaultsAppendsToUserHooks(t *testing.T) { t.Parallel() diff --git a/pkg/hooks/builtins/redact_secrets_test.go b/pkg/hooks/builtins/redact_secrets_test.go index f7888929d..13ca68889 100644 --- a/pkg/hooks/builtins/redact_secrets_test.go +++ b/pkg/hooks/builtins/redact_secrets_test.go @@ -128,8 +128,7 @@ func TestRedactSecretsIsRegistered(t *testing.T) { t.Parallel() reg := hooks.NewRegistry() - _, err := Register(reg) - require.NoError(t, err) + require.NoError(t, Register(reg)) handler, ok := reg.LookupBuiltin(RedactSecrets) require.Truef(t, ok, "builtin %q must be registered", RedactSecrets) diff --git a/pkg/hooks/builtins/snapshot.go b/pkg/hooks/builtins/snapshot.go index 5a2d8656c..4b5014907 100644 --- a/pkg/hooks/builtins/snapshot.go +++ b/pkg/hooks/builtins/snapshot.go @@ -19,12 +19,104 @@ type SnapshotInfo struct { Files int } -// Snapshots tracks per-session shadow-git checkpoints. The same +// SnapshotController exposes the operations the embedder uses to drive +// shadow-git snapshot commands (/undo, /snapshots, /reset). It is +// returned by [RegisterSnapshot] and intentionally narrow: the runtime +// no longer brokers snapshot operations on the embedder's behalf. +// +// Enabled() reports whether snapshot auto-injection (capturing +// checkpoints at session/turn boundaries) is configured. The other +// methods always work against any checkpoints already captured for +// sessionID, regardless of Enabled(). +// +// SnapshotController also satisfies [AutoInjector] so the same +// instance the App uses for /undo can be passed to the runtime via +// runtime.WithAutoInjector. AutoInject is a runtime-internal call; +// embedders normally don't invoke it directly. +type SnapshotController interface { + AutoInjector + Enabled() bool + UndoLast(ctx context.Context, sessionID, cwd string) (files int, ok bool, err error) + List(sessionID string) []SnapshotInfo + Reset(ctx context.Context, sessionID, cwd string, keep int) (files int, ok bool, err error) +} + +// RegisterSnapshot installs the snapshot builtin on r and returns a +// [SnapshotController]. enabled controls whether the controller's +// AutoInject mounts the snapshot hook on session/turn boundaries; pass +// false to keep the hook resolvable for users who wire it manually via +// YAML without auto-capturing checkpoints. +// +// Embedders typically pass the same controller to both the runtime +// (via runtime.WithAutoInjector) and the App (via +// app.WithSnapshotController) so /undo et al. drive the same instance +// that captures the checkpoints. +func RegisterSnapshot(r *hooks.Registry, enabled bool) (SnapshotController, error) { + b := newSnapshotBuiltin() + if err := r.RegisterBuiltin(Snapshot, b.hook); err != nil { + return nil, err + } + return &snapshotController{builtin: b, enabled: enabled}, nil +} + +type snapshotController struct { + builtin *snapshotBuiltin + enabled bool +} + +var ( + _ SnapshotController = (*snapshotController)(nil) + _ AutoInjector = (*snapshotController)(nil) +) + +func (c *snapshotController) Enabled() bool { + return c != nil && c.enabled +} + +func (c *snapshotController) UndoLast(ctx context.Context, sessionID, cwd string) (int, bool, error) { + if c == nil || c.builtin == nil || sessionID == "" || cwd == "" { + return 0, false, nil + } + return c.builtin.undoLast(ctx, sessionID, cwd) +} + +func (c *snapshotController) List(sessionID string) []SnapshotInfo { + if c == nil || c.builtin == nil || sessionID == "" { + return nil + } + return c.builtin.listSnapshots(sessionID) +} + +func (c *snapshotController) Reset(ctx context.Context, sessionID, cwd string, keep int) (int, bool, error) { + if c == nil || c.builtin == nil || sessionID == "" || cwd == "" { + return 0, false, nil + } + return c.builtin.resetSnapshot(ctx, sessionID, cwd, keep) +} + +// AutoInject mounts the snapshot hook on session/turn boundaries when +// the controller is enabled. The four-event surface (session_start, +// turn_start, turn_end, session_end) matches what the snapshot builtin +// needs to bracket every session and every model turn; per-tool +// capture (pre_tool_use / post_tool_use) is opt-in via YAML and is not +// auto-wired here. +func (c *snapshotController) AutoInject(cfg *hooks.Config) { + if c == nil || !c.enabled || cfg == nil { + return + } + hook := hooks.Hook{Type: hooks.HookTypeBuiltin, Command: Snapshot} + cfg.SessionStart = append(cfg.SessionStart, hook) + cfg.TurnStart = append(cfg.TurnStart, hook) + cfg.TurnEnd = append(cfg.TurnEnd, hook) + cfg.SessionEnd = append(cfg.SessionEnd, hook) +} + +// snapshotBuiltin tracks per-session shadow-git checkpoints. The same // instance is dispatched as the snapshot builtin (registered under -// [Snapshot] via [Hook]) and exposed to the runtime for /undo, -// /list-snapshots and /reset (via [UndoLast] / [List] / [Reset]). -// Construct with [NewSnapshots]; the zero value is not usable. -type Snapshots struct { +// [Snapshot] via [snapshotBuiltin.hook]) and exposed to embedders via +// [snapshotController] for /undo, /snapshots, and /reset. Construct +// with [newSnapshotBuiltin]; the zero value is not usable. +type snapshotBuiltin struct { manager *snapshot.Manager mu sync.Mutex session map[string]*snapshotSession @@ -41,21 +133,21 @@ type snapshotCheckpoint struct { files []string } -// NewSnapshots returns a fresh snapshot tracker. Held by the runtime -// for /undo, /list-snapshots and /reset; the same instance backs the -// snapshot hook registered under [Snapshot]. -func NewSnapshots() *Snapshots { - return &Snapshots{ +// newSnapshotBuiltin returns a fresh snapshot tracker. Held by +// [snapshotController] for /undo, /snapshots and /reset; the same +// instance backs the snapshot hook registered under [Snapshot]. +func newSnapshotBuiltin() *snapshotBuiltin { + return &snapshotBuiltin{ manager: snapshot.NewManager(""), session: map[string]*snapshotSession{}, } } -// Hook is the [hooks.BuiltinFunc] dispatched on every snapshot event. +// hook is the [hooks.BuiltinFunc] dispatched on every snapshot event. // It tracks per-session turn/tool hashes, captures patches at // turn_end / post_tool_use, and runs the shadow-repo cleanup at // session_end. -func (b *Snapshots) Hook(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { +func (b *snapshotBuiltin) hook(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { if in == nil || in.Cwd == "" || in.SessionID == "" { return nil, nil } @@ -128,7 +220,7 @@ func (b *Snapshots) Hook(ctx context.Context, in *hooks.Input, _ []string) (*hoo return nil, nil } -func (b *Snapshots) getSession(sessionID string) *snapshotSession { +func (b *snapshotBuiltin) getSession(sessionID string) *snapshotSession { s := b.session[sessionID] if s == nil { s = &snapshotSession{tools: map[string]string{}} @@ -137,13 +229,13 @@ func (b *Snapshots) getSession(sessionID string) *snapshotSession { return s } -func (b *Snapshots) setTurn(sessionID, hash string) { +func (b *snapshotBuiltin) setTurn(sessionID, hash string) { b.mu.Lock() defer b.mu.Unlock() b.getSession(sessionID).turn = hash } -func (b *Snapshots) popTurn(sessionID string) string { +func (b *snapshotBuiltin) popTurn(sessionID string) string { b.mu.Lock() defer b.mu.Unlock() s := b.session[sessionID] @@ -155,7 +247,7 @@ func (b *Snapshots) popTurn(sessionID string) string { return hash } -func (b *Snapshots) setTool(sessionID, toolUseID, hash string) { +func (b *snapshotBuiltin) setTool(sessionID, toolUseID, hash string) { if toolUseID == "" { return } @@ -164,7 +256,7 @@ func (b *Snapshots) setTool(sessionID, toolUseID, hash string) { b.getSession(sessionID).tools[toolUseID] = hash } -func (b *Snapshots) popTool(sessionID, toolUseID string) string { +func (b *snapshotBuiltin) popTool(sessionID, toolUseID string) string { if toolUseID == "" { return "" } @@ -179,7 +271,7 @@ func (b *Snapshots) popTool(sessionID, toolUseID string) string { return hash } -func (b *Snapshots) pushCheckpoint(sessionID string, checkpoint snapshotCheckpoint) { +func (b *snapshotBuiltin) pushCheckpoint(sessionID string, checkpoint snapshotCheckpoint) { if len(checkpoint.files) == 0 { return } @@ -189,7 +281,7 @@ func (b *Snapshots) pushCheckpoint(sessionID string, checkpoint snapshotCheckpoi s.history = append(s.history, checkpoint) } -func (b *Snapshots) popCheckpoint(sessionID string) (snapshotCheckpoint, bool) { +func (b *snapshotBuiltin) popCheckpoint(sessionID string) (snapshotCheckpoint, bool) { b.mu.Lock() defer b.mu.Unlock() s := b.session[sessionID] @@ -203,10 +295,10 @@ func (b *Snapshots) popCheckpoint(sessionID string) (snapshotCheckpoint, bool) { return checkpoint, true } -// UndoLast restores the files captured by the most recent checkpoint. +// undoLast restores the files captured by the most recent checkpoint. // Returns (filesRestored, true, nil) on success, (0, false, nil) when // there is nothing to undo. -func (b *Snapshots) UndoLast(ctx context.Context, sessionID, cwd string) (files int, ok bool, err error) { +func (b *snapshotBuiltin) undoLast(ctx context.Context, sessionID, cwd string) (files int, ok bool, err error) { checkpoint, ok := b.popCheckpoint(sessionID) if !ok { return 0, false, nil @@ -225,9 +317,9 @@ func (b *Snapshots) UndoLast(ctx context.Context, sessionID, cwd string) (files return len(checkpoint.files), true, nil } -// List returns the completed checkpoints for a session in chronological -// order (oldest first). The returned slice may be empty. -func (b *Snapshots) List(sessionID string) []SnapshotInfo { +// listSnapshots returns the completed checkpoints for a session in +// chronological order (oldest first). The returned slice may be empty. +func (b *snapshotBuiltin) listSnapshots(sessionID string) []SnapshotInfo { b.mu.Lock() defer b.mu.Unlock() s := b.session[sessionID] @@ -241,12 +333,12 @@ func (b *Snapshots) List(sessionID string) []SnapshotInfo { return out } -// Reset reverts every checkpoint with index >= keep so the workspace -// returns to the state captured at snapshot keep. keep == 0 means -// "reset to the original state". A keep value greater than or equal -// to the snapshot count is a no-op. Reverted checkpoints are dropped -// from the session history. -func (b *Snapshots) Reset(ctx context.Context, sessionID, cwd string, keep int) (files int, ok bool, err error) { +// resetSnapshot reverts every checkpoint with index >= keep so the +// workspace returns to the state captured at snapshot keep. keep == 0 +// means "reset to the original state". A keep value greater than or +// equal to the snapshot count is a no-op. Reverted checkpoints are +// dropped from the session history. +func (b *snapshotBuiltin) resetSnapshot(ctx context.Context, sessionID, cwd string, keep int) (files int, ok bool, err error) { tail := b.popHistoryTail(sessionID, keep) if len(tail) == 0 { return 0, false, nil @@ -273,7 +365,7 @@ func (b *Snapshots) Reset(ctx context.Context, sessionID, cwd string, keep int) // the surviving prefix in the session history. keep is clamped to [0, len]. // The popped slots in the backing array are zeroed so the dropped file lists // can be garbage-collected before the slice grows past them again. -func (b *Snapshots) popHistoryTail(sessionID string, keep int) []snapshotCheckpoint { +func (b *snapshotBuiltin) popHistoryTail(sessionID string, keep int) []snapshotCheckpoint { b.mu.Lock() defer b.mu.Unlock() s := b.session[sessionID] diff --git a/pkg/hooks/builtins/snapshot_test.go b/pkg/hooks/builtins/snapshot_test.go index 1d07ec5a8..1495f9cb0 100644 --- a/pkg/hooks/builtins/snapshot_test.go +++ b/pkg/hooks/builtins/snapshot_test.go @@ -22,7 +22,7 @@ func TestSnapshotBuiltinUndoSurvivesStreamEnd(t *testing.T) { t.Cleanup(func() { paths.SetDataDir("") }) r := hooks.NewRegistry() - snapshots, err := builtins.Register(r) + ctrl, err := builtins.RegisterSnapshot(r, true) require.NoError(t, err) fn, ok := r.LookupBuiltin(builtins.Snapshot) require.True(t, ok) @@ -74,7 +74,7 @@ func TestSnapshotBuiltinUndoSurvivesStreamEnd(t *testing.T) { require.Len(t, entries, 1) require.DirExists(t, filepath.Join(paths.GetDataDir(), "snapshot", entries[0].Name())) - files, restored, err := snapshots.UndoLast(t.Context(), "s", dir) + files, restored, err := ctrl.UndoLast(t.Context(), "s", dir) require.NoError(t, err) assert.True(t, restored) assert.Equal(t, 1, files) @@ -89,7 +89,7 @@ func TestSnapshotBuiltinListAndReset(t *testing.T) { t.Cleanup(func() { paths.SetDataDir("") }) r := hooks.NewRegistry() - snapshots, err := builtins.Register(r) + ctrl, err := builtins.RegisterSnapshot(r, true) require.NoError(t, err) fn, ok := r.LookupBuiltin(builtins.Snapshot) require.True(t, ok) @@ -97,7 +97,7 @@ func TestSnapshotBuiltinListAndReset(t *testing.T) { dir := snapshotBuiltinRepo(t) // Initially: no checkpoints. - assert.Empty(t, snapshots.List("s")) + assert.Empty(t, ctrl.List("s")) // Capture three snapshots: each turn modifies one file. recordTurn := func(t *testing.T, name, contents string) { @@ -122,33 +122,33 @@ func TestSnapshotBuiltinListAndReset(t *testing.T) { recordTurn(t, "b.txt", "b") recordTurn(t, "c.txt", "c") - snaps := snapshots.List("s") + snaps := ctrl.List("s") require.Len(t, snaps, 3) assert.Equal(t, 1, snaps[0].Files) assert.Equal(t, 1, snaps[1].Files) assert.Equal(t, 1, snaps[2].Files) // Reset to snapshot 2: revert turn 3 only, leaving a.txt and b.txt intact. - files, restored, err := snapshots.Reset(t.Context(), "s", dir, 2) + files, restored, err := ctrl.Reset(t.Context(), "s", dir, 2) require.NoError(t, err) assert.True(t, restored) assert.Equal(t, 1, files) assert.FileExists(t, filepath.Join(dir, "a.txt")) assert.FileExists(t, filepath.Join(dir, "b.txt")) assert.NoFileExists(t, filepath.Join(dir, "c.txt")) - require.Len(t, snapshots.List("s"), 2) + require.Len(t, ctrl.List("s"), 2) // Reset to original: revert remaining checkpoints, deleting all three files. - files, restored, err = snapshots.Reset(t.Context(), "s", dir, 0) + files, restored, err = ctrl.Reset(t.Context(), "s", dir, 0) require.NoError(t, err) assert.True(t, restored) assert.Equal(t, 2, files) assert.NoFileExists(t, filepath.Join(dir, "a.txt")) assert.NoFileExists(t, filepath.Join(dir, "b.txt")) - assert.Empty(t, snapshots.List("s")) + assert.Empty(t, ctrl.List("s")) // Subsequent reset is a no-op (nothing to revert). - _, restored, err = snapshots.Reset(t.Context(), "s", dir, 0) + _, restored, err = ctrl.Reset(t.Context(), "s", dir, 0) require.NoError(t, err) assert.False(t, restored) } @@ -161,7 +161,7 @@ func TestSnapshotBuiltinResetKeepBeyondHistoryIsNoop(t *testing.T) { t.Cleanup(func() { paths.SetDataDir("") }) r := hooks.NewRegistry() - snapshots, err := builtins.Register(r) + ctrl, err := builtins.RegisterSnapshot(r, true) require.NoError(t, err) fn, ok := r.LookupBuiltin(builtins.Snapshot) require.True(t, ok) @@ -183,18 +183,18 @@ func TestSnapshotBuiltinResetKeepBeyondHistoryIsNoop(t *testing.T) { require.NoError(t, err) // keep == len(history) means "keep everything" — no checkpoints reverted. - files, restored, err := snapshots.Reset(t.Context(), "s", dir, 1) + files, restored, err := ctrl.Reset(t.Context(), "s", dir, 1) require.NoError(t, err) assert.False(t, restored) assert.Equal(t, 0, files) assert.FileExists(t, filepath.Join(dir, "a.txt")) - require.Len(t, snapshots.List("s"), 1) + require.Len(t, ctrl.List("s"), 1) // keep way past the end is also a no-op. - _, restored, err = snapshots.Reset(t.Context(), "s", dir, 99) + _, restored, err = ctrl.Reset(t.Context(), "s", dir, 99) require.NoError(t, err) assert.False(t, restored) - require.Len(t, snapshots.List("s"), 1) + require.Len(t, ctrl.List("s"), 1) } func snapshotBuiltinRepo(t *testing.T) string { diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index 39b9405b1..158cacc55 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -15,11 +15,13 @@ import ( ) // buildHooksExecutors builds a [hooks.Executor] for every agent in the -// team that has user-configured hooks, an agent-flag or runtime setting -// that maps to a builtin (AddDate / AddEnvironmentInfo / AddPromptFiles / -// Snapshot), or a configured response cache (which auto-injects a -// cache_response stop hook). Agents with no hooks have no entry; lookups fall through to -// nil so callers can short-circuit cheaply. +// team that has user-configured hooks, an agent-flag setting that maps +// to a builtin (AddDate / AddEnvironmentInfo / AddPromptFiles), an +// [builtins.AutoInjector] supplied via [WithAutoInjector] (today the +// snapshot controller), or a configured response cache (which +// auto-injects a cache_response stop hook). Agents with no hooks have +// no entry; lookups fall through to nil so callers can short-circuit +// cheaply. // // Called once from [NewLocalRuntime] after r.workingDir, r.env and // r.hooksRegistry are finalized; the resulting map is read-only for @@ -37,8 +39,8 @@ func (r *LocalRuntime) buildHooksExecutors() { AddEnvironmentInfo: a.AddEnvironmentInfo(), AddPromptFiles: a.AddPromptFiles(), RedactSecrets: a.RedactSecrets(), - Snapshot: r.snapshotsEnabled, }) + cfg = applyAutoInjectors(cfg, r.autoInjectors) cfg = applyCacheDefault(cfg, a) if cfg == nil { continue @@ -47,6 +49,28 @@ func (r *LocalRuntime) buildHooksExecutors() { } } +// applyAutoInjectors runs each AutoInjector against cfg, allocating a +// fresh Config when needed so a previously-empty agent picks up the +// injector's hooks. Returns nil iff cfg ends up empty after every +// injector has run, mirroring [ApplyAgentDefaults]. +func applyAutoInjectors(cfg *hooks.Config, injectors []builtins.AutoInjector) *hooks.Config { + if len(injectors) == 0 { + return cfg + } + if cfg == nil { + cfg = &hooks.Config{} + } + for _, inj := range injectors { + if inj != nil { + inj.AutoInject(cfg) + } + } + if cfg.IsEmpty() { + return nil + } + return cfg +} + // hooksExec returns the pre-built [hooks.Executor] for a, or nil when // the agent has no hooks (see [buildHooksExecutors]). func (r *LocalRuntime) hooksExec(a *agent.Agent) *hooks.Executor { diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index f1219a088..33b254efc 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -161,7 +161,6 @@ type LocalRuntime struct { workingDir string // Working directory for hooks execution env []string // Environment variables for hooks execution modelSwitcherCfg *ModelSwitcherConfig - snapshotsEnabled bool // hooksRegistry is the runtime-private hooks.Registry used to build // every Executor. It carries the runtime-owned builtin hooks @@ -170,12 +169,11 @@ type LocalRuntime struct { // touching any process-wide state. hooksRegistry *hooks.Registry - // snapshots is the shared shadow-git snapshot tracker. The hook - // side ([Snapshots.Hook], wired in via [builtins.Register]) writes - // checkpoints during the run; the runtime side - // ([UndoLastSnapshot] / [ListSnapshots] / [ResetSnapshot]) - // reads them for /undo and /reset commands. - snapshots *builtins.Snapshots + // autoInjectors run on every per-agent hook config during + // [buildHooksExecutors] so embedders can plug in builtins (today + // snapshot via [builtins.SnapshotController]) without the runtime + // hard-coding their wiring. Set via [WithAutoInjector]. + autoInjectors []builtins.AutoInjector // hooksExecByAgent holds the per-agent [hooks.Executor], keyed by // agent name. Built once in [NewLocalRuntime.buildHooksExecutors] @@ -389,6 +387,40 @@ func WithRetryOnRateLimit() Opt { } } +// WithAutoInjector adds an [builtins.AutoInjector] that augments every +// per-agent hook configuration during executor build. The canonical +// use case is the snapshot controller returned by +// [builtins.RegisterSnapshot]: pass the same controller to the App via +// app.WithSnapshotController so /undo and friends drive the same +// instance that captures the checkpoints. +// +// Multiple calls accumulate; injectors run in registration order. +func WithAutoInjector(inj builtins.AutoInjector) Opt { + return func(r *LocalRuntime) { + if inj != nil { + r.autoInjectors = append(r.autoInjectors, inj) + } + } +} + +// WithHooksRegistry plugs a pre-populated [hooks.Registry] into the +// runtime instead of letting it allocate a fresh one. Embedders use +// this to pre-register builtins they own (today snapshot, tomorrow +// any custom builtin) so the auto-injection chain set up by +// [WithAutoInjector] resolves against the same registry. +// +// The runtime continues to register its own stateless and +// closure-bound builtins (add_date, max_iterations, cache_response, +// unload, ...) on top of the supplied registry, so the embedder only +// needs to install entries that the runtime can't construct itself. +func WithHooksRegistry(reg *hooks.Registry) Opt { + return func(r *LocalRuntime) { + if reg != nil { + r.hooksRegistry = reg + } + } +} + // New creates a runtime ready to drive an agent loop. It is a thin // alias for [NewLocalRuntime] returning the [Runtime] interface, kept // for source compatibility with callers written before persistence @@ -408,13 +440,6 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { return nil, err } - hooksRegistry := hooks.NewRegistry() - snapshots, err := builtins.Register(hooksRegistry) - if err != nil { - return nil, fmt.Errorf("register builtin hooks: %w", err) - } - registerModelHook(hooksRegistry) - r := &LocalRuntime{ toolMap: make(map[string]ToolHandlerFunc), team: agents, @@ -426,8 +451,6 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { sessionCompaction: true, managedOAuth: true, sessionStore: session.NewInMemorySessionStore(), - hooksRegistry: hooksRegistry, - snapshots: snapshots, fallback: newFallbackExecutor(), now: time.Now, telemetry: defaultTelemetry{}, @@ -435,22 +458,6 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { } r.bgAgents = agenttool.NewHandler(r) - // cache_response is registered here (not in pkg/hooks/builtins) because - // it needs to capture the runtime to resolve the agent referenced by - // Input.AgentName. The other builtins are stateless and can stay as - // package-level functions registered via builtins.Register above. - if err := hooksRegistry.RegisterBuiltin(BuiltinCacheResponse, r.cacheResponseBuiltin); err != nil { - return nil, fmt.Errorf("register %q builtin: %w", BuiltinCacheResponse, err) - } - - // unload is registered alongside cache_response for the same - // reason: it needs to walk Input.FromAgent up to the previous agent's - // configured models and dispatch to provider.Unloader implementations, - // which the runtime owns through the team. - if err := hooksRegistry.RegisterBuiltin(BuiltinUnload, r.unloadBuiltin); err != nil { - return nil, fmt.Errorf("register %q builtin: %w", BuiltinUnload, err) - } - // stripUnsupportedModalitiesTransform captures the runtime closure to // resolve the agent from Input.AgentName, so it lives here rather // than as a stateless builtin in pkg/hooks/builtins. It drops image @@ -474,6 +481,36 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { opt(r) } + // Set up the hooks registry. Use the embedder-supplied registry + // (via [WithHooksRegistry]) when present so any builtins the + // embedder pre-registered — typically the snapshot builtin from + // [builtins.RegisterSnapshot] — are visible to the runtime, then + // register the runtime-owned builtins on top. + if r.hooksRegistry == nil { + r.hooksRegistry = hooks.NewRegistry() + } + if err := builtins.Register(r.hooksRegistry); err != nil { + return nil, fmt.Errorf("register builtin hooks: %w", err) + } + registerModelHook(r.hooksRegistry) + + // cache_response is registered here (not in pkg/hooks/builtins) + // because it needs to capture the runtime to resolve the agent + // referenced by Input.AgentName. The other builtins are stateless + // and can stay as package-level functions registered via + // [builtins.Register] above. + if err := r.hooksRegistry.RegisterBuiltin(BuiltinCacheResponse, r.cacheResponseBuiltin); err != nil { + return nil, fmt.Errorf("register %q builtin: %w", BuiltinCacheResponse, err) + } + + // unload is registered alongside cache_response for the same reason: + // it needs to walk Input.FromAgent up to the previous agent's + // configured models and dispatch to provider.Unloader + // implementations, which the runtime owns through the team. + if err := r.hooksRegistry.RegisterBuiltin(BuiltinUnload, r.unloadBuiltin); err != nil { + return nil, fmt.Errorf("register %q builtin: %w", BuiltinUnload, err) + } + // Build the cooldown manager and wire the fallback executor's // runtime-bound dependencies after opts so they pick up the final // clock and telemetry sink ([WithClock] / [WithTelemetry]). diff --git a/pkg/runtime/snapshot.go b/pkg/runtime/snapshot.go deleted file mode 100644 index d527e80e1..000000000 --- a/pkg/runtime/snapshot.go +++ /dev/null @@ -1,68 +0,0 @@ -package runtime - -import ( - "context" - "os" - - "github.com/docker/docker-agent/pkg/hooks/builtins" - "github.com/docker/docker-agent/pkg/session" -) - -// WithSnapshots configures whether snapshot hooks are auto-injected for every agent. -func WithSnapshots(enabled bool) Opt { - return func(r *LocalRuntime) { - r.snapshotsEnabled = enabled - } -} - -// SnapshotsEnabled reports whether automatic snapshot hooks are active for -// this runtime. Used by the TUI to hide snapshot-related commands when the -// feature is off. -func (r *LocalRuntime) SnapshotsEnabled() bool { - return r != nil && r.snapshotsEnabled -} - -// UndoLastSnapshot restores files recorded for the latest completed snapshot hook checkpoint. -func (r *LocalRuntime) UndoLastSnapshot(ctx context.Context, sess *session.Session) (int, bool, error) { - cwd := r.snapshotCwd(sess) - if cwd == "" { - return 0, false, nil - } - return r.snapshots.UndoLast(ctx, sess.ID, cwd) -} - -// ListSnapshots returns the completed snapshot checkpoints recorded for the -// session, oldest first. Returns nil when none exist. -func (r *LocalRuntime) ListSnapshots(sess *session.Session) []builtins.SnapshotInfo { - if r == nil || sess == nil { - return nil - } - return r.snapshots.List(sess.ID) -} - -// ResetSnapshot reverts every checkpoint past index keep so the workspace -// returns to the state captured at that snapshot. keep == 0 resets to the -// original (pre-agent) state. -func (r *LocalRuntime) ResetSnapshot(ctx context.Context, sess *session.Session, keep int) (int, bool, error) { - cwd := r.snapshotCwd(sess) - if cwd == "" { - return 0, false, nil - } - return r.snapshots.Reset(ctx, sess.ID, cwd, keep) -} - -// snapshotCwd resolves the working directory used to open the shadow -// repository for snapshot operations. Returns "" when no candidate is usable. -func (r *LocalRuntime) snapshotCwd(sess *session.Session) string { - if r == nil || sess == nil { - return "" - } - if sess.WorkingDir != "" { - return sess.WorkingDir - } - if r.workingDir != "" { - return r.workingDir - } - cwd, _ := os.Getwd() - return cwd -}