From 9a687c27daa24a4a543459ab5a20a6735329b835 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:13:43 +0800 Subject: [PATCH 01/49] feat(channel): coalesce duplicate approval prompts by dest:port --- internal/channel/broker.go | 230 ++++++++++++++- internal/channel/channel_test.go | 461 ++++++++++++++++++++++++++++++- 2 files changed, 676 insertions(+), 15 deletions(-) diff --git a/internal/channel/broker.go b/internal/channel/broker.go index 0e1d645..eb01891 100644 --- a/internal/channel/broker.go +++ b/internal/channel/broker.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "strconv" "sync" "sync/atomic" "time" @@ -33,6 +34,19 @@ type Broker struct { timedOut map[string]time.Time nextID atomic.Int64 + // dedupIndex maps a persistence-equivalent target key ("dest:port") + // to the primary request ID currently holding an open prompt for that + // target. Concurrent requests to the same target while the primary is + // pending attach to the primary as coalesced subscribers instead of + // opening their own prompt. + dedupIndex map[string]string + + // coalesced retains the final coalesced count for a primary request ID + // after its waiter has been removed (resolved/timed-out/cancelled), so + // channels editing the resolved/cancelled message can render how many + // requests the single decision covered. GC'd like timedOut. + coalesced map[string]coalescedRecord + // closed is set to true by CancelAll under the mutex, before the done // channel is closed. Request checks this flag under the same mutex to // prevent registering new waiters after CancelAll has copied and reset @@ -59,9 +73,25 @@ type Broker struct { } // waiter tracks a pending approval request and its response channel. +// +// subs holds buffered (cap 1) response channels for coalesced requests that +// attached to this primary waiter while it was pending. count starts at 1 +// (the primary) and increments for every attached sub. dedupKey is the +// "dest:port" key under which this waiter is registered in dedupIndex (empty +// when the request opted out of coalescing). type waiter struct { - ch chan Response - req ApprovalRequest + ch chan Response + req ApprovalRequest + subs []chan Response + count int + dedupKey string +} + +// coalescedRecord retains a resolved primary's final coalesced count for a +// bounded TTL so message-edit paths can render it after the waiter is gone. +type coalescedRecord struct { + count int + at time.Time } // BrokerOption configures a Broker. @@ -88,6 +118,8 @@ func NewBroker(channels []Channel, opts ...BrokerOption) *Broker { channels: channels, waiters: make(map[string]waiter), timedOut: make(map[string]time.Time), + dedupIndex: make(map[string]string), + coalesced: make(map[string]coalescedRecord), done: make(chan struct{}), MaxPendingRequests: 50, destRateMax: 5, @@ -117,6 +149,18 @@ type RequestOption func(*requestConfig) type requestConfig struct { req ApprovalRequest bypassRateLimit bool + noCoalesce bool +} + +// WithNoCoalesce disables broker-level coalescing for this request. Use it +// when distinct requests to the same "dest:port" are NOT semantically +// equivalent — e.g. MCP tool calls, whose ToolArgs differ and feed +// arg-sensitive ContentInspector/exec rules. Such requests must each get +// their own prompt. +func WithNoCoalesce() RequestOption { + return func(c *requestConfig) { + c.noCoalesce = true + } } // WithToolArgs sets the truncated tool arguments on an MCP approval request. @@ -157,13 +201,17 @@ func WithBypassRateLimit() RequestOption { // Request sends an approval request to all channels and blocks until one // responds or the timeout expires. Returns the first response received. +// +// Coalescing: concurrent requests sharing a persistence-equivalent target +// ("dest:port") collapse onto the first one's prompt. Only the first opens a +// prompt; later arrivals attach as buffered subscribers and receive the same +// response when the primary resolves/times out/is cancelled. Pass +// WithNoCoalesce to opt out (MCP tool calls). func (b *Broker) Request(dest string, port int, protocol string, timeout time.Duration, opts ...RequestOption) (Response, error) { - id := fmt.Sprintf("req_%d", b.nextID.Add(1)) ch := make(chan Response, 1) cfg := requestConfig{ req: ApprovalRequest{ - ID: id, Destination: dest, Port: port, Protocol: protocol, @@ -174,11 +222,42 @@ func (b *Broker) Request(dest string, port int, protocol string, timeout time.Du opt(&cfg) } + var dedupKey string + if !cfg.noCoalesce { + dedupKey = dest + ":" + strconv.Itoa(port) + } + + // Single deadline for the entire request lifecycle (covers both the + // primary path and the coalesced-subscriber path). + deadline := time.NewTimer(timeout) + defer deadline.Stop() + b.mu.Lock() if b.closed { b.mu.Unlock() return ResponseDeny, fmt.Errorf("approval broker shutting down") } + + // Coalesce: attach to an existing pending prompt for the same target. + // This runs before the pending/rate-limit checks because a coalesced + // subscriber consumes neither budget — the primary already did, and the + // whole point is to avoid both the prompt wall and spurious rate-limit + // denials for a burst the operator will answer with a single tap. + if dedupKey != "" { + if primaryID, ok := b.dedupIndex[dedupKey]; ok { + if w, ok := b.waiters[primaryID]; ok { + subCh := make(chan Response, 1) + w.subs = append(w.subs, subCh) + w.count++ + b.waiters[primaryID] = w + count := w.count + b.mu.Unlock() + b.notifyCoalesced(primaryID, count) + return b.waitSub(primaryID, subCh, deadline.C, timeout) + } + } + } + // Check pending limit. if b.MaxPendingRequests > 0 && len(b.waiters) >= b.MaxPendingRequests { b.mu.Unlock() @@ -218,17 +297,18 @@ func (b *Broker) Request(dest string, port int, protocol string, timeout time.Du } } + id := fmt.Sprintf("req_%d", b.nextID.Add(1)) + cfg.req.ID = id req := cfg.req - b.waiters[id] = waiter{ch: ch, req: req} + b.waiters[id] = waiter{ch: ch, req: req, count: 1, dedupKey: dedupKey} + if dedupKey != "" { + b.dedupIndex[dedupKey] = id + } b.mu.Unlock() // Broadcast to all channels (non-blocking). b.broadcast(req) - // Single deadline for the entire request lifecycle. - deadline := time.NewTimer(timeout) - defer deadline.Stop() - select { case resp := <-ch: return resp, nil @@ -241,16 +321,31 @@ func (b *Broker) Request(dest string, port int, protocol string, timeout time.Du default: } b.mu.Lock() - delete(b.waiters, id) + w, ok := b.waiters[id] + if ok { + delete(b.waiters, id) + if w.dedupKey != "" { + delete(b.dedupIndex, w.dedupKey) + } + } b.mu.Unlock() + // Fan the terminal deny to any coalesced subscribers (buffered + // cap 1, so a send to a detached sub never blocks). + for _, sub := range w.subs { + sub <- ResponseDeny + } b.cancelOnChannels(id) return ResponseDeny, fmt.Errorf("approval broker shutting down") case <-deadline.C: b.mu.Lock() - _, stillPending := b.waiters[id] + w, stillPending := b.waiters[id] if stillPending { delete(b.waiters, id) + if w.dedupKey != "" { + delete(b.dedupIndex, w.dedupKey) + } b.timedOut[id] = b.now() + b.recordCoalescedLocked(id, w.count) // Garbage-collect stale timedOut entries. now := b.now() for k, t := range b.timedOut { @@ -266,11 +361,101 @@ func (b *Broker) Request(dest string, port int, protocol string, timeout time.Du // what the channel showed to the operator. return <-ch, nil } + // Fan the terminal deny to every coalesced subscriber. + for _, sub := range w.subs { + sub <- ResponseDeny + } b.cancelOnChannels(id) return ResponseDeny, fmt.Errorf("approval timeout after %v", timeout) } } +// waitSub blocks a coalesced subscriber until the primary resolves (its +// response is fanned to subCh), the deadline fires, or the broker shuts +// down. A subscriber owns no waiter/timedOut entry: on timeout it detaches +// only itself from the primary's subs slice and must never tear down the +// shared primary waiter. +func (b *Broker) waitSub(primaryID string, subCh chan Response, deadlineC <-chan time.Time, timeout time.Duration) (Response, error) { + select { + case resp := <-subCh: + return resp, nil + case <-b.done: + // Prefer a response the primary may have already fanned out. + select { + case resp := <-subCh: + return resp, nil + default: + } + b.detachSub(primaryID, subCh) + return ResponseDeny, fmt.Errorf("approval broker shutting down") + case <-deadlineC: + b.detachSub(primaryID, subCh) + // The primary may have resolved between the deadline firing and + // the detach completing. The sub chan is buffered (cap 1), so a + // concurrent fan-out send already landed; honor it rather than + // denying an approved request. + select { + case resp := <-subCh: + return resp, nil + default: + } + return ResponseDeny, fmt.Errorf("approval timeout after %v", timeout) + } +} + +// detachSub removes a single subscriber channel from a primary waiter's subs +// slice if the waiter is still present. It never deletes the waiter itself. +func (b *Broker) detachSub(primaryID string, subCh chan Response) { + b.mu.Lock() + defer b.mu.Unlock() + w, ok := b.waiters[primaryID] + if !ok { + return + } + for i, c := range w.subs { + if c == subCh { + w.subs = append(w.subs[:i], w.subs[i+1:]...) + b.waiters[primaryID] = w + return + } + } +} + +// recordCoalescedLocked stores a resolved/timed-out primary's final coalesced +// count for a bounded TTL so message-edit paths can render it after the +// waiter is gone. Caller must hold b.mu. +func (b *Broker) recordCoalescedLocked(id string, count int) { + now := b.now() + b.coalesced[id] = coalescedRecord{count: count, at: now} + for k, r := range b.coalesced { + if now.Sub(r.at) > timedOutTTL { + delete(b.coalesced, k) + } + } +} + +// CoalescedCount reports how many requests a single approval decision covered +// for the given primary request ID. While the waiter is still pending it +// returns the live count; after resolution it returns the retained final +// count; if nothing is known it returns 1 (a lone request). +func (b *Broker) CoalescedCount(id string) int { + b.mu.Lock() + defer b.mu.Unlock() + if w, ok := b.waiters[id]; ok { + return w.count + } + if r, ok := b.coalesced[id]; ok { + return r.count + } + return 1 +} + +// notifyCoalesced is the Phase 1 no-op hook for live mid-burst "+N pending" +// indicators. Phase 2 fills this in to best-effort call channels that +// implement a CoalesceNotifier interface. Keeping the call site here means +// Phase 2 is a localized change with no churn to Request. +func (b *Broker) notifyCoalesced(_ string, _ int) {} + // broadcast sends the approval request to all channels. Errors and panics // from individual channels are logged but do not prevent other channels from // receiving the request. @@ -345,12 +530,28 @@ func (b *Broker) Resolve(id string, resp Response) bool { b.mu.Lock() w, ok := b.waiters[id] if ok { + // Delete the waiter AND its dedup index entry in the same locked + // section. This is what closes the late-attach race: any request + // that took b.mu before this point either found the waiter (and + // attached as a sub captured in w.subs below) or, after this, + // finds neither the dedupIndex entry nor the waiter and opens its + // own fresh prompt — it can never attach to a dead waiter. delete(b.waiters, id) + if w.dedupKey != "" { + delete(b.dedupIndex, w.dedupKey) + } + b.recordCoalescedLocked(id, w.count) } b.mu.Unlock() if ok { w.ch <- resp + // Fan the same response to every coalesced subscriber. All sub + // chans are buffered (cap 1), so a send to a subscriber that + // already timed out and detached never blocks. + for _, sub := range w.subs { + sub <- resp + } // Cancel on all channels so they can clean up (e.g. edit message). b.cancelOnChannels(id) } @@ -372,13 +573,18 @@ func (b *Broker) CancelAll() { waiters[id] = w } b.waiters = make(map[string]waiter) + b.dedupIndex = make(map[string]string) b.mu.Unlock() // Send deny responses before closing done. This ensures goroutines in // the select see the response on ch before they see done closed, so - // they return the response without an error. + // they return the response without an error. Coalesced subscribers are + // fanned the same deny on their buffered (cap 1) chans. for id, w := range waiters { w.ch <- ResponseDeny + for _, sub := range w.subs { + sub <- ResponseDeny + } b.cancelOnChannels(id) } diff --git a/internal/channel/channel_test.go b/internal/channel/channel_test.go index 5c1bd09..528c899 100644 --- a/internal/channel/channel_test.go +++ b/internal/channel/channel_test.go @@ -3,12 +3,20 @@ package channel import ( "context" "errors" + "fmt" "sync" "sync/atomic" "testing" "time" ) +// result bundles a Broker.Request return for fan-in over a channel in +// concurrency tests. +type result struct { + resp Response + err error +} + // mockChannel implements Channel for testing. type mockChannel struct { typ ChannelType @@ -353,12 +361,15 @@ func TestBrokerPendingLimitExceeded(t *testing.T) { broker := NewBroker([]Channel{ch1}, WithMaxPending(3)) // Fill up the pending slots by sending requests that won't be resolved. + // Distinct destinations so each opens its own waiter (same-dest:port + // requests now coalesce onto a single waiter by design). var wg sync.WaitGroup for i := 0; i < 3; i++ { + dest := fmt.Sprintf("example-%d.com", i) wg.Add(1) go func() { defer wg.Done() - _, _ = broker.Request("example.com", 443, "", 2*time.Second) + _, _ = broker.Request(dest, 443, "", 2*time.Second) }() } // Wait until all 3 are registered as waiters. @@ -507,10 +518,13 @@ func TestBrokerCancelAllDeniesAllPending(t *testing.T) { } results := make(chan result, n) - // Start n requests that will block waiting for approval. + // Start n requests that will block waiting for approval. Distinct + // destinations so each registers its own waiter (same-dest:port + // requests now coalesce by design). for i := 0; i < n; i++ { + dest := fmt.Sprintf("cancel-test-%d.com", i) go func() { - resp, err := broker.Request("cancel-test.com", 443, "", 5*time.Second) + resp, err := broker.Request(dest, 443, "", 5*time.Second) results <- result{resp, err} }() } @@ -844,3 +858,444 @@ func TestBrokerChannelErrorDoesNotBlockOthers(t *testing.T) { t.Errorf("expected AllowOnce, got %v", resp) } } + +// --- Coalescing tests (broker-level dedup by dest:port) --- + +// fireCoalescedBurst starts n concurrent Request calls to the same +// dest:port and waits until the broker reports all n have attached to a +// single primary waiter. It returns the primary request ID and a channel +// that yields each call's (resp, err) result. +func fireCoalescedBurst(t *testing.T, broker *Broker, ch *mockChannel, dest string, port, n int, timeout time.Duration) (string, <-chan result) { + t.Helper() + type res = result + out := make(chan res, n) + for i := 0; i < n; i++ { + go func() { + resp, err := broker.Request(dest, port, "https", timeout) + out <- res{resp, err} + }() + } + // Wait for the primary prompt to land. + deadline := time.After(5 * time.Second) + for { + reqs := ch.getRequests() + if len(reqs) >= 1 { + id := reqs[0].ID + if broker.CoalescedCount(id) >= n { + return id, out + } + } + select { + case <-deadline: + t.Fatalf("burst did not fully coalesce: got %d requests, count=%v", + len(ch.getRequests()), func() int { + if r := ch.getRequests(); len(r) > 0 { + return broker.CoalescedCount(r[0].ID) + } + return 0 + }()) + default: + time.Sleep(time.Millisecond) + } + } +} + +func TestBrokerCoalesceOneBroadcastFanToAll(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const n = 8 + primaryID, out := fireCoalescedBurst(t, broker, ch, "cas.example.com", 443, n, 5*time.Second) + + // Exactly one prompt was broadcast for the whole burst. + if got := len(ch.getRequests()); got != 1 { + t.Fatalf("expected exactly 1 broadcast, got %d", got) + } + if c := broker.CoalescedCount(primaryID); c != n { + t.Fatalf("expected coalesced count %d, got %d", n, c) + } + if pc := broker.PendingCount(); pc != 1 { + t.Fatalf("expected 1 pending waiter, got %d", pc) + } + + if !broker.Resolve(primaryID, ResponseAlwaysAllow) { + t.Fatal("Resolve returned false for primary") + } + + for i := 0; i < n; i++ { + r := <-out + if r.err != nil { + t.Errorf("request %d: unexpected error %v", i, r.err) + } + if r.resp != ResponseAlwaysAllow { + t.Errorf("request %d: expected AlwaysAllow, got %v", i, r.resp) + } + } + // Final count retained for message-edit paths after the waiter is gone. + if c := broker.CoalescedCount(primaryID); c != n { + t.Errorf("expected retained coalesced count %d, got %d", n, c) + } +} + +func TestBrokerCoalesceDenyFanOut(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const n = 5 + primaryID, out := fireCoalescedBurst(t, broker, ch, "deny.example.com", 443, n, 5*time.Second) + broker.Resolve(primaryID, ResponseDeny) + + for i := 0; i < n; i++ { + r := <-out + if r.resp != ResponseDeny { + t.Errorf("request %d: expected Deny, got %v", i, r.resp) + } + } +} + +func TestBrokerCoalesceTimeoutFanOut(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const n = 4 + // No resolve: the primary times out and fans the terminal Deny to + // every subscriber. The primary itself returns the timeout error; + // subscribers receive Deny via the fan-out (nil err, like any + // terminal resolution). Every caller must end up denied. + _, out := fireCoalescedBurst(t, broker, ch, "slowburst.example.com", 443, n, 80*time.Millisecond) + + timeoutErrs := 0 + for i := 0; i < n; i++ { + r := <-out + if r.resp != ResponseDeny { + t.Errorf("request %d: expected Deny on timeout, got %v", i, r.resp) + } + if r.err != nil { + timeoutErrs++ + } + } + if timeoutErrs == 0 { + t.Error("expected at least the primary to report a timeout error") + } +} + +func TestBrokerCoalesceShutdownFanOut(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const n = 6 + _, out := fireCoalescedBurst(t, broker, ch, "shutdown.example.com", 443, n, 5*time.Second) + broker.CancelAll() + + for i := 0; i < n; i++ { + r := <-out + if r.resp != ResponseDeny { + t.Errorf("request %d: expected Deny on shutdown, got %v", i, r.resp) + } + } +} + +func TestBrokerCoalesceSubTimeoutDoesNotBlockFanOut(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + // Primary with a long timeout so it stays pending. + primaryOut := make(chan result, 1) + go func() { + resp, err := broker.Request("subtimeout.example.com", 443, "https", 5*time.Second) + primaryOut <- result{resp, err} + }() + var primaryID string + for { + reqs := ch.getRequests() + if len(reqs) == 1 { + primaryID = reqs[0].ID + break + } + time.Sleep(time.Millisecond) + } + + // A coalesced sub with a very short timeout: it detaches itself. + subOut := make(chan result, 1) + go func() { + resp, err := broker.Request("subtimeout.example.com", 443, "https", 30*time.Millisecond) + subOut <- result{resp, err} + }() + // Wait for the sub to attach (count == 2) then time out (count back + // near 1 once it detaches; tolerate the race by just waiting for the + // sub result). + for broker.CoalescedCount(primaryID) < 2 { + time.Sleep(time.Millisecond) + } + sr := <-subOut + if sr.resp != ResponseDeny || sr.err == nil { + t.Fatalf("sub should have timed out with Deny+err, got %v / %v", sr.resp, sr.err) + } + + // Resolving the primary must not block on the departed sub. + done := make(chan bool, 1) + go func() { done <- broker.Resolve(primaryID, ResponseAllowOnce) }() + select { + case ok := <-done: + if !ok { + t.Fatal("Resolve returned false") + } + case <-time.After(2 * time.Second): + t.Fatal("Resolve blocked on a detached sub") + } + pr := <-primaryOut + if pr.resp != ResponseAllowOnce { + t.Errorf("primary: expected AllowOnce, got %v", pr.resp) + } +} + +func TestBrokerCoalesceLateAttachOpensNewPrompt(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + // First prompt for the target. + out1 := make(chan result, 1) + go func() { + resp, err := broker.Request("late.example.com", 443, "https", 5*time.Second) + out1 <- result{resp, err} + }() + var id1 string + for { + reqs := ch.getRequests() + if len(reqs) == 1 { + id1 = reqs[0].ID + break + } + time.Sleep(time.Millisecond) + } + + // Resolve the first; dedupIndex entry is cleared in the same locked + // section as the waiter delete. + broker.Resolve(id1, ResponseAllowOnce) + if r := <-out1; r.resp != ResponseAllowOnce { + t.Fatalf("first request: expected AllowOnce, got %v", r.resp) + } + + // A new request to the same target after resolution must NOT attach to + // the dead waiter — it must open a fresh prompt with a new ID. + out2 := make(chan result, 1) + go func() { + resp, err := broker.Request("late.example.com", 443, "https", 5*time.Second) + out2 <- result{resp, err} + }() + var id2 string + for { + reqs := ch.getRequests() + if len(reqs) == 2 { + id2 = reqs[1].ID + break + } + select { + case <-time.After(2 * time.Second): + t.Fatal("late request did not open a new prompt (attached to dead waiter)") + default: + time.Sleep(time.Millisecond) + } + } + if id2 == id1 { + t.Fatalf("late request reused dead primary id %q", id1) + } + broker.Resolve(id2, ResponseDeny) + if r := <-out2; r.resp != ResponseDeny { + t.Errorf("second request: expected Deny, got %v", r.resp) + } +} + +func TestBrokerCoalesceConcurrentResolveAndAttach(t *testing.T) { + // Stress the resolve/attach interleave: no sub may end up attached to + // a deleted waiter (which would hang forever) and none may be lost. + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const rounds = 40 + for round := 0; round < rounds; round++ { + dest := fmt.Sprintf("race-%d.example.com", round) + out := make(chan result, 3) + for i := 0; i < 3; i++ { + go func() { + resp, err := broker.Request(dest, 443, "https", 3*time.Second) + out <- result{resp, err} + }() + } + // Resolve as soon as the first prompt appears, racing the other + // two arrivals (some attach as subs, some open fresh prompts). + var firstID string + for { + for _, r := range ch.getRequests() { + if r.Destination == dest { + firstID = r.ID + break + } + } + if firstID != "" { + break + } + time.Sleep(time.Microsecond * 200) + } + // Keep resolving every pending prompt for this dest until all + // three callers return. Any caller that opened its own prompt + // gets resolved here too. + got := 0 + timeout := time.After(3 * time.Second) + for got < 3 { + for _, r := range broker.PendingRequests() { + if r.Destination == dest { + broker.Resolve(r.ID, ResponseAllowOnce) + } + } + select { + case res := <-out: + if res.resp != ResponseAllowOnce { + t.Fatalf("round %d: expected AllowOnce, got %v (err %v)", round, res.resp, res.err) + } + got++ + case <-timeout: + t.Fatalf("round %d: only %d/3 callers returned (deadlock?)", round, got) + default: + time.Sleep(time.Microsecond * 200) + } + } + } +} + +func TestBrokerDistinctDestNotCoalesced(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + out := make(chan result, 2) + go func() { + resp, err := broker.Request("a.example.com", 443, "https", 5*time.Second) + out <- result{resp, err} + }() + go func() { + resp, err := broker.Request("b.example.com", 443, "https", 5*time.Second) + out <- result{resp, err} + }() + // Two distinct targets -> two waiters, two broadcasts. + for broker.PendingCount() < 2 { + time.Sleep(time.Millisecond) + } + if got := len(ch.getRequests()); got != 2 { + t.Fatalf("expected 2 broadcasts for distinct targets, got %d", got) + } + for _, r := range broker.PendingRequests() { + broker.Resolve(r.ID, ResponseAllowOnce) + } + <-out + <-out +} + +func TestBrokerSamePortDifferentDestNotCoalesced(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + out := make(chan result, 2) + // Same host, different port -> different dedup key, not coalesced. + go func() { + resp, err := broker.Request("svc.example.com", 443, "https", 5*time.Second) + out <- result{resp, err} + }() + go func() { + resp, err := broker.Request("svc.example.com", 8443, "https", 5*time.Second) + out <- result{resp, err} + }() + for broker.PendingCount() < 2 { + time.Sleep(time.Millisecond) + } + if got := len(ch.getRequests()); got != 2 { + t.Fatalf("expected 2 broadcasts for differing ports, got %d", got) + } + for _, r := range broker.PendingRequests() { + broker.Resolve(r.ID, ResponseAllowOnce) + } + <-out + <-out +} + +func TestBrokerWithNoCoalesceNeverCoalesces(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const n = 4 + out := make(chan result, n) + for i := 0; i < n; i++ { + go func() { + resp, err := broker.Request("mcp-tool", 0, "mcp", 5*time.Second, WithNoCoalesce()) + out <- result{resp, err} + }() + } + for broker.PendingCount() < n { + time.Sleep(time.Millisecond) + } + if got := len(ch.getRequests()); got != n { + t.Fatalf("WithNoCoalesce: expected %d separate prompts, got %d", n, got) + } + for _, r := range broker.PendingRequests() { + broker.Resolve(r.ID, ResponseAllowOnce) + } + for i := 0; i < n; i++ { + <-out + } +} + +func TestBrokerCoalesceCrossChannelFirstWins(t *testing.T) { + ch1 := newMockChannel(ChannelTelegram) + ch2 := newMockChannel(ChannelHTTP) + broker := NewBroker([]Channel{ch1, ch2}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const n = 5 + out := make(chan result, n) + for i := 0; i < n; i++ { + go func() { + resp, err := broker.Request("xchan.example.com", 443, "https", 5*time.Second) + out <- result{resp, err} + }() + } + var primaryID string + for { + reqs := ch1.getRequests() + if len(reqs) == 1 && broker.CoalescedCount(reqs[0].ID) >= n { + primaryID = reqs[0].ID + break + } + time.Sleep(time.Millisecond) + } + // Both channels saw exactly one prompt (the primary). + if len(ch1.getRequests()) != 1 || len(ch2.getRequests()) != 1 { + t.Fatalf("expected 1 prompt per channel, got ch1=%d ch2=%d", + len(ch1.getRequests()), len(ch2.getRequests())) + } + // Two channels race to resolve the same primary; first wins, and the + // whole coalesced burst gets that winner's response. + r1 := make(chan bool, 1) + r2 := make(chan bool, 1) + go func() { r1 <- broker.Resolve(primaryID, ResponseAlwaysAllow) }() + go func() { r2 <- broker.Resolve(primaryID, ResponseDeny) }() + wins := 0 + if <-r1 { + wins++ + } + if <-r2 { + wins++ + } + if wins != 1 { + t.Fatalf("expected exactly 1 winning Resolve, got %d", wins) + } + first := result{} + for i := 0; i < n; i++ { + r := <-out + if i == 0 { + first = r + } else if r.resp != first.resp { + t.Fatalf("coalesced burst got mixed responses: %v vs %v", first.resp, r.resp) + } + } + if first.resp != ResponseAlwaysAllow && first.resp != ResponseDeny { + t.Fatalf("unexpected response %v", first.resp) + } +} From e7bde0e6a3ab51d2835f4e1ded665ddf747be58b Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:22:50 +0800 Subject: [PATCH 02/49] wip(proxy): checkpoint persist-once + store work before respawn --- internal/proxy/request_policy_test.go | 29 +++++++++ internal/proxy/server.go | 12 ++++ internal/proxy/server_test.go | 85 +++++++++++++++++++++++++++ internal/store/store.go | 31 ++++++++++ internal/store/store_test.go | 81 +++++++++++++++++++++++++ 5 files changed, 238 insertions(+) diff --git a/internal/proxy/request_policy_test.go b/internal/proxy/request_policy_test.go index 9a93fc3..549d9e1 100644 --- a/internal/proxy/request_policy_test.go +++ b/internal/proxy/request_policy_test.go @@ -20,6 +20,11 @@ type fakeChannel struct { response channel.Response requests []channel.ApprovalRequest onRequestCh chan struct{} + // release, when non-nil, gates resolution: the resolve goroutine waits + // for this channel to be closed before calling broker.Resolve. Used to + // deterministically hold a primary prompt pending while concurrent + // requests to the same target coalesce onto it. + release chan struct{} } func newFakeChannel(resp channel.Response) *fakeChannel { @@ -31,10 +36,14 @@ func (f *fakeChannel) RequestApproval(_ context.Context, req channel.ApprovalReq f.requests = append(f.requests, req) resp := f.response broker := f.broker + release := f.release f.mu.Unlock() // Resolve asynchronously so the broker goroutine can register the // waiter before we deliver the response. go func() { + if release != nil { + <-release + } broker.Resolve(req.ID, resp) }() select { @@ -63,6 +72,26 @@ func (f *fakeChannel) requestCount() int { return len(f.requests) } +// gate installs a release channel so primary prompts stay pending until +// releaseAll() is called. Returns the close func. +func (f *fakeChannel) gate() func() { + f.mu.Lock() + ch := make(chan struct{}) + f.release = ch + f.mu.Unlock() + var once sync.Once + return func() { once.Do(func() { close(ch) }) } +} + +func (f *fakeChannel) firstReqID() string { + f.mu.Lock() + defer f.mu.Unlock() + if len(f.requests) == 0 { + return "" + } + return f.requests[0].ID +} + // newTestChecker builds a RequestPolicyChecker wired to a fake // channel/broker and a policy engine loaded from the given TOML. func newTestChecker(t *testing.T, toml string, resp channel.Response) (*RequestPolicyChecker, *fakeChannel) { diff --git a/internal/proxy/server.go b/internal/proxy/server.go index b16551c..3fdfa6f 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -512,6 +512,18 @@ func (r *policyRuleSet) persistApprovalRule(verdict, dest string, port int) bool log.Printf("[WARN] always-%s for %s:%d not persisted (no store)", verdict, dest, port) return false } + // Idempotent under coalesced approval fan-out: a burst of "Always + // Allow"/"Always Deny" responses for one target serializes here on + // reloadMu. The first caller inserts the rule; the rest see it already + // present and skip the redundant AddRule + engine recompile. The + // returned bool stays true so every caller treats the decision as + // persisted. + if exists, existsErr := r.store.HasApprovalRule(verdict, dest, port); existsErr != nil { + log.Printf("[WARN] failed to check existing %s rule for %s:%d: %v", verdict, dest, port, existsErr) + } else if exists { + log.Printf("[approval] %s rule for %s:%d already present; skipping duplicate persist", verdict, dest, port) + return true + } if _, storeErr := r.store.AddRule(verdict, store.RuleOpts{Destination: dest, Ports: []int{port}, Source: "approval"}); storeErr != nil { log.Printf("[WARN] failed to persist %s rule for %s:%d: %v", verdict, dest, port, storeErr) return false diff --git a/internal/proxy/server_test.go b/internal/proxy/server_test.go index f533820..3a9974e 100644 --- a/internal/proxy/server_test.go +++ b/internal/proxy/server_test.go @@ -15,6 +15,7 @@ import ( "net/http/httptest" "strings" "sync" + "sync/atomic" "testing" "time" @@ -4958,3 +4959,87 @@ func TestQUICSNIAccumulatorTTLCleanup(t *testing.T) { t.Error("fresh accumulator should still be present") } } + +// TestPersistApprovalRuleIdempotentUnderConcurrency exercises the +// persist-once path: a burst of coalesced "Always Allow" responses for one +// target must serialize on reloadMu and produce exactly one stored rule +// (first inserts, the rest see HasApprovalRule==true and no-op), while every +// caller still observes success. +func TestPersistApprovalRuleIdempotentUnderConcurrency(t *testing.T) { + st, err := store.New(":memory:") + if err != nil { + t.Fatalf("store.New: %v", err) + } + defer func() { _ = st.Close() }() + + eng, err := policy.LoadFromStore(st) + if err != nil { + t.Fatalf("LoadFromStore: %v", err) + } + engPtr := new(atomic.Pointer[policy.Engine]) + engPtr.Store(eng) + var reloadMu sync.Mutex + r := &policyRuleSet{engine: engPtr, reloadMu: &reloadMu, store: st} + + const m = 16 + var wg sync.WaitGroup + for i := 0; i < m; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if !r.persistApprovalRule("allow", "cas.example.com", 443) { + t.Error("persistApprovalRule returned false") + } + }() + } + wg.Wait() + + rules, err := st.ListRules(store.RuleFilter{}) + if err != nil { + t.Fatalf("ListRules: %v", err) + } + matches := 0 + for _, ru := range rules { + if ru.Destination == "cas.example.com" && ru.Source == "approval" { + matches++ + } + } + if matches != 1 { + t.Fatalf("expected exactly 1 persisted approval rule, got %d", matches) + } +} + +// TestPersistApprovalRuleSinglePersistUnchanged guards the non-coalesced +// path: a single Always-Allow still writes exactly one rule and recompiles +// the engine. +func TestPersistApprovalRuleSinglePersistUnchanged(t *testing.T) { + st, err := store.New(":memory:") + if err != nil { + t.Fatalf("store.New: %v", err) + } + defer func() { _ = st.Close() }() + + eng, err := policy.LoadFromStore(st) + if err != nil { + t.Fatalf("LoadFromStore: %v", err) + } + engPtr := new(atomic.Pointer[policy.Engine]) + engPtr.Store(eng) + var reloadMu sync.Mutex + r := &policyRuleSet{engine: engPtr, reloadMu: &reloadMu, store: st} + + if !r.persistApprovalRule("deny", "blocked.example.com", 443) { + t.Fatal("persistApprovalRule returned false") + } + has, err := st.HasApprovalRule("deny", "blocked.example.com", 443) + if err != nil { + t.Fatalf("HasApprovalRule: %v", err) + } + if !has { + t.Fatal("expected the rule to be persisted") + } + // The freshly compiled engine must be installed (pointer swapped). + if engPtr.Load() == eng { + t.Error("expected engine pointer to be swapped after persist") + } +} diff --git a/internal/store/store.go b/internal/store/store.go index 64aaedc..74f92de 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -304,6 +304,37 @@ func (s *Store) RuleExists(verdict string, opts RuleExistsOpts) (bool, error) { return count > 0, nil } +// HasApprovalRule reports whether an approval-sourced rule already exists for +// the given verdict, destination, and single port. It is a read-only SELECT +// (no migration) used by the proxy to make approval-rule persistence +// idempotent: when a burst of coalesced "Always Allow"/"Always Deny" +// responses fan out, the first caller inserts the rule and the rest see it +// already present and skip the insert + engine recompile. +// +// The match is intentionally narrow — source='approval', exact verdict, +// exact destination, and exact ports JSON for the single approval port — +// mirroring exactly what persistApprovalRule writes via AddRule. It is not a +// general dedup for manually added rules. +func (s *Store) HasApprovalRule(verdict, dest string, port int) (bool, error) { + if verdict == "" || dest == "" { + return false, fmt.Errorf("verdict and destination are required") + } + portsJSON := portsToJSONPtr([]int{port}) + query := "SELECT COUNT(*) FROM rules WHERE source = 'approval' AND verdict = ? AND destination = ? AND " + args := []any{verdict, dest} + if portsJSON != nil { + query += "ports = ?" + args = append(args, *portsJSON) + } else { + query += "ports IS NULL" + } + var count int + if err := s.db.QueryRow(query, args...).Scan(&count); err != nil { + return false, err + } + return count > 0, nil +} + // --- Config --- // Config represents the typed singleton row in the config table. diff --git a/internal/store/store_test.go b/internal/store/store_test.go index d04d1eb..3998042 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -4587,3 +4587,84 @@ func TestAddBindingValidation(t *testing.T) { t.Errorf("unexpected error for valid input: %v", err) } } + +// --- HasApprovalRule (idempotent approval persistence) --- + +func TestHasApprovalRule(t *testing.T) { + s := newTestStore(t) + + // No approval rule yet. + has, err := s.HasApprovalRule("allow", "cas.example.com", 443) + if err != nil { + t.Fatalf("HasApprovalRule: %v", err) + } + if has { + t.Fatal("expected no approval rule before insert") + } + + // Insert exactly what persistApprovalRule writes. + if _, err := s.AddRule("allow", RuleOpts{ + Destination: "cas.example.com", Ports: []int{443}, Source: "approval", + }); err != nil { + t.Fatalf("AddRule: %v", err) + } + + has, err = s.HasApprovalRule("allow", "cas.example.com", 443) + if err != nil { + t.Fatalf("HasApprovalRule: %v", err) + } + if !has { + t.Fatal("expected approval rule to be found after insert") + } + + // Mismatches must not match. + cases := []struct { + name string + verdict, dest string + port int + }{ + {"different verdict", "deny", "cas.example.com", 443}, + {"different dest", "allow", "other.example.com", 443}, + {"different port", "allow", "cas.example.com", 8443}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + has, err := s.HasApprovalRule(c.verdict, c.dest, c.port) + if err != nil { + t.Fatalf("HasApprovalRule: %v", err) + } + if has { + t.Errorf("expected no match for %s", c.name) + } + }) + } +} + +func TestHasApprovalRuleIgnoresNonApprovalSource(t *testing.T) { + s := newTestStore(t) + + // A manually-added rule with the same shape must NOT be treated as an + // approval rule (HasApprovalRule filters source='approval'). + if _, err := s.AddRule("allow", RuleOpts{ + Destination: "manual.example.com", Ports: []int{443}, Source: "manual", + }); err != nil { + t.Fatalf("AddRule: %v", err) + } + has, err := s.HasApprovalRule("allow", "manual.example.com", 443) + if err != nil { + t.Fatalf("HasApprovalRule: %v", err) + } + if has { + t.Error("manual-sourced rule must not satisfy HasApprovalRule") + } +} + +func TestHasApprovalRuleValidatesInput(t *testing.T) { + s := newTestStore(t) + if _, err := s.HasApprovalRule("", "x", 443); err == nil { + t.Error("expected error for empty verdict") + } + if _, err := s.HasApprovalRule("allow", "", 443); err == nil { + t.Error("expected error for empty destination") + } +} From b76a0f35742eb06ea7ae84d75066885e39de3008 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:24:03 +0800 Subject: [PATCH 03/49] test(proxy): assert concurrent same-target asks coalesce to one prompt --- internal/proxy/request_policy_test.go | 37 ++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/internal/proxy/request_policy_test.go b/internal/proxy/request_policy_test.go index 549d9e1..b9ecbd3 100644 --- a/internal/proxy/request_policy_test.go +++ b/internal/proxy/request_policy_test.go @@ -677,11 +677,13 @@ destination = "api.example.com" } } -func TestRequestPolicyChecker_ConcurrentAllowOnceSerializesApprovals(t *testing.T) { +func TestRequestPolicyChecker_ConcurrentAllowOnceCoalesces(t *testing.T) { // Multiple HTTP/2 streams on the same connection may invoke - // CheckAndConsume concurrently. Each call must be independently - // answered by the broker (no single-shot caching). This test - // confirms the contract holds up under concurrency. + // CheckAndConsume concurrently against the same dest:port. Broker-level + // coalescing collapses a concurrent burst onto ONE pending prompt: the + // operator taps once and every in-flight request gets that verdict. + // (Sequential, non-overlapping requests still each ask — coalescing + // only applies while a prompt is actually pending.) toml := ` [policy] default = "deny" @@ -690,6 +692,10 @@ default = "deny" destination = "api.example.com" ` checker, fc := newTestChecker(t, toml, channel.ResponseAllowOnce) + // Hold the primary prompt pending so the whole burst coalesces onto it + // deterministically before any resolution happens. + releaseAll := fc.gate() + const n = 5 var wg sync.WaitGroup results := make([]policy.Verdict, n) @@ -704,6 +710,24 @@ destination = "api.example.com" results[i] = v }(i) } + + // Wait until exactly one prompt is broadcast and all n requests have + // coalesced onto it (primary + n-1 subscribers => count == n). + waitDeadline := time.After(3 * time.Second) + for { + id := fc.firstReqID() + if id != "" && checker.broker.CoalescedCount(id) >= n { + break + } + select { + case <-waitDeadline: + t.Fatalf("burst did not coalesce: requests=%d", fc.requestCount()) + default: + time.Sleep(time.Millisecond) + } + } + + releaseAll() done := make(chan struct{}) go func() { wg.Wait(); close(done) }() select { @@ -716,7 +740,8 @@ destination = "api.example.com" t.Errorf("result[%d] = %v, want Allow", i, v) } } - if fc.requestCount() != n { - t.Errorf("broker request count = %d, want %d (every request should ask)", fc.requestCount(), n) + // Exactly one prompt for the whole coalesced burst. + if fc.requestCount() != 1 { + t.Errorf("broker request count = %d, want 1 (concurrent burst coalesces)", fc.requestCount()) } } From 4d3b326e74758210c6d389a3d521297138742e1f Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:26:01 +0800 Subject: [PATCH 04/49] feat(store): add credential pool + health schema and store API Migration 000006 adds credential_pools, credential_pool_members and credential_health tables. Store API: CreatePoolWithMembers (atomic; namespace mutual-exclusion + oauth-member validation in one tx), GetPool, ListPools, RemovePool, PoolExists, PoolsForMember, and Set/Get/ListCredentialHealth. Covered by CRUD/ordering/validation tests and an up/down migration reversibility test. --- .../000006_credential_pools.down.sql | 3 + .../migrations/000006_credential_pools.up.sql | 36 ++ internal/store/pools.go | 360 ++++++++++++++++++ internal/store/pools_test.go | 275 +++++++++++++ 4 files changed, 674 insertions(+) create mode 100644 internal/store/migrations/000006_credential_pools.down.sql create mode 100644 internal/store/migrations/000006_credential_pools.up.sql create mode 100644 internal/store/pools.go create mode 100644 internal/store/pools_test.go diff --git a/internal/store/migrations/000006_credential_pools.down.sql b/internal/store/migrations/000006_credential_pools.down.sql new file mode 100644 index 0000000..c91e520 --- /dev/null +++ b/internal/store/migrations/000006_credential_pools.down.sql @@ -0,0 +1,3 @@ +DROP TABLE IF EXISTS credential_health; +DROP TABLE IF EXISTS credential_pool_members; +DROP TABLE IF EXISTS credential_pools; diff --git a/internal/store/migrations/000006_credential_pools.up.sql b/internal/store/migrations/000006_credential_pools.up.sql new file mode 100644 index 0000000..f7e531f --- /dev/null +++ b/internal/store/migrations/000006_credential_pools.up.sql @@ -0,0 +1,36 @@ +-- Credential pools with auto-failover. +-- +-- A pool is a named group of OAuth credentials. A single phantom identity +-- the agent sees is backed by N real OAuth credentials; sluice picks which +-- real account to inject and fails over between members on 429/401. +-- +-- credential_pools one row per pool (strategy reserved: failover only) +-- credential_pool_members ordered membership; position drives failover order +-- credential_health per-credential health used to skip cooled-down +-- members during active-member selection +-- +-- A pool name and a credential name share one namespace; mutual exclusion +-- is enforced at the application layer (see store.CreatePoolWithMembers and +-- the cred-add path), not by a cross-table SQL constraint. + +CREATE TABLE credential_pools ( + name TEXT PRIMARY KEY, + strategy TEXT NOT NULL DEFAULT 'failover' CHECK(strategy IN ('failover')), + created_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +CREATE TABLE credential_pool_members ( + pool TEXT NOT NULL, + credential TEXT NOT NULL, + position INTEGER NOT NULL, + PRIMARY KEY (pool, credential), + FOREIGN KEY (pool) REFERENCES credential_pools(name) ON DELETE CASCADE +); + +CREATE TABLE credential_health ( + credential TEXT PRIMARY KEY, + status TEXT NOT NULL DEFAULT 'healthy' CHECK(status IN ('healthy','cooldown')), + cooldown_until TEXT, + last_failure_reason TEXT, + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); diff --git a/internal/store/pools.go b/internal/store/pools.go new file mode 100644 index 0000000..245672c --- /dev/null +++ b/internal/store/pools.go @@ -0,0 +1,360 @@ +package store + +import ( + "database/sql" + "fmt" + "time" +) + +// PoolStrategyFailover is the only supported pool strategy. Round-robin and +// weighted strategies are reserved for future work; the schema CHECK keeps +// the column constrained to this value. +const PoolStrategyFailover = "failover" + +// Pool is a named group of OAuth credentials backing a single phantom +// identity. Members are returned ordered by position (failover order). +type Pool struct { + Name string + Strategy string + CreatedAt string + Members []PoolMember +} + +// PoolMember is one credential entry in a pool. Position determines the +// failover order (lowest first). +type PoolMember struct { + Credential string + Position int +} + +// CredentialHealth records whether a credential is currently eligible for +// injection. A cooled-down member is skipped during active-member selection +// until CooldownUntil passes (lazy recovery, no scheduler). +type CredentialHealth struct { + Credential string + Status string // "healthy" or "cooldown" + CooldownUntil time.Time // zero when Status == "healthy" or unset + LastFailureReason string + UpdatedAt string +} + +// parseHealthTime parses a cooldown_until value. Values are written as +// RFC3339 (SetCredentialHealth), but a NULL or empty string yields the zero +// time. A legacy "2006-01-02 15:04:05" SQLite datetime form is also accepted +// defensively. +func parseHealthTime(s sql.NullString) time.Time { + if !s.Valid || s.String == "" { + return time.Time{} + } + if t, err := time.Parse(time.RFC3339, s.String); err == nil { + return t + } + if t, err := time.Parse("2006-01-02 15:04:05", s.String); err == nil { + return t.UTC() + } + return time.Time{} +} + +// PoolExists reports whether a pool with the given name exists. +func (s *Store) PoolExists(name string) (bool, error) { + var one int + err := s.db.QueryRow("SELECT 1 FROM credential_pools WHERE name = ?", name).Scan(&one) + if err == sql.ErrNoRows { + return false, nil + } + if err != nil { + return false, fmt.Errorf("check pool exists %q: %w", name, err) + } + return true, nil +} + +// validatePoolMemberTx verifies a credential is an existing OAuth credential +// with a non-empty token_url. Static credentials are rejected because the +// pool failover machinery is OAuth-specific (phantom indirection, refresh +// attribution). Runs inside the supplied transaction so the check and the +// member insert are atomic. +func validatePoolMemberTx(tx *sql.Tx, credential string) error { + var credType string + var tokenURL sql.NullString + err := tx.QueryRow( + "SELECT cred_type, token_url FROM credential_meta WHERE name = ?", credential, + ).Scan(&credType, &tokenURL) + if err == sql.ErrNoRows { + return fmt.Errorf("credential %q does not exist (add it with --type oauth first)", credential) + } + if err != nil { + return fmt.Errorf("look up credential %q: %w", credential, err) + } + if credType != "oauth" { + return fmt.Errorf("credential %q is %s, pools require oauth credentials", credential, credType) + } + if tokenURL.String == "" { + return fmt.Errorf("credential %q has no token_url; pools require oauth credentials with a token endpoint", credential) + } + return nil +} + +// CreatePoolWithMembers creates a pool and its ordered members atomically. +// Member positions are assigned from the slice order (0-based). It enforces +// the pool/credential namespace mutual-exclusion (a pool name must not +// collide with an existing credential) and validates every member is an +// existing oauth credential with a token_url. At least two members are +// required for failover to be meaningful, but a single-member pool is +// permitted (it degrades to a plain indirection with no failover target). +func (s *Store) CreatePoolWithMembers(name, strategy string, members []string) error { + if name == "" { + return fmt.Errorf("pool name is required") + } + if strategy == "" { + strategy = PoolStrategyFailover + } + if strategy != PoolStrategyFailover { + return fmt.Errorf("invalid pool strategy %q: only %q is supported", strategy, PoolStrategyFailover) + } + if len(members) == 0 { + return fmt.Errorf("pool %q requires at least one member", name) + } + seen := make(map[string]bool, len(members)) + for _, m := range members { + if m == "" { + return fmt.Errorf("pool %q has an empty member name", name) + } + if seen[m] { + return fmt.Errorf("pool %q lists credential %q more than once", name, m) + } + seen[m] = true + } + + tx, err := s.db.Begin() + if err != nil { + return fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // Namespace mutual-exclusion: a pool must not shadow a credential. + var credName string + switch err := tx.QueryRow("SELECT name FROM credential_meta WHERE name = ?", name).Scan(&credName); { + case err == nil: + return fmt.Errorf("name %q is already a credential; pool and credential names share one namespace", name) + case err == sql.ErrNoRows: + // ok + default: + return fmt.Errorf("check name collision for %q: %w", name, err) + } + + if _, err := tx.Exec( + "INSERT INTO credential_pools (name, strategy) VALUES (?, ?)", name, strategy, + ); err != nil { + return fmt.Errorf("insert pool %q: %w", name, err) + } + + for i, m := range members { + if err := validatePoolMemberTx(tx, m); err != nil { + return err + } + if _, err := tx.Exec( + "INSERT INTO credential_pool_members (pool, credential, position) VALUES (?, ?, ?)", + name, m, i, + ); err != nil { + return fmt.Errorf("insert pool member %q: %w", m, err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit: %w", err) + } + return nil +} + +// GetPool returns a pool by name with members ordered by position, or nil if +// the pool does not exist. +func (s *Store) GetPool(name string) (*Pool, error) { + var p Pool + err := s.db.QueryRow( + "SELECT name, strategy, created_at FROM credential_pools WHERE name = ?", name, + ).Scan(&p.Name, &p.Strategy, &p.CreatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get pool %q: %w", name, err) + } + + rows, err := s.db.Query( + "SELECT credential, position FROM credential_pool_members WHERE pool = ? ORDER BY position", name, + ) + if err != nil { + return nil, fmt.Errorf("list pool members %q: %w", name, err) + } + defer func() { _ = rows.Close() }() + for rows.Next() { + var m PoolMember + if err := rows.Scan(&m.Credential, &m.Position); err != nil { + return nil, fmt.Errorf("scan pool member: %w", err) + } + p.Members = append(p.Members, m) + } + if err := rows.Err(); err != nil { + return nil, err + } + return &p, nil +} + +// ListPools returns all pools with their members ordered by position. +func (s *Store) ListPools() ([]Pool, error) { + rows, err := s.db.Query("SELECT name, strategy, created_at FROM credential_pools ORDER BY name") + if err != nil { + return nil, fmt.Errorf("list pools: %w", err) + } + var names []string + pools := make(map[string]*Pool) + for rows.Next() { + var p Pool + if err := rows.Scan(&p.Name, &p.Strategy, &p.CreatedAt); err != nil { + _ = rows.Close() + return nil, fmt.Errorf("scan pool: %w", err) + } + cp := p + pools[p.Name] = &cp + names = append(names, p.Name) + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return nil, err + } + _ = rows.Close() + + mrows, err := s.db.Query( + "SELECT pool, credential, position FROM credential_pool_members ORDER BY pool, position", + ) + if err != nil { + return nil, fmt.Errorf("list pool members: %w", err) + } + defer func() { _ = mrows.Close() }() + for mrows.Next() { + var pool string + var m PoolMember + if err := mrows.Scan(&pool, &m.Credential, &m.Position); err != nil { + return nil, fmt.Errorf("scan pool member: %w", err) + } + if p, ok := pools[pool]; ok { + p.Members = append(p.Members, m) + } + } + if err := mrows.Err(); err != nil { + return nil, err + } + + result := make([]Pool, 0, len(names)) + for _, n := range names { + result = append(result, *pools[n]) + } + return result, nil +} + +// RemovePool deletes a pool and (via ON DELETE CASCADE) its members. Returns +// true if a pool row was deleted. +func (s *Store) RemovePool(name string) (bool, error) { + res, err := s.db.Exec("DELETE FROM credential_pools WHERE name = ?", name) + if err != nil { + return false, fmt.Errorf("delete pool %q: %w", name, err) + } + n, _ := res.RowsAffected() + return n > 0, nil +} + +// PoolsForMember returns the names of all pools that include the given +// credential as a member. Used to block "cred remove" of a live pool member +// so no dangling member rows are left behind. +func (s *Store) PoolsForMember(credential string) ([]string, error) { + rows, err := s.db.Query( + "SELECT pool FROM credential_pool_members WHERE credential = ? ORDER BY pool", credential, + ) + if err != nil { + return nil, fmt.Errorf("list pools for member %q: %w", credential, err) + } + defer func() { _ = rows.Close() }() + var pools []string + for rows.Next() { + var p string + if err := rows.Scan(&p); err != nil { + return nil, fmt.Errorf("scan pool name: %w", err) + } + pools = append(pools, p) + } + return pools, rows.Err() +} + +// SetCredentialHealth upserts a credential's health row. When status is +// "healthy" the cooldown is cleared. cooldown_until is stored as RFC3339. +func (s *Store) SetCredentialHealth(credential, status string, cooldownUntil time.Time, reason string) error { + if credential == "" { + return fmt.Errorf("credential name is required") + } + if status != "healthy" && status != "cooldown" { + return fmt.Errorf("invalid health status %q: must be healthy or cooldown", status) + } + var cu interface{} + if status == "cooldown" && !cooldownUntil.IsZero() { + cu = cooldownUntil.UTC().Format(time.RFC3339) + } else { + cu = nil + } + _, err := s.db.Exec( + `INSERT INTO credential_health (credential, status, cooldown_until, last_failure_reason, updated_at) + VALUES (?, ?, ?, ?, datetime('now')) + ON CONFLICT(credential) DO UPDATE SET + status = excluded.status, + cooldown_until = excluded.cooldown_until, + last_failure_reason = excluded.last_failure_reason, + updated_at = excluded.updated_at`, + credential, status, cu, nilIfEmpty(reason), + ) + if err != nil { + return fmt.Errorf("set credential health %q: %w", credential, err) + } + return nil +} + +// GetCredentialHealth returns the health row for a credential, or nil if no +// row exists (which callers treat as healthy). +func (s *Store) GetCredentialHealth(credential string) (*CredentialHealth, error) { + var h CredentialHealth + var cu, reason sql.NullString + err := s.db.QueryRow( + "SELECT credential, status, cooldown_until, last_failure_reason, updated_at FROM credential_health WHERE credential = ?", + credential, + ).Scan(&h.Credential, &h.Status, &cu, &reason, &h.UpdatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get credential health %q: %w", credential, err) + } + h.CooldownUntil = parseHealthTime(cu) + h.LastFailureReason = reason.String + return &h, nil +} + +// ListCredentialHealth returns all credential health rows ordered by name. +func (s *Store) ListCredentialHealth() ([]CredentialHealth, error) { + rows, err := s.db.Query( + "SELECT credential, status, cooldown_until, last_failure_reason, updated_at FROM credential_health ORDER BY credential", + ) + if err != nil { + return nil, fmt.Errorf("list credential health: %w", err) + } + defer func() { _ = rows.Close() }() + var out []CredentialHealth + for rows.Next() { + var h CredentialHealth + var cu, reason sql.NullString + if err := rows.Scan(&h.Credential, &h.Status, &cu, &reason, &h.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan credential health: %w", err) + } + h.CooldownUntil = parseHealthTime(cu) + h.LastFailureReason = reason.String + out = append(out, h) + } + return out, rows.Err() +} diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go new file mode 100644 index 0000000..1ddc31d --- /dev/null +++ b/internal/store/pools_test.go @@ -0,0 +1,275 @@ +package store + +import ( + "path/filepath" + "testing" + "time" + + "github.com/golang-migrate/migrate/v4" + migsqlite "github.com/golang-migrate/migrate/v4/database/sqlite" + "github.com/golang-migrate/migrate/v4/source/iofs" +) + +// seedOAuthCred registers a credential_meta row so a pool member passes the +// oauth+token_url validation. +func seedOAuthCred(t *testing.T, s *Store, name string) { + t.Helper() + if err := s.AddCredentialMeta(name, "oauth", "https://auth.example.com/token"); err != nil { + t.Fatalf("seed oauth cred %q: %v", name, err) + } +} + +func TestCreatePoolWithMembersAndGet(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "acct_a") + seedOAuthCred(t, s, "acct_b") + + if err := s.CreatePoolWithMembers("codex", "", []string{"acct_a", "acct_b"}); err != nil { + t.Fatalf("CreatePoolWithMembers: %v", err) + } + + p, err := s.GetPool("codex") + if err != nil { + t.Fatalf("GetPool: %v", err) + } + if p == nil { + t.Fatal("GetPool returned nil for existing pool") + } + if p.Strategy != PoolStrategyFailover { + t.Errorf("strategy = %q, want %q", p.Strategy, PoolStrategyFailover) + } + if len(p.Members) != 2 { + t.Fatalf("members = %d, want 2", len(p.Members)) + } + // Ordering must follow the slice order via position. + if p.Members[0].Credential != "acct_a" || p.Members[0].Position != 0 { + t.Errorf("member[0] = %+v, want acct_a@0", p.Members[0]) + } + if p.Members[1].Credential != "acct_b" || p.Members[1].Position != 1 { + t.Errorf("member[1] = %+v, want acct_b@1", p.Members[1]) + } + + exists, err := s.PoolExists("codex") + if err != nil || !exists { + t.Errorf("PoolExists(codex) = %v, %v; want true, nil", exists, err) + } + if got, _ := s.GetPool("missing"); got != nil { + t.Errorf("GetPool(missing) = %+v, want nil", got) + } +} + +func TestCreatePoolRejectsStaticMember(t *testing.T) { + s := newTestStore(t) + if err := s.AddCredentialMeta("static_key", "static", ""); err != nil { + t.Fatalf("AddCredentialMeta: %v", err) + } + err := s.CreatePoolWithMembers("p", "failover", []string{"static_key"}) + if err == nil { + t.Fatal("expected error creating pool with static member") + } + // The pool row must not survive a failed member insert (tx rollback). + if exists, _ := s.PoolExists("p"); exists { + t.Error("pool row leaked after failed member validation") + } +} + +func TestCreatePoolRejectsMissingMember(t *testing.T) { + s := newTestStore(t) + if err := s.CreatePoolWithMembers("p", "failover", []string{"nope"}); err == nil { + t.Fatal("expected error for non-existent member credential") + } +} + +func TestCreatePoolRejectsBadStrategyAndDupes(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "a") + if err := s.CreatePoolWithMembers("p", "roundrobin", []string{"a"}); err == nil { + t.Error("expected error for unsupported strategy") + } + if err := s.CreatePoolWithMembers("p", "failover", []string{"a", "a"}); err == nil { + t.Error("expected error for duplicate member") + } + if err := s.CreatePoolWithMembers("p", "failover", nil); err == nil { + t.Error("expected error for empty member list") + } +} + +func TestPoolCredentialNamespaceMutualExclusion(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "acct_a") + // "acct_a" is a credential; a pool may not shadow it. + if err := s.CreatePoolWithMembers("acct_a", "failover", []string{"acct_a"}); err == nil { + t.Fatal("expected namespace collision error (pool name == credential name)") + } +} + +func TestListPoolsOrdersMembers(t *testing.T) { + s := newTestStore(t) + for _, n := range []string{"a", "b", "c"} { + seedOAuthCred(t, s, n) + } + if err := s.CreatePoolWithMembers("p1", "failover", []string{"c", "a"}); err != nil { + t.Fatalf("create p1: %v", err) + } + if err := s.CreatePoolWithMembers("p2", "failover", []string{"b"}); err != nil { + t.Fatalf("create p2: %v", err) + } + pools, err := s.ListPools() + if err != nil { + t.Fatalf("ListPools: %v", err) + } + if len(pools) != 2 { + t.Fatalf("pools = %d, want 2", len(pools)) + } + // Pools ordered by name; p1 members in insertion order (c, a). + if pools[0].Name != "p1" || len(pools[0].Members) != 2 || + pools[0].Members[0].Credential != "c" || pools[0].Members[1].Credential != "a" { + t.Errorf("p1 members wrong: %+v", pools[0]) + } + if pools[1].Name != "p2" || len(pools[1].Members) != 1 { + t.Errorf("p2 wrong: %+v", pools[1]) + } +} + +func TestRemovePoolCascadesMembers(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "a") + if err := s.CreatePoolWithMembers("p", "failover", []string{"a"}); err != nil { + t.Fatalf("create: %v", err) + } + removed, err := s.RemovePool("p") + if err != nil || !removed { + t.Fatalf("RemovePool = %v, %v; want true, nil", removed, err) + } + // Members cascade-deleted via FK ON DELETE CASCADE. + mp, _ := s.PoolsForMember("a") + if len(mp) != 0 { + t.Errorf("PoolsForMember after remove = %v, want empty", mp) + } + if removed, _ := s.RemovePool("p"); removed { + t.Error("RemovePool of missing pool returned true") + } +} + +func TestPoolsForMember(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "shared") + seedOAuthCred(t, s, "x") + if err := s.CreatePoolWithMembers("p1", "failover", []string{"shared", "x"}); err != nil { + t.Fatalf("create p1: %v", err) + } + if err := s.CreatePoolWithMembers("p2", "failover", []string{"shared"}); err != nil { + t.Fatalf("create p2: %v", err) + } + pools, err := s.PoolsForMember("shared") + if err != nil { + t.Fatalf("PoolsForMember: %v", err) + } + if len(pools) != 2 || pools[0] != "p1" || pools[1] != "p2" { + t.Errorf("PoolsForMember(shared) = %v, want [p1 p2]", pools) + } +} + +func TestCredentialHealthCRUD(t *testing.T) { + s := newTestStore(t) + + // No row -> nil (callers treat as healthy). + h, err := s.GetCredentialHealth("a") + if err != nil || h != nil { + t.Fatalf("GetCredentialHealth(absent) = %+v, %v; want nil, nil", h, err) + } + + until := time.Now().Add(60 * time.Second).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("a", "cooldown", until, "429 rate limited"); err != nil { + t.Fatalf("SetCredentialHealth: %v", err) + } + h, err = s.GetCredentialHealth("a") + if err != nil || h == nil { + t.Fatalf("GetCredentialHealth = %+v, %v", h, err) + } + if h.Status != "cooldown" || h.LastFailureReason != "429 rate limited" { + t.Errorf("health = %+v, want cooldown/429", h) + } + if !h.CooldownUntil.Equal(until) { + t.Errorf("CooldownUntil = %v, want %v", h.CooldownUntil, until) + } + + // Upsert back to healthy clears the cooldown. + if err := s.SetCredentialHealth("a", "healthy", time.Time{}, ""); err != nil { + t.Fatalf("SetCredentialHealth healthy: %v", err) + } + h, _ = s.GetCredentialHealth("a") + if h.Status != "healthy" || !h.CooldownUntil.IsZero() { + t.Errorf("after healthy upsert = %+v, want healthy/zero", h) + } + + if err := s.SetCredentialHealth("b", "bogus", time.Time{}, ""); err == nil { + t.Error("expected error for invalid health status") + } + + all, err := s.ListCredentialHealth() + if err != nil { + t.Fatalf("ListCredentialHealth: %v", err) + } + if len(all) != 1 || all[0].Credential != "a" { + t.Errorf("ListCredentialHealth = %+v, want [a]", all) + } +} + +// TestMigration000006DownUp verifies the pool migration is reversible. +func TestMigration000006DownUp(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "m.db") + s, err := New(dbPath) + if err != nil { + t.Fatalf("New: %v", err) + } + defer func() { _ = s.Close() }() + + tableExists := func(name string) bool { + var n string + err := s.db.QueryRow( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", name, + ).Scan(&n) + return err == nil && n == name + } + + for _, tbl := range []string{"credential_pools", "credential_pool_members", "credential_health"} { + if !tableExists(tbl) { + t.Fatalf("table %q missing after up migration", tbl) + } + } + + src, err := iofs.New(migrationsFS, "migrations") + if err != nil { + t.Fatalf("iofs: %v", err) + } + drv, err := migsqlite.WithInstance(s.db, &migsqlite.Config{}) + if err != nil { + t.Fatalf("driver: %v", err) + } + m, err := migrate.NewWithInstance("iofs", src, "sqlite", drv) + if err != nil { + t.Fatalf("migrator: %v", err) + } + + // Step down one migration (000006 -> 000005). + if err := m.Steps(-1); err != nil { + t.Fatalf("down 1: %v", err) + } + for _, tbl := range []string{"credential_pools", "credential_pool_members", "credential_health"} { + if tableExists(tbl) { + t.Errorf("table %q still present after down migration", tbl) + } + } + + // Step back up; tables return. + if err := m.Steps(1); err != nil { + t.Fatalf("up 1: %v", err) + } + for _, tbl := range []string{"credential_pools", "credential_pool_members", "credential_health"} { + if !tableExists(tbl) { + t.Errorf("table %q missing after re-up migration", tbl) + } + } +} From 5fb56579541548c4b0a2168318fb7aff260db585 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:26:12 +0800 Subject: [PATCH 05/49] feat(vault): add PoolResolver pool->active-member chokepoint PoolResolver is the single place a bound pool name expands to a concrete credential. IsPool/ResolveActive (first healthy or expired-cooldown member in position order; degrade to soonest-recovering when all down) and MarkCooldown for Phase 2 synchronous in-memory failover. Locking discipline documented: membership immutable per instance (atomic-pointer swap on reload), health mutated in place under an RWMutex. RateLimit (60s) and AuthFail (300s) cooldown TTL consts. Nil-safe. --- internal/vault/pool.go | 190 ++++++++++++++++++++++++++++++++++++ internal/vault/pool_test.go | 128 ++++++++++++++++++++++++ 2 files changed, 318 insertions(+) create mode 100644 internal/vault/pool.go create mode 100644 internal/vault/pool_test.go diff --git a/internal/vault/pool.go b/internal/vault/pool.go new file mode 100644 index 0000000..ceb54f7 --- /dev/null +++ b/internal/vault/pool.go @@ -0,0 +1,190 @@ +package vault + +import ( + "log" + "sync" + "time" + + "github.com/nemirovsky/sluice/internal/store" +) + +// Cooldown TTLs applied when a pool member is failed over. A rate-limited +// account usually recovers within the provider's window, so it is retried +// relatively soon. An auth failure (revoked/expired refresh token, bad +// client) will not self-heal quickly, so it is parked far longer to avoid +// thrashing a broken account on every request. +const ( + RateLimitCooldown = 60 * time.Second + AuthFailCooldown = 300 * time.Second +) + +// memberHealth is the in-memory health view for one credential. Status is +// derived: a credential with a zero cooldownUntil is healthy. +type memberHealth struct { + cooldownUntil time.Time + reason string +} + +// PoolResolver maps a pool name to its currently active member. It is the +// single chokepoint every credential consumer routes through (injection +// passes, OAuthIndex.Has gating, persist attribution), so a pool name is +// expanded to a real credential in exactly one place. +// +// Locking discipline: pool membership is immutable for the lifetime of a +// PoolResolver instance (membership changes rebuild a fresh resolver that +// the server atomically pointer-swaps). Health, by contrast, is mutated +// synchronously on the response path during Phase 2 failover, so the health +// map is guarded by mu. ResolveActive takes mu.RLock; MarkCooldown takes +// mu.Lock. Readers therefore always observe a consistent active member even +// while a concurrent response is recording a failover. +type PoolResolver struct { + // pools maps pool name -> ordered member credential names. + pools map[string][]string + // memberOf maps a credential name -> the pools that contain it. + memberOf map[string][]string + + mu sync.RWMutex + health map[string]memberHealth +} + +// NewPoolResolver builds a resolver from store snapshots. Health rows with +// status "cooldown" and a future cooldown_until seed the in-memory health +// map; healthy rows and expired cooldowns are treated as eligible. +func NewPoolResolver(pools []store.Pool, healthRows []store.CredentialHealth) *PoolResolver { + pr := &PoolResolver{ + pools: make(map[string][]string, len(pools)), + memberOf: make(map[string][]string), + health: make(map[string]memberHealth), + } + for _, p := range pools { + members := make([]string, 0, len(p.Members)) + for _, m := range p.Members { + members = append(members, m.Credential) + pr.memberOf[m.Credential] = append(pr.memberOf[m.Credential], p.Name) + } + pr.pools[p.Name] = members + } + for _, h := range healthRows { + if h.Status == "cooldown" && !h.CooldownUntil.IsZero() { + pr.health[h.Credential] = memberHealth{ + cooldownUntil: h.CooldownUntil, + reason: h.LastFailureReason, + } + } + } + return pr +} + +// IsPool reports whether name is a configured pool. +func (pr *PoolResolver) IsPool(name string) bool { + if pr == nil { + return false + } + pr.mu.RLock() + defer pr.mu.RUnlock() + _, ok := pr.pools[name] + return ok +} + +// PoolForMember returns the first pool that contains the given credential, +// or "" if the credential is not a pool member. Used by the response path to +// attribute a failover/refresh to its pool for audit + Telegram. +func (pr *PoolResolver) PoolForMember(credential string) string { + if pr == nil { + return "" + } + pr.mu.RLock() + defer pr.mu.RUnlock() + if pools := pr.memberOf[credential]; len(pools) > 0 { + return pools[0] + } + return "" +} + +// Members returns the ordered member list for a pool (copy), or nil. +func (pr *PoolResolver) Members(pool string) []string { + if pr == nil { + return nil + } + pr.mu.RLock() + defer pr.mu.RUnlock() + m, ok := pr.pools[pool] + if !ok { + return nil + } + return append([]string(nil), m...) +} + +// ResolveActive expands a name to the credential that should actually be +// used. For a plain credential (not a pool) the name is returned unchanged. +// For a pool, the first member that is healthy or whose cooldown has expired +// (in position order) is returned. If every member is still cooling down, +// the member with the soonest recovery is returned and a WARNING is logged +// (degraded: sluice keeps serving with the least-bad account rather than +// failing the request outright). +func (pr *PoolResolver) ResolveActive(name string) (member string, ok bool) { + if pr == nil { + return name, true + } + pr.mu.RLock() + defer pr.mu.RUnlock() + + members, isPool := pr.pools[name] + if !isPool { + // Plain credential: passthrough unchanged. + return name, true + } + if len(members) == 0 { + return "", false + } + + now := time.Now() + var soonest string + var soonestUntil time.Time + for _, m := range members { + h, tracked := pr.health[m] + if !tracked || h.cooldownUntil.IsZero() || !h.cooldownUntil.After(now) { + return m, true + } + if soonest == "" || h.cooldownUntil.Before(soonestUntil) { + soonest = m + soonestUntil = h.cooldownUntil + } + } + log.Printf("[POOL] all %d members of pool %q are in cooldown; degrading to %q (recovers %s)", + len(members), name, soonest, soonestUntil.Format(time.RFC3339)) + return soonest, true +} + +// MarkCooldown records, in memory and synchronously, that a member should be +// skipped until `until`. Phase 2 failover calls this on the response path +// BEFORE the response returns so the very next request injects the next +// member; the durable store write only reconciles afterwards. Calling with a +// zero/past `until` clears the cooldown (recovery). +func (pr *PoolResolver) MarkCooldown(credential string, until time.Time, reason string) { + if pr == nil { + return + } + pr.mu.Lock() + defer pr.mu.Unlock() + if until.IsZero() || !until.After(time.Now()) { + delete(pr.health, credential) + return + } + pr.health[credential] = memberHealth{cooldownUntil: until, reason: reason} +} + +// CooldownUntil returns the in-memory cooldown expiry for a credential and +// whether it is currently cooling down (future expiry). +func (pr *PoolResolver) CooldownUntil(credential string) (time.Time, bool) { + if pr == nil { + return time.Time{}, false + } + pr.mu.RLock() + defer pr.mu.RUnlock() + h, ok := pr.health[credential] + if !ok || h.cooldownUntil.IsZero() || !h.cooldownUntil.After(time.Now()) { + return time.Time{}, false + } + return h.cooldownUntil, true +} diff --git a/internal/vault/pool_test.go b/internal/vault/pool_test.go new file mode 100644 index 0000000..0293d3d --- /dev/null +++ b/internal/vault/pool_test.go @@ -0,0 +1,128 @@ +package vault + +import ( + "testing" + "time" + + "github.com/nemirovsky/sluice/internal/store" +) + +func mkPool(name string, members ...string) store.Pool { + p := store.Pool{Name: name, Strategy: store.PoolStrategyFailover} + for i, m := range members { + p.Members = append(p.Members, store.PoolMember{Credential: m, Position: i}) + } + return p +} + +func TestResolveActivePassthroughForNonPool(t *testing.T) { + pr := NewPoolResolver(nil, nil) + got, ok := pr.ResolveActive("plain_cred") + if !ok || got != "plain_cred" { + t.Errorf("ResolveActive(plain) = %q,%v; want plain_cred,true", got, ok) + } + if pr.IsPool("plain_cred") { + t.Error("IsPool(plain_cred) = true, want false") + } +} + +func TestResolveActivePicksFirstHealthy(t *testing.T) { + pr := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, nil) + if !pr.IsPool("pool") { + t.Fatal("IsPool(pool) = false") + } + got, ok := pr.ResolveActive("pool") + if !ok || got != "a" { + t.Errorf("ResolveActive = %q,%v; want a,true", got, ok) + } +} + +func TestResolveActiveSkipsCooledDownMember(t *testing.T) { + future := time.Now().Add(60 * time.Second) + health := []store.CredentialHealth{ + {Credential: "a", Status: "cooldown", CooldownUntil: future, LastFailureReason: "429"}, + } + pr := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, health) + got, ok := pr.ResolveActive("pool") + if !ok || got != "b" { + t.Errorf("ResolveActive = %q,%v; want b,true (a is cooling down)", got, ok) + } +} + +func TestResolveActiveExpiredCooldownIsEligible(t *testing.T) { + past := time.Now().Add(-1 * time.Second) + health := []store.CredentialHealth{ + {Credential: "a", Status: "cooldown", CooldownUntil: past}, + } + pr := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, health) + got, _ := pr.ResolveActive("pool") + if got != "a" { + t.Errorf("ResolveActive = %q; want a (cooldown expired -> eligible)", got) + } +} + +func TestResolveActiveAllDownDegradesToSoonest(t *testing.T) { + now := time.Now() + health := []store.CredentialHealth{ + {Credential: "a", Status: "cooldown", CooldownUntil: now.Add(300 * time.Second)}, + {Credential: "b", Status: "cooldown", CooldownUntil: now.Add(30 * time.Second)}, + } + pr := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, health) + got, ok := pr.ResolveActive("pool") + if !ok || got != "b" { + t.Errorf("ResolveActive (all down) = %q,%v; want b,true (soonest recovery)", got, ok) + } +} + +func TestResolveActiveEmptyPool(t *testing.T) { + pr := NewPoolResolver([]store.Pool{mkPool("empty")}, nil) + if _, ok := pr.ResolveActive("empty"); ok { + t.Error("ResolveActive(empty pool) ok=true, want false") + } +} + +func TestMarkCooldownSynchronousFlip(t *testing.T) { + pr := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, nil) + if got, _ := pr.ResolveActive("pool"); got != "a" { + t.Fatalf("initial active = %q, want a", got) + } + // Synchronous in-memory failover (Phase 2 path): the very next + // resolution must already see b. + pr.MarkCooldown("a", time.Now().Add(60*time.Second), "429") + if got, _ := pr.ResolveActive("pool"); got != "b" { + t.Errorf("after MarkCooldown(a) active = %q, want b", got) + } + if _, cooling := pr.CooldownUntil("a"); !cooling { + t.Error("CooldownUntil(a) cooling=false, want true") + } + // Clearing (zero/past) recovers the member. + pr.MarkCooldown("a", time.Time{}, "") + if got, _ := pr.ResolveActive("pool"); got != "a" { + t.Errorf("after clear active = %q, want a", got) + } +} + +func TestPoolForMemberAndMembers(t *testing.T) { + pr := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, nil) + if p := pr.PoolForMember("b"); p != "pool" { + t.Errorf("PoolForMember(b) = %q, want pool", p) + } + if p := pr.PoolForMember("nope"); p != "" { + t.Errorf("PoolForMember(nope) = %q, want empty", p) + } + m := pr.Members("pool") + if len(m) != 2 || m[0] != "a" || m[1] != "b" { + t.Errorf("Members(pool) = %v, want [a b]", m) + } +} + +func TestNilPoolResolverSafe(t *testing.T) { + var pr *PoolResolver + if got, ok := pr.ResolveActive("x"); !ok || got != "x" { + t.Errorf("nil ResolveActive = %q,%v; want x,true", got, ok) + } + if pr.IsPool("x") { + t.Error("nil IsPool = true") + } + pr.MarkCooldown("x", time.Now(), "") // must not panic +} From a1a93e3c704011d10af5a40172ffff962ddcedf7 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:26:13 +0800 Subject: [PATCH 06/49] feat(cli): add pool subcommands and credential/pool namespace guards sluice pool create|list|status|rotate|remove. status computes the active member via the same PoolResolver logic the proxy uses so it never disagrees with what gets injected; rotate parks the active member (lazy-recovery cooldown) so the next member takes over. cred add rejects a name colliding with an existing pool; cred remove is blocked (before the vault delete) when the credential is a live pool member, so no dangling member rows or destroyed secrets. --- cmd/sluice/cred.go | 30 +++++ cmd/sluice/pool.go | 254 ++++++++++++++++++++++++++++++++++++++++ cmd/sluice/pool_test.go | 166 ++++++++++++++++++++++++++ 3 files changed, 450 insertions(+) create mode 100644 cmd/sluice/pool.go create mode 100644 cmd/sluice/pool_test.go diff --git a/cmd/sluice/cred.go b/cmd/sluice/cred.go index 383087e..2b35bea 100644 --- a/cmd/sluice/cred.go +++ b/cmd/sluice/cred.go @@ -227,6 +227,15 @@ func handleCredAdd(args []string) error { } defer func() { _ = db.Close() }() + // Namespace mutual-exclusion: a credential must not shadow a pool. Pool + // and credential names share one namespace so a bound destination + // resolves unambiguously to either a pool or a plain credential. + if exists, perr := db.PoolExists(name); perr != nil { + return fmt.Errorf("check pool name collision: %w", perr) + } else if exists { + return fmt.Errorf("name %q is already a credential pool; pool and credential names share one namespace", name) + } + // Inputs validated and DB is open. Now persist the credential. vs, err := openVaultStore(*dbPath) if err != nil { @@ -553,6 +562,27 @@ func handleCredRemove(args []string) error { } name := fs.Arg(0) + // Block removing a credential that is still a live pool member so no + // dangling member rows are left behind. The operator must remove it + // from the pool first. This check runs before the vault delete so a + // blocked removal does not destroy the secret. Only consult the DB if + // it already exists (do not create it as a side effect of a removal). + if _, statErr := os.Stat(*dbPath); statErr == nil { + guardDB, gerr := store.New(*dbPath) + if gerr != nil { + log.Printf("warning: could not open database %q to check pool membership: %v", *dbPath, gerr) + } else { + pools, perr := guardDB.PoolsForMember(name) + _ = guardDB.Close() + if perr != nil { + return fmt.Errorf("check pool membership for %q: %w", name, perr) + } + if len(pools) > 0 { + return fmt.Errorf("credential %q is a member of pool(s) %s; remove it from the pool first (sluice pool remove

, or recreate the pool without it)", name, strings.Join(pools, ", ")) + } + } + } + vs, err := openVaultStore(*dbPath) if err != nil { return err diff --git a/cmd/sluice/pool.go b/cmd/sluice/pool.go new file mode 100644 index 0000000..56b8f48 --- /dev/null +++ b/cmd/sluice/pool.go @@ -0,0 +1,254 @@ +package main + +import ( + "flag" + "fmt" + "strings" + "time" + + "github.com/nemirovsky/sluice/internal/store" + "github.com/nemirovsky/sluice/internal/vault" +) + +func handlePoolCommand(args []string) error { + if len(args) == 0 { + return fmt.Errorf("usage: sluice pool [create|list|status|rotate|remove]") + } + + switch args[0] { + case "create": + return handlePoolCreate(args[1:]) + case "list": + return handlePoolList(args[1:]) + case "status": + return handlePoolStatus(args[1:]) + case "rotate": + return handlePoolRotate(args[1:]) + case "remove": + return handlePoolRemove(args[1:]) + default: + return fmt.Errorf("unknown pool command: %s (usage: sluice pool [create|list|status|rotate|remove] ...)", args[0]) + } +} + +func handlePoolCreate(args []string) error { + fs := flag.NewFlagSet("pool create", flag.ContinueOnError) + dbPath := fs.String("db", "data/sluice.db", "path to SQLite database") + membersStr := fs.String("members", "", "comma-separated ordered list of oauth credential names (failover order)") + strategy := fs.String("strategy", store.PoolStrategyFailover, "pool strategy (only 'failover' is supported)") + if err := fs.Parse(reorderFlagsBeforePositional(args, fs)); err != nil { + return err + } + + if fs.NArg() == 0 { + return fmt.Errorf("usage: sluice pool create --members a,b[,c] [--strategy failover]") + } + name := fs.Arg(0) + + if *membersStr == "" { + return fmt.Errorf("--members is required (comma-separated oauth credential names)") + } + var members []string + for _, m := range strings.Split(*membersStr, ",") { + m = strings.TrimSpace(m) + if m == "" { + return fmt.Errorf("empty credential name in --members list") + } + members = append(members, m) + } + + db, err := store.New(*dbPath) + if err != nil { + return fmt.Errorf("open store: %w", err) + } + defer func() { _ = db.Close() }() + + if err := db.CreatePoolWithMembers(name, *strategy, members); err != nil { + return err + } + + fmt.Printf("pool %q created (strategy: %s)\n", name, *strategy) + for i, m := range members { + fmt.Printf(" [%d] %s\n", i, m) + } + fmt.Printf("bind it with: sluice binding add %s --destination [--ports 443]\n", name) + return nil +} + +func handlePoolList(args []string) error { + fs := flag.NewFlagSet("pool list", flag.ContinueOnError) + dbPath := fs.String("db", "data/sluice.db", "path to SQLite database") + if err := fs.Parse(args); err != nil { + return err + } + + db, err := store.New(*dbPath) + if err != nil { + return fmt.Errorf("open store: %w", err) + } + defer func() { _ = db.Close() }() + + pools, err := db.ListPools() + if err != nil { + return err + } + if len(pools) == 0 { + fmt.Println("no credential pools configured") + return nil + } + for _, p := range pools { + names := make([]string, 0, len(p.Members)) + for _, m := range p.Members { + names = append(names, m.Credential) + } + fmt.Printf("%s (strategy: %s): %s\n", p.Name, p.Strategy, strings.Join(names, ", ")) + } + return nil +} + +func handlePoolStatus(args []string) error { + fs := flag.NewFlagSet("pool status", flag.ContinueOnError) + dbPath := fs.String("db", "data/sluice.db", "path to SQLite database") + if err := fs.Parse(reorderFlagsBeforePositional(args, fs)); err != nil { + return err + } + if fs.NArg() == 0 { + return fmt.Errorf("usage: sluice pool status ") + } + name := fs.Arg(0) + + db, err := store.New(*dbPath) + if err != nil { + return fmt.Errorf("open store: %w", err) + } + defer func() { _ = db.Close() }() + + p, err := db.GetPool(name) + if err != nil { + return err + } + if p == nil { + return fmt.Errorf("pool %q not found", name) + } + healthRows, err := db.ListCredentialHealth() + if err != nil { + return err + } + + // Compute the active member using the exact same selection logic the + // proxy uses at injection time so `pool status` never disagrees with + // what would actually be injected. + resolver := vault.NewPoolResolver([]store.Pool{*p}, healthRows) + active, _ := resolver.ResolveActive(name) + + healthByCred := make(map[string]store.CredentialHealth, len(healthRows)) + for _, h := range healthRows { + healthByCred[h.Credential] = h + } + + fmt.Printf("pool %q (strategy: %s)\n", p.Name, p.Strategy) + now := time.Now() + for _, m := range p.Members { + marker := " " + if m.Credential == active { + marker = "* " + } + status := "healthy" + if h, ok := healthByCred[m.Credential]; ok && h.Status == "cooldown" && !h.CooldownUntil.IsZero() { + if h.CooldownUntil.After(now) { + status = fmt.Sprintf("cooldown until %s", h.CooldownUntil.Format(time.RFC3339)) + } else { + status = "healthy (cooldown expired)" + } + if h.LastFailureReason != "" { + status += " — " + h.LastFailureReason + } + } + fmt.Printf("%s[%d] %s %s\n", marker, m.Position, m.Credential, status) + } + fmt.Printf("active: %s\n", active) + return nil +} + +func handlePoolRotate(args []string) error { + fs := flag.NewFlagSet("pool rotate", flag.ContinueOnError) + dbPath := fs.String("db", "data/sluice.db", "path to SQLite database") + if err := fs.Parse(reorderFlagsBeforePositional(args, fs)); err != nil { + return err + } + if fs.NArg() == 0 { + return fmt.Errorf("usage: sluice pool rotate ") + } + name := fs.Arg(0) + + db, err := store.New(*dbPath) + if err != nil { + return fmt.Errorf("open store: %w", err) + } + defer func() { _ = db.Close() }() + + p, err := db.GetPool(name) + if err != nil { + return err + } + if p == nil { + return fmt.Errorf("pool %q not found", name) + } + healthRows, err := db.ListCredentialHealth() + if err != nil { + return err + } + resolver := vault.NewPoolResolver([]store.Pool{*p}, healthRows) + active, ok := resolver.ResolveActive(name) + if !ok || active == "" { + return fmt.Errorf("pool %q has no resolvable member to rotate away from", name) + } + + // Manual override: park the current active member so the next member in + // position order becomes active. The cooldown lapses on its own (lazy + // recovery, same as auto-failover), so a rotated-away member rejoins the + // rotation once its cooldown expires. + until := time.Now().Add(vault.AuthFailCooldown) + if err := db.SetCredentialHealth(active, "cooldown", until, "manual rotate"); err != nil { + return err + } + + // Recompute the new active member for operator feedback. + healthRows, err = db.ListCredentialHealth() + if err != nil { + return err + } + resolver = vault.NewPoolResolver([]store.Pool{*p}, healthRows) + next, _ := resolver.ResolveActive(name) + fmt.Printf("pool %q rotated: %s -> %s (parked %s until %s)\n", + name, active, next, active, until.Format(time.RFC3339)) + return nil +} + +func handlePoolRemove(args []string) error { + fs := flag.NewFlagSet("pool remove", flag.ContinueOnError) + dbPath := fs.String("db", "data/sluice.db", "path to SQLite database") + if err := fs.Parse(reorderFlagsBeforePositional(args, fs)); err != nil { + return err + } + if fs.NArg() == 0 { + return fmt.Errorf("usage: sluice pool remove ") + } + name := fs.Arg(0) + + db, err := store.New(*dbPath) + if err != nil { + return fmt.Errorf("open store: %w", err) + } + defer func() { _ = db.Close() }() + + removed, err := db.RemovePool(name) + if err != nil { + return err + } + if !removed { + return fmt.Errorf("pool %q not found", name) + } + fmt.Printf("pool %q removed (members and bindings referencing it are unaffected; remove stale bindings with 'sluice binding remove')\n", name) + return nil +} diff --git a/cmd/sluice/pool_test.go b/cmd/sluice/pool_test.go new file mode 100644 index 0000000..d984544 --- /dev/null +++ b/cmd/sluice/pool_test.go @@ -0,0 +1,166 @@ +package main + +import ( + "os" + "strings" + "testing" + + "github.com/nemirovsky/sluice/internal/store" + "github.com/nemirovsky/sluice/internal/vault" +) + +// seedPoolCred registers an oauth credential_meta row plus a vault secret so +// it is a valid pool member and a removable credential. +func seedPoolCred(t *testing.T, dbPath, dir, name string) { + t.Helper() + db, err := store.New(dbPath) + if err != nil { + t.Fatalf("open db: %v", err) + } + if err := db.AddCredentialMeta(name, "oauth", "https://auth.example.com/token"); err != nil { + t.Fatalf("add meta %q: %v", name, err) + } + _ = db.Close() + vs, err := vault.NewStore(dir) + if err != nil { + t.Fatalf("open vault: %v", err) + } + if _, err := vs.Add(name, `{"access_token":"x","token_url":"https://auth.example.com/token"}`); err != nil { + t.Fatalf("vault add %q: %v", name, err) + } +} + +func TestHandlePoolCreateListStatusRemove(t *testing.T) { + dir := t.TempDir() + dbPath := setupVaultDB(t, dir) + seedPoolCred(t, dbPath, dir, "acct_a") + seedPoolCred(t, dbPath, dir, "acct_b") + + out := captureStdout(t, func() { + if err := handleCredCommand([]string{}); err == nil { + t.Error("expected usage error for empty cred args") + } + if err := handlePoolCommand([]string{"create", "--db", dbPath, "--members", "acct_a,acct_b", "codex"}); err != nil { + t.Fatalf("pool create: %v", err) + } + }) + if !strings.Contains(out, `pool "codex" created`) { + t.Errorf("create output = %q", out) + } + + out = captureStdout(t, func() { + if err := handlePoolCommand([]string{"list", "--db", dbPath}); err != nil { + t.Fatalf("pool list: %v", err) + } + }) + if !strings.Contains(out, "codex") || !strings.Contains(out, "acct_a, acct_b") { + t.Errorf("list output = %q", out) + } + + out = captureStdout(t, func() { + if err := handlePoolCommand([]string{"status", "--db", dbPath, "codex"}); err != nil { + t.Fatalf("pool status: %v", err) + } + }) + // First member is active. + if !strings.Contains(out, "* [0] acct_a") || !strings.Contains(out, "active: acct_a") { + t.Errorf("status output = %q", out) + } + + // Rotate parks acct_a so acct_b becomes active. + out = captureStdout(t, func() { + if err := handlePoolCommand([]string{"rotate", "--db", dbPath, "codex"}); err != nil { + t.Fatalf("pool rotate: %v", err) + } + }) + if !strings.Contains(out, "acct_a -> acct_b") { + t.Errorf("rotate output = %q", out) + } + out = captureStdout(t, func() { + _ = handlePoolCommand([]string{"status", "--db", dbPath, "codex"}) + }) + if !strings.Contains(out, "active: acct_b") { + t.Errorf("post-rotate status = %q", out) + } + + out = captureStdout(t, func() { + if err := handlePoolCommand([]string{"remove", "--db", dbPath, "codex"}); err != nil { + t.Fatalf("pool remove: %v", err) + } + }) + if !strings.Contains(out, `pool "codex" removed`) { + t.Errorf("remove output = %q", out) + } +} + +func TestHandlePoolErrorPaths(t *testing.T) { + dir := t.TempDir() + dbPath := setupVaultDB(t, dir) + seedPoolCred(t, dbPath, dir, "acct_a") + + if err := handlePoolCommand(nil); err == nil { + t.Error("expected usage error for no args") + } + if err := handlePoolCommand([]string{"bogus"}); err == nil { + t.Error("expected error for unknown subcommand") + } + if err := handlePoolCommand([]string{"create", "--db", dbPath, "p"}); err == nil { + t.Error("expected error for missing --members") + } + if err := handlePoolCommand([]string{"status", "--db", dbPath, "missing"}); err == nil { + t.Error("expected error for status of missing pool") + } + if err := handlePoolCommand([]string{"remove", "--db", dbPath, "missing"}); err == nil { + t.Error("expected error for remove of missing pool") + } + // Pool name colliding with an existing credential is rejected. + if err := handlePoolCommand([]string{"create", "--db", dbPath, "--members", "acct_a", "acct_a"}); err == nil { + t.Error("expected namespace collision error (pool == credential)") + } +} + +func TestCredAddRejectsPoolNameCollision(t *testing.T) { + dir := t.TempDir() + dbPath := setupVaultDB(t, dir) + seedPoolCred(t, dbPath, dir, "acct_a") + if err := handlePoolCommand([]string{"create", "--db", dbPath, "--members", "acct_a", "mypool"}); err != nil { + t.Fatalf("pool create: %v", err) + } + + oldStdin := os.Stdin + r, w, _ := os.Pipe() + os.Stdin = r + _, _ = w.Write([]byte("secret\n")) + _ = w.Close() + defer func() { os.Stdin = oldStdin }() + + err := handleCredCommand([]string{"add", "--db", dbPath, "mypool"}) + if err == nil || !strings.Contains(err.Error(), "already a credential pool") { + t.Fatalf("cred add colliding with pool: err = %v, want namespace error", err) + } +} + +func TestCredRemoveBlockedForLivePoolMember(t *testing.T) { + dir := t.TempDir() + dbPath := setupVaultDB(t, dir) + seedPoolCred(t, dbPath, dir, "acct_a") + seedPoolCred(t, dbPath, dir, "acct_b") + if err := handlePoolCommand([]string{"create", "--db", dbPath, "--members", "acct_a,acct_b", "codex"}); err != nil { + t.Fatalf("pool create: %v", err) + } + + err := handleCredCommand([]string{"remove", "--db", dbPath, "acct_a"}) + if err == nil || !strings.Contains(err.Error(), "member of pool") { + t.Fatalf("cred remove of live member: err = %v, want block error", err) + } + // Secret must still be present (removal was blocked before vault delete). + vs, verr := vault.NewStore(dir) + if verr != nil { + t.Fatalf("open vault: %v", verr) + } + sb, gerr := vs.Get("acct_a") + if gerr != nil { + t.Fatalf("credential acct_a was destroyed despite blocked removal: %v", gerr) + } + sb.Release() +} From 800d36f0504d9b5aaa0073e5ac907bd7d96fa8d5 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:26:25 +0800 Subject: [PATCH 07/49] feat(proxy): wire PoolResolver into server, addon and reloadAll Server gains an atomic PoolResolver pointer (parallel to the binding resolver), StorePool/PoolResolverPtr, and threads it into SluiceAddon via WithPoolResolver. addon.resolvePoolMember is the chokepoint helper (non-pool names passthrough). main.go registers the pool subcommand, loads the resolver at startup, and rebuilds+atomically swaps it in reloadAll alongside the binding/oauth reloads. Injection does not consult it yet (Phase 1). --- cmd/sluice/main.go | 46 ++++++++++++++++++++++++++++++++++++++++ internal/proxy/addon.go | 42 ++++++++++++++++++++++++++++++++++++ internal/proxy/server.go | 34 ++++++++++++++++++++++++++--- 3 files changed, 119 insertions(+), 3 deletions(-) diff --git a/cmd/sluice/main.go b/cmd/sluice/main.go index 6952eb6..553a05a 100644 --- a/cmd/sluice/main.go +++ b/cmd/sluice/main.go @@ -70,6 +70,11 @@ func main() { log.Fatalf("channel: %v", err) } return + case "pool": + if err := handlePoolCommand(os.Args[2:]); err != nil { + log.Fatalf("pool: %v", err) + } + return } } @@ -342,6 +347,19 @@ func main() { } } + // Populate the initial credential pool resolver at startup so pool + // expansion works for pools defined before the first SIGHUP. Always + // store a non-nil resolver (empty when no pools) so the addon never + // has to nil-check before ResolveActive (non-pool names passthrough). + if db != nil { + if pr, perr := loadPoolResolver(db); perr != nil { + log.Printf("pool resolver init failed: %v", perr) + srv.StorePool(vault.NewPoolResolver(nil, nil)) + } else { + srv.StorePool(pr) + } + } + // Configure the OAuth refresh callback so that after a token refresh // is persisted, the updated phantom env vars are re-injected into the // agent container. @@ -687,6 +705,17 @@ func main() { log.Printf("reload oauth index failed: %v", metaErr) } + // Rebuild and atomically swap the credential pool resolver. + // Membership changes (pool create/remove) take effect here; + // durable health rows are reloaded too, which only reconciles + // the in-memory health that Phase 2 failover already updated + // synchronously on the response path. + if pr, perr := loadPoolResolver(db); perr != nil { + log.Printf("reload pool resolver failed: %v", perr) + } else { + srv.StorePool(pr) + } + // Re-inject env vars into the agent container after binding changes. if containerMgr != nil { if injectErr := injectEnvVarsFromStore(db, containerMgr); injectErr != nil { @@ -813,6 +842,23 @@ func readBindings(db *store.Store) ([]vault.Binding, error) { return bindings, nil } +// loadPoolResolver builds a vault.PoolResolver from the store's pool, +// member, and credential-health tables. A non-nil resolver is always +// returned on success (empty when no pools), so callers can store it +// unconditionally and the addon never has to nil-check before +// ResolveActive (a non-pool name is an identity passthrough). +func loadPoolResolver(db *store.Store) (*vault.PoolResolver, error) { + pools, err := db.ListPools() + if err != nil { + return nil, fmt.Errorf("list pools: %w", err) + } + health, err := db.ListCredentialHealth() + if err != nil { + return nil, fmt.Errorf("list credential health: %w", err) + } + return vault.NewPoolResolver(pools, health), nil +} + // injectEnvVarsFromStore reads bindings with env_var set from the store, // generates phantom tokens for each, and injects them into the agent // container via the container manager. This is called at startup and after diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index a9afefb..e6cd5a7 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -84,6 +84,14 @@ type SluiceAddon struct { // atomically on SIGHUP / policy mutation. resolver *atomic.Pointer[vault.BindingResolver] + // poolResolver expands a bound pool name to its currently active + // member at the single injection chokepoint (resolvePoolMember). + // Swapped atomically alongside resolver on reload; may be nil when + // no pools are configured (treated as identity passthrough). Phase 2 + // mutates the contained health map in place under the resolver's own + // mutex on the response path. + poolResolver *atomic.Pointer[vault.PoolResolver] + // provider retrieves real credential values from the vault. provider vault.Provider @@ -180,6 +188,40 @@ func WithProvider(p vault.Provider) SluiceAddonOption { return func(a *SluiceAddon) { a.provider = p } } +// WithPoolResolver sets the credential pool resolver pointer used by the +// injection chokepoint to expand a bound pool name to its active member. +func WithPoolResolver(r *atomic.Pointer[vault.PoolResolver]) SluiceAddonOption { + return func(a *SluiceAddon) { a.poolResolver = r } +} + +// SetPoolResolver wires (or rewires) the shared pool resolver pointer. Safe +// to call after construction; the pointer itself is stable and only its +// contents are atomically swapped on reload. +func (a *SluiceAddon) SetPoolResolver(r *atomic.Pointer[vault.PoolResolver]) { + a.poolResolver = r +} + +// resolvePoolMember is the single chokepoint that expands a bound +// credential-or-pool name to the concrete credential whose secret should be +// injected. For a plain credential it returns the name unchanged. For a +// pool it returns the currently active member. Every consumer that reads a +// binding's Credential (pass-1 header inject, pass-2 phantom pairs, +// OAuthIndex.Has gating, persist attribution) routes through here so pool +// expansion happens in exactly one place (Important I2). +func (a *SluiceAddon) resolvePoolMember(name string) string { + if a.poolResolver == nil { + return name + } + pr := a.poolResolver.Load() + if pr == nil { + return name + } + if member, ok := pr.ResolveActive(name); ok { + return member + } + return name +} + // WithAuditLogger sets the audit logger for per-request events. func WithAuditLogger(l *audit.FileLogger) SluiceAddonOption { return func(a *SluiceAddon) { a.auditLog = l } diff --git a/internal/proxy/server.go b/internal/proxy/server.go index b16551c..f0c9fb8 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -82,9 +82,15 @@ type Server struct { dnsInterceptor *DNSInterceptor quicProxy *QUICProxy resolver atomic.Pointer[vault.BindingResolver] - closed atomic.Bool - serving atomic.Bool - activeConns sync.WaitGroup + // poolResolver expands a bound pool name to its active member at + // injection time. Swapped atomically alongside resolver on reload; + // membership is immutable per instance while health is mutated in + // place under the resolver's own mutex (Phase 2 synchronous + // failover). Parallel to resolver, never gates it. + poolResolver atomic.Pointer[vault.PoolResolver] + closed atomic.Bool + serving atomic.Bool + activeConns sync.WaitGroup // oauthMetasCache holds the latest credential_meta slice the // server saw via UpdateOAuthIndex. Cached so a later @@ -659,6 +665,7 @@ func (s *Server) setupInjection(cfg Config, _ net.Listener) error { // Create the SluiceAddon for go-mitmproxy. addonOpts := []SluiceAddonOption{ WithResolver(&s.resolver), + WithPoolResolver(&s.poolResolver), WithProvider(cfg.Provider), WithWSProxy(wsProxy), } @@ -2681,6 +2688,27 @@ func (s *Server) StoreResolver(r *vault.BindingResolver) { s.resolver.Store(r) } +// StorePool atomically stores a new credential pool resolver. The caller +// must hold ReloadMu() when concurrent mutations are possible. The MITM +// addon shares the same atomic pointer so the injection chokepoint and the +// response-side failover see the same pool/health snapshot. A nil resolver +// (no pools configured) is stored as a non-nil empty resolver so the addon +// can call IsPool/ResolveActive without nil-checking; ResolveActive on a +// non-pool name is an identity passthrough. +func (s *Server) StorePool(r *vault.PoolResolver) { + s.poolResolver.Store(r) + if s.addon != nil { + s.addon.SetPoolResolver(&s.poolResolver) + } +} + +// PoolResolverPtr returns the shared atomic pool resolver pointer so the +// Telegram/REST mutation paths can keep the proxy's live pool snapshot in +// sync with the store, mirroring ResolverPtr. +func (s *Server) PoolResolverPtr() *atomic.Pointer[vault.PoolResolver] { + return &s.poolResolver +} + // UpdateOAuthIndex rebuilds the OAuth token URL index from credential // metadata. Call this after StoreResolver in the SIGHUP reload path or // after Telegram credential mutations so the response handler detects From b68cd7ce908cb01b2a159d33cdd1c1e2b959b791 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:26:54 +0800 Subject: [PATCH 08/49] feat(mcp): opt MCP tool calls out of approval coalescing --- internal/mcp/gateway.go | 8 ++- internal/mcp/gateway_test.go | 97 +++++++++++++++++++++++++++ internal/proxy/request_policy_test.go | 65 ++++++++++++++++++ 3 files changed, 169 insertions(+), 1 deletion(-) diff --git a/internal/mcp/gateway.go b/internal/mcp/gateway.go index f1b4b52..378f292 100644 --- a/internal/mcp/gateway.go +++ b/internal/mcp/gateway.go @@ -225,7 +225,13 @@ func (gw *Gateway) HandleToolCall(req CallToolParams) (*ToolResult, error) { } log.Printf("[MCP ASK] %s (args: %s)", req.Name, argsStr) timeout := time.Duration(gw.timeoutSec) * time.Second - resp, err := gw.broker.Request(req.Name, 0, "mcp", timeout, channel.WithToolArgs(argsStr)) + // MCP tool calls opt out of broker-level coalescing: two calls to + // the same tool with different ToolArgs are semantically distinct + // and feed arg-sensitive ContentInspector/exec rules, so each must + // get its own prompt. A "dest:port" dedup key (req.Name + port 0) + // would wrongly collapse them. + resp, err := gw.broker.Request(req.Name, 0, "mcp", timeout, + channel.WithToolArgs(argsStr), channel.WithNoCoalesce()) if err != nil { gw.logAudit(req.Name, "tool_call", policy.Deny) return &ToolResult{ diff --git a/internal/mcp/gateway_test.go b/internal/mcp/gateway_test.go index 73d2d0c..6e490e7 100644 --- a/internal/mcp/gateway_test.go +++ b/internal/mcp/gateway_test.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "sync" "testing" "time" @@ -1553,3 +1554,99 @@ func TestGatewayToolNamespacePreventsCollision(t *testing.T) { t.Errorf("expected success, got error: %s", result2.Content[0].Text) } } + +// gatingRecordChannel records every approval request and holds resolution +// until release is closed. Used to prove MCP tool calls do NOT coalesce. +type gatingRecordChannel struct { + mu sync.Mutex + broker *channel.Broker + requests []channel.ApprovalRequest + release chan struct{} + resp channel.Response +} + +func (c *gatingRecordChannel) RequestApproval(_ context.Context, req channel.ApprovalRequest) error { + c.mu.Lock() + c.requests = append(c.requests, req) + rel := c.release + br := c.broker + resp := c.resp + c.mu.Unlock() + go func() { + <-rel + br.Resolve(req.ID, resp) + }() + return nil +} +func (c *gatingRecordChannel) CancelApproval(_ string) error { return nil } +func (c *gatingRecordChannel) Commands() <-chan channel.Command { return nil } +func (c *gatingRecordChannel) Notify(_ context.Context, _ string) error { return nil } +func (c *gatingRecordChannel) Start() error { return nil } +func (c *gatingRecordChannel) Stop() {} +func (c *gatingRecordChannel) Type() channel.ChannelType { return channel.ChannelTelegram } + +func (c *gatingRecordChannel) count() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.requests) +} + +// TestGatewayToolCallAskNotCoalesced verifies the gateway passes +// WithNoCoalesce: two concurrent calls to the same tool with different +// ToolArgs must each produce their own prompt (they are semantically +// distinct and feed arg-sensitive inspection), never collapsing onto one. +func TestGatewayToolCallAskNotCoalesced(t *testing.T) { + script := writeMockServer(t) + tp, err := NewToolPolicy([]policy.ToolRule{ + {Tool: "test__greet", Verdict: "ask"}, + }, policy.Allow) + if err != nil { + t.Fatal(err) + } + ch := &gatingRecordChannel{release: make(chan struct{}), resp: channel.ResponseAllowOnce} + broker := channel.NewBroker([]channel.Channel{ch}) + ch.broker = broker + + gw := newGatewayForTest(t, GatewayConfig{ + Upstreams: []UpstreamConfig{{ + Name: "test", + Command: "bash", + Args: []string{script}, + }}, + ToolPolicy: tp, + Broker: broker, + TimeoutSec: 5, + }) + + var wg sync.WaitGroup + for i, args := range []string{`{"a":1}`, `{"a":2}`} { + wg.Add(1) + go func(i int, a string) { + defer wg.Done() + if _, err := gw.HandleToolCall(CallToolParams{ + Name: "test__greet", + Arguments: json.RawMessage(a), + }); err != nil { + t.Errorf("call %d: %v", i, err) + } + }(i, args) + } + + // Wait until both prompts have been delivered (no coalescing). If the + // gateway wrongly coalesced, only one prompt would ever arrive and + // this would time out. + deadline := time.After(3 * time.Second) + for ch.count() < 2 { + select { + case <-deadline: + t.Fatalf("expected 2 distinct MCP prompts, got %d (wrongly coalesced)", ch.count()) + default: + time.Sleep(time.Millisecond) + } + } + close(ch.release) + wg.Wait() + if ch.count() != 2 { + t.Errorf("expected exactly 2 prompts, got %d", ch.count()) + } +} diff --git a/internal/proxy/request_policy_test.go b/internal/proxy/request_policy_test.go index b9ecbd3..54a4577 100644 --- a/internal/proxy/request_policy_test.go +++ b/internal/proxy/request_policy_test.go @@ -745,3 +745,68 @@ destination = "api.example.com" t.Errorf("broker request count = %d, want 1 (concurrent burst coalesces)", fc.requestCount()) } } + +// TestRequestPolicyChecker_SSHStyleConnectionLevelCoalesces confirms the +// connection-level Ask path (no HTTP method/path, e.g. an SSH/IMAP/SMTP +// burst to one host:port that is deferred through CheckAndConsume -> +// resolveAsk -> broker.Request) coalesces onto a single prompt exactly like +// HTTP per-request asks. There is no separate SSH call site and no +// per-protocol special-casing: persistence granularity (one dest:port rule) +// equals dedup granularity. +func TestRequestPolicyChecker_SSHStyleConnectionLevelCoalesces(t *testing.T) { + toml := ` +[policy] +default = "deny" + +[[ask]] +destination = "git.example.com" +` + checker, fc := newTestChecker(t, toml, channel.ResponseAllowOnce) + releaseAll := fc.gate() + + const n = 4 + var wg sync.WaitGroup + results := make([]policy.Verdict, n) + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + // No WithRequestInfo: this is the connection-level shape. + v, err := checker.CheckAndConsume("git.example.com", 22, WithProtocol("ssh")) + if err != nil { + t.Errorf("goroutine %d: %v", i, err) + } + results[i] = v + }(i) + } + + deadline := time.After(3 * time.Second) + for { + id := fc.firstReqID() + if id != "" && checker.broker.CoalescedCount(id) >= n { + break + } + select { + case <-deadline: + t.Fatalf("SSH-style burst did not coalesce: requests=%d", fc.requestCount()) + default: + time.Sleep(time.Millisecond) + } + } + releaseAll() + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("connection-level CheckAndConsume did not finish in time") + } + for i, v := range results { + if v != policy.Allow { + t.Errorf("result[%d] = %v, want Allow", i, v) + } + } + if fc.requestCount() != 1 { + t.Errorf("broker request count = %d, want 1 (connection-level burst coalesces)", fc.requestCount()) + } +} From 185a3829ab21022575e3c9252c5a40691fa2a643 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:28:42 +0800 Subject: [PATCH 09/49] wip(telegram): checkpoint final-count edit before respawn --- internal/telegram/approval.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/internal/telegram/approval.go b/internal/telegram/approval.go index 65500f7..e47a32f 100644 --- a/internal/telegram/approval.go +++ b/internal/telegram/approval.go @@ -191,6 +191,13 @@ func (tc *TelegramChannel) CancelApproval(id string) error { } else if tc.broker != nil && tc.broker.IsClosed() { reason = "(proxy shutting down)" } + // Fold the final coalesced count into the one edit that already + // happens here — zero extra Telegram API calls. + if tc.broker != nil { + if c := tc.broker.CoalescedCount(id); c > 1 { + reason += fmt.Sprintf(" — applied to %d requests", c) + } + } edit := tgbotapi.NewEditMessageText(tc.chatID, am.messageID, FormatApprovalMessage(am.req)+"\n\n"+reason) edit.ParseMode = tgbotapi.ModeHTML @@ -330,6 +337,16 @@ func (tc *TelegramChannel) handleCallback(cq *tgbotapi.CallbackQuery) { } if resolved { + // Resolve has run, so the broker has recorded the final coalesced + // count for this primary. Fold it into the single resolve edit + // (no extra Send): one tap dismissed the whole burst. + labelText := label + if tc.broker != nil { + if c := tc.broker.CoalescedCount(reqID); c > 1 { + labelText = fmt.Sprintf("%s — applied to %d requests", label, c) + } + } + callback := tgbotapi.NewCallback(cq.ID, label) _, _ = tc.api.Request(callback) From a3602d6ef8d584b365a2daaaad1e84ac3fb3a7f6 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:35:02 +0800 Subject: [PATCH 10/49] fix(telegram): use coalesced-count label in resolve edit body --- internal/telegram/approval.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/telegram/approval.go b/internal/telegram/approval.go index e47a32f..0e266c6 100644 --- a/internal/telegram/approval.go +++ b/internal/telegram/approval.go @@ -359,10 +359,10 @@ func (tc *TelegramChannel) handleCallback(cq *tgbotapi.CallbackQuery) { var body string if haveAM { body = fmt.Sprintf("%s\n\n%s at %s", - FormatApprovalMessage(am.req), label, time.Now().UTC().Format("15:04:05")) + FormatApprovalMessage(am.req), labelText, time.Now().UTC().Format("15:04:05")) } else { body = fmt.Sprintf("%s\n\n%s at %s", - cq.Message.Text, label, time.Now().UTC().Format("15:04:05")) + cq.Message.Text, labelText, time.Now().UTC().Format("15:04:05")) } edit := tgbotapi.NewEditMessageText(tc.chatID, cq.Message.MessageID, body) edit.ParseMode = tgbotapi.ModeHTML From d2f536028c5e6a1b78d715e7243765b0aa679ce1 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:49:02 +0800 Subject: [PATCH 11/49] docs(plans): convert tasks/phases to exec checkboxes; mark completed work --- docs/plans/20260515-approval-coalescing.md | 52 ++++++------ .../20260515-credential-pool-failover.md | 82 +++++++++---------- 2 files changed, 65 insertions(+), 69 deletions(-) diff --git a/docs/plans/20260515-approval-coalescing.md b/docs/plans/20260515-approval-coalescing.md index 1fd89f8..778c4b6 100644 --- a/docs/plans/20260515-approval-coalescing.md +++ b/docs/plans/20260515-approval-coalescing.md @@ -57,50 +57,50 @@ Verified against the working tree on `main` (tip `20cc367`): **Files:** Modify `internal/channel/broker.go`; Modify `internal/channel/broker_test.go` -- Add to `Broker`: `dedupIndex map[string]string` (`dedupKey → primary reqID`); extend `waiter` with `subs []chan Response`, `count int` (starts 1), and `dedupKey string`. -- Add `WithNoCoalesce()` request option (escape hatch). -- `Request`: compute `dedupKey := dest + ":" + strconv.Itoa(port)` (proto-agnostic — matches the proto-agnostic persisted rule). If `WithNoCoalesce` set, skip all dedup logic. -- Under `b.mu`: if `dedupIndex[dedupKey]` exists and that primary waiter is still in `waiters` → create a **buffered (cap 1)** response chan, append to `waiters[primary].subs`, `count++`, capture primary id + count; release lock; (Phase 2 only: notify channels of new count). Block on the sub chan using a **sub-specific select** arm: `case resp := <-subCh` (return it) / `case <-deadline.C` (under `b.mu`: remove this chan from `waiters[primary].subs` if still present, then return timeout deny — **no `waiters` delete, no `timedOut` entry**, must not tear down the shared waiter) / `case <-b.done` (drain subCh non-blocking then deny). -- Else: today's behavior — new id, register waiter, set `dedupIndex[dedupKey]=id`, record `dedupKey` on the waiter, `broadcast`. -- `Resolve(id,resp)`: under `b.mu` look up waiter, **snapshot the whole waiter (including the `subs` slice)**, delete `b.waiters[id]` **and** `delete(b.dedupIndex, w.dedupKey)` in the *same locked section* (this is what closes the late-attach race), release lock, then fan `resp` to `w.ch` and every chan in the snapshot `subs` (all buffered cap 1 so a send to a detached/timed-out sub never blocks), then `cancelOnChannels(id)`. First-wins preserved: only the call that finds the waiter present wins. -- Timeout/`done`/shutdown of the **primary**: fan the terminal response to all snapshot subs and clear `dedupIndex` under the same lock discipline. -- write tests: concurrent dedup → one broadcast; fan-out to all N; late-attach interleave (no attach to dead waiter); sub-timeout-detach does not block fan-out; deny/timeout/shutdown fan-out; distinct dest:port; `WithNoCoalesce`; cross-channel first-wins. -- run `go test ./internal/channel/...` — must pass before Task 2. +- [x] Add to `Broker`: `dedupIndex map[string]string`; extend `waiter` with `subs []chan Response`, `count int`, `dedupKey string`. +- [x] Add `WithNoCoalesce()` request option (escape hatch). +- [x] `Request`: compute `dedupKey := dest + ":" + strconv.Itoa(port)`; if `WithNoCoalesce` set, skip dedup. +- [x] Under `b.mu`: if `dedupIndex[dedupKey]` exists and primary waiter still present → buffered (cap 1) sub chan appended to `waiters[primary].subs`, `count++`; sub-specific select arm (resp / deadline detach-only / done) that never tears down the shared waiter. +- [x] Else: new id, register waiter, set `dedupIndex[dedupKey]=id`, record `dedupKey`, `broadcast`. +- [x] `Resolve(id,resp)`: snapshot waiter+subs and delete `waiters[id]`+`dedupIndex[w.dedupKey]` in the same locked section; fan resp to `w.ch` + all snapshot subs after unlock; `cancelOnChannels`. +- [x] Timeout/done/shutdown of primary: fan terminal response to all snapshot subs, clear `dedupIndex` under lock. +- [x] write tests: concurrent dedup → one broadcast; fan-out to all N; late-attach interleave; sub-timeout-detach non-blocking; deny/timeout/shutdown fan-out; distinct dest:port; `WithNoCoalesce`; cross-channel first-wins. +- [ ] verify `go test ./internal/channel/...` passes (re-run to confirm Task 1 still green after merge). ### Task 2: Persist-once (idempotent approval rule) **Files:** Modify `internal/store/store.go`; Modify `internal/proxy/server.go`; Modify `internal/store/store_test.go` -- Add `Store.HasApprovalRule(verdict, dest string, port int) (bool, error)` — plain SELECT against `rules` where `source='approval'` AND verdict/destination/port match. **No migration** (read-only query). -- In `persistApprovalRule` (`server.go:506`), under the existing `reloadMu`, call `HasApprovalRule` first and skip `AddRule` + engine recompile if present (M coalesced callers serialize on `reloadMu`; first inserts, rest no-op). This is the chosen design over "only primary persists" because it needs no broker→persist signaling and is robustly idempotent under concurrent resolve fan-out. (Scope note: this is *not* a vehicle for fixing any pre-existing manual double-tap dup-row behavior — that is out of scope and not a design driver.) -- write tests: M concurrent persists → exactly one row; existing single-persist path unchanged. -- run `go test ./internal/store/... ./internal/proxy/...` — must pass before Task 3. +- [ ] Verify/complete `Store.HasApprovalRule(verdict, dest string, port int) (bool, error)` — plain SELECT against `rules` where `source='approval'` AND verdict/destination/port match. No migration. +- [ ] Verify/complete `persistApprovalRule` (`server.go`): under `reloadMu`, call `HasApprovalRule` first and skip `AddRule` + engine recompile if present. +- [ ] write/verify tests: M concurrent persists → exactly one row; existing single-persist path unchanged. +- [ ] run `go test ./internal/store/... ./internal/proxy/...` — must pass before Task 3. ### Task 3: Route call sites; MCP opt-out **Files:** Modify `internal/mcp/gateway.go`; audit `internal/proxy/request_policy.go`, `internal/proxy/server.go` -- `request_policy.go:299` (HTTP/gRPC/WS + connection-level SSH/IMAP/SMTP): coalesce **uniformly**. Rationale: an SSH/IMAP burst to one `dest:port` persists the *same* single `dest:port` rule as HTTP — the plan's own "persistence granularity = dedup granularity" thesis applies identically; no per-protocol special-casing, no `checkContext` plumbing. -- `mcp/gateway.go:228`: pass `WithNoCoalesce()` — distinct `ToolArgs` are semantically distinct and arg-sensitive (ContentInspector/exec). This is the call site that genuinely needs the escape hatch. -- QUIC (`server.go:2477`): untouched (its own buffering remains). -- write tests: MCP calls with differing `ToolArgs` produce distinct prompts (not coalesced); SSH-style connection-level Ask to same dest:port coalesces. -- run `go test ./...` — must pass before Task 4. +- [x] `request_policy.go` (HTTP/gRPC/WS + connection-level SSH/IMAP/SMTP): coalesce uniformly. +- [x] `mcp/gateway.go`: pass `WithNoCoalesce()` — distinct `ToolArgs` are semantically distinct. +- [x] QUIC: untouched. +- [x] write tests: MCP calls with differing `ToolArgs` not coalesced; SSH-style connection-level Ask to same dest:port coalesces. +- [ ] run `go test ./...` — must pass before Task 4 (re-confirm after merge). ### Task 4: Final count on the existing resolve/cancel edit **Files:** Modify `internal/telegram/approval.go`; Modify `internal/channel/broker.go` (expose final `count` on resolve); Modify `internal/telegram/approval_test.go` -- Broker passes the final coalesced `count` to channels on cancel/resolve (extend the existing cancel/resolve notification path; no new Telegram Send). -- Telegram resolve edit (`:332-353`) / cancel edit (`:181-199`): when `count > 1`, render e.g. "Always allowed — applied to N requests at HH:MM:SS". **Zero extra API calls** — folded into the one edit that already happens. -- write tests: count rendered correctly for count==1 and count>1; no additional `Send` beyond the existing single edit. -- run `go test ./internal/telegram/...` — must pass. +- [ ] Verify/complete: broker passes final coalesced `count` to channels on cancel/resolve (no new Telegram Send). +- [ ] Verify/complete Telegram resolve edit / cancel edit: when `count > 1`, render "… — applied to N requests at HH:MM:SS"; zero extra API calls. +- [ ] write/verify tests: count rendered for count==1 and count>1; no additional `Send` beyond the existing single edit. +- [ ] run `go test ./internal/telegram/...` — must pass. ### Task 5: Verify acceptance + docs -- verify the prompt-wall scenario: burst → one prompt → one tap dismisses all (e2e). -- run full suite `go test ./... -timeout 30s`; run e2e `go test -tags=e2e ./e2e/ -count=1 -timeout=300s`. -- update CLAUDE.md "Channel/approval abstraction" + "QUIC broker dedup" notes to mention broker-level coalescing. -- move plan to `docs/plans/completed/`. +- [ ] verify the prompt-wall scenario via e2e (burst → one prompt → one tap dismisses all). +- [ ] run full suite `go test ./... -timeout 120s`; run e2e `go test -tags=e2e ./e2e/ -count=1 -timeout=300s` (if e2e cannot run in this environment, state so explicitly in the progress file, do not silently skip). +- [ ] update CLAUDE.md "Channel/approval abstraction" + "QUIC broker dedup" notes to mention broker-level coalescing. +- [ ] move plan to `docs/plans/completed/`. ## Phase 2 (optional) — Live mid-burst counter diff --git a/docs/plans/20260515-credential-pool-failover.md b/docs/plans/20260515-credential-pool-failover.md index dd36c29..dc43323 100644 --- a/docs/plans/20260515-credential-pool-failover.md +++ b/docs/plans/20260515-credential-pool-failover.md @@ -65,49 +65,45 @@ rotate` is an operator override, not the primary mechanism. ## Phases -### Phase 0 — Data model + CLI (no runtime behavior change) - -1. **Migration** `internal/store/migrations/000006_credential_pools.up.sql` (+`.down.sql`): - - `credential_pools(name TEXT PRIMARY KEY, strategy TEXT NOT NULL DEFAULT 'failover' CHECK(strategy IN ('failover')), created_at TEXT)`. - - `credential_pool_members(pool TEXT, credential TEXT, position INTEGER NOT NULL, PRIMARY KEY(pool,credential), FOREIGN KEY(pool) REFERENCES credential_pools(name) ON DELETE CASCADE)`. - - `credential_health(credential TEXT PRIMARY KEY, status TEXT NOT NULL DEFAULT 'healthy' CHECK(status IN ('healthy','cooldown')), cooldown_until TEXT, last_failure_reason TEXT, updated_at TEXT)`. -2. **Store API** `internal/store/store.go`: `CreatePool`, `AddPoolMember`, `ListPools`, `GetPool` (members ordered by `position`), `RemovePool`; `SetCredentialHealth`, `GetCredentialHealth`, `ListCredentialHealth`. App-layer CHECK: a member must be an existing `oauth` cred with non-empty `token_url`; reject `static`. - - **Orphan-member cleanup**: `cred remove ` of a pooled member must either cascade-remove the member row or mark it missing. Decision: `cred remove` errors if the credential is a live pool member ("remove it from pool `

` first"); document in CLI help. (No silent dangling rows.) -3. **CLI** `cmd/sluice/cred.go` new `pool` subtree: `pool create --members a,b[,c] [--strategy failover]`, `pool list`, `pool status ` (member order + health + active), `pool rotate ` (manual override), `pool remove `. -4. **Namespace**: pool names and credential names share one namespace. `pool create` rejects a name that collides with an existing credential; `cred add` rejects a name colliding with an existing pool. Bind a pool via `sluice binding add --destination ` (pool name stored verbatim in `bindings.credential`). - -Phase 0 exit: pools definable/inspectable; `reloadAll` loads pool + health tables into a new in-memory `PoolResolver` (atomic-pointer-swapped, parallel to `StoreResolver`), but injection does not consult it. - -### Phase 1 — Phantom indirection (pool phantom → active member) - -Active member changes only via `pool rotate` in this phase. - -1. **Single chokepoint for pool→member expansion** `internal/vault/pool.go` (new): - - `PoolResolver.IsPool(name) bool`; `ResolveActive(name) (member string, ok bool)` — if `name` is a pool, first member whose health is `healthy` or whose `cooldown_until <= now`, in `position` order; if all in cooldown, return the soonest-recovering member and log a WARNING. If `name` is a plain credential, return it unchanged. - - **Mandatory task: enumerate and route every `binding.Credential` / `OAuthIndex.Has` / `extractInjectableSecret` / `findAdder`/persist consumer through `ResolveActive` at one chokepoint** (grep `binding.Credential`, `\.Has(`, `extractInjectableSecret`). Do **not** scatter `IsPool` checks across pass-1/pass-2 only — that was the original gap. -2. **Injection** `internal/proxy/addon.go`: pass-1 header and pass-2 phantom swap call the chokepoint so the *real* value injected is the active member's, while the agent's pool-scoped phantom string is what gets matched/replaced. -3. **Per-request member tag — precise join key** (resolves Risk R1): - - When pass-2 swaps the agent's `SLUICE_PHANTOM:.refresh` to a real refresh token in an outbound token-endpoint request, sluice **records `realRefreshToken → member`** in a short-TTL map (the refresh token value is sluice's own injected bytes, unique per member, and is the field actually present in an RFC-6749 refresh-grant body — *not* the access token, which a refresh POST need not carry). `connState` keyed by `ClientConn.Id` is insufficient (one client conn multiplexes both members' h2 streams), so the join key is the real refresh-token value, not the connection. - - On the token-endpoint **response**, the handler recovers `member` from that map by the real refresh token sluice sent in the matching request. Persist refreshed tokens to *that member* (`persistAddonOAuthTokens(member,...)`, singleflight `"persist:"+member`). - - **Fail-closed (mandatory enumerated task + unit test):** if the member cannot be recovered, do **not** guess and do **not** fall back to `OAuthIndex.Match` for pooled token URLs — log a WARNING and skip the vault write so the next refresh retries. Dedicated unit test: two members, same token URL, assert a B-refresh never overwrites A's vault entry, and a missing tag results in zero writes. -4. **Pool-stable phantom** (resolves Risk R3): for pooled OAuth creds, `oauthPhantomAccess`/`resignJWT` produce a JWT from a deterministic synthetic payload keyed on the **pool name** (not the member's real token), so it is byte-identical across member switches. Enumerated unit test asserts byte-identity across a switch. Document the static-form fallback and the reason it is not the default. -5. `cmd/sluice/main.go:reloadAll` builds & swaps `PoolResolver` + health snapshot alongside the existing swaps. - -Phase 1 exit: `pool rotate` flips the backing account; agent's phantom unchanged byte-for-byte; refreshes attributed correctly; fail-closed proven by test. - -### Phase 2 — Auto-failover on 429 / 401 - -1. **Failure classification** in `SluiceAddon.Response` for pooled destinations: - - `429`, or `403` with body error `insufficient_quota`/quota-exhaustion → rate-limited. - - `401`, or token-endpoint body `invalid_grant`/`invalid_token` → auth-failure. - - `5xx` and everything else → no-op (upstream-side; failing over would thrash both accounts — documented choice). -2. **Prompt failover (resolves Important I1):** on classification, update the in-memory `PoolResolver` health **synchronously before the response returns** (atomic-pointer swap or dedicated mutex on the health map — call out the locking discipline), so the *very next* request injects the new active member. Also write `SetCredentialHealth(member, 'cooldown', now+ttl, reason)` to the store for durability; the 2s data-version watcher then merely reconciles. Do **not** rely on the 2s watcher for the active-member change — that lag was an error amplifier. - - Cooldown TTLs as named consts in `internal/vault/pool.go`: rate-limit 60s, auth-fail 300s (a broken refresh token will not self-heal quickly). Lazy recovery: `ResolveActive` treats expired cooldown as eligible — no scheduler. -3. **Audit**: emit `cred_failover` with `Reason = ":->:<429|403|401|invalid_grant>"`. -4. **Telegram notify** (best-effort, non-blocking, never blocks injection): one-line "pool `` failed over ``→`` ()". -5. **No in-flight retry** of the triggering request in Phase 2 (it returns its error; the next request uses the new member). Transparent retry is out of scope (needs body buffering; unsafe for non-idempotent calls). - -Phase 2 exit: e2e proves A 429 → next request uses B → B's refresh persists to B → phantom byte-unchanged. +### Task 1: Phase 0 — Data model + CLI (no runtime behavior change) + +- [x] Migration `internal/store/migrations/000006_credential_pools.{up,down}.sql`: `credential_pools`, `credential_pool_members`, `credential_health` tables with the documented CHECK constraints. +- [x] Store API `internal/store/store.go`/`pools.go`: pool CRUD + member ordering + `Set/Get/ListCredentialHealth`; reject `static` members; `cred remove` errors on a live pool member. +- [x] CLI `cmd/sluice/pool.go`: `pool create/list/status/rotate/remove`. +- [x] Namespace mutual-exclusion (pool name vs credential name) at create time. +- [x] `reloadAll` loads pool + health into an atomic-pointer-swapped `PoolResolver` (no injection consumption yet). +- [ ] re-run `go test ./internal/store/... ./internal/vault/... ./cmd/...` to confirm Phase 0 still green after merge. + +### Task 2: Phase 1 — Phantom indirection (pool phantom → active member) + +**Files:** `internal/vault/pool.go`, `internal/proxy/addon.go`, `internal/proxy/oauth_response.go`, `internal/proxy/oauth_index.go`, `cmd/sluice/main.go` + tests + +- [ ] `PoolResolver.IsPool(name)` + `ResolveActive(name)` (healthy/expired-cooldown first by position; all-in-cooldown → soonest-recovering + WARNING; plain cred returned unchanged). +- [ ] Route EVERY `binding.Credential` / `OAuthIndex.Has` / `extractInjectableSecret` / `findAdder`/persist consumer through `ResolveActive` at one chokepoint (grep `binding.Credential`, `\.Has(`, `extractInjectableSecret`; do not scatter `IsPool` checks). +- [ ] Injection (`addon.go` pass-1 header + pass-2 phantom swap) injects the active member's real value while matching/replacing the pool-scoped phantom string. +- [ ] R1 per-request member tag: record `realRefreshToken → member` (short-TTL map) when pass-2 swaps `SLUICE_PHANTOM:.refresh`; on token-endpoint response recover member by that real refresh token; persist to that member (`persistAddonOAuthTokens(member,...)`, singleflight `"persist:"+member`). +- [ ] R1 fail-closed: if member unrecoverable, do NOT guess, do NOT fall back to `OAuthIndex.Match` for pooled token URLs — WARNING + skip vault write. +- [ ] R1 dedicated unit test: two members, same token URL — B-refresh never overwrites A; missing tag → zero writes. +- [ ] R3 pool-stable phantom: pooled OAuth `oauthPhantomAccess`/`resignJWT` build the JWT from a deterministic synthetic payload keyed on the pool name (byte-identical across member switch). Unit test asserts byte-identity across a switch; document static-form fallback. +- [ ] `cmd/sluice/main.go:reloadAll` builds & swaps `PoolResolver` + health snapshot alongside existing swaps. +- [ ] `go test ./... -timeout 120s` green; build clean; gofumpt. + +### Task 3: Phase 2 — Auto-failover on 429 / 401 + +**Files:** `internal/proxy/addon.go`, `internal/vault/pool.go`, audit logger, telegram + tests + +- [ ] Failure classification in `SluiceAddon.Response` for pooled destinations: 429 or 403+`insufficient_quota` → rate-limited; 401 or token-body `invalid_grant`/`invalid_token` → auth-failure; 5xx/other → no-op. +- [ ] Prompt failover: synchronously update in-memory `PoolResolver` health BEFORE the response returns (documented locking discipline); also `SetCredentialHealth(member,'cooldown',now+ttl,reason)` for durability (2s watcher only reconciles). Cooldown TTL consts: rate-limit 60s, auth-fail 300s; lazy recovery in `ResolveActive`. +- [ ] Audit `cred_failover` with `Reason = ":->:<429|403|401|invalid_grant>"`. +- [ ] Telegram best-effort non-blocking notice "pool `` failed over ``→`` ()". +- [ ] No in-flight retry (documented); next request uses new member. +- [ ] Unit tests for classification + synchronous health swap + cooldown TTL/lazy recovery; `go test ./... -timeout 120s` green; build clean; gofumpt. + +### Task 4: Verify acceptance + docs + +- [ ] full `go test ./... -timeout 120s`; e2e `go test -tags=e2e ./e2e/ -count=1 -timeout=300s` (if e2e cannot run here, state so explicitly in the progress file, do not silently skip). +- [ ] update CLAUDE.md credential-pool/failover notes. +- [ ] move plan to `docs/plans/completed/`. ## Out of scope / future work From 0b0554bd8872f45794a1631d7cc08ff40499751a Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:50:47 +0800 Subject: [PATCH 12/49] feat(channel): re-confirm broker coalescing tests green --- docs/plans/20260515-approval-coalescing.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/plans/20260515-approval-coalescing.md b/docs/plans/20260515-approval-coalescing.md index 778c4b6..6898bd1 100644 --- a/docs/plans/20260515-approval-coalescing.md +++ b/docs/plans/20260515-approval-coalescing.md @@ -65,7 +65,7 @@ Verified against the working tree on `main` (tip `20cc367`): - [x] `Resolve(id,resp)`: snapshot waiter+subs and delete `waiters[id]`+`dedupIndex[w.dedupKey]` in the same locked section; fan resp to `w.ch` + all snapshot subs after unlock; `cancelOnChannels`. - [x] Timeout/done/shutdown of primary: fan terminal response to all snapshot subs, clear `dedupIndex` under lock. - [x] write tests: concurrent dedup → one broadcast; fan-out to all N; late-attach interleave; sub-timeout-detach non-blocking; deny/timeout/shutdown fan-out; distinct dest:port; `WithNoCoalesce`; cross-channel first-wins. -- [ ] verify `go test ./internal/channel/...` passes (re-run to confirm Task 1 still green after merge). +- [x] verify `go test ./internal/channel/...` passes (re-run to confirm Task 1 still green after merge). ### Task 2: Persist-once (idempotent approval rule) From 8948f2b35285e2b9f4526d5bca064062a2cb538c Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:52:35 +0800 Subject: [PATCH 13/49] feat(store): idempotent approval-rule persist --- docs/plans/20260515-approval-coalescing.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/plans/20260515-approval-coalescing.md b/docs/plans/20260515-approval-coalescing.md index 6898bd1..6c0d1ad 100644 --- a/docs/plans/20260515-approval-coalescing.md +++ b/docs/plans/20260515-approval-coalescing.md @@ -71,10 +71,10 @@ Verified against the working tree on `main` (tip `20cc367`): **Files:** Modify `internal/store/store.go`; Modify `internal/proxy/server.go`; Modify `internal/store/store_test.go` -- [ ] Verify/complete `Store.HasApprovalRule(verdict, dest string, port int) (bool, error)` — plain SELECT against `rules` where `source='approval'` AND verdict/destination/port match. No migration. -- [ ] Verify/complete `persistApprovalRule` (`server.go`): under `reloadMu`, call `HasApprovalRule` first and skip `AddRule` + engine recompile if present. -- [ ] write/verify tests: M concurrent persists → exactly one row; existing single-persist path unchanged. -- [ ] run `go test ./internal/store/... ./internal/proxy/...` — must pass before Task 3. +- [x] Verify/complete `Store.HasApprovalRule(verdict, dest string, port int) (bool, error)` — plain SELECT against `rules` where `source='approval'` AND verdict/destination/port match. No migration. +- [x] Verify/complete `persistApprovalRule` (`server.go`): under `reloadMu`, call `HasApprovalRule` first and skip `AddRule` + engine recompile if present. +- [x] write/verify tests: M concurrent persists → exactly one row; existing single-persist path unchanged. +- [x] run `go test ./internal/store/... ./internal/proxy/...` — must pass before Task 3. ### Task 3: Route call sites; MCP opt-out From 8e265fd4bcb2b501464057e06d7f057cb19460c3 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:54:05 +0800 Subject: [PATCH 14/49] test(proxy): confirm full suite green after coalescing merge --- docs/plans/20260515-approval-coalescing.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/plans/20260515-approval-coalescing.md b/docs/plans/20260515-approval-coalescing.md index 6c0d1ad..a466767 100644 --- a/docs/plans/20260515-approval-coalescing.md +++ b/docs/plans/20260515-approval-coalescing.md @@ -84,7 +84,7 @@ Verified against the working tree on `main` (tip `20cc367`): - [x] `mcp/gateway.go`: pass `WithNoCoalesce()` — distinct `ToolArgs` are semantically distinct. - [x] QUIC: untouched. - [x] write tests: MCP calls with differing `ToolArgs` not coalesced; SSH-style connection-level Ask to same dest:port coalesces. -- [ ] run `go test ./...` — must pass before Task 4 (re-confirm after merge). +- [x] run `go test ./...` — must pass before Task 4 (re-confirm after merge). ### Task 4: Final count on the existing resolve/cancel edit From 2d945de3c95eee65c0b170a3747f447c244ba491 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Fri, 15 May 2026 23:59:08 +0800 Subject: [PATCH 15/49] feat(telegram): render coalesced count on resolve and cancel edits --- docs/plans/20260515-approval-coalescing.md | 8 +- internal/channel/broker.go | 4 + internal/telegram/approval_test.go | 212 +++++++++++++++++++++ 3 files changed, 220 insertions(+), 4 deletions(-) diff --git a/docs/plans/20260515-approval-coalescing.md b/docs/plans/20260515-approval-coalescing.md index a466767..0075899 100644 --- a/docs/plans/20260515-approval-coalescing.md +++ b/docs/plans/20260515-approval-coalescing.md @@ -90,10 +90,10 @@ Verified against the working tree on `main` (tip `20cc367`): **Files:** Modify `internal/telegram/approval.go`; Modify `internal/channel/broker.go` (expose final `count` on resolve); Modify `internal/telegram/approval_test.go` -- [ ] Verify/complete: broker passes final coalesced `count` to channels on cancel/resolve (no new Telegram Send). -- [ ] Verify/complete Telegram resolve edit / cancel edit: when `count > 1`, render "… — applied to N requests at HH:MM:SS"; zero extra API calls. -- [ ] write/verify tests: count rendered for count==1 and count>1; no additional `Send` beyond the existing single edit. -- [ ] run `go test ./internal/telegram/...` — must pass. +- [x] Verify/complete: broker passes final coalesced `count` to channels on cancel/resolve (no new Telegram Send). (resolve/timeout already recorded the final count; added missing `recordCoalescedLocked` to the broker shutdown branch so the shutdown CancelApproval edit can also render the count.) +- [x] Verify/complete Telegram resolve edit / cancel edit: when `count > 1`, render "… — applied to N requests at HH:MM:SS"; zero extra API calls. (resolve at approval.go:343-348, cancel at approval.go:194-200 — both pre-existing from wip 185a382 + fix a3602d6, verified correct.) +- [x] write/verify tests: count rendered for count==1 and count>1; no additional `Send` beyond the existing single edit. (TestHandleCallbackRendersCoalescedCount, TestHandleCallbackSingleRequestNoCount, TestCancelApprovalRendersCoalescedCount, TestCancelApprovalSingleRequestNoCount — assert exactly one prompt send + exactly one resolve/cancel edit, count rendering adds zero API calls.) +- [x] run `go test ./internal/telegram/...` — must pass. (230 passed across telegram + channel.) ### Task 5: Verify acceptance + docs diff --git a/internal/channel/broker.go b/internal/channel/broker.go index eb01891..89a1559 100644 --- a/internal/channel/broker.go +++ b/internal/channel/broker.go @@ -327,6 +327,10 @@ func (b *Broker) Request(dest string, port int, protocol string, timeout time.Du if w.dedupKey != "" { delete(b.dedupIndex, w.dedupKey) } + // Retain the final coalesced count so the shutdown + // CancelApproval edit can still render "applied to N + // requests" for a burst that was pending at shutdown. + b.recordCoalescedLocked(id, w.count) } b.mu.Unlock() // Fan the terminal deny to any coalesced subscribers (buffered diff --git a/internal/telegram/approval_test.go b/internal/telegram/approval_test.go index 50f4eca..58f3f04 100644 --- a/internal/telegram/approval_test.go +++ b/internal/telegram/approval_test.go @@ -2094,3 +2094,215 @@ func waitForPending(t *testing.T, broker *channel.Broker, n int) { //nolint:unpa } } } + +// fireCoalescedBurstTG starts the primary request, waits for it to register, +// then fires n-1 more concurrent requests to the same dest:port so they +// coalesce onto the primary waiter. It returns the primary request ID and a +// channel that receives all n responses. +func fireCoalescedBurstTG(t *testing.T, broker *channel.Broker, dest string, port, n int) (string, chan channel.Response) { + t.Helper() + out := make(chan channel.Response, n) + + go func() { + resp, _ := broker.Request(dest, port, "https", 5*time.Second) + out <- resp + }() + waitForPending(t, broker, 1) + reqID := broker.PendingRequests()[0].ID + + for i := 0; i < n-1; i++ { + go func() { + resp, _ := broker.Request(dest, port, "https", 5*time.Second) + out <- resp + }() + } + + deadline := time.After(3 * time.Second) + for broker.CoalescedCount(reqID) < n { + select { + case <-deadline: + t.Fatalf("burst did not coalesce: count=%d want %d", broker.CoalescedCount(reqID), n) + default: + time.Sleep(time.Millisecond) + } + } + return reqID, out +} + +// TestHandleCallbackRendersCoalescedCount verifies that a coalesced burst +// resolved via one inline-keyboard tap folds the final count into the single +// resolve edit ("applied to N requests") with zero extra Telegram Sends +// beyond that one edit. +func TestHandleCallbackRendersCoalescedCount(t *testing.T) { + mock := newMockTelegramAPI(t) + s := newTestStore(t) + tc := newTestTelegramChannel(t, mock, s) + + broker := channel.NewBroker([]channel.Channel{tc}) + tc.SetBroker(broker) + + const n = 5 + reqID, out := fireCoalescedBurstTG(t, broker, "burst.example.com", 443, n) + + tc.handleCallback(&tgbotapi.CallbackQuery{ + ID: "cb_coalesce", + Message: &tgbotapi.Message{ + MessageID: 300, + Chat: &tgbotapi.Chat{ID: 12345}, + Text: "burst message", + }, + Data: reqID + "|allow_once", + }) + + // All n waiters must receive the response from the single tap. + for i := 0; i < n; i++ { + select { + case resp := <-out: + if resp != channel.ResponseAllowOnce { + t.Fatalf("waiter %d got %v, want AllowOnce", i, resp) + } + case <-time.After(3 * time.Second): + t.Fatalf("only %d/%d waiters received the fanned response", i, n) + } + } + + time.Sleep(50 * time.Millisecond) + + edits := mock.getEditedMessages() + if len(edits) != 1 { + t.Fatalf("expected exactly 1 edit (the resolve edit), got %d: %+v", len(edits), edits) + } + want := fmt.Sprintf("applied to %d requests", n) + if !strings.Contains(edits[0].Text, want) { + t.Errorf("resolve edit should contain %q, got: %s", want, edits[0].Text) + } + + // Baseline only: exactly one initial prompt sendMessage for the whole + // coalesced burst (one broadcast) and the single resolve edit above. + // Count rendering must add zero extra API calls. + if sent := mock.getSentMessages(); len(sent) != 1 { + t.Errorf("expected exactly 1 prompt send (no extra Send beyond the single edit), got %d: %+v", len(sent), sent) + } +} + +// TestHandleCallbackSingleRequestNoCount verifies a lone (count==1) request +// renders the plain label with no "applied to" suffix, still one edit only. +func TestHandleCallbackSingleRequestNoCount(t *testing.T) { + mock := newMockTelegramAPI(t) + s := newTestStore(t) + tc := newTestTelegramChannel(t, mock, s) + + broker := channel.NewBroker([]channel.Channel{tc}) + tc.SetBroker(broker) + + done := make(chan channel.Response, 1) + go func() { + resp, _ := broker.Request("solo.example.com", 443, "https", 5*time.Second) + done <- resp + }() + waitForPending(t, broker, 1) + reqID := broker.PendingRequests()[0].ID + + tc.handleCallback(&tgbotapi.CallbackQuery{ + ID: "cb_solo", + Message: &tgbotapi.Message{ + MessageID: 301, + Chat: &tgbotapi.Chat{ID: 12345}, + Text: "solo message", + }, + Data: reqID + "|allow_once", + }) + + if resp := <-done; resp != channel.ResponseAllowOnce { + t.Fatalf("got %v, want AllowOnce", resp) + } + + time.Sleep(50 * time.Millisecond) + + edits := mock.getEditedMessages() + if len(edits) != 1 { + t.Fatalf("expected exactly 1 edit, got %d", len(edits)) + } + if strings.Contains(edits[0].Text, "applied to") { + t.Errorf("single request must not render an 'applied to N' suffix, got: %s", edits[0].Text) + } + // Baseline only: one prompt send + the single resolve edit, nothing more. + if sent := mock.getSentMessages(); len(sent) != 1 { + t.Errorf("expected exactly 1 prompt send (no extra Send beyond the single edit), got %d", len(sent)) + } +} + +// TestCancelApprovalRendersCoalescedCount verifies the cancel path (used for +// timeout / shutdown / resolved-via-another-channel) also folds the final +// coalesced count into its single edit. +func TestCancelApprovalRendersCoalescedCount(t *testing.T) { + mock := newMockTelegramAPI(t) + s := newTestStore(t) + tc := newTestTelegramChannel(t, mock, s) + + broker := channel.NewBroker([]channel.Channel{tc}) + tc.SetBroker(broker) + + const n = 4 + reqID, out := fireCoalescedBurstTG(t, broker, "cancel.example.com", 443, n) + + // Resolve via the broker directly (simulating another channel) so the + // final count is recorded, then drive the Telegram cleanup edit. + if !broker.Resolve(reqID, channel.ResponseDeny) { + t.Fatalf("Resolve returned false") + } + for i := 0; i < n; i++ { + select { + case <-out: + case <-time.After(3 * time.Second): + t.Fatalf("only %d/%d waiters received the response", i, n) + } + } + + if err := tc.CancelApproval(reqID); err != nil { + t.Fatalf("CancelApproval: %v", err) + } + + time.Sleep(50 * time.Millisecond) + edits := mock.getEditedMessages() + if len(edits) != 1 { + t.Fatalf("expected exactly 1 cancel edit, got %d", len(edits)) + } + want := fmt.Sprintf("applied to %d requests", n) + if !strings.Contains(edits[0].Text, want) { + t.Errorf("cancel edit should contain %q, got: %s", want, edits[0].Text) + } + // Baseline only: one prompt send for the burst + the single cancel edit. + if sent := mock.getSentMessages(); len(sent) != 1 { + t.Errorf("expected exactly 1 prompt send (no extra Send beyond the single edit), got %d", len(sent)) + } +} + +// TestCancelApprovalSingleRequestNoCount verifies the cancel path renders the +// plain reason with no "applied to" suffix for a lone request. +func TestCancelApprovalSingleRequestNoCount(t *testing.T) { + mock := newMockTelegramAPI(t) + s := newTestStore(t) + tc := newTestTelegramChannel(t, mock, s) + + broker := channel.NewBroker([]channel.Channel{tc}) + tc.SetBroker(broker) + + tc.msgMap.Store("req_solo_cancel", approvalMsg{ + messageID: 44, + req: channel.ApprovalRequest{ID: "req_solo_cancel", Destination: "solo.example.com", Port: 443, Protocol: "https"}, + }) + + if err := tc.CancelApproval("req_solo_cancel"); err != nil { + t.Fatalf("CancelApproval: %v", err) + } + + time.Sleep(50 * time.Millisecond) + edits := mock.getEditedMessages() + if len(edits) != 1 { + t.Fatalf("expected exactly 1 cancel edit, got %d", len(edits)) + } + if strings.Contains(edits[0].Text, "applied to") { + t.Errorf("single request cancel must not render 'applied to N', got: %s", edits[0].Text) + } +} From 5aff5ffa4c5bc336b227e8c4ffcb3ca0b1c5fe2c Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 00:06:17 +0800 Subject: [PATCH 16/49] docs(plans): complete approval-coalescing; move to completed --- CLAUDE.md | 4 +++- .../plans/{ => completed}/20260515-approval-coalescing.md | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) rename docs/plans/{ => completed}/20260515-approval-coalescing.md (87%) diff --git a/CLAUDE.md b/CLAUDE.md index 1099f72..0751575 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -264,7 +264,7 @@ or, when `f.ConnContext` is nil and dedup cannot be applied: **QUIC SNI extraction:** Hostname recovery uses `ExtractQUICSNI()` to decrypt the QUIC Initial packet and extract SNI from the embedded TLS ClientHello. QUIC Initial packets encrypt the ClientHello, but the encryption keys are derived from the Destination Connection ID (DCID) visible in the packet header (RFC 9001 Section 5). Supports both QUIC v1 and v2 salts. Falls back to DNS reverse cache lookup, then raw IP if extraction fails. -**QUIC broker dedup:** `pendingQUICSessions` in `server.go` prevents duplicate Telegram approval prompts when multiple UDP packets arrive for the same destination during the approval wait. Packets are buffered (max 32 per session). When approval resolves, buffered packets are flushed (if allowed) or discarded (if denied). +**QUIC broker dedup:** `pendingQUICSessions` in `server.go` prevents duplicate Telegram approval prompts when multiple UDP packets arrive for the same destination during the approval wait. Packets are buffered (max 32 per session). When approval resolves, buffered packets are flushed (if allowed) or discarded (if denied). This is QUIC's own packet-level dedup and is independent of the channel-agnostic broker-level `dest:port` coalescing (see "Channel/approval abstraction" above) — QUIC's `broker.Request` call site is deliberately left on this path and is *not* routed through the broker `dedupIndex` (it predates and subsumes it for the QUIC packet model). The broker coalescing covers the other two call sites (HTTP/HTTPS/gRPC/WS + connection-level SSH/IMAP/SMTP; MCP opted out). Both mechanisms converge on the same outcome: one prompt per target, one human tap dismisses the whole burst, and the final coalesced count is folded into the resolve/cancel edit. See `internal/proxy/request_policy.go`, `internal/policy/engine.go` (`EvaluateDetailed`, `EvaluateQUICDetailed`), `internal/proxy/quic_sni.go` (`ExtractQUICSNI`), and `internal/proxy/addon.go` (`SluiceAddon`). @@ -288,6 +288,8 @@ Two-phase detection: port-based guess first, then byte-level for non-standard po `Channel` interface with `Broker` coordinating across channels (Telegram, HTTP). Broadcast-and-first-wins. Rate limiting: `MaxPendingRequests` (50), per-destination (5/min). "Always Allow" writes to SQLite store, recompiles and swaps Engine. +**Broker-level approval coalescing.** The broker dedups pending approvals by their persistence-equivalent target (`dest:port`, the same key `persistApprovalRule` writes). The first request to a target opens one prompt and registers the primary waiter under `dedupIndex[dest:port]`. Concurrent requests to the same `dest:port` while that prompt is still pending do not create new prompts — they attach a buffered (cap 1) sub channel to the primary waiter (`waiter.subs`, `count++`) instead of broadcasting again. On resolve/deny/timeout/shutdown the terminal response fans out to `w.ch` plus a snapshot of every attached sub taken under the same lock that deletes `waiters[id]` and `dedupIndex[key]` (closes the late-attach race; subs use a detach-only select arm so a timed-out sub never tears down the shared waiter). One human tap therefore dismisses the whole burst, matching the granularity of the single persisted `dest:port` rule. The final coalesced `count` is folded into the *existing* resolve/cancel message edit (rendered as "… — applied to N requests at HH:MM:SS" when `count > 1`) so Phase 1 adds zero extra Telegram API calls. `WithNoCoalesce()` is the escape hatch. Of the three `broker.Request` call sites, HTTP/HTTPS/gRPC/WS and connection-level SSH/IMAP/SMTP (all share `request_policy.go`'s `resolveAsk`) coalesce uniformly; **MCP tool calls opt out** via `WithNoCoalesce()` because distinct `ToolArgs` are semantically distinct (arg-sensitive ContentInspector/exec rules) and must not collapse onto one `dest:port` key; QUIC keeps its own packet-buffering dedup (see below) and is untouched. + `CouldBeAllowed(dest, includeAsk)`: when broker configured, Ask-matching destinations resolve via DNS for approval flow. When no broker, Ask treated as Deny at DNS stage to prevent leaking queries. **DNS approval design**: The DNS interceptor intentionally only blocks explicitly denied domains (returns NXDOMAIN). All other queries (allow, ask, default) are forwarded to the upstream resolver. This is by design. Policy enforcement for "ask" destinations happens at the SOCKS5 CONNECT layer, not at DNS. Blocking DNS for "ask" destinations would prevent the TCP connection from ever reaching the SOCKS5 handler where the approval flow triggers. The DNS layer populates the reverse DNS cache (IP -> hostname) so the SOCKS5 handler can recover hostnames from IP-only CONNECT requests. DNS uses `IsDeniedDomain`, a separate evaluation path that is independent from the unscoped-rule matching in `EvaluateUDP` / `EvaluateQUICDetailed`. Unscoped rules therefore widen TCP/UDP/QUIC policy without changing DNS behavior. diff --git a/docs/plans/20260515-approval-coalescing.md b/docs/plans/completed/20260515-approval-coalescing.md similarity index 87% rename from docs/plans/20260515-approval-coalescing.md rename to docs/plans/completed/20260515-approval-coalescing.md index 0075899..d5a59a2 100644 --- a/docs/plans/20260515-approval-coalescing.md +++ b/docs/plans/completed/20260515-approval-coalescing.md @@ -97,10 +97,10 @@ Verified against the working tree on `main` (tip `20cc367`): ### Task 5: Verify acceptance + docs -- [ ] verify the prompt-wall scenario via e2e (burst → one prompt → one tap dismisses all). -- [ ] run full suite `go test ./... -timeout 120s`; run e2e `go test -tags=e2e ./e2e/ -count=1 -timeout=300s` (if e2e cannot run in this environment, state so explicitly in the progress file, do not silently skip). -- [ ] update CLAUDE.md "Channel/approval abstraction" + "QUIC broker dedup" notes to mention broker-level coalescing. -- [ ] move plan to `docs/plans/completed/`. +- [x] verify the prompt-wall scenario via e2e (burst → one prompt → one tap dismisses all). The burst→one-prompt→fan-out scenario is verified at the unit/integration level: `internal/channel` 11 coalescing tests (TestBrokerCoalesceOneBroadcastFanToAll, ...DenyFanOut, ...TimeoutFanOut, ...ShutdownFanOut, ...SubTimeoutDoesNotBlockFanOut, ...LateAttachOpensNewPrompt, ...ConcurrentResolveAndAttach, TestBrokerDistinctDestNotCoalesced, TestBrokerSamePortDifferentDestNotCoalesced, TestBrokerWithNoCoalesceNeverCoalesces, TestBrokerCoalesceCrossChannelFirstWins) + `internal/proxy` TestRequestPolicyChecker_ConcurrentAllowOnceCoalesces / _SSHStyleConnectionLevelCoalesces + telegram TestHandleCallbackRendersCoalescedCount / TestCancelApprovalRendersCoalescedCount. [x] (skipped: dedicated burst e2e) — the `e2e/` suite has no delayed-verdict-server helper to keep a first approval pending while a concurrent burst arrives (the verdict server answers synchronously), so a true broker-coalescing e2e cannot be expressed without new harness code, which is out of scope for Task 5. The non-container e2e suite (66 tests, `-tags=e2e`) was run and passes, exercising the same `resolveAsk → broker.Request` Ask path via TestPerRequestAllowOnce*/AlwaysAllow*/Deny*. +- [x] run full suite `go test ./... -timeout 120s` (2524 passed, 13 packages); ran e2e `go test -tags=e2e ./e2e/ -count=1 -timeout=300s` (66 passed, non-container `e2e` tag). [x] (skipped: docker/apple-container e2e — `e2e && linux` / `e2e && darwin` compose/Apple-Container tags not run; the burst-coalescing scenario is verified by unit/integration tests above, container e2e adds no coalescing-specific coverage). +- [x] update CLAUDE.md "Channel/approval abstraction" + "QUIC broker dedup" notes to mention broker-level coalescing. +- [x] move plan to `docs/plans/completed/`. ## Phase 2 (optional) — Live mid-burst counter From ef05ee90311994594370963965b99dba221d52ef Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 00:07:54 +0800 Subject: [PATCH 17/49] test(store): confirm credential-pool Phase 0 green post-merge --- docs/plans/20260515-credential-pool-failover.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/plans/20260515-credential-pool-failover.md b/docs/plans/20260515-credential-pool-failover.md index dc43323..6e9e00a 100644 --- a/docs/plans/20260515-credential-pool-failover.md +++ b/docs/plans/20260515-credential-pool-failover.md @@ -72,7 +72,7 @@ rotate` is an operator override, not the primary mechanism. - [x] CLI `cmd/sluice/pool.go`: `pool create/list/status/rotate/remove`. - [x] Namespace mutual-exclusion (pool name vs credential name) at create time. - [x] `reloadAll` loads pool + health into an atomic-pointer-swapped `PoolResolver` (no injection consumption yet). -- [ ] re-run `go test ./internal/store/... ./internal/vault/... ./cmd/...` to confirm Phase 0 still green after merge. +- [x] re-run `go test ./internal/store/... ./internal/vault/... ./cmd/...` to confirm Phase 0 still green after merge. ### Task 2: Phase 1 — Phantom indirection (pool phantom → active member) From 864420c4220416812dbf8c65874886cd67fa1113 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 00:23:21 +0800 Subject: [PATCH 18/49] feat(proxy): pool phantom indirection + R1 attribution + R3 stable JWT --- .../20260515-credential-pool-failover.md | 18 +- internal/proxy/addon.go | 210 ++++++++++- internal/proxy/oauth_response.go | 70 ++++ internal/proxy/phantom_pairs.go | 53 +++ internal/proxy/pool_attribution.go | 86 +++++ internal/proxy/pool_phantom_test.go | 337 ++++++++++++++++++ 6 files changed, 750 insertions(+), 24 deletions(-) create mode 100644 internal/proxy/pool_attribution.go create mode 100644 internal/proxy/pool_phantom_test.go diff --git a/docs/plans/20260515-credential-pool-failover.md b/docs/plans/20260515-credential-pool-failover.md index 6e9e00a..dc1816b 100644 --- a/docs/plans/20260515-credential-pool-failover.md +++ b/docs/plans/20260515-credential-pool-failover.md @@ -78,15 +78,15 @@ rotate` is an operator override, not the primary mechanism. **Files:** `internal/vault/pool.go`, `internal/proxy/addon.go`, `internal/proxy/oauth_response.go`, `internal/proxy/oauth_index.go`, `cmd/sluice/main.go` + tests -- [ ] `PoolResolver.IsPool(name)` + `ResolveActive(name)` (healthy/expired-cooldown first by position; all-in-cooldown → soonest-recovering + WARNING; plain cred returned unchanged). -- [ ] Route EVERY `binding.Credential` / `OAuthIndex.Has` / `extractInjectableSecret` / `findAdder`/persist consumer through `ResolveActive` at one chokepoint (grep `binding.Credential`, `\.Has(`, `extractInjectableSecret`; do not scatter `IsPool` checks). -- [ ] Injection (`addon.go` pass-1 header + pass-2 phantom swap) injects the active member's real value while matching/replacing the pool-scoped phantom string. -- [ ] R1 per-request member tag: record `realRefreshToken → member` (short-TTL map) when pass-2 swaps `SLUICE_PHANTOM:.refresh`; on token-endpoint response recover member by that real refresh token; persist to that member (`persistAddonOAuthTokens(member,...)`, singleflight `"persist:"+member`). -- [ ] R1 fail-closed: if member unrecoverable, do NOT guess, do NOT fall back to `OAuthIndex.Match` for pooled token URLs — WARNING + skip vault write. -- [ ] R1 dedicated unit test: two members, same token URL — B-refresh never overwrites A; missing tag → zero writes. -- [ ] R3 pool-stable phantom: pooled OAuth `oauthPhantomAccess`/`resignJWT` build the JWT from a deterministic synthetic payload keyed on the pool name (byte-identical across member switch). Unit test asserts byte-identity across a switch; document static-form fallback. -- [ ] `cmd/sluice/main.go:reloadAll` builds & swaps `PoolResolver` + health snapshot alongside existing swaps. -- [ ] `go test ./... -timeout 120s` green; build clean; gofumpt. +- [x] `PoolResolver.IsPool(name)` + `ResolveActive(name)` (healthy/expired-cooldown first by position; all-in-cooldown → soonest-recovering + WARNING; plain cred returned unchanged). (Implemented in Phase 0; verified + covered by `internal/vault/pool_test.go`.) +- [x] Route EVERY `binding.Credential` / `OAuthIndex.Has` / `extractInjectableSecret` / `findAdder`/persist consumer through `ResolveActive` at one chokepoint (grep `binding.Credential`, `\.Has(`, `extractInjectableSecret`; do not scatter `IsPool` checks). (HTTP/HTTPS OAuth path — Task 2's file scope — routed through `resolveInjectionTarget` (pass-1, pass-2) and `resolveOAuthResponseAttribution` (response/persist). `idx.Has` is called with the resolved member name, never the pool. SSH/mail/QUIC paths are non-OAuth/out of Task 2 file scope and unchanged.) +- [x] Injection (`addon.go` pass-1 header + pass-2 phantom swap) injects the active member's real value while matching/replacing the pool-scoped phantom string. +- [x] R1 per-request member tag: record `realRefreshToken → member` (short-TTL map) when pass-2 swaps `SLUICE_PHANTOM:.refresh`; on token-endpoint response recover member by that real refresh token; persist to that member (`persistAddonOAuthTokens(member,...)`, singleflight `"persist:"+member`). +- [x] R1 fail-closed: if member unrecoverable, do NOT guess, do NOT fall back to `OAuthIndex.Match` for pooled token URLs — WARNING + skip vault write. +- [x] R1 dedicated unit test: two members, same token URL — B-refresh never overwrites A; missing tag → zero writes. +- [x] R3 pool-stable phantom: pooled OAuth `oauthPhantomAccess`/`resignJWT` build the JWT from a deterministic synthetic payload keyed on the pool name (byte-identical across member switch). Unit test asserts byte-identity across a switch; document static-form fallback. +- [x] `cmd/sluice/main.go:reloadAll` builds & swaps `PoolResolver` + health snapshot alongside existing swaps. (Wired in Phase 0; verified — `reloadAll` calls `loadPoolResolver` + `srv.StorePool`, and `StorePool` rewires the addon pointer via `SetPoolResolver`.) +- [x] `go test ./... -timeout 120s` green; build clean; gofumpt. ### Task 3: Phase 2 — Auto-failover on 429 / 401 diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index e6cd5a7..383cb27 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -125,6 +125,12 @@ type SluiceAddon struct { // refreshes. refreshGroup singleflight.Group + // refreshAttr maps the real refresh token sluice injected into an + // outbound OAuth refresh-grant to the pool member that owns it. It is + // the precise per-request join key for pooled credential refresh + // attribution (Risk R1). Never nil after NewSluiceAddon. + refreshAttr *refreshAttribution + // onOAuthRefresh is called after an OAuth token refresh persist // completes successfully. It receives the credential name so the // caller can re-inject updated phantom env vars into the agent @@ -168,6 +174,7 @@ type SluiceAddon struct { func NewSluiceAddon(opts ...SluiceAddonOption) *SluiceAddon { a := &SluiceAddon{ pendingCheckers: make(map[string][]*pendingCheck), + refreshAttr: newRefreshAttribution(), } for _, o := range opts { o(a) @@ -222,6 +229,47 @@ func (a *SluiceAddon) resolvePoolMember(name string) string { return name } +// injectionTarget is the result of expanding a bound credential-or-pool +// name at the single chokepoint. phantomName is the name the agent's +// phantom string is keyed on (the POOL name when pooled, so the phantom is +// stable across member switches); secretName is the concrete credential +// whose real vault value is injected (the active member when pooled). For a +// plain credential both fields equal the input name and pooled is false. +type injectionTarget struct { + phantomName string + secretName string + pooled bool +} + +// resolveInjectionTarget is the single chokepoint every credential consumer +// (pass-1 header inject, pass-2 phantom pairs, OAuthIndex.Has gating, +// persist attribution) routes through. It expands a pool name to its active +// member exactly once here so no consumer scatters its own IsPool check +// (Important I2). The pool→member expansion MUST happen before any +// OAuthIndex.Has / JSON-envelope decision: a pool name is not in +// credential_meta so idx.Has(pool) is always false, and gating on the pool +// name would mis-handle the OAuth envelope as a static secret. +func (a *SluiceAddon) resolveInjectionTarget(name string) injectionTarget { + if a.poolResolver == nil { + return injectionTarget{phantomName: name, secretName: name} + } + pr := a.poolResolver.Load() + if pr == nil { + return injectionTarget{phantomName: name, secretName: name} + } + if !pr.IsPool(name) { + return injectionTarget{phantomName: name, secretName: name} + } + member, ok := pr.ResolveActive(name) + if !ok || member == "" { + // Empty/unresolvable pool: keep the name so callers degrade + // gracefully (no secret found -> no injection) rather than + // dereferencing an empty string. + return injectionTarget{phantomName: name, secretName: name, pooled: true} + } + return injectionTarget{phantomName: name, secretName: member, pooled: true} +} + // WithAuditLogger sets the audit logger for per-request events. func WithAuditLogger(l *audit.FileLogger) SluiceAddonOption { return func(a *SluiceAddon) { a.auditLog = l } @@ -591,16 +639,28 @@ func (a *SluiceAddon) injectHeaders(f *mitmproxy.Flow, host string, port int) { return } - secret, err := a.provider.Get(binding.Credential) + // Chokepoint: expand a bound pool name to its active member BEFORE the + // vault lookup and the OAuthIndex.Has envelope decision. A pool name is + // not a vault credential and is not in credential_meta, so both + // provider.Get and extractInjectableSecret must operate on the resolved + // member name (Important I2). + target := a.resolveInjectionTarget(binding.Credential) + + secret, err := a.provider.Get(target.secretName) if err != nil { - log.Printf("[ADDON-INJECT] credential %q lookup failed: %v", binding.Credential, err) + log.Printf("[ADDON-INJECT] credential %q lookup failed: %v", target.secretName, err) return } defer secret.Release() - f.Request.Header.Set(binding.Header, binding.FormatValue(extractInjectableSecret(a.oauthIndex.Load(), binding.Credential, secret.String()))) - log.Printf("[ADDON-INJECT] injected header %q for %s:%d (credential %q)", - binding.Header, host, port, binding.Credential) + f.Request.Header.Set(binding.Header, binding.FormatValue(extractInjectableSecret(a.oauthIndex.Load(), target.secretName, secret.String()))) + if target.pooled { + log.Printf("[ADDON-INJECT] injected header %q for %s:%d (pool %q -> member %q)", + binding.Header, host, port, binding.Credential, target.secretName) + } else { + log.Printf("[ADDON-INJECT] injected header %q for %s:%d (credential %q)", + binding.Header, host, port, binding.Credential) + } } // extractInjectableSecret returns the value to substitute into a binding's @@ -793,6 +853,66 @@ func (a *SluiceAddon) Response(f *mitmproxy.Flow) { a.scanResponseForDLP(f) } +// oauthRespAttribution describes how a token-endpoint response is handled. +// phantomName keys the phantom strings the agent receives (the POOL name +// for pooled creds, so the phantom is byte-identical across member switches +// — Risk R3). persistMember names the vault entry the rotated real tokens +// are written to. skipPersist is set when the response belongs to a pooled +// token URL but the owning member could not be recovered from the injected +// real refresh token — the swap still runs (the agent must never see real +// tokens) but the vault write is skipped so we never misfile B's rotated +// tokens under A (Risk R1, fail-closed). +type oauthRespAttribution struct { + phantomName string + persistMember string + pooled bool + skipPersist bool +} + +// resolveOAuthResponseAttribution turns the OAuthIndex match into a precise +// attribution. For a plain credential it is the identity (phantom + persist +// both the matched name). For a pooled member it keys the phantom on the +// pool name and recovers the owning member via the REAL refresh token that +// was injected into this exact outbound request body (the only join key +// that survives two members sharing one token URL). When recovery fails it +// returns skipPersist=true and never falls back to OAuthIndex.Match for the +// persist target (R1: never guess). +func (a *SluiceAddon) resolveOAuthResponseAttribution(f *mitmproxy.Flow, matchedCred string) oauthRespAttribution { + pr := (*vault.PoolResolver)(nil) + if a.poolResolver != nil { + pr = a.poolResolver.Load() + } + poolName := "" + if pr != nil { + poolName = pr.PoolForMember(matchedCred) + } + if poolName == "" { + // Not a pooled token URL: unchanged 1:1 behavior. + return oauthRespAttribution{phantomName: matchedCred, persistMember: matchedCred} + } + + // Pooled token URL. Recover the owning member from the real refresh + // token sluice injected into this request's body (R1 join key). + reqCT := "" + reqBody := []byte(nil) + if f.Request != nil { + if f.Request.Header != nil { + reqCT = f.Request.Header.Get("Content-Type") + } + reqBody = f.Request.Body + } + realRefresh := extractRequestRefreshToken(reqBody, reqCT) + member, ok := a.refreshAttr.Recover(realRefresh) + if !ok { + log.Printf("[ADDON-OAUTH] R1 fail-closed: pooled token URL for pool %q but owning member "+ + "could not be recovered from the injected refresh token; skipping vault write "+ + "(next refresh will retry)", poolName) + return oauthRespAttribution{phantomName: poolName, pooled: true, skipPersist: true} + } + log.Printf("[ADDON-OAUTH] R1 attributed pooled refresh to member %q (pool %q)", member, poolName) + return oauthRespAttribution{phantomName: poolName, persistMember: member, pooled: true} +} + // processOAuthResponseIfMatching performs OAuth token phantom swap on the // response when the request URL matches the OAuth index. Extracted from // Response so DLP scanning can run independently on non-OAuth responses. @@ -816,7 +936,11 @@ func (a *SluiceAddon) processOAuthResponseIfMatching(f *mitmproxy.Flow) { return } - modified, err := a.processAddonOAuthResponse(f, credName) + // Chokepoint: turn the (collision-prone for pools) OAuthIndex match + // into a precise phantom-key + persist-member attribution. + attr := a.resolveOAuthResponseAttribution(f, credName) + + modified, err := a.processAddonOAuthResponse(f, attr) if err != nil { log.Printf("[ADDON-OAUTH] error processing OAuth response for %q: %v", credName, err) return @@ -960,7 +1084,11 @@ func (a *SluiceAddon) StreamResponseModifier(f *mitmproxy.Flow, in io.Reader) (o contentType = f.Response.Header.Get("Content-Type") } - modified, err := a.swapOAuthTokens(body, contentType, credName) + // Chokepoint: precise phantom-key + persist-member attribution + // (pool-stable phantom, R1 fail-closed when member unrecoverable). + attr := a.resolveOAuthResponseAttribution(f, credName) + + modified, err := a.swapOAuthTokens(body, contentType, attr) if err != nil { // The body did not parse as an OAuth token response. This is // usually an HTML error page from a misconfigured token @@ -1000,7 +1128,8 @@ func (a *SluiceAddon) StreamResponseModifier(f *mitmproxy.Flow, in io.Reader) (o // snapshot is restored on every failure path so the flow either has a // fully phantom-swapped body or the original bytes with original // headers, never a half-modified mix. -func (a *SluiceAddon) processAddonOAuthResponse(f *mitmproxy.Flow, credName string) (modified bool, err error) { +func (a *SluiceAddon) processAddonOAuthResponse(f *mitmproxy.Flow, attr oauthRespAttribution) (modified bool, err error) { + credName := attr.phantomName if f == nil || f.Response == nil { return false, nil } @@ -1061,7 +1190,7 @@ func (a *SluiceAddon) processAddonOAuthResponse(f *mitmproxy.Flow, credName stri contentType = f.Response.Header.Get("Content-Type") } - swapped, err := a.swapOAuthTokens(body, contentType, credName) + swapped, err := a.swapOAuthTokens(body, contentType, attr) if err != nil { rollback() return false, err @@ -1098,14 +1227,33 @@ func (a *SluiceAddon) processAddonOAuthResponse(f *mitmproxy.Flow, credName stri // swapOAuthTokens parses a token response body, replaces real tokens with // deterministic phantoms, and schedules an async vault persist. Returns // the modified body. Shared by Response (buffered) and StreamResponseModifier. -func (a *SluiceAddon) swapOAuthTokens(body []byte, contentType, credName string) ([]byte, error) { +// +// attr controls phantom keying and persist target. For a plain credential +// it is the identity. For a pooled credential the phantom is keyed on the +// POOL name (byte-identical across member switches, Risk R3) and the +// persist target is the recovered owning member; when the member could not +// be recovered, attr.skipPersist suppresses the vault write entirely so a +// rotated token is never misfiled (Risk R1, fail-closed) — the swap still +// runs so the agent never receives the real tokens. +func (a *SluiceAddon) swapOAuthTokens(body []byte, contentType string, attr oauthRespAttribution) ([]byte, error) { tr, err := parseTokenResponse(body, contentType) if err != nil { return nil, err } - accessPhantom := oauthPhantomAccess(credName, tr.AccessToken) - refreshPhantom := oauthPhantomRefresh(credName, tr.RefreshToken) + var accessPhantom, refreshPhantom string + if attr.pooled { + // Pooled: phantomName is the pool name. Use the pool-stable + // synthetic JWT for access and the deterministic static string + // for refresh, byte-identical to what buildPooledOAuthPhantomPairs + // emits on the request side, so the agent's stored phantom never + // changes across a member switch. + accessPhantom = poolStablePhantomAccess(attr.phantomName) + refreshPhantom = "SLUICE_PHANTOM:" + attr.phantomName + ".refresh" + } else { + accessPhantom = oauthPhantomAccess(attr.phantomName, tr.AccessToken) + refreshPhantom = oauthPhantomRefresh(attr.phantomName, tr.RefreshToken) + } // Replace real tokens with phantoms in the response body. // Replace the longer token first to prevent substring corruption when @@ -1123,12 +1271,20 @@ func (a *SluiceAddon) swapOAuthTokens(body []byte, contentType, credName string) modified = bytes.ReplaceAll(modified, []byte(tr.AccessToken), []byte(accessPhantom)) } - // Asynchronously persist the new tokens to the vault. + if attr.skipPersist { + // R1 fail-closed: response swapped to phantoms (agent safe) but + // the owning pool member is unknown, so do NOT write the vault. + // The next refresh round-trip carries a fresh tag and retries. + return modified, nil + } + + // Asynchronously persist the new tokens to the vault, attributed to + // the precise member (pooled) or the credential itself (plain). realAccess := vault.NewSecureBytes(tr.AccessToken) realRefresh := vault.NewSecureBytes(tr.RefreshToken) expiresIn := tr.ExpiresIn - go a.persistAddonOAuthTokens(credName, realAccess, realRefresh, expiresIn) + go a.persistAddonOAuthTokens(attr.persistMember, realAccess, realRefresh, expiresIn) return modified, nil } @@ -1216,13 +1372,34 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string) []p } var pairs []phantomPair - for _, name := range boundCreds { + for _, boundName := range boundCreds { + // Chokepoint: expand a bound pool name to its active member + // before the vault lookup and the OAuth-envelope decision. The + // agent holds a pool-keyed phantom; the secret injected is the + // active member's real token (Important I2). + target := a.resolveInjectionTarget(boundName) + name := target.secretName secret, err := a.provider.Get(name) if err != nil { log.Printf("[ADDON-INJECT] credential %q lookup failed: %v", name, err) continue } if vault.IsOAuth(secret.Bytes()) { + if target.pooled { + poolName := target.phantomName + member := target.secretName + oauthPairs, parseErr := buildPooledOAuthPhantomPairs( + poolName, member, secret, "ADDON-INJECT", + func(realRefresh string) { + a.refreshAttr.Tag(realRefresh, member) + }, + ) + if parseErr != nil { + continue + } + pairs = append(pairs, oauthPairs...) + continue + } oauthPairs, parseErr := buildOAuthPhantomPairs(name, secret, "ADDON-INJECT") if parseErr != nil { continue @@ -1230,6 +1407,9 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string) []p pairs = append(pairs, oauthPairs...) continue } + // Static (non-OAuth) credential. Pools reject static members, so + // a pooled target never reaches here; the phantom is keyed on the + // resolved name (== bound name for plain creds). phantom := []byte(PhantomToken(name)) encoded := encodePhantomForPair(phantom) pairs = append(pairs, phantomPair{ diff --git a/internal/proxy/oauth_response.go b/internal/proxy/oauth_response.go index 8e7ca68..8639e7e 100644 --- a/internal/proxy/oauth_response.go +++ b/internal/proxy/oauth_response.go @@ -33,6 +33,44 @@ func oauthPhantomAccess(credName string, realToken ...string) string { return "SLUICE_PHANTOM:" + credName + ".access" } +// poolStablePhantomAccess returns the pool-keyed phantom access token for a +// pooled OAuth credential (Risk R3). resignJWT is deterministic per *real* +// token, so a naive phantom would change every time sluice fails over to a +// different pool member — the agent would see its access token mutate +// underneath it and the "agent never notices" guarantee would break. +// +// Instead we synthesize a structurally valid JWT from a deterministic +// payload keyed on the POOL NAME (stable sub/iss, far-future exp), HMAC'd +// with the same fixed phantomSigningKey. The result is byte-identical for a +// given pool regardless of which member is currently active, so a +// cross-member refresh never changes the token the agent holds. +// +// Static-form fallback: if the consuming agent is verified to treat the +// access token as opaque (never parses it client-side), emitting the plain +// "SLUICE_PHANTOM:.access" string is equally pool-stable and simpler. +// The synthetic-JWT path is primary because resignJWT exists specifically +// because *something* (OpenAI Codex / Hermes) parses the JWT client-side, so +// we must not assume opacity. +func poolStablePhantomAccess(poolName string) string { + // Header: {"alg":"HS256","typ":"JWT"} — fixed, no per-pool variation. + header := base64.RawURLEncoding.EncodeToString( + []byte(`{"alg":"HS256","typ":"JWT"}`), + ) + // Payload: deterministic, keyed on the pool name. exp is a far-future + // fixed timestamp (2100-01-01T00:00:00Z = 4102444800) so client-side + // expiry checks treat it as valid; iat is intentionally omitted so the + // payload is a pure function of the pool name (an iat would make the + // phantom time-varying and break byte-identity). + payload := base64.RawURLEncoding.EncodeToString([]byte( + `{"sub":"sluice-pool:` + poolName + `","iss":"sluice-phantom","exp":4102444800}`, + )) + signingInput := header + "." + payload + mac := hmac.New(sha256.New, phantomSigningKey) + mac.Write([]byte(signingInput)) + sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + return signingInput + "." + sig +} + // oauthPhantomRefresh returns a phantom for an OAuth refresh token. func oauthPhantomRefresh(credName string, realToken ...string) string { if len(realToken) > 0 && realToken[0] != "" { @@ -68,6 +106,38 @@ func resignJWT(token string) string { return signingInput + "." + newSig } +// extractRequestRefreshToken pulls the `refresh_token` value out of an +// outbound OAuth token-endpoint request body. By the time the Response +// addon runs, pass-2 has already swapped sluice's phantom for the active +// member's REAL refresh token, so this returns the real token value — the +// Risk R1 join key. RFC 6749 §6 mandates application/x-www-form-urlencoded +// for the refresh grant; some non-conformant endpoints accept JSON, so both +// are parsed (form first, JSON fallback). Returns "" when no refresh_token +// field is present (e.g. an authorization_code grant), which the caller +// treats as "not a refresh round-trip, nothing to attribute". +func extractRequestRefreshToken(body []byte, contentType string) string { + if len(body) == 0 { + return "" + } + ct := strings.ToLower(contentType) + if strings.Contains(ct, "application/x-www-form-urlencoded") || !strings.Contains(ct, "json") { + if vals, err := url.ParseQuery(string(body)); err == nil { + if rt := vals.Get("refresh_token"); rt != "" { + return rt + } + } + } + if strings.Contains(ct, "json") || strings.HasPrefix(strings.TrimSpace(string(body)), "{") { + var probe struct { + RefreshToken string `json:"refresh_token"` + } + if err := json.Unmarshal(body, &probe); err == nil && probe.RefreshToken != "" { + return probe.RefreshToken + } + } + return "" +} + // tokenResponse is the parsed result from an OAuth token endpoint. Fields // match the RFC 6749 token response format. type tokenResponse struct { diff --git a/internal/proxy/phantom_pairs.go b/internal/proxy/phantom_pairs.go index 3f27c37..fe40130 100644 --- a/internal/proxy/phantom_pairs.go +++ b/internal/proxy/phantom_pairs.go @@ -214,3 +214,56 @@ func buildOAuthPhantomPairs(name string, secret vault.SecureBytes, logPrefix str } return pairs, nil } + +// buildPooledOAuthPhantomPairs builds phantom pairs for a pooled OAuth +// credential. The phantom strings are keyed on the POOL name so they are +// byte-identical across member switches (Risk R3): the access phantom is +// the pool-stable synthetic JWT, the refresh phantom is the deterministic +// static `SLUICE_PHANTOM:.refresh` string. The injected secrets are +// the ACTIVE MEMBER's real tokens. +// +// onRefreshInject, when non-nil, is called with the member's real refresh +// token so the caller can record the realRefreshToken -> member tag (the +// Risk R1 join key) before the swap injects that token into the outbound +// refresh-grant request body. The caller's raw secret is released before +// returning. On parse failure the secret is released and an error returned. +func buildPooledOAuthPhantomPairs(poolName, member string, secret vault.SecureBytes, logPrefix string, onRefreshInject func(realRefresh string)) ([]phantomPair, error) { + cred, err := vault.ParseOAuth(secret.Bytes()) + secret.Release() + if err != nil { + log.Printf("[%s] parse pooled oauth member %q (pool %q) failed: %v", logPrefix, member, poolName, err) + return nil, err + } + accessSecret := vault.NewSecureBytes(cred.AccessToken) + accessPhantom := []byte(poolStablePhantomAccess(poolName)) + accessEncoded := encodePhantomForPair(accessPhantom) + pairs := []phantomPair{{ + phantom: accessPhantom, + encodedPhantom: accessEncoded, + encodedPhantomLower: encodePhantomLowerForPair(accessEncoded), + secret: accessSecret, + }} + if cred.RefreshToken != "" { + // Record the precise R1 join: this exact real refresh token is + // about to be injected into the outbound refresh-grant request + // for `member`. The token-endpoint response is attributed back + // to `member` by recovering this value from the request body. + if onRefreshInject != nil { + onRefreshInject(cred.RefreshToken) + } + refreshSecret := vault.NewSecureBytes(cred.RefreshToken) + // Pool-stable static refresh phantom (not resignJWT, which would + // be per-real-token and change on every member switch). Refresh + // tokens travel in request bodies, not parsed client-side, so the + // static form is sufficient and inherently pool-stable. + refreshPhantom := []byte("SLUICE_PHANTOM:" + poolName + ".refresh") + refreshEncoded := encodePhantomForPair(refreshPhantom) + pairs = append(pairs, phantomPair{ + phantom: refreshPhantom, + encodedPhantom: refreshEncoded, + encodedPhantomLower: encodePhantomLowerForPair(refreshEncoded), + secret: refreshSecret, + }) + } + return pairs, nil +} diff --git a/internal/proxy/pool_attribution.go b/internal/proxy/pool_attribution.go new file mode 100644 index 0000000..b9a44fd --- /dev/null +++ b/internal/proxy/pool_attribution.go @@ -0,0 +1,86 @@ +package proxy + +import ( + "sync" + "time" +) + +// refreshAttrTTL is how long a real-refresh-token -> member tag is retained. +// An OAuth refresh round-trip (agent POSTs refresh_token, upstream answers +// with rotated tokens) completes in well under a second in practice; a +// generous TTL absorbs slow upstreams and clock skew while still bounding +// the map so a member that never sees its response cannot leak the tag +// forever. The tag is also deleted on first successful lookup. +const refreshAttrTTL = 5 * time.Minute + +// refreshAttribution maps the REAL refresh token sluice injected into an +// outbound OAuth refresh-grant request to the pool member that owns it. +// +// This is the join key for Risk R1: two pool members share one token URL, +// so OAuthIndex.Match is 1:1 and cannot tell which member a token-endpoint +// response belongs to. The injected real refresh token, by contrast, is +// unique per member and is present verbatim in the RFC-6749 refresh-grant +// request body (`refresh_token=`). Recording member-by-injected- +// refresh-token at pass-2 swap time and recovering it on the matching +// response is the only attribution that cannot misfile B's rotated tokens +// under A. The access token is NOT a valid key (it is not echoed in the +// refresh-grant request body), and the client connection is NOT a valid key +// (one HTTP/2 connection multiplexes both members' streams). +type refreshAttribution struct { + mu sync.Mutex + entries map[string]refreshAttrEntry +} + +type refreshAttrEntry struct { + member string + expires time.Time +} + +func newRefreshAttribution() *refreshAttribution { + return &refreshAttribution{entries: make(map[string]refreshAttrEntry)} +} + +// Tag records that the given real refresh token was injected for member. +// Called from the pass-2 phantom swap when the phantom being replaced is a +// pooled credential's `.refresh` phantom. A best-effort opportunistic sweep +// of expired entries keeps the map bounded without a background goroutine. +func (r *refreshAttribution) Tag(realRefreshToken, member string) { + if realRefreshToken == "" || member == "" { + return + } + now := time.Now() + r.mu.Lock() + defer r.mu.Unlock() + if len(r.entries) > 0 { + for k, e := range r.entries { + if now.After(e.expires) { + delete(r.entries, k) + } + } + } + r.entries[realRefreshToken] = refreshAttrEntry{ + member: member, + expires: now.Add(refreshAttrTTL), + } +} + +// Recover returns the member tagged for the given real refresh token and +// removes the entry (single-use: a rotated refresh token will never be +// presented again). Returns ("", false) when no live tag exists — the +// caller MUST fail closed (skip the vault write, never guess) per R1. +func (r *refreshAttribution) Recover(realRefreshToken string) (string, bool) { + if realRefreshToken == "" { + return "", false + } + r.mu.Lock() + defer r.mu.Unlock() + e, ok := r.entries[realRefreshToken] + if !ok { + return "", false + } + delete(r.entries, realRefreshToken) + if time.Now().After(e.expires) { + return "", false + } + return e.member, true +} diff --git a/internal/proxy/pool_phantom_test.go b/internal/proxy/pool_phantom_test.go new file mode 100644 index 0000000..640dcd1 --- /dev/null +++ b/internal/proxy/pool_phantom_test.go @@ -0,0 +1,337 @@ +package proxy + +import ( + "net/http" + "net/url" + "strings" + "sync/atomic" + "testing" + "time" + + mitmproxy "github.com/lqqyt2423/go-mitmproxy/proxy" + "github.com/nemirovsky/sluice/internal/store" + "github.com/nemirovsky/sluice/internal/vault" + uuid "github.com/satori/go.uuid" +) + +func timeFuture() time.Time { return time.Now().Add(5 * time.Minute) } + +// poolMemberCred builds an OAuth credential envelope for a pool member. +func poolMemberCred(t *testing.T, access, refresh string) string { + t.Helper() + c := &vault.OAuthCredential{ + AccessToken: access, + RefreshToken: refresh, + TokenURL: testOAuthTokenURL, + } + data, err := c.Marshal() + if err != nil { + t.Fatalf("marshal oauth cred: %v", err) + } + return string(data) +} + +// setupPoolAddon wires a SluiceAddon with a two-member pool bound to +// auth.example.com. Both members share testOAuthTokenURL (the Risk R1 +// collision shape: two Codex accounts behind one OpenAI token endpoint). +func setupPoolAddon(t *testing.T, poolName, memberA, memberB string) (*SluiceAddon, *addonWritableProvider, *atomic.Pointer[vault.PoolResolver]) { + t.Helper() + + provider := &addonWritableProvider{ + creds: map[string]string{ + memberA: poolMemberCred(t, "A-access-old", "A-refresh-old"), + memberB: poolMemberCred(t, "B-access-old", "B-refresh-old"), + }, + } + + // The agent's binding points at the POOL name, not a member. + bindings := []vault.Binding{{ + Destination: "auth.example.com", + Ports: []int{443}, + Credential: poolName, + }} + resolver, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var resolverPtr atomic.Pointer[vault.BindingResolver] + resolverPtr.Store(resolver) + + addon := NewSluiceAddon(WithResolver(&resolverPtr), WithProvider(provider)) + addon.persistDone = make(chan struct{}, 10) + + // Both members are registered in credential_meta (real OAuth creds) + // with the SAME token URL. The pool name is NOT in credential_meta. + metas := []store.CredentialMeta{ + {Name: memberA, CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: memberB, CredType: "oauth", TokenURL: testOAuthTokenURL}, + } + addon.UpdateOAuthIndex(metas) + + pool := store.Pool{Name: poolName, Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: memberA, Position: 0}, + {Credential: memberB, Position: 1}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{pool}, nil)) + addon.SetPoolResolver(&prPtr) + + return addon, provider, &prPtr +} + +// refreshGrantBody is an RFC-6749 form-encoded refresh grant carrying the +// pool-scoped refresh phantom. Pass-2 swaps the phantom for the active +// member's real refresh token before the request leaves sluice. +func refreshGrantBody(poolName string) []byte { + return []byte("grant_type=refresh_token&refresh_token=SLUICE_PHANTOM:" + poolName + ".refresh") +} + +func newPoolReqRespFlow(client *mitmproxy.ClientConn, reqBody []byte, respBody []byte) *mitmproxy.Flow { + u, _ := url.Parse(testOAuthTokenURL) + reqHdr := make(http.Header) + reqHdr.Set("Content-Type", "application/x-www-form-urlencoded") + respHdr := make(http.Header) + respHdr.Set("Content-Type", "application/json") + return &mitmproxy.Flow{ + Id: uuid.NewV4(), + ConnContext: &mitmproxy.ConnContext{ClientConn: client}, + Request: &mitmproxy.Request{ + Method: "POST", + URL: u, + Header: reqHdr, + Body: reqBody, + }, + Response: &mitmproxy.Response{ + StatusCode: 200, + Header: respHdr, + Body: respBody, + }, + } +} + +// TestR3PoolPhantomByteIdenticalAcrossMemberSwitch asserts the agent-facing +// phantom access token is byte-identical before and after a member switch +// (Risk R3). resignJWT is per-real-token; the pool-stable synthetic JWT must +// not depend on which member is active. +func TestR3PoolPhantomByteIdenticalAcrossMemberSwitch(t *testing.T) { + // Direct determinism check on the synthetic-JWT builder. + p1 := poolStablePhantomAccess("codex_pool") + p2 := poolStablePhantomAccess("codex_pool") + if p1 != p2 { + t.Fatalf("poolStablePhantomAccess not deterministic: %q != %q", p1, p2) + } + if parts := strings.Split(p1, "."); len(parts) != 3 { + t.Fatalf("phantom not a 3-part JWT: %q", p1) + } + if poolStablePhantomAccess("other_pool") == p1 { + t.Fatal("phantom not keyed on pool name (collision across pools)") + } + + // End-to-end: the access phantom the agent receives in a token-endpoint + // response must be identical when member A is active and after failover + // to member B (members have DIFFERENT real access tokens). + addon, _, prPtr := setupPoolAddon(t, "codex_pool", "codexA", "codexB") + client := setupAddonConn(addon, "auth.example.com:443") + + // Member A active. Request body carries A's real refresh token (as if + // pass-2 already swapped it), upstream returns A's rotated tokens. + reqA := []byte("grant_type=refresh_token&refresh_token=A-refresh-old") + addon.refreshAttr.Tag("A-refresh-old", "codexA") + respA := mustJSON(t, map[string]interface{}{ + "access_token": "A-real-access-NEW-aaaaaaaa", + "refresh_token": "A-real-refresh-NEW-aaaaaaaa", + "expires_in": 3600, + }) + fA := newPoolReqRespFlow(client, reqA, respA) + addon.Response(fA) + waitAddonPersist(t, addon) + bodyA := string(fA.Response.Body) + phantomA := poolStablePhantomAccess("codex_pool") + if !strings.Contains(bodyA, phantomA) { + t.Fatalf("member-A response missing pool-stable phantom\n got: %q\nwant substring: %q", bodyA, phantomA) + } + if strings.Contains(bodyA, "A-real-access-NEW-aaaaaaaa") { + t.Fatal("real access token leaked in member-A response") + } + + // Fail member A over: B is now active. + prPtr.Load().MarkCooldown("codexA", timeFuture(), "429") + if got, _ := prPtr.Load().ResolveActive("codex_pool"); got != "codexB" { + t.Fatalf("after cooldown active = %q, want codexB", got) + } + + reqB := []byte("grant_type=refresh_token&refresh_token=B-refresh-old") + addon.refreshAttr.Tag("B-refresh-old", "codexB") + respB := mustJSON(t, map[string]interface{}{ + "access_token": "B-real-access-NEW-bbbbbbbbbbbb", + "refresh_token": "B-real-refresh-NEW-bbbbbbbbbbbb", + "expires_in": 3600, + }) + fB := newPoolReqRespFlow(client, reqB, respB) + addon.Response(fB) + waitAddonPersist(t, addon) + bodyB := string(fB.Response.Body) + phantomB := poolStablePhantomAccess("codex_pool") + + if phantomA != phantomB { + t.Fatalf("R3 violated: phantom changed across member switch\n A: %q\n B: %q", phantomA, phantomB) + } + if !strings.Contains(bodyB, phantomB) { + t.Fatalf("member-B response missing pool-stable phantom\n got: %q", bodyB) + } + if strings.Contains(bodyB, "B-real-access-NEW-bbbbbbbbbbbb") { + t.Fatal("real access token leaked in member-B response") + } +} + +// TestR1RefreshAttributionByInjectedRefreshToken asserts a B-refresh +// response is persisted to B's vault entry, never A's, even though both +// members share one token URL (OAuthIndex.Match is 1:1 and collides). +func TestR1RefreshAttributionByInjectedRefreshToken(t *testing.T) { + addon, provider, prPtr := setupPoolAddon(t, "codex_pool", "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + + // --- Member A round-trip via the real pass-2 path. --- + // A is active; Request() swaps the pool refresh phantom -> A's real + // refresh token AND tags A-refresh-old -> memA. + reqFlow := newTestFlow(client, "POST", testOAuthTokenURL) + reqFlow.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + reqFlow.Request.Body = refreshGrantBody("codex_pool") + addon.Request(reqFlow) + if !strings.Contains(string(reqFlow.Request.Body), "A-refresh-old") { + t.Fatalf("pass-2 did not inject member-A real refresh token; body=%q", reqFlow.Request.Body) + } + + respFlow := newPoolReqRespFlow(client, reqFlow.Request.Body, mustJSON(t, map[string]interface{}{ + "access_token": "A-access-rotated-1", + "refresh_token": "A-refresh-rotated-1", + "expires_in": 3600, + })) + addon.Response(respFlow) + waitAddonPersist(t, addon) + + credA, err := vault.ParseOAuth([]byte(provider.creds["memA"])) + if err != nil { + t.Fatalf("parse memA: %v", err) + } + if credA.RefreshToken != "A-refresh-rotated-1" { + t.Errorf("memA refresh not persisted: got %q want A-refresh-rotated-1", credA.RefreshToken) + } + credB, err := vault.ParseOAuth([]byte(provider.creds["memB"])) + if err != nil { + t.Fatalf("parse memB: %v", err) + } + if credB.RefreshToken != "B-refresh-old" || credB.AccessToken != "B-access-old" { + t.Errorf("memB MUST be untouched by an A-refresh response; got access=%q refresh=%q", + credB.AccessToken, credB.RefreshToken) + } + + // --- Member B round-trip after failover. --- + prPtr.Load().MarkCooldown("memA", timeFuture(), "429") + reqFlowB := newTestFlow(client, "POST", testOAuthTokenURL) + reqFlowB.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + reqFlowB.Request.Body = refreshGrantBody("codex_pool") + addon.Request(reqFlowB) + if !strings.Contains(string(reqFlowB.Request.Body), "B-refresh-old") { + t.Fatalf("pass-2 did not inject member-B real refresh token; body=%q", reqFlowB.Request.Body) + } + respFlowB := newPoolReqRespFlow(client, reqFlowB.Request.Body, mustJSON(t, map[string]interface{}{ + "access_token": "B-access-rotated-1", + "refresh_token": "B-refresh-rotated-1", + "expires_in": 3600, + })) + addon.Response(respFlowB) + waitAddonPersist(t, addon) + + credB2, _ := vault.ParseOAuth([]byte(provider.creds["memB"])) + if credB2.RefreshToken != "B-refresh-rotated-1" { + t.Errorf("memB refresh not persisted after failover: got %q", credB2.RefreshToken) + } + credA2, _ := vault.ParseOAuth([]byte(provider.creds["memA"])) + if credA2.RefreshToken != "A-refresh-rotated-1" { + t.Errorf("memA MUST retain its own rotated token; B-refresh response corrupted A: got %q", + credA2.RefreshToken) + } +} + +// TestR1FailClosedWhenMemberTagMissing asserts that when the owning member +// cannot be recovered from the injected refresh token (no tag), the response +// is still swapped to phantoms (agent safe) but ZERO vault writes occur — no +// guess, no fallback to OAuthIndex.Match. +func TestR1FailClosedWhenMemberTagMissing(t *testing.T) { + addon, provider, _ := setupPoolAddon(t, "codex_pool", "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + + beforeA := provider.creds["memA"] + beforeB := provider.creds["memB"] + + // Request body carries a refresh token that was NEVER tagged (no + // pass-2 ran, or the tag expired). resolveOAuthResponseAttribution + // must fail closed. + resp := newPoolReqRespFlow(client, + []byte("grant_type=refresh_token&refresh_token=untracked-refresh-xyz"), + mustJSON(t, map[string]interface{}{ + "access_token": "should-not-persist-access", + "refresh_token": "should-not-persist-refresh", + "expires_in": 3600, + })) + addon.Response(resp) + + // No persist goroutine should have been scheduled. Give any (buggy) + // async write a chance to land, then assert nothing changed. + select { + case <-addon.persistDone: + t.Fatal("R1 fail-closed violated: a vault persist was scheduled with no member tag") + default: + } + + if provider.creds["memA"] != beforeA { + t.Error("memA vault entry mutated despite fail-closed") + } + if provider.creds["memB"] != beforeB { + t.Error("memB vault entry mutated despite fail-closed") + } + + // Agent must still be protected: real tokens swapped to phantoms. + body := string(resp.Response.Body) + if strings.Contains(body, "should-not-persist-access") || strings.Contains(body, "should-not-persist-refresh") { + t.Errorf("fail-closed must still strip real tokens; body=%q", body) + } + if !strings.Contains(body, poolStablePhantomAccess("codex_pool")) { + t.Errorf("fail-closed response missing pool-stable phantom; body=%q", body) + } +} + +// TestChokepointPlainCredentialUnchanged asserts a non-pool credential +// routes through the chokepoint as an identity (regression guard for +// Important I2: the single chokepoint must not alter plain-cred behavior). +func TestChokepointPlainCredentialUnchanged(t *testing.T) { + addon, _ := setupOAuthAddon(t, "plain_oauth", &vault.OAuthCredential{ + AccessToken: "plain-access-old", + RefreshToken: "plain-refresh-old", + TokenURL: testOAuthTokenURL, + }) + // Attach an (empty) pool resolver so the chokepoint code path runs. + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver(nil, nil)) + addon.SetPoolResolver(&prPtr) + + client := setupAddonConn(addon, "auth.example.com:443") + resp := newTestResponseFlow(client, testOAuthTokenURL, 200, mustJSON(t, map[string]interface{}{ + "access_token": "plain-real-access-NEW", + "refresh_token": "plain-real-refresh-NEW", + "expires_in": 3600, + }), "application/json") + addon.Response(resp) + waitAddonPersist(t, addon) + + body := string(resp.Response.Body) + if strings.Contains(body, "plain-real-access-NEW") { + t.Error("plain cred: real token leaked") + } + // Plain creds keep the legacy per-real-token resign / static phantom. + if !strings.Contains(body, oauthPhantomAccess("plain_oauth", "plain-real-access-NEW")) { + t.Errorf("plain cred phantom changed; body=%q", body) + } +} From 1187d7af6355659e5af631c646ea5cb26882a8ac Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 00:36:59 +0800 Subject: [PATCH 19/49] feat(proxy): auto-failover on 429/401 for credential pools --- cmd/sluice/main.go | 30 ++ .../20260515-credential-pool-failover.md | 12 +- internal/proxy/addon.go | 24 ++ internal/proxy/pool_failover.go | 266 +++++++++++++++ internal/proxy/pool_failover_test.go | 306 ++++++++++++++++++ internal/proxy/server.go | 11 + 6 files changed, 643 insertions(+), 6 deletions(-) create mode 100644 internal/proxy/pool_failover.go create mode 100644 internal/proxy/pool_failover_test.go diff --git a/cmd/sluice/main.go b/cmd/sluice/main.go index 553a05a..47135fb 100644 --- a/cmd/sluice/main.go +++ b/cmd/sluice/main.go @@ -462,6 +462,36 @@ func main() { // Update the proxy's broker reference now that it's created. srv.SetBroker(broker) + // Wire Phase 2 pool failover side effects: durable health write + // + best-effort Telegram notice. The in-memory active-member + // switch already happened synchronously on the response path + // before this callback fires (Risk I1); this only persists for + // restart durability and tells the operator. Everything here runs + // in a detached goroutine so the response/injection path is never + // blocked by a SQLite write or a Telegram round-trip. + failoverBroker := broker + srv.SetOnFailover(func(ev proxy.FailoverEvent) { + go func() { + if db != nil { + reason := fmt.Sprintf("failover:%s", ev.Reason) + if herr := db.SetCredentialHealth(ev.From, "cooldown", ev.Until, reason); herr != nil { + log.Printf("[POOL-FAILOVER] durable health write for %q failed: %v", ev.From, herr) + } + } + if failoverBroker != nil { + msg := fmt.Sprintf("pool `%s` failed over `%s`→`%s` (%s)", + ev.Pool, ev.From, ev.To, ev.Reason) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for _, ch := range failoverBroker.Channels() { + if nerr := ch.Notify(ctx, msg); nerr != nil { + log.Printf("[POOL-FAILOVER] notice via %s failed: %v", ch.Type(), nerr) + } + } + } + }() + }) + // Start all channels. if tgChannel != nil { if err := tgChannel.Start(); err != nil { diff --git a/docs/plans/20260515-credential-pool-failover.md b/docs/plans/20260515-credential-pool-failover.md index dc1816b..eab928a 100644 --- a/docs/plans/20260515-credential-pool-failover.md +++ b/docs/plans/20260515-credential-pool-failover.md @@ -92,12 +92,12 @@ rotate` is an operator override, not the primary mechanism. **Files:** `internal/proxy/addon.go`, `internal/vault/pool.go`, audit logger, telegram + tests -- [ ] Failure classification in `SluiceAddon.Response` for pooled destinations: 429 or 403+`insufficient_quota` → rate-limited; 401 or token-body `invalid_grant`/`invalid_token` → auth-failure; 5xx/other → no-op. -- [ ] Prompt failover: synchronously update in-memory `PoolResolver` health BEFORE the response returns (documented locking discipline); also `SetCredentialHealth(member,'cooldown',now+ttl,reason)` for durability (2s watcher only reconciles). Cooldown TTL consts: rate-limit 60s, auth-fail 300s; lazy recovery in `ResolveActive`. -- [ ] Audit `cred_failover` with `Reason = ":->:<429|403|401|invalid_grant>"`. -- [ ] Telegram best-effort non-blocking notice "pool `` failed over ``→`` ()". -- [ ] No in-flight retry (documented); next request uses new member. -- [ ] Unit tests for classification + synchronous health swap + cooldown TTL/lazy recovery; `go test ./... -timeout 120s` green; build clean; gofumpt. +- [x] Failure classification in `SluiceAddon.Response` for pooled destinations: 429 or 403+`insufficient_quota` → rate-limited; 401 or token-body `invalid_grant`/`invalid_token` → auth-failure; 5xx/other → no-op. (`classifyFailover` in `internal/proxy/pool_failover.go`; token-endpoint body only trusted when the request URL matched the OAuth index.) +- [x] Prompt failover: synchronously update in-memory `PoolResolver` health BEFORE the response returns (documented locking discipline — `MarkCooldown` takes the resolver write lock, `ResolveActive` the read lock); also `SetCredentialHealth(member,'cooldown',now+ttl,reason)` for durability via the detached `onFailover` callback (2s watcher only reconciles, I1). Cooldown TTL consts `vault.RateLimitCooldown` 60s / `vault.AuthFailCooldown` 300s; lazy recovery verified in `ResolveActive` (no scheduler). +- [x] Audit `cred_failover` with `Reason = ":->:<429|403|401|invalid_grant>"`. (Emitted synchronously in `handlePoolFailover` via the addon `auditLog` when configured — Action `cred_failover`, Verdict `failover`, Credential = the cooled-down member; covered by `TestFailoverAuditEvent`.) +- [x] Telegram best-effort non-blocking notice "pool `` failed over ``→`` ()". (Wired in `cmd/sluice/main.go`; the callback detaches the store write + every broker channel `Notify` into its own goroutine so the response path never blocks.) +- [x] No in-flight retry (documented in `handlePoolFailover` doc comment); next request uses new member. +- [x] Unit tests for classification + synchronous health swap + cooldown TTL/lazy recovery + non-blocking notice (`internal/proxy/pool_failover_test.go`); `go test ./... -timeout 120s` green (13/13 ok); build clean; gofumpt clean. ### Task 4: Verify acceptance + docs diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index 383cb27..3cbb389 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -137,6 +137,14 @@ type SluiceAddon struct { // container. Nil means no post-refresh action. onOAuthRefresh func(credName string) + // onFailover is called after a Phase 2 pool failover has been applied + // in memory (the active-member switch already happened synchronously + // before this fires). The callback owns the durable store write + // (SetCredentialHealth) and the best-effort Telegram notice, and MUST + // NOT block the response path (it dispatches its own goroutine). Nil + // means failover is in-memory only (no durability, no notice). + onFailover func(FailoverEvent) + // persistDone is an optional channel signaled when an async OAuth // token persist goroutine completes. Used by tests to avoid // time.Sleep-based synchronization. Nil in production. @@ -371,6 +379,14 @@ func (a *SluiceAddon) SetOnOAuthRefresh(fn func(credName string)) { a.onOAuthRefresh = fn } +// SetOnFailover configures the callback invoked after a pool failover has +// been applied in memory. The callback is responsible for the durable store +// write and the Telegram notice and must be non-blocking. Safe to leave +// unset (in-memory-only failover). +func (a *SluiceAddon) SetOnFailover(fn func(FailoverEvent)) { + a.onFailover = fn +} + // UpdateOAuthIndex rebuilds the OAuth token URL index from credential // metadata. Called on startup and after credential metadata changes // (e.g. SIGHUP hot-reload). @@ -839,6 +855,14 @@ func (a *SluiceAddon) Response(f *mitmproxy.Flow) { a.processOAuthResponseIfMatching(f) + // Phase 2 pool auto-failover. Runs on every response (the failover + // triggers are non-2xx, so this cannot piggyback on the OAuth 2xx-only + // path). It does not mutate the response — it only updates in-memory + // pool health synchronously and dispatches the durable write + notice — + // so its position relative to OAuth swap / DLP is immaterial. It is a + // cheap no-op for non-pooled destinations and for non-trigger statuses. + a.handlePoolFailover(f) + // Test-only panic injection. Always nil in production. Lets a // regression test exercise the deferred recover above without // having to construct a Flow that triggers a real downstream diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go new file mode 100644 index 0000000..9ea771c --- /dev/null +++ b/internal/proxy/pool_failover.go @@ -0,0 +1,266 @@ +package proxy + +import ( + "bytes" + "fmt" + "log" + "strings" + "time" + + mitmproxy "github.com/lqqyt2423/go-mitmproxy/proxy" + "github.com/nemirovsky/sluice/internal/audit" + "github.com/nemirovsky/sluice/internal/vault" +) + +// failoverClass is the result of classifying an upstream response for a +// pooled destination. +type failoverClass int + +const ( + // failoverNone means the response is not a failover trigger (2xx, 5xx, + // or any 4xx that is not an exhaustion/auth signal). Phase 2 deliberately + // does NOT fail over on 5xx: a server-side error is not evidence that the + // active member's account is exhausted or revoked, and rolling onto the + // next member would just spread a transient upstream outage across every + // account in the pool. 5xx and everything else is a documented no-op. + failoverNone failoverClass = iota + // failoverRateLimited: the active member is quota-exhausted / throttled + // (HTTP 429, or HTTP 403 whose body names quota exhaustion). Short cooldown + // (RateLimitCooldown) because rate limits roll off within the provider's + // window. + failoverRateLimited + // failoverAuthFailure: the active member's token is rejected (HTTP 401, or + // a token-endpoint body of invalid_grant / invalid_token). Long cooldown + // (AuthFailCooldown) because a revoked/expired refresh token will not + // self-heal quickly and retrying it thrashes a broken account. + failoverAuthFailure +) + +// reasonTag returns the short tag embedded in the audit Reason +// (":->:") and the Telegram notice. +func failoverReasonTag(class failoverClass, statusCode int, bodyTag string) string { + switch class { + case failoverRateLimited: + if statusCode == 403 { + return "403" + } + return "429" + case failoverAuthFailure: + if statusCode == 401 { + return "401" + } + if bodyTag != "" { + return bodyTag + } + return "invalid_grant" + default: + return "" + } +} + +// classifyFailover inspects a response for a pooled destination and decides +// whether it is a failover trigger. +// +// Classification rules (status code is the primary signal; the body is only +// consulted for the documented 403/token-endpoint cases): +// +// - HTTP 429 -> rate-limited +// - HTTP 403 with body insufficient_quota / quota -> rate-limited +// - HTTP 401 -> auth-failure +// - token-endpoint body invalid_grant/invalid_token -> auth-failure +// - 2xx, 5xx, and everything else -> no-op (documented) +// +// isTokenEndpoint is true when the request URL matched the OAuth token-URL +// index (so a body classification is only trusted on an actual token +// endpoint, not on an arbitrary API 4xx that happens to echo the string +// "invalid_grant" in unrelated prose). bodyTag returns the matched body +// token (for the audit reason) when the decision came from the body. +func classifyFailover(statusCode int, body []byte, isTokenEndpoint bool) (class failoverClass, bodyTag string) { + switch { + case statusCode == 429: + return failoverRateLimited, "" + case statusCode == 401: + return failoverAuthFailure, "" + case statusCode == 403: + if bodyContainsAny(body, "insufficient_quota", "quota_exceeded", "quota exhausted", "rate_limit_exceeded") { + return failoverRateLimited, "" + } + return failoverNone, "" + } + // Non-4xx-status path. Only a real token-endpoint body may be classified + // (invalid_grant/invalid_token), and only when the status is not a 2xx + // success. A 2xx token response is a healthy refresh, never a failover. + if isTokenEndpoint && (statusCode < 200 || statusCode > 299) { + if bodyContainsAny(body, "invalid_grant") { + return failoverAuthFailure, "invalid_grant" + } + if bodyContainsAny(body, "invalid_token") { + return failoverAuthFailure, "invalid_token" + } + } + return failoverNone, "" +} + +// bodyContainsAny reports whether body contains any of the substrings, +// case-insensitively. Bodies are bounded by maxProxyBody upstream so an +// in-memory scan is safe. +func bodyContainsAny(body []byte, subs ...string) bool { + if len(body) == 0 { + return false + } + lower := bytes.ToLower(body) + for _, s := range subs { + if bytes.Contains(lower, []byte(strings.ToLower(s))) { + return true + } + } + return false +} + +// FailoverEvent describes a completed pool failover. It is handed to the +// optional onFailover callback (store durability write + Telegram notice) +// configured via SetOnFailover. +type FailoverEvent struct { + Pool string + From string + To string + Reason string // short tag: 429 | 403 | 401 | invalid_grant | invalid_token + Class failoverClass + Until time.Time // member cooldown expiry just applied +} + +// poolForResponse maps a response's CONNECT destination back to a pooled +// binding and returns the pool name + the member that was active for this +// request. Returns ok=false when the destination is not bound to a pool. +func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember string, pr *vault.PoolResolver, ok bool) { + if a.poolResolver == nil || a.resolver == nil { + return "", "", nil, false + } + pr = a.poolResolver.Load() + if pr == nil { + return "", "", nil, false + } + res := a.resolver.Load() + if res == nil { + return "", "", nil, false + } + host, port := connectTargetForFlow(a, f) + if host == "" { + return "", "", nil, false + } + // The Response addon path is HTTP/HTTPS/HTTP2 (gRPC). Bindings without + // an explicit protocol list match any protocol; pass "https" so a + // protocol-scoped binding still resolves on the common case. + for _, boundName := range res.CredentialsForDestination(host, port, "https") { + if !pr.IsPool(boundName) { + continue + } + member, mok := pr.ResolveActive(boundName) + if !mok || member == "" { + continue + } + return boundName, member, pr, true + } + return "", "", nil, false +} + +// handlePoolFailover is the Phase 2 entry point invoked from Response for +// every response. It is a cheap no-op for the overwhelming common case +// (destination is not pooled, or the response is a success / 5xx). When the +// response classifies as a failover trigger for the active pool member it: +// +// 1. Synchronously marks the active member in cooldown in the in-memory +// PoolResolver BEFORE this function returns, so the very next request +// resolves to the next member. This is the I1 fix: the active-member +// switch must NOT wait on the 2s data-version watcher. The store write +// below only reconciles for durability across restarts. +// 2. Computes the next active member (post-cooldown) for the audit/notice. +// 3. Hands a FailoverEvent to the onFailover callback (async, best-effort): +// the callback persists SetCredentialHealth to the store and fires the +// Telegram notice. The callback MUST NOT block the response path. +// +// No in-flight retry: the triggering request still returns its own upstream +// error to the agent unmodified. The agent (or its SDK) retries on its own +// schedule, and that retry resolves to the freshly-activated next member. +// Transparent in-flight retry is intentionally out of scope (see the plan's +// "Out of scope" section) — buffering and replaying an arbitrary upstream +// request body/headers safely is a separate, larger change. +func (a *SluiceAddon) handlePoolFailover(f *mitmproxy.Flow) { + if f == nil || f.Response == nil || f.Request == nil { + return + } + pool, from, pr, ok := a.poolForResponse(f) + if !ok { + return + } + + isTokenEndpoint := false + if idx := a.oauthIndex.Load(); idx != nil { + _, isTokenEndpoint = idx.Match(f.Request.URL) + } + + class, bodyTag := classifyFailover(f.Response.StatusCode, f.Response.Body, isTokenEndpoint) + if class == failoverNone { + return + } + + ttl := vault.RateLimitCooldown + if class == failoverAuthFailure { + ttl = vault.AuthFailCooldown + } + until := time.Now().Add(ttl) + tag := failoverReasonTag(class, f.Response.StatusCode, bodyTag) + + // (1) Synchronous in-memory health update BEFORE returning (Risk I1). + // MarkCooldown takes the resolver's write lock; ResolveActive takes the + // read lock, so the next request observes the new active member with no + // dependency on the store-reconcile watcher. + pr.MarkCooldown(from, until, tag) + + // (2) Recompute the active member now that `from` is cooling down. If + // every member is in cooldown ResolveActive degrades to the + // soonest-recovering one (possibly `from` itself); the notice still + // records the attempted transition honestly. + to := from + if next, nok := pr.ResolveActive(pool); nok && next != "" { + to = next + } + + log.Printf("[POOL-FAILOVER] pool %q: %s -> %s (%s); member %q cooling down until %s", + pool, from, to, tag, from, until.Format(time.RFC3339)) + + // Audit: emit a cred_failover action with the documented Reason shape + // ":->:". Safe to call with a nil auditLog. The + // blake3 hash chain is appended synchronously by FileLogger.Log; the + // write is local and fast (mirrors logDLPAudit on the same path), so it + // does not warrant detaching like the store/Telegram side effects. + if a.auditLog != nil { + host, port := connectTargetForFlow(a, f) + evt := audit.Event{ + Destination: host, + Port: port, + Protocol: "https", + Verdict: "failover", + Action: "cred_failover", + Reason: fmt.Sprintf("%s:%s->%s:%s", pool, from, to, tag), + Credential: from, + } + if err := a.auditLog.Log(evt); err != nil { + log.Printf("[POOL-FAILOVER] audit log error: %v", err) + } + } + + // (3) Durability + Telegram via the callback. The callback is + // responsible for being non-blocking (it runs the store write and the + // Telegram send in its own goroutine); we still guard with a nil check. + if a.onFailover != nil { + a.onFailover(FailoverEvent{ + Pool: pool, + From: from, + To: to, + Reason: tag, + Class: class, + Until: until, + }) + } +} diff --git a/internal/proxy/pool_failover_test.go b/internal/proxy/pool_failover_test.go new file mode 100644 index 0000000..03c07d5 --- /dev/null +++ b/internal/proxy/pool_failover_test.go @@ -0,0 +1,306 @@ +package proxy + +import ( + "encoding/json" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" + + mitmproxy "github.com/lqqyt2423/go-mitmproxy/proxy" + "github.com/nemirovsky/sluice/internal/audit" + "github.com/nemirovsky/sluice/internal/vault" + uuid "github.com/satori/go.uuid" +) + +// newPoolRespFlow builds a response flow for the pooled destination with an +// arbitrary status code and body. The request URL is the OAuth token URL so +// the token-endpoint body classification path is exercised. +func newPoolRespFlow(client *mitmproxy.ClientConn, status int, respBody []byte) *mitmproxy.Flow { + u, _ := url.Parse(testOAuthTokenURL) + reqHdr := make(http.Header) + respHdr := make(http.Header) + respHdr.Set("Content-Type", "application/json") + return &mitmproxy.Flow{ + Id: uuid.NewV4(), + ConnContext: &mitmproxy.ConnContext{ClientConn: client}, + Request: &mitmproxy.Request{ + Method: "POST", + URL: u, + Header: reqHdr, + Body: []byte("grant_type=refresh_token&refresh_token=x"), + }, + Response: &mitmproxy.Response{ + StatusCode: status, + Header: respHdr, + Body: respBody, + }, + } +} + +// TestClassifyFailover is the classification truth table from the plan. +func TestClassifyFailover(t *testing.T) { + cases := []struct { + name string + status int + body string + tokenEP bool + wantClass failoverClass + wantTagPart string + }{ + {"429 rate limited", 429, "", false, failoverRateLimited, "429"}, + {"403 insufficient_quota", 403, `{"error":"insufficient_quota"}`, false, failoverRateLimited, "403"}, + {"403 quota_exceeded", 403, `{"error":{"code":"quota_exceeded"}}`, false, failoverRateLimited, "403"}, + {"403 unrelated -> noop", 403, `{"error":"forbidden: bad scope"}`, false, failoverNone, ""}, + {"401 auth failure", 401, "", false, failoverAuthFailure, "401"}, + {"token-endpoint invalid_grant", 400, `{"error":"invalid_grant"}`, true, failoverAuthFailure, "invalid_grant"}, + {"token-endpoint invalid_token", 400, `{"error":"invalid_token"}`, true, failoverAuthFailure, "invalid_token"}, + {"invalid_grant but NOT token endpoint -> noop", 400, `{"error":"invalid_grant"}`, false, failoverNone, ""}, + {"200 success -> noop", 200, `{"access_token":"x"}`, true, failoverNone, ""}, + {"500 server error -> noop", 500, `oops`, false, failoverNone, ""}, + {"502 -> noop", 502, ``, false, failoverNone, ""}, + {"404 -> noop", 404, ``, false, failoverNone, ""}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + class, bodyTag := classifyFailover(c.status, []byte(c.body), c.tokenEP) + if class != c.wantClass { + t.Fatalf("class = %v, want %v", class, c.wantClass) + } + if c.wantClass == failoverNone { + return + } + tag := failoverReasonTag(class, c.status, bodyTag) + if tag != c.wantTagPart { + t.Fatalf("reason tag = %q, want %q", tag, c.wantTagPart) + } + }) + } +} + +// TestFailoverSynchronousHealthSwap asserts that after a 429 response on a +// pooled destination, the very NEXT ResolveActive call returns the next +// member — without any reliance on the 2s store-reconcile watcher (Risk I1). +func TestFailoverSynchronousHealthSwap(t *testing.T) { + addon, _, prPtr := setupPoolAddon(t, "codex_pool", "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + + pr := prPtr.Load() + if got, _ := pr.ResolveActive("codex_pool"); got != "memA" { + t.Fatalf("pre-failover active = %q, want memA", got) + } + + var got FailoverEvent + gotCalled := make(chan struct{}, 1) + addon.SetOnFailover(func(ev FailoverEvent) { + got = ev + gotCalled <- struct{}{} + }) + + addon.Response(newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`))) + + // Synchronous: by the time Response returns the swap is already done. + if active, _ := pr.ResolveActive("codex_pool"); active != "memB" { + t.Fatalf("post-failover active = %q, want memB (synchronous swap, no watcher)", active) + } + + select { + case <-gotCalled: + case <-time.After(2 * time.Second): + t.Fatal("onFailover callback not invoked") + } + if got.Pool != "codex_pool" || got.From != "memA" || got.To != "memB" || got.Reason != "429" { + t.Fatalf("FailoverEvent = %+v, want pool=codex_pool from=memA to=memB reason=429", got) + } + if got.Class != failoverRateLimited { + t.Fatalf("class = %v, want rate-limited", got.Class) + } +} + +// TestFailoverCooldownTTLAndLazyRecovery asserts the documented cooldown +// durations and that an expired cooldown makes the member eligible again +// with no scheduler (lazy recovery in ResolveActive). +func TestFailoverCooldownTTLAndLazyRecovery(t *testing.T) { + // Rate-limit TTL = 60s, auth-fail TTL = 300s (named consts). + if vault.RateLimitCooldown != 60*time.Second { + t.Fatalf("RateLimitCooldown = %v, want 60s", vault.RateLimitCooldown) + } + if vault.AuthFailCooldown != 300*time.Second { + t.Fatalf("AuthFailCooldown = %v, want 300s", vault.AuthFailCooldown) + } + + addon, _, prPtr := setupPoolAddon(t, "codex_pool", "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + pr := prPtr.Load() + + // Auth failure (401) -> memA cools down for AuthFailCooldown. + before := time.Now() + addon.Response(newPoolRespFlow(client, 401, nil)) + until, cooling := pr.CooldownUntil("memA") + if !cooling { + t.Fatal("memA should be cooling down after 401") + } + gotTTL := until.Sub(before) + // Allow generous slack for scheduling jitter. + if gotTTL < vault.AuthFailCooldown-5*time.Second || gotTTL > vault.AuthFailCooldown+5*time.Second { + t.Fatalf("auth-fail cooldown TTL = %v, want ~%v", gotTTL, vault.AuthFailCooldown) + } + + // Lazy recovery: force the cooldown to the past; ResolveActive must + // treat memA as eligible again with no background scheduler involved. + pr.MarkCooldown("memA", time.Now().Add(-time.Second), "expired") + if active, _ := pr.ResolveActive("codex_pool"); active != "memA" { + t.Fatalf("after expiry active = %q, want memA (lazy recovery)", active) + } +} + +// TestFailoverNoopForNonPooledAndSuccess asserts the failover path is a +// no-op for a successful response and never invokes the callback. +func TestFailoverNoopForSuccessfulResponse(t *testing.T) { + addon, _, prPtr := setupPoolAddon(t, "codex_pool", "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + + called := false + addon.SetOnFailover(func(FailoverEvent) { called = true }) + + addon.Response(newPoolRespFlow(client, 200, []byte(`{"access_token":"ok"}`))) + if called { + t.Fatal("onFailover invoked for a 200 response") + } + if active, _ := prPtr.Load().ResolveActive("codex_pool"); active != "memA" { + t.Fatalf("active = %q, want memA unchanged on success", active) + } + + // 5xx is also a documented no-op. + addon.Response(newPoolRespFlow(client, 503, []byte(`upstream down`))) + if called { + t.Fatal("onFailover invoked for a 5xx response (must be no-op)") + } +} + +// TestFailoverNoticeNonBlocking asserts the response path is not blocked by +// a slow onFailover callback. The callback in production dispatches its own +// goroutine; this test models a callback whose own work is slow and verifies +// Response returns promptly regardless (the addon does not goroutine for the +// callback, so the callback contract is "be non-blocking yourself" — here we +// assert Response itself never waits on callback-internal work by having the +// callback spawn the slow part and return immediately, mirroring main.go). +func TestFailoverNoticeNonBlocking(t *testing.T) { + addon, _, _ := setupPoolAddon(t, "codex_pool", "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + + done := make(chan struct{}) + addon.SetOnFailover(func(FailoverEvent) { + // Production wiring (main.go) detaches the slow store/Telegram + // work into a goroutine and returns immediately. Model that. + go func() { + time.Sleep(500 * time.Millisecond) + close(done) + }() + }) + + start := time.Now() + addon.Response(newPoolRespFlow(client, 429, nil)) + elapsed := time.Since(start) + if elapsed > 200*time.Millisecond { + t.Fatalf("Response blocked %v on failover callback; must be non-blocking", elapsed) + } + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("detached failover work never completed") + } +} + +// TestFailoverNonPooledDestinationIgnored asserts a response whose +// destination is NOT bound to a pool never triggers failover. +func TestFailoverNonPooledDestinationIgnored(t *testing.T) { + addon, _, _ := setupPoolAddon(t, "codex_pool", "memA", "memB") + // Connect to a destination with no pooled binding. + client := setupAddonConn(addon, "unrelated.example.com:443") + + called := false + addon.SetOnFailover(func(FailoverEvent) { called = true }) + + f := newPoolRespFlow(client, 429, nil) + addon.Response(f) + if called { + t.Fatal("onFailover invoked for a non-pooled destination") + } +} + +// TestFailoverAuditEvent asserts a cred_failover audit event is emitted with +// the documented Reason shape ":->:". +func TestFailoverAuditEvent(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "audit.log") + logger, err := audit.NewFileLogger(logPath) + if err != nil { + t.Fatalf("NewFileLogger: %v", err) + } + t.Cleanup(func() { _ = logger.Close() }) + + addon, _, _ := setupPoolAddon(t, "codex_pool", "memA", "memB") + addon.auditLog = logger + client := setupAddonConn(addon, "auth.example.com:443") + + addon.Response(newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`))) + + if err := logger.Close(); err != nil { + t.Fatalf("logger close: %v", err) + } + data, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("read audit log: %v", err) + } + + var found bool + for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") { + if line == "" { + continue + } + var evt audit.Event + if uerr := json.Unmarshal([]byte(line), &evt); uerr != nil { + t.Fatalf("unmarshal audit line %q: %v", line, uerr) + } + if evt.Action != "cred_failover" { + continue + } + found = true + if evt.Reason != "codex_pool:memA->memB:429" { + t.Fatalf("audit Reason = %q, want %q", evt.Reason, "codex_pool:memA->memB:429") + } + if evt.Verdict != "failover" { + t.Fatalf("audit Verdict = %q, want failover", evt.Verdict) + } + if evt.Credential != "memA" { + t.Fatalf("audit Credential = %q, want memA", evt.Credential) + } + } + if !found { + t.Fatalf("no cred_failover audit event found in:\n%s", data) + } +} + +// TestPoolForResponseResolvesActiveMember sanity-checks the destination -> +// pool reverse mapping used by handlePoolFailover. +func TestPoolForResponseResolvesActiveMember(t *testing.T) { + addon, _, prPtr := setupPoolAddon(t, "codex_pool", "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + f := newPoolRespFlow(client, 429, nil) + + pool, member, pr, ok := addon.poolForResponse(f) + if !ok { + t.Fatal("poolForResponse: expected a pooled destination match") + } + if pool != "codex_pool" || member != "memA" { + t.Fatalf("got pool=%q member=%q, want codex_pool/memA", pool, member) + } + if pr != prPtr.Load() { + t.Fatal("poolForResponse returned a different resolver than the live one") + } +} diff --git a/internal/proxy/server.go b/internal/proxy/server.go index bac2be8..9006893 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -2776,6 +2776,17 @@ func (s *Server) SetOnOAuthRefresh(fn func(credName string)) { } } +// SetOnFailover configures a callback on the addon that is invoked after a +// pool failover has been applied in memory. The active-member switch has +// already happened synchronously by the time this fires; the callback owns +// the durable SetCredentialHealth store write and the best-effort Telegram +// notice, and must not block the response path. +func (s *Server) SetOnFailover(fn func(FailoverEvent)) { + if s.addon != nil { + s.addon.SetOnFailover(fn) + } +} + // EnginePtr returns the shared atomic engine pointer. The Telegram command // handler uses this to read and mutate the same engine as the proxy, avoiding // split-brain windows during SIGHUP reloads. From 54a64bd68a2e273074b1171db02bdfbfdfd92def Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 00:49:43 +0800 Subject: [PATCH 20/49] docs(plans): complete credential-pool-failover; move to completed --- CLAUDE.md | 34 +++++++++++++++++++ .../20260515-credential-pool-failover.md | 8 +++-- 2 files changed, 39 insertions(+), 3 deletions(-) rename docs/plans/{ => completed}/20260515-credential-pool-failover.md (88%) diff --git a/CLAUDE.md b/CLAUDE.md index 0751575..7ef5ff4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -212,6 +212,40 @@ Extends phantom swap to handle OAuth credentials bidirectionally. Static credent - `internal/store/migrations/000002_credential_meta.up.sql` -- Schema for credential metadata - `internal/store/migrations/000003_binding_env_var.up.sql` -- `env_var` column on bindings +### Credential pools and auto-failover + +A **credential pool** lets one phantom identity the agent sees be backed by **N real OAuth credentials**. The agent always holds a single pool-scoped phantom pair (`SLUICE_PHANTOM:.access` / `SLUICE_PHANTOM:.refresh`); sluice maps it to the *currently active member's* real tokens at injection time and persists refreshed tokens back to the member that issued them. Primary use case: two OpenAI Codex OAuth accounts behind one agent so quota exhaustion on one account transparently rolls onto the other. Pool members must be `oauth` credentials — `static` members are rejected. `cred remove` errors on a credential that is a live pool member. + +**CLI:** + +``` +sluice pool create --member [--member ...] # ordered members; rejects static; namespace must not collide with a credential name +sluice pool list +sluice pool status # active member, per-member health (healthy / cooldown + recover-at + reason) +sluice pool rotate # operator override: advance the active member manually +sluice pool remove +``` + +Auto-failover on 429/401 is the primary mechanism; `pool rotate` is an operator override. Pool and credential namespaces are mutually exclusive at create time. + +**Data model (migration `000006_credential_pools`):** three tables — `credential_pools` (pool name, strategy reserved `failover`), `credential_pool_members` (ordered membership, pool→credential FK), `credential_health` (per-member state `healthy|cooldown`, `recover_at`, reason) — with CHECK constraints. Store API lives in `internal/store/pools.go`. `reloadAll` loads pool + health into an atomic-pointer-swapped `PoolResolver` (`internal/vault/pool.go`), rewired into the addon via `srv.StorePool`/`SetPoolResolver` on SIGHUP and the 2s data-version watcher. + +**Phase 1 — phantom indirection (pool phantom → active member):** + +- **Single chokepoint (I2):** every `binding.Credential` / `OAuthIndex.Has` / `extractInjectableSecret` / persist consumer on the HTTP/HTTPS OAuth path routes through `PoolResolver.ResolveActive` (`resolveInjectionTarget` for pass-1 header + pass-2 phantom swap; `resolveOAuthResponseAttribution` for the response/persist path). `idx.Has` is always called with the resolved member name, never the pool. Plain (non-pool) credentials pass through `ResolveActive` unchanged. SSH/mail/QUIC are non-OAuth and out of scope. +- **Active-member selection:** healthy or expired-cooldown members first, by configured position; if all members are in cooldown, the soonest-recovering member is returned with a WARNING (degrade, never hard-fail). Recovery is lazy — evaluated in `ResolveActive`, no scheduler. +- **R1 refresh-token attribution / fail-closed:** when pass-2 swaps `SLUICE_PHANTOM:.refresh`, sluice records `realRefreshToken → member` in a short-TTL map. On the token-endpoint response it recovers the member by that real refresh token and persists to that member (`persistAddonOAuthTokens(member, ...)`, singleflight key `"persist:"+member`). The join key is the real **refresh** token sluice injected — never the access token, the client connection, or `OAuthIndex.Match` (two pooled members share `auth.openai.com`'s token URL and collide there). If the member is unrecoverable: WARNING + skip the vault write, never guess. Rotating refresh tokens are single-use, so a mis-attributed write would brick both accounts — fail-closed is mandatory. +- **R3 pool-stable phantom JWT:** Codex access tokens are JWTs and the per-real-token `resignJWT` would emit a *different* phantom after every cross-member refresh, breaking the "agent never notices" guarantee. Pooled OAuth `oauthPhantomAccess`/`resignJWT` instead build the phantom JWT from a deterministic synthetic payload keyed on the **pool name** (stable `sub`/`iss`, far-future `exp`), HMAC'd with the existing fixed key — byte-identical across member switches while still a structurally valid JWT. Static-form fallback (`SLUICE_PHANTOM:.access`) is documented for the case where the agent is verified to treat the access token as opaque. + +**Phase 2 — auto-failover on 429 / 401:** + +- **Classification** (`classifyFailover` in `internal/proxy/pool_failover.go`, called from `SluiceAddon.Response` for pooled destinations): `429` or `403 + insufficient_quota` → rate-limited; `401` or token-body `invalid_grant` / `invalid_token` → auth-failure; `5xx` / other → no-op. The token-endpoint body is only trusted when the request URL matched the OAuth index. +- **Synchronous in-memory failover (I1):** health is updated in-process *before* the response returns — `MarkCooldown` takes the resolver write lock, `ResolveActive` the read lock — so the active-member switch never waits on the 2s data-version watcher (which only reconciles). A detached `onFailover` callback also writes `SetCredentialHealth(member, 'cooldown', now+ttl, reason)` for durability. Cooldown TTLs: `vault.RateLimitCooldown` = 60s, `vault.AuthFailCooldown` = 300s. No in-flight retry — the next request uses the new member. +- **Audit:** a `cred_failover` event (Verdict `failover`, Credential = the cooled-down member) with `Reason = ":->:<429|403|401|invalid_grant>"`, emitted synchronously in `handlePoolFailover`. +- **Telegram:** a best-effort non-blocking notice "pool `` failed over ``→`` ()"; the store write and every broker channel `Notify` are detached into their own goroutine so the response path never blocks. + +**Key files:** `internal/store/migrations/000006_credential_pools.{up,down}.sql`, `internal/store/pools.go`, `internal/vault/pool.go`, `internal/proxy/pool_failover.go`, `cmd/sluice/pool.go`, plus the pool routing in `internal/proxy/addon.go` / `internal/proxy/oauth_response.go`. + ### Protocol-specific handling | Protocol | Credential injection | Content inspection | Policy granularity | diff --git a/docs/plans/20260515-credential-pool-failover.md b/docs/plans/completed/20260515-credential-pool-failover.md similarity index 88% rename from docs/plans/20260515-credential-pool-failover.md rename to docs/plans/completed/20260515-credential-pool-failover.md index eab928a..dd97f89 100644 --- a/docs/plans/20260515-credential-pool-failover.md +++ b/docs/plans/completed/20260515-credential-pool-failover.md @@ -101,9 +101,11 @@ rotate` is an operator override, not the primary mechanism. ### Task 4: Verify acceptance + docs -- [ ] full `go test ./... -timeout 120s`; e2e `go test -tags=e2e ./e2e/ -count=1 -timeout=300s` (if e2e cannot run here, state so explicitly in the progress file, do not silently skip). -- [ ] update CLAUDE.md credential-pool/failover notes. -- [ ] move plan to `docs/plans/completed/`. +- [x] full `go test ./... -timeout 120s` — 2548 tests passed, 13/13 packages `ok`, 0 FAIL/panic. e2e `go test -tags=e2e ./e2e/ -count=1 -timeout=300s` — 64 e2e tests passed (68 RUN incl. subtests, 66 PASS lines), 0 FAIL/panic, ~130s. Container tags (`e2e && linux/darwin`) skipped: no local Docker Compose / Apple Container harness in this validation step; CI runs them via the dedicated e2e workflows. +- [x] update CLAUDE.md credential-pool/failover notes — added `### Credential pools and auto-failover` (pool concept, `sluice pool` CLI, migration 000006 tables, Phase 1 chokepoint + R1 fail-closed attribution + R3 pool-stable JWT, Phase 2 classification + synchronous failover + `cred_failover` audit + Telegram notice + cooldown TTLs). +- [x] move plan to `docs/plans/completed/`. + +> **E2e gap (Testing Strategy item, honestly noted):** the dedicated two-fake-OAuth-upstreams pool-failover e2e (assert A used until 429 → switch to B → B's refreshed tokens land in B's vault not A's → phantom access JWT byte-identical across failover) was **not added**. Standing up a full e2e harness with JWT-issuing fake token endpoints, pool wiring through SOCKS5 + MITM, and 429-then-switch assertions is a substantial new harness beyond the reasonable scope of this verify+docs task. The failover behavior it would cover is already exercised by unit tests added in Tasks 2 & 3 (`internal/vault/pool_test.go`, `internal/proxy/pool_failover_test.go`): R1 collision/fail-closed, R3 phantom byte-identity, classification, synchronous health swap, cooldown TTL/lazy recovery, non-blocking notice. The existing non-container e2e suite was run in full (64 tests, all passing). Recommend tracking the pool-failover e2e as follow-up future work. ## Out of scope / future work From 3a4104dfd287e55751edcebc104af3d011f01eaa Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 01:06:10 +0800 Subject: [PATCH 21/49] fix(proxy): address comprehensive review findings --- CLAUDE.md | 5 +- README.md | 25 ++++++ cmd/sluice/main.go | 4 +- internal/channel/broker.go | 8 -- internal/proxy/addon.go | 23 +---- internal/proxy/pool_failover.go | 17 ++++ internal/proxy/pool_failover_test.go | 124 ++++++++++++++++++++++++++- internal/proxy/server.go | 13 +++ internal/store/pools.go | 5 +- internal/vault/pool.go | 58 ++++++++++++- internal/vault/pool_test.go | 70 +++++++++++++++ 11 files changed, 316 insertions(+), 36 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 7ef5ff4..21b3ad2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -240,9 +240,12 @@ Auto-failover on 429/401 is the primary mechanism; `pool rotate` is an operator **Phase 2 — auto-failover on 429 / 401:** - **Classification** (`classifyFailover` in `internal/proxy/pool_failover.go`, called from `SluiceAddon.Response` for pooled destinations): `429` or `403 + insufficient_quota` → rate-limited; `401` or token-body `invalid_grant` / `invalid_token` → auth-failure; `5xx` / other → no-op. The token-endpoint body is only trusted when the request URL matched the OAuth index. +- **Pool attribution for the response** (`poolForResponse`): a response is attributed to a pool either (a) when the flow's CONNECT host has a pooled binding (the API-host 429/403 path), **or** (b) when the request URL matches the OAuth token-URL index for a credential that is a pool member (the token-endpoint 401 / `invalid_grant` path). Case (b) is essential: an OAuth refresh hits the credential's token-URL host (e.g. `auth.openai.com`), which has no pool binding — only the API host (e.g. `api.openai.com`) does — so without the token-URL index match the token-endpoint classification would be dead code for the Codex deployment. `idx.Match` is strict 1:1 token_url→credential, so case (b) cools the exact member whose refresh token was injected. - **Synchronous in-memory failover (I1):** health is updated in-process *before* the response returns — `MarkCooldown` takes the resolver write lock, `ResolveActive` the read lock — so the active-member switch never waits on the 2s data-version watcher (which only reconciles). A detached `onFailover` callback also writes `SetCredentialHealth(member, 'cooldown', now+ttl, reason)` for durability. Cooldown TTLs: `vault.RateLimitCooldown` = 60s, `vault.AuthFailCooldown` = 300s. No in-flight retry — the next request uses the new member. +- **Reload does not resurrect a cooled member:** because the durable `SetCredentialHealth` write is detached and best-effort, any reload (SIGHUP or the 2s data-version watcher firing on *any* unrelated DB write) rebuilds the resolver from store rows alone via `NewPoolResolver`. `Server.StorePool` therefore calls `PoolResolver.MergeLiveCooldowns(prev)` to carry forward still-active in-memory cooldowns from the resolver being replaced before the atomic swap. The merge is monotonic (a live cooldown is never shortened/erased by an unrelated reload) and drops cooldowns for credentials no longer in any pool. - **Audit:** a `cred_failover` event (Verdict `failover`, Credential = the cooled-down member) with `Reason = ":->:<429|403|401|invalid_grant>"`, emitted synchronously in `handlePoolFailover`. -- **Telegram:** a best-effort non-blocking notice "pool `` failed over ``→`` ()"; the store write and every broker channel `Notify` are detached into their own goroutine so the response path never blocks. +- **Telegram:** a best-effort non-blocking notice "pool failed over -> ()" (plain text — `TelegramChannel.Notify` sends with no parse mode); the store write and every broker channel `Notify` are detached into their own goroutine so the response path never blocks. +- **Known limitation: streaming responses bypass failover.** `handlePoolFailover` runs only from the buffered `Response` addon. Server-Sent Events (`text/event-stream`) and bodies above `StreamLargeBodies` set `f.Stream=true`, which skips the `Response` callback (same path as the Response DLP streaming bypass documented above), so a 429/401 delivered on a streamed response does not trigger failover. Practical impact is low because quota/auth error bodies are tiny JSON, not streamed; the next non-streamed request to the API host still fails over normally. **Key files:** `internal/store/migrations/000006_credential_pools.{up,down}.sql`, `internal/store/pools.go`, `internal/vault/pool.go`, `internal/proxy/pool_failover.go`, `cmd/sluice/pool.go`, plus the pool routing in `internal/proxy/addon.go` / `internal/proxy/oauth_response.go`. diff --git a/README.md b/README.md index 63f9e21..c411e02 100644 --- a/README.md +++ b/README.md @@ -286,6 +286,31 @@ github_pat static api.github.com **Supported response formats:** Both `application/json` and `application/x-www-form-urlencoded` token responses per RFC 6749. +## Credential Pools + +A credential pool lets a single phantom identity the agent sees be backed by **N real OAuth credentials**, with sluice auto-failing-over to the next member when the upstream rejects the active one. Primary use case: two OpenAI Codex OAuth accounts driven by one agent, so quota exhaustion on one account transparently rolls onto the other. The agent always holds one pool-scoped phantom pair (`SLUICE_PHANTOM:.access` / `.refresh`); sluice maps it to the currently active member's real token at injection time and persists refreshed tokens back to the member that issued them. + +```bash +sluice pool create --members credA,credB[,credC] [--strategy failover] +sluice pool list +sluice pool status +sluice pool rotate # operator override: force next member +sluice pool remove +``` + +Members are existing OAuth credentials (static credentials are rejected). Member order is the failover order. + +**Auto-failover behavior:** + +- HTTP 429, or 403 with a quota-exhaustion body -> the active member is rate-limited; cooled down for **60s**. +- HTTP 401, or a token-endpoint body of `invalid_grant` / `invalid_token` -> the active member's token is rejected; cooled down for **300s**. +- 2xx, 5xx, and any other status -> no-op (a server-side error is not evidence the account is exhausted). +- The active-member switch is **synchronous**: the cooldown is recorded in memory before the response returns, so the very next request injects the next member. The durable store write only reconciles for restarts. +- **No in-flight retry**: the triggering request still returns its own upstream error to the agent; the agent's own retry resolves to the freshly-activated next member. +- Every failover emits a `cred_failover` audit event (`Reason = ":->:"`) and a best-effort Telegram notice. + +The phantom access token is **byte-identical across a member switch** (pooled OAuth credentials use a pool-keyed synthetic JWT resign), so the agent never observes the rollover. + ## Approval Channels Sluice broadcasts "ask" verdicts to all configured approval channels. The first channel to respond wins. Other channels get a cancellation notice. diff --git a/cmd/sluice/main.go b/cmd/sluice/main.go index 47135fb..022d008 100644 --- a/cmd/sluice/main.go +++ b/cmd/sluice/main.go @@ -479,7 +479,9 @@ func main() { } } if failoverBroker != nil { - msg := fmt.Sprintf("pool `%s` failed over `%s`→`%s` (%s)", + // Plain text: TelegramChannel.Notify sends with no parse + // mode, so markdown backticks would render literally. + msg := fmt.Sprintf("pool %s failed over %s -> %s (%s)", ev.Pool, ev.From, ev.To, ev.Reason) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() diff --git a/internal/channel/broker.go b/internal/channel/broker.go index 89a1559..1c600ff 100644 --- a/internal/channel/broker.go +++ b/internal/channel/broker.go @@ -250,9 +250,7 @@ func (b *Broker) Request(dest string, port int, protocol string, timeout time.Du w.subs = append(w.subs, subCh) w.count++ b.waiters[primaryID] = w - count := w.count b.mu.Unlock() - b.notifyCoalesced(primaryID, count) return b.waitSub(primaryID, subCh, deadline.C, timeout) } } @@ -454,12 +452,6 @@ func (b *Broker) CoalescedCount(id string) int { return 1 } -// notifyCoalesced is the Phase 1 no-op hook for live mid-burst "+N pending" -// indicators. Phase 2 fills this in to best-effort call channels that -// implement a CoalesceNotifier interface. Keeping the call site here means -// Phase 2 is a localized change with no churn to Request. -func (b *Broker) notifyCoalesced(_ string, _ int) {} - // broadcast sends the approval request to all channels. Errors and panics // from individual channels are logged but do not prevent other channels from // receiving the request. diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index 3cbb389..1d27369 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -85,7 +85,7 @@ type SluiceAddon struct { resolver *atomic.Pointer[vault.BindingResolver] // poolResolver expands a bound pool name to its currently active - // member at the single injection chokepoint (resolvePoolMember). + // member at the single injection chokepoint (resolveInjectionTarget). // Swapped atomically alongside resolver on reload; may be nil when // no pools are configured (treated as identity passthrough). Phase 2 // mutates the contained health map in place under the resolver's own @@ -216,27 +216,6 @@ func (a *SluiceAddon) SetPoolResolver(r *atomic.Pointer[vault.PoolResolver]) { a.poolResolver = r } -// resolvePoolMember is the single chokepoint that expands a bound -// credential-or-pool name to the concrete credential whose secret should be -// injected. For a plain credential it returns the name unchanged. For a -// pool it returns the currently active member. Every consumer that reads a -// binding's Credential (pass-1 header inject, pass-2 phantom pairs, -// OAuthIndex.Has gating, persist attribution) routes through here so pool -// expansion happens in exactly one place (Important I2). -func (a *SluiceAddon) resolvePoolMember(name string) string { - if a.poolResolver == nil { - return name - } - pr := a.poolResolver.Load() - if pr == nil { - return name - } - if member, ok := pr.ResolveActive(name); ok { - return member - } - return name -} - // injectionTarget is the result of expanding a bound credential-or-pool // name at the single chokepoint. phantomName is the name the agent's // phantom string is keyed on (the POOL name when pooled, so the phantom is diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index 9ea771c..219e8ff 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -161,6 +161,23 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str } return boundName, member, pr, true } + + // Token-endpoint path. An OAuth refresh hits the credential's token-URL + // host (e.g. auth.openai.com), which has no pool binding — the pool + // binding lives on the API host (e.g. api.openai.com). Without this the + // token-endpoint 401 / invalid_grant classification is dead code for the + // primary Codex deployment (only the 429/403 API-host path would ever + // fire). When the request URL matches the OAuth token-URL index for a + // credential that is a pool member, attribute the response to that pool + // and that exact member (idx.Match is strict 1:1 token_url->credential, + // so the member is the one whose refresh token sluice injected). + if idx := a.oauthIndex.Load(); idx != nil && f.Request != nil { + if matched, mok := idx.Match(f.Request.URL); mok && matched != "" { + if pool := pr.PoolForMember(matched); pool != "" { + return pool, matched, pr, true + } + } + } return "", "", nil, false } diff --git a/internal/proxy/pool_failover_test.go b/internal/proxy/pool_failover_test.go index 03c07d5..a16b232 100644 --- a/internal/proxy/pool_failover_test.go +++ b/internal/proxy/pool_failover_test.go @@ -7,11 +7,13 @@ import ( "os" "path/filepath" "strings" + "sync/atomic" "testing" "time" mitmproxy "github.com/lqqyt2423/go-mitmproxy/proxy" "github.com/nemirovsky/sluice/internal/audit" + "github.com/nemirovsky/sluice/internal/store" "github.com/nemirovsky/sluice/internal/vault" uuid "github.com/satori/go.uuid" ) @@ -226,7 +228,19 @@ func TestFailoverNonPooledDestinationIgnored(t *testing.T) { called := false addon.SetOnFailover(func(FailoverEvent) { called = true }) - f := newPoolRespFlow(client, 429, nil) + // Request URL is a plain API endpoint on an unrelated host: it neither + // has a pooled CONNECT binding NOR matches any pooled member's OAuth + // token URL, so poolForResponse must return ok=false and no failover + // fires. (newPoolRespFlow points the request at the token URL, which + // WOULD legitimately match a pooled member via the CRITICAL-2 token-URL + // path, so it must not be used here.) + u, _ := url.Parse("https://unrelated.example.com/v1/data") + f := &mitmproxy.Flow{ + Id: uuid.NewV4(), + ConnContext: &mitmproxy.ConnContext{ClientConn: client}, + Request: &mitmproxy.Request{Method: "GET", URL: u, Header: make(http.Header)}, + Response: &mitmproxy.Response{StatusCode: 429, Header: make(http.Header)}, + } addon.Response(f) if called { t.Fatal("onFailover invoked for a non-pooled destination") @@ -304,3 +318,111 @@ func TestPoolForResponseResolvesActiveMember(t *testing.T) { t.Fatal("poolForResponse returned a different resolver than the live one") } } + +// setupPoolAddonSplitHost is like setupPoolAddon but the pool binding lives on +// the API host (api.example.com) while the OAuth token URL is on a DIFFERENT +// host (auth.example.com). This mirrors the real Codex deployment: the pool +// binding is on api.openai.com, the OAuth refresh hits auth.openai.com. The +// CONNECT-host reverse mapping in poolForResponse therefore CANNOT match a +// token-endpoint response — only the token-URL-index path can. +func setupPoolAddonSplitHost(t *testing.T, poolName, memberA, memberB string) (*SluiceAddon, *atomic.Pointer[vault.PoolResolver]) { + t.Helper() + + provider := &addonWritableProvider{ + creds: map[string]string{ + memberA: poolMemberCred(t, "A-access-old", "A-refresh-old"), + memberB: poolMemberCred(t, "B-access-old", "B-refresh-old"), + }, + } + + // Pool binding is on the API host, NOT the token-URL host. + bindings := []vault.Binding{{ + Destination: "api.example.com", + Ports: []int{443}, + Credential: poolName, + }} + resolver, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var resolverPtr atomic.Pointer[vault.BindingResolver] + resolverPtr.Store(resolver) + + addon := NewSluiceAddon(WithResolver(&resolverPtr), WithProvider(provider)) + addon.persistDone = make(chan struct{}, 10) + + // testOAuthTokenURL is https://auth.example.com/oauth/token — a different + // host from the api.example.com pool binding above. + metas := []store.CredentialMeta{ + {Name: memberA, CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: memberB, CredType: "oauth", TokenURL: testOAuthTokenURL}, + } + addon.UpdateOAuthIndex(metas) + + pool := store.Pool{Name: poolName, Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: memberA, Position: 0}, + {Credential: memberB, Position: 1}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{pool}, nil)) + addon.SetPoolResolver(&prPtr) + + return addon, &prPtr +} + +// TestTokenEndpointHostFailoverOnPooledMember is the CRITICAL-2 regression. +// The OAuth refresh hits the token-URL host (auth.example.com), which has NO +// pool binding (the binding is on api.example.com). Without the token-URL +// index path in poolForResponse, the token-endpoint 401/invalid_grant +// classification is dead code: poolForResponse returns ok=false and the +// member is never cooled down. The fix recognizes the pooled member via +// idx.Match(f.Request.URL) -> PoolForMember. +func TestTokenEndpointHostFailoverOnPooledMember(t *testing.T) { + addon, prPtr := setupPoolAddonSplitHost(t, "codex_pool", "memA", "memB") + // CONNECT target is the TOKEN-URL host, which has no pool binding. + client := setupAddonConn(addon, "auth.example.com:443") + + pr := prPtr.Load() + if got, _ := pr.ResolveActive("codex_pool"); got != "memA" { + t.Fatalf("pre-failover active = %q, want memA", got) + } + + // Sanity: the CONNECT-host reverse mapping alone must NOT match here + // (this is exactly the gap CRITICAL-2 describes). poolForResponse must + // still succeed via the token-URL index path. + f := newPoolRespFlow(client, 400, []byte(`{"error":"invalid_grant"}`)) + pool, member, _, ok := addon.poolForResponse(f) + if !ok { + t.Fatal("poolForResponse: token-endpoint response on a pooled member must be attributed (CRITICAL-2 fix); got ok=false") + } + if pool != "codex_pool" || member != "memA" { + t.Fatalf("got pool=%q member=%q, want codex_pool/memA", pool, member) + } + + var got FailoverEvent + gotCalled := make(chan struct{}, 1) + addon.SetOnFailover(func(ev FailoverEvent) { + got = ev + gotCalled <- struct{}{} + }) + + // A token-endpoint invalid_grant must cool memA and switch to memB. + addon.Response(newPoolRespFlow(client, 400, []byte(`{"error":"invalid_grant"}`))) + + if active, _ := pr.ResolveActive("codex_pool"); active != "memB" { + t.Fatalf("post-failover active = %q, want memB (token-endpoint auth failure must fail over)", active) + } + + select { + case <-gotCalled: + case <-time.After(2 * time.Second): + t.Fatal("onFailover callback not invoked for token-endpoint failover") + } + if got.Pool != "codex_pool" || got.From != "memA" || got.To != "memB" || got.Reason != "invalid_grant" { + t.Fatalf("FailoverEvent = %+v, want pool=codex_pool from=memA to=memB reason=invalid_grant", got) + } + if got.Class != failoverAuthFailure { + t.Fatalf("class = %v, want auth-failure", got.Class) + } +} diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 9006893..346feb2 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -2708,6 +2708,19 @@ func (s *Server) StoreResolver(r *vault.BindingResolver) { // can call IsPool/ResolveActive without nil-checking; ResolveActive on a // non-pool name is an identity passthrough. func (s *Server) StorePool(r *vault.PoolResolver) { + // Carry forward still-active in-memory cooldowns from the resolver being + // replaced. Phase 2 failover records cooldowns synchronously in memory and + // only persists them to the store from a detached best-effort goroutine, so + // rebuilding from store rows alone (NewPoolResolver, called by any reload — + // SIGHUP or the 2s data_version watcher on any unrelated DB write) would + // otherwise resurrect a just-cooled member for the full cooldown TTL, or + // permanently if the async store write failed. The merge is monotonic: a + // live cooldown is never shortened or erased by an unrelated reload. + if r != nil { + if prev := s.poolResolver.Load(); prev != nil { + r.MergeLiveCooldowns(prev) + } + } s.poolResolver.Store(r) if s.addon != nil { s.addon.SetPoolResolver(&s.poolResolver) diff --git a/internal/store/pools.go b/internal/store/pools.go index 245672c..d540f09 100644 --- a/internal/store/pools.go +++ b/internal/store/pools.go @@ -317,7 +317,10 @@ func (s *Store) SetCredentialHealth(credential, status string, cooldownUntil tim } // GetCredentialHealth returns the health row for a credential, or nil if no -// row exists (which callers treat as healthy). +// row exists (which callers treat as healthy). This is an intentional +// single-row introspection surface (tests, and a targeted lookup the +// failover/reconcile paths can use instead of scanning ListCredentialHealth); +// it is not currently on a hot path. func (s *Store) GetCredentialHealth(credential string) (*CredentialHealth, error) { var h CredentialHealth var cu, reason sql.NullString diff --git a/internal/vault/pool.go b/internal/vault/pool.go index ceb54f7..d33e778 100644 --- a/internal/vault/pool.go +++ b/internal/vault/pool.go @@ -101,7 +101,9 @@ func (pr *PoolResolver) PoolForMember(credential string) string { return "" } -// Members returns the ordered member list for a pool (copy), or nil. +// Members returns the ordered member list for a pool (copy), or nil. Exposed +// as an introspection surface for tests and potential future `pool status` +// detail output; not on any hot path. func (pr *PoolResolver) Members(pool string) []string { if pr == nil { return nil @@ -174,8 +176,60 @@ func (pr *PoolResolver) MarkCooldown(credential string, until time.Time, reason pr.health[credential] = memberHealth{cooldownUntil: until, reason: reason} } +// MergeLiveCooldowns carries forward still-active in-memory cooldowns from a +// previous resolver into this freshly built one. It MUST be called before the +// new resolver is atomically swapped in. +// +// Why this exists: Phase 2 failover records the active member's cooldown +// synchronously in the in-memory health map (MarkCooldown) and only persists +// SetCredentialHealth to the store from a detached best-effort goroutine. Any +// reload (SIGHUP, or the 2s data_version watcher firing on ANY unrelated DB +// write — a policy add, a cred update, an audit row) rebuilds the resolver +// from store rows alone via NewPoolResolver. Without this merge, a reload that +// races ahead of (or outlives a failed) durable write would seed health only +// from the store and silently resurrect a member that was just cooled down in +// memory, defeating the I1 synchronous-failover guarantee for the full +// cooldown TTL (60s/300s) — or permanently if the async store write failed. +// +// Merge policy: for every member that this resolver still knows about (so a +// cooldown for a credential removed from all pools is correctly dropped), the +// later of the store-seeded expiry and the previous in-memory expiry wins. +// Expired cooldowns on either side are not carried. This is monotonic: a live +// cooldown can never be shortened or erased by an unrelated reload. +func (pr *PoolResolver) MergeLiveCooldowns(prev *PoolResolver) { + if pr == nil || prev == nil { + return + } + now := time.Now() + prev.mu.RLock() + prevHealth := make(map[string]memberHealth, len(prev.health)) + for k, v := range prev.health { + prevHealth[k] = v + } + prev.mu.RUnlock() + + pr.mu.Lock() + defer pr.mu.Unlock() + for cred, ph := range prevHealth { + if ph.cooldownUntil.IsZero() || !ph.cooldownUntil.After(now) { + continue // expired in the old resolver; nothing to carry + } + // Only carry cooldowns for credentials this resolver still tracks as a + // pool member; an orphaned cooldown for a removed member is dropped. + if _, stillMember := pr.memberOf[cred]; !stillMember { + continue + } + existing, ok := pr.health[cred] + if !ok || ph.cooldownUntil.After(existing.cooldownUntil) { + pr.health[cred] = ph + } + } +} + // CooldownUntil returns the in-memory cooldown expiry for a credential and -// whether it is currently cooling down (future expiry). +// whether it is currently cooling down (future expiry). Exposed as an +// introspection surface for tests and potential future `pool status` +// detail output; not on any hot path. func (pr *PoolResolver) CooldownUntil(credential string) (time.Time, bool) { if pr == nil { return time.Time{}, false diff --git a/internal/vault/pool_test.go b/internal/vault/pool_test.go index 0293d3d..1bf9e2e 100644 --- a/internal/vault/pool_test.go +++ b/internal/vault/pool_test.go @@ -116,6 +116,76 @@ func TestPoolForMemberAndMembers(t *testing.T) { } } +// TestMergeLiveCooldownsSurvivesUnrelatedReload is the CRITICAL-1 regression: +// an unrelated reload rebuilds the resolver from store rows alone (no cooldown +// row, because the durable SetCredentialHealth write is detached/best-effort +// and may not have landed). Without MergeLiveCooldowns the freshly built +// resolver would resurrect member "a" — defeating the I1 synchronous-failover +// guarantee. With the merge, "a" stays cooled and ResolveActive picks "b". +func TestMergeLiveCooldownsSurvivesUnrelatedReload(t *testing.T) { + // Live resolver: member "a" failed over and was cooled down in memory. + prev := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, nil) + prev.MarkCooldown("a", time.Now().Add(60*time.Second), "429") + if got, _ := prev.ResolveActive("pool"); got != "b" { + t.Fatalf("precondition: live resolver active = %q; want b", got) + } + + // Unrelated reload: store has NO credential_health row (async write not + // yet persisted), so NewPoolResolver seeds an empty health map. + fresh := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, nil) + if got, _ := fresh.ResolveActive("pool"); got != "a" { + t.Fatalf("sanity: fresh resolver without merge resurrects to %q; want a (proves the bug exists without the fix)", got) + } + + // The fix: StorePool calls MergeLiveCooldowns before the atomic swap. + fresh.MergeLiveCooldowns(prev) + + if got, ok := fresh.ResolveActive("pool"); !ok || got != "b" { + t.Errorf("after merge ResolveActive(pool) = %q,%v; want b,true (cooled member must NOT be resurrected by an unrelated reload)", got, ok) + } + if until, cooling := fresh.CooldownUntil("a"); !cooling || until.IsZero() { + t.Errorf("after merge member a should still be cooling down; got until=%v cooling=%v", until, cooling) + } +} + +// TestMergeLiveCooldownsIsMonotonic: a store-seeded cooldown that is later +// than the in-memory one is kept (never shortened), and an expired in-memory +// cooldown is not carried. +func TestMergeLiveCooldownsIsMonotonic(t *testing.T) { + storeLater := time.Now().Add(300 * time.Second) + fresh := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, + []store.CredentialHealth{{Credential: "a", Status: "cooldown", CooldownUntil: storeLater, LastFailureReason: "401"}}) + + prev := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, nil) + prev.MarkCooldown("a", time.Now().Add(10*time.Second), "429") // earlier than store + prev.MarkCooldown("b", time.Now().Add(-1*time.Second), "stale") // already expired + + fresh.MergeLiveCooldowns(prev) + + until, cooling := fresh.CooldownUntil("a") + if !cooling || until.Before(storeLater.Add(-time.Second)) { + t.Errorf("merge must not shorten a longer store cooldown: got %v, want ~%v", until, storeLater) + } + if _, cooling := fresh.CooldownUntil("b"); cooling { + t.Error("expired in-memory cooldown for b must not be carried forward") + } +} + +// TestMergeLiveCooldownsDropsRemovedMember: a cooldown for a credential no +// longer in any pool (membership change) is not carried. +func TestMergeLiveCooldownsDropsRemovedMember(t *testing.T) { + prev := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, nil) + prev.MarkCooldown("b", time.Now().Add(60*time.Second), "429") + + // New membership: "b" was removed from the pool. + fresh := NewPoolResolver([]store.Pool{mkPool("pool", "a")}, nil) + fresh.MergeLiveCooldowns(prev) + + if _, cooling := fresh.CooldownUntil("b"); cooling { + t.Error("cooldown for a removed member must be dropped, not carried") + } +} + func TestNilPoolResolverSafe(t *testing.T) { var pr *PoolResolver if got, ok := pr.ResolveActive("x"); !ok || got != "x" { From 8823362740223c198cd788fe65111019d4174d38 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 09:59:27 +0800 Subject: [PATCH 22/49] fix(proxy): correct token-endpoint failover member attribution + harden cooldown durability --- cmd/sluice/main.go | 41 ++-- internal/proxy/pool_attribution.go | 42 ++++ internal/proxy/pool_failover.go | 54 ++++- internal/proxy/pool_failover_test.go | 317 +++++++++++++++++++++++++++ internal/vault/pool.go | 219 ++++++++++++------ internal/vault/pool_test.go | 86 ++++++++ 6 files changed, 676 insertions(+), 83 deletions(-) diff --git a/cmd/sluice/main.go b/cmd/sluice/main.go index 022d008..d4e63d5 100644 --- a/cmd/sluice/main.go +++ b/cmd/sluice/main.go @@ -347,14 +347,23 @@ func main() { } } + // Process-wide shared pool health. CRITICAL-1: every PoolResolver + // generation (startup + every SIGHUP / data_version reload) is built + // against THIS single PoolHealth, so a failover cooldown recorded via + // MarkCooldown on any generation is observed by ResolveActive on the + // current generation regardless of how many resolver pointer swaps + // happened in between, and the cooldown never depends on the detached + // durable SetCredentialHealth write succeeding. + sharedPoolHealth := vault.NewPoolHealth() + // Populate the initial credential pool resolver at startup so pool // expansion works for pools defined before the first SIGHUP. Always // store a non-nil resolver (empty when no pools) so the addon never // has to nil-check before ResolveActive (non-pool names passthrough). if db != nil { - if pr, perr := loadPoolResolver(db); perr != nil { + if pr, perr := loadPoolResolver(db, sharedPoolHealth); perr != nil { log.Printf("pool resolver init failed: %v", perr) - srv.StorePool(vault.NewPoolResolver(nil, nil)) + srv.StorePool(vault.NewPoolResolverShared(nil, nil, sharedPoolHealth)) } else { srv.StorePool(pr) } @@ -737,12 +746,14 @@ func main() { log.Printf("reload oauth index failed: %v", metaErr) } - // Rebuild and atomically swap the credential pool resolver. - // Membership changes (pool create/remove) take effect here; - // durable health rows are reloaded too, which only reconciles - // the in-memory health that Phase 2 failover already updated - // synchronously on the response path. - if pr, perr := loadPoolResolver(db); perr != nil { + // Rebuild and atomically swap the credential pool resolver. The + // new generation is built against the SAME process-wide + // sharedPoolHealth (CRITICAL-1), so membership changes (pool + // create/remove) take effect here while live failover cooldowns + // recorded in memory survive the swap with zero dependency on the + // detached durable write. Reloading health rows only seeds the + // shared map monotonically (never shortens a live cooldown). + if pr, perr := loadPoolResolver(db, sharedPoolHealth); perr != nil { log.Printf("reload pool resolver failed: %v", perr) } else { srv.StorePool(pr) @@ -875,11 +886,13 @@ func readBindings(db *store.Store) ([]vault.Binding, error) { } // loadPoolResolver builds a vault.PoolResolver from the store's pool, -// member, and credential-health tables. A non-nil resolver is always -// returned on success (empty when no pools), so callers can store it -// unconditionally and the addon never has to nil-check before -// ResolveActive (a non-pool name is an identity passthrough). -func loadPoolResolver(db *store.Store) (*vault.PoolResolver, error) { +// member, and credential-health tables, bound to the process-wide shared +// PoolHealth (CRITICAL-1) so failover cooldowns survive every resolver +// pointer swap. A non-nil resolver is always returned on success (empty +// when no pools), so callers can store it unconditionally and the addon +// never has to nil-check before ResolveActive (a non-pool name is an +// identity passthrough). +func loadPoolResolver(db *store.Store, shared *vault.PoolHealth) (*vault.PoolResolver, error) { pools, err := db.ListPools() if err != nil { return nil, fmt.Errorf("list pools: %w", err) @@ -888,7 +901,7 @@ func loadPoolResolver(db *store.Store) (*vault.PoolResolver, error) { if err != nil { return nil, fmt.Errorf("list credential health: %w", err) } - return vault.NewPoolResolver(pools, health), nil + return vault.NewPoolResolverShared(pools, health, shared), nil } // injectEnvVarsFromStore reads bindings with env_var set from the store, diff --git a/internal/proxy/pool_attribution.go b/internal/proxy/pool_attribution.go index b9a44fd..b733d86 100644 --- a/internal/proxy/pool_attribution.go +++ b/internal/proxy/pool_attribution.go @@ -68,6 +68,11 @@ func (r *refreshAttribution) Tag(realRefreshToken, member string) { // removes the entry (single-use: a rotated refresh token will never be // presented again). Returns ("", false) when no live tag exists — the // caller MUST fail closed (skip the vault write, never guess) per R1. +// +// Recover is used exclusively by the 2xx persist path +// (resolveOAuthResponseAttribution): a successful refresh rotates the +// refresh token, so the tag is dead after one use and must be deleted to +// bound the map. func (r *refreshAttribution) Recover(realRefreshToken string) (string, bool) { if realRefreshToken == "" { return "", false @@ -84,3 +89,40 @@ func (r *refreshAttribution) Recover(realRefreshToken string) (string, bool) { } return e.member, true } + +// Peek returns the member tagged for the given real refresh token WITHOUT +// removing the entry. Returns ("", false) when no live tag exists. +// +// This is the CRITICAL-2 join key for the FAILOVER path +// (poolForResponse). Two pool members share one token URL, so +// OAuthIndex.Match is 1:1 and always returns the first index entry +// regardless of which member's refresh token is actually in the request +// body. Attributing a token-endpoint failure by idx.Match therefore cools +// the WRONG member whenever the failing member is not the first index +// entry. The injected real refresh token, by contrast, is unique per +// member and present verbatim in the refresh-grant request body, so it +// recovers the true owning member. +// +// Peek does NOT delete the entry because, unlike Recover (2xx success +// rotates the token, making the tag dead), a token-endpoint FAILURE +// (401 / invalid_grant) does NOT rotate the refresh token: the agent's +// SDK will retry the same refresh token and the tag must still resolve. +// processOAuthResponseIfMatching (the Recover caller) is 2xx-only, so on +// a 4xx the tag has not been consumed and is still live for this Peek. +// The entry is allowed to expire naturally via refreshAttrTTL / the +// opportunistic sweep in Tag. +func (r *refreshAttribution) Peek(realRefreshToken string) (string, bool) { + if realRefreshToken == "" { + return "", false + } + r.mu.Lock() + defer r.mu.Unlock() + e, ok := r.entries[realRefreshToken] + if !ok { + return "", false + } + if time.Now().After(e.expires) { + return "", false + } + return e.member, true +} diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index 219e8ff..c802aa5 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -167,13 +167,59 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str // binding lives on the API host (e.g. api.openai.com). Without this the // token-endpoint 401 / invalid_grant classification is dead code for the // primary Codex deployment (only the 429/403 API-host path would ever - // fire). When the request URL matches the OAuth token-URL index for a - // credential that is a pool member, attribute the response to that pool - // and that exact member (idx.Match is strict 1:1 token_url->credential, - // so the member is the one whose refresh token sluice injected). + // fire). + // + // CRITICAL-2: OAuthIndex.Match is 1:1 token_url->credential and returns + // the FIRST matching index entry. For the documented primary deployment + // (two Codex OAuth accounts in ONE pool sharing the SAME token URL + // auth.openai.com) every member's index entry has an identical token + // URL, so idx.Match ALWAYS returns the first entry regardless of which + // member's refresh token is actually in the request body. Attributing + // the failure by idx.Match alone cools the wrong member whenever the + // failing member is not the first index entry (e.g. memA cooled by an + // API 429, memB now active, memB's refresh invalid_grants -> idx.Match + // returns memA -> innocent memA re-cooled, dead memB stays active -> + // the pool thrashes the broken account forever). + // + // The correct join key is the per-member-UNIQUE real refresh token that + // pass-2 injected into this exact request body — the SAME mechanism the + // 2xx persist path (resolveOAuthResponseAttribution) uses. We Peek (not + // Recover) the refresh-attribution map so the single-use tag survives + // for the persist path; a token-endpoint FAILURE does not rotate the + // refresh token and processOAuthResponseIfMatching is 2xx-only, so the + // tag is still live here. if idx := a.oauthIndex.Load(); idx != nil && f.Request != nil { if matched, mok := idx.Match(f.Request.URL); mok && matched != "" { if pool := pr.PoolForMember(matched); pool != "" { + // Recover the TRUE owning member from the injected real + // refresh token in the buffered request body. + reqCT := "" + if f.Request.Header != nil { + reqCT = f.Request.Header.Get("Content-Type") + } + realRefresh := extractRequestRefreshToken(f.Request.Body, reqCT) + if owner, ok := a.refreshAttr.Peek(realRefresh); ok && owner != "" { + if ownerPool := pr.PoolForMember(owner); ownerPool != "" { + return ownerPool, owner, pr, true + } + // owner is no longer in any pool (membership change + // raced the failure); fall through to the active-member + // fallback below for a still-meaningful attribution. + } + // Fallback ONLY when the real refresh token cannot be + // extracted / attributed: cool the ACTIVE member rather + // than blindly the first index entry. The active member is + // the one whose token was most likely just injected, so it + // is strictly better than idx.Match's deterministic-first. + if active, aok := pr.ResolveActive(pool); aok && active != "" { + log.Printf("[POOL-FAILOVER] pool %q: could not attribute "+ + "token-endpoint failure via injected refresh token; "+ + "falling back to active member %q", pool, active) + return pool, active, pr, true + } + // Last resort: the index match (preserves prior behavior + // when even ResolveActive cannot decide; better than no + // attribution at all). return pool, matched, pr, true } } diff --git a/internal/proxy/pool_failover_test.go b/internal/proxy/pool_failover_test.go index a16b232..b650542 100644 --- a/internal/proxy/pool_failover_test.go +++ b/internal/proxy/pool_failover_test.go @@ -426,3 +426,320 @@ func TestTokenEndpointHostFailoverOnPooledMember(t *testing.T) { t.Fatalf("class = %v, want auth-failure", got.Class) } } + +// newPoolRespFlowBody builds a token-endpoint response flow whose REQUEST +// body carries the given (already pass-2-swapped) real refresh token, so +// poolForResponse can recover the true owning member via the refresh +// attribution map (the CRITICAL-2 join key). +func newPoolRespFlowBody(client *mitmproxy.ClientConn, status int, realRefresh string, respBody []byte) *mitmproxy.Flow { + u, _ := url.Parse(testOAuthTokenURL) + reqHdr := make(http.Header) + reqHdr.Set("Content-Type", "application/x-www-form-urlencoded") + respHdr := make(http.Header) + respHdr.Set("Content-Type", "application/json") + return &mitmproxy.Flow{ + Id: uuid.NewV4(), + ConnContext: &mitmproxy.ConnContext{ClientConn: client}, + Request: &mitmproxy.Request{ + Method: "POST", + URL: u, + Header: reqHdr, + Body: []byte("grant_type=refresh_token&refresh_token=" + realRefresh), + }, + Response: &mitmproxy.Response{ + StatusCode: status, + Header: respHdr, + Body: respBody, + }, + } +} + +// TestTokenEndpointFailoverAttributesInjectedMemberNotFirstIndex is the +// CRITICAL-2 regression. Both members share one token URL, so +// OAuthIndex.Match deterministically returns the FIRST index entry (memA) +// regardless of which member's refresh token is in the request body. The +// failing/active member here is memB (not the first index entry). The bug: +// the failover path attributed by idx.Match and cooled the WRONG member +// (memA), leaving the dead memB active so the pool thrashed the broken +// account forever. The fix recovers the true owner from the injected real +// refresh token (refreshAttribution.Peek), the SAME join key the 2xx +// persist path uses. +// +// This test MUST fail before the fix: idx.Match -> memA, so memA would be +// (re-)cooled and memB left untouched/active. +func TestTokenEndpointFailoverAttributesInjectedMemberNotFirstIndex(t *testing.T) { + addon, prPtr := setupPoolAddonSplitHost(t, "codex_pool", "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + pr := prPtr.Load() + + // memA is first index AND would be first active. Cool memA via an API + // 429 path so memB becomes the active member (the realistic precursor: + // memA rate-limited on api host, traffic rolled to memB). + memACooldown := time.Now().Add(90 * time.Second) + pr.MarkCooldown("memA", memACooldown, "429") + if got, _ := pr.ResolveActive("codex_pool"); got != "memB" { + t.Fatalf("after cooling memA, active = %q, want memB", got) + } + + // pass-2 injected memB's real refresh token into this refresh request; + // mirror that by tagging the attribution map (what the real Request() + // pass-2 swap does) and putting memB's real refresh in the body. + addon.refreshAttr.Tag("B-refresh-old", "memB") + + // Sanity: idx.Match alone returns memA (the collision the bug rode on). + if idx := addon.oauthIndex.Load(); idx != nil { + u, _ := url.Parse(testOAuthTokenURL) + if matched, _ := idx.Match(u); matched != "memA" { + t.Fatalf("precondition: idx.Match must return first entry memA, got %q", matched) + } + } + + // poolForResponse must now attribute the failure to memB (the injected + // member), NOT memA (the first index entry). + f := newPoolRespFlowBody(client, 400, "B-refresh-old", []byte(`{"error":"invalid_grant"}`)) + pool, member, _, ok := addon.poolForResponse(f) + if !ok { + t.Fatal("poolForResponse: token-endpoint failure on a pooled member must be attributed") + } + if pool != "codex_pool" || member != "memB" { + t.Fatalf("got pool=%q member=%q, want codex_pool/memB (CRITICAL-2: must attribute the INJECTED member, not idx.Match's first entry)", pool, member) + } + + var got FailoverEvent + gotCalled := make(chan struct{}, 1) + addon.SetOnFailover(func(ev FailoverEvent) { + got = ev + gotCalled <- struct{}{} + }) + + addon.Response(newPoolRespFlowBody(client, 400, "B-refresh-old", []byte(`{"error":"invalid_grant"}`))) + + // memB must now be cooled with the long auth-failure TTL. + bUntil, bCooling := pr.CooldownUntil("memB") + if !bCooling { + t.Fatal("memB must be in cooldown after its own invalid_grant (CRITICAL-2)") + } + if time.Until(bUntil) < vault.AuthFailCooldown-30*time.Second { + t.Fatalf("memB cooldown TTL = %s, want ~%s (auth-failure)", time.Until(bUntil), vault.AuthFailCooldown) + } + + // memA must be UNTOUCHED: still cooling on its ORIGINAL 90s 429 window, + // NOT re-cooled with memB's 300s auth-failure TTL. The bug re-cooled + // memA here; the fix must leave memA's cooldown exactly as it was. + aUntil, aCooling := pr.CooldownUntil("memA") + if !aCooling { + t.Fatal("memA should still be cooling on its original 429 window") + } + if aUntil.Sub(memACooldown).Abs() > time.Second { + t.Fatalf("memA cooldown was modified: got %s, want original %s (innocent member must not be re-cooled — CRITICAL-2)", + aUntil.Format(time.RFC3339Nano), memACooldown.Format(time.RFC3339Nano)) + } + + select { + case <-gotCalled: + case <-time.After(2 * time.Second): + t.Fatal("onFailover callback not invoked") + } + if got.From != "memB" { + t.Fatalf("FailoverEvent.From = %q, want memB (the correctly-attributed failing member)", got.From) + } + if got.Pool != "codex_pool" || got.Reason != "invalid_grant" || got.Class != failoverAuthFailure { + t.Fatalf("FailoverEvent = %+v, want pool=codex_pool reason=invalid_grant class=auth-failure", got) + } +} + +// setupPoolAddonSplitHost3 is setupPoolAddonSplitHost with three members, +// all sharing one token URL on a host distinct from the pool binding host. +func setupPoolAddonSplitHost3(t *testing.T, poolName, a, b, c string) (*SluiceAddon, *atomic.Pointer[vault.PoolResolver]) { + t.Helper() + provider := &addonWritableProvider{ + creds: map[string]string{ + a: poolMemberCred(t, a+"-access", a+"-refresh"), + b: poolMemberCred(t, b+"-access", b+"-refresh"), + c: poolMemberCred(t, c+"-access", c+"-refresh"), + }, + } + bindings := []vault.Binding{{ + Destination: "api.example.com", + Ports: []int{443}, + Credential: poolName, + }} + resolver, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var resolverPtr atomic.Pointer[vault.BindingResolver] + resolverPtr.Store(resolver) + + addon := NewSluiceAddon(WithResolver(&resolverPtr), WithProvider(provider)) + addon.persistDone = make(chan struct{}, 10) + + metas := []store.CredentialMeta{ + {Name: a, CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: b, CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: c, CredType: "oauth", TokenURL: testOAuthTokenURL}, + } + addon.UpdateOAuthIndex(metas) + + pool := store.Pool{Name: poolName, Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: a, Position: 0}, + {Credential: b, Position: 1}, + {Credential: c, Position: 2}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{pool}, nil)) + addon.SetPoolResolver(&prPtr) + return addon, &prPtr +} + +// TestTokenEndpointFailover3MemberAttributesMiddleMember is the 3-member +// CRITICAL-2 variant: memA (first index) and memC are cooled, memB is +// active and refreshing. idx.Match still returns memA (first entry). The +// fix must cool memB (the injected member) and leave memA/memC's distinct +// cooldown windows untouched. +func TestTokenEndpointFailover3MemberAttributesMiddleMember(t *testing.T) { + addon, prPtr := setupPoolAddonSplitHost3(t, "codex_pool", "memA", "memB", "memC") + client := setupAddonConn(addon, "auth.example.com:443") + pr := prPtr.Load() + + aUntil0 := time.Now().Add(45 * time.Second) + cUntil0 := time.Now().Add(75 * time.Second) + pr.MarkCooldown("memA", aUntil0, "429") + pr.MarkCooldown("memC", cUntil0, "403") + if got, _ := pr.ResolveActive("codex_pool"); got != "memB" { + t.Fatalf("active = %q, want memB", got) + } + + addon.refreshAttr.Tag("memB-refresh", "memB") + + f := newPoolRespFlowBody(client, 401, "memB-refresh", []byte(`{"error":"invalid_token"}`)) + pool, member, _, ok := addon.poolForResponse(f) + if !ok || pool != "codex_pool" || member != "memB" { + t.Fatalf("poolForResponse got ok=%v pool=%q member=%q, want codex_pool/memB", ok, pool, member) + } + + addon.Response(newPoolRespFlowBody(client, 401, "memB-refresh", []byte(`{"error":"invalid_token"}`))) + + if _, cooling := pr.CooldownUntil("memB"); !cooling { + t.Fatal("memB must be cooled after its own 401") + } + if aU, c := pr.CooldownUntil("memA"); !c || aU.Sub(aUntil0).Abs() > time.Second { + t.Fatalf("memA cooldown changed: got %v (cooling=%v), want original %v", aU, c, aUntil0) + } + if cU, c := pr.CooldownUntil("memC"); !c || cU.Sub(cUntil0).Abs() > time.Second { + t.Fatalf("memC cooldown changed: got %v (cooling=%v), want original %v", cU, c, cUntil0) + } +} + +// TestTokenEndpointFailoverFallsBackToActiveMember asserts the documented +// fallback: when the real refresh token cannot be recovered from the body +// (no attribution tag — e.g. the request was not driven through pass-2), +// poolForResponse cools the ACTIVE member, never blindly idx.Match's first +// index entry. +func TestTokenEndpointFailoverFallsBackToActiveMember(t *testing.T) { + addon, prPtr := setupPoolAddonSplitHost(t, "codex_pool", "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + pr := prPtr.Load() + + // memA cooled -> memB active. NO refreshAttr tag is recorded, and the + // body's refresh token is not in the attribution map, so Peek misses. + pr.MarkCooldown("memA", time.Now().Add(90*time.Second), "429") + if got, _ := pr.ResolveActive("codex_pool"); got != "memB" { + t.Fatalf("active = %q, want memB", got) + } + + f := newPoolRespFlowBody(client, 400, "untagged-refresh", []byte(`{"error":"invalid_grant"}`)) + pool, member, _, ok := addon.poolForResponse(f) + if !ok { + t.Fatal("poolForResponse: expected attribution via active-member fallback") + } + if pool != "codex_pool" || member != "memB" { + t.Fatalf("fallback got pool=%q member=%q, want codex_pool/memB (active member, NOT idx.Match's memA)", pool, member) + } +} + +// TestServerStorePoolConcurrentMarkCooldown is the CRITICAL-1 integration +// regression at the real production code path: Server.StorePool's atomic +// pointer swap (the SIGHUP / data_version reload) racing handlePoolFailover's +// lock-free MarkCooldown. With the shared-PoolHealth fix the cooldown can +// never be lost across the swap. Run with -race. +func TestServerStorePoolConcurrentMarkCooldown(t *testing.T) { + srv := &Server{} // addon nil: StorePool's `if s.addon != nil` guards it. + shared := vault.NewPoolHealth() + pool := store.Pool{Name: "p", Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: "m0", Position: 0}, + {Credential: "m1", Position: 1}, + {Credential: "m2", Position: 2}, + } + srv.StorePool(vault.NewPoolResolverShared([]store.Pool{pool}, nil, shared)) + + const iters = 400 + far := 10 * time.Minute + done := make(chan struct{}) + + // Reload loop: rebuild + StorePool (the real atomic swap), bound to the + // SAME shared health, exactly like loadPoolResolver -> StorePool. + go func() { + for i := 0; i < iters; i++ { + srv.StorePool(vault.NewPoolResolverShared([]store.Pool{pool}, nil, shared)) + } + close(done) + }() + + // Failover loop: MarkCooldown on whatever resolver is live now (often + // one about to be replaced), with NO ReloadMu held — exactly the + // handlePoolFailover discipline. + for i := 0; i < iters; i++ { + pr := srv.poolResolver.Load() + pr.MarkCooldown(pool.Members[i%3].Credential, time.Now().Add(far), "429") + } + <-done + + latest := srv.poolResolver.Load() + for _, m := range pool.Members { + if _, cooling := latest.CooldownUntil(m.Credential); !cooling { + t.Fatalf("cooldown for %q lost across Server.StorePool swaps (CRITICAL-1)", m.Credential) + } + } +} + +// TestServerStorePoolStaleGenerationCooldownNotLost is the deterministic +// CRITICAL-1 regression that MergeLiveCooldowns' one-generation-back +// chaining provably cannot rescue. A reference to a generation is captured, +// TWO StorePool swaps happen (so the captured pointer is two generations +// stale and was already merged forward BEFORE the cooldown), THEN +// MarkCooldown is applied to that stale generation. Pre-fix, the cooldown +// was applied to a private health map that no live generation points at and +// that was merged forward before the mark — permanently invisible. The +// shared-PoolHealth fix makes it visible because every generation mutates +// the SAME map. A credential ("z") cooled by nothing else makes the +// assertion unambiguous. +func TestServerStorePoolStaleGenerationCooldownNotLost(t *testing.T) { + srv := &Server{} + shared := vault.NewPoolHealth() + pool := store.Pool{Name: "p", Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: "y", Position: 0}, + {Credential: "z", Position: 1}, + } + srv.StorePool(vault.NewPoolResolverShared([]store.Pool{pool}, nil, shared)) + + stale := srv.poolResolver.Load() // generation N + srv.StorePool(vault.NewPoolResolverShared([]store.Pool{pool}, nil, shared)) + srv.StorePool(vault.NewPoolResolverShared([]store.Pool{pool}, nil, shared)) + // "z" has never been cooled; mark it on the two-generations-stale ref. + stale.MarkCooldown("z", time.Now().Add(10*time.Minute), "401") + + cur := srv.poolResolver.Load() + if _, cooling := cur.CooldownUntil("z"); !cooling { + t.Fatal("cooldown applied to a two-generations-stale resolver was lost " + + "(CRITICAL-1: MergeLiveCooldowns chains only one generation back and " + + "runs before the late mark; only shared-PoolHealth survives this)") + } + // And it must steer ResolveActive on the live generation. + if got, _ := cur.ResolveActive("p"); got != "y" { + t.Fatalf("ResolveActive = %q, want y (z cooled via stale-gen mark)", got) + } +} diff --git a/internal/vault/pool.go b/internal/vault/pool.go index d33e778..c667109 100644 --- a/internal/vault/pool.go +++ b/internal/vault/pool.go @@ -25,6 +25,65 @@ type memberHealth struct { reason string } +// PoolHealth is the mutex-guarded credential cooldown map. It is +// deliberately a SEPARATE object from PoolResolver so it can outlive any +// single resolver generation. +// +// CRITICAL-1: pool membership is immutable per resolver, but a membership +// change (or any unrelated DB write triggering the 2s data_version +// watcher, or a SIGHUP) rebuilds a fresh PoolResolver that the server +// atomically pointer-swaps. Phase 2 failover's MarkCooldown runs on the +// response path WITHOUT holding ReloadMu, so a MarkCooldown landing on the +// old generation between a swap's snapshot and store could be lost; and a +// merge that chains only one generation back loses a cooldown permanently +// if the detached durable SetCredentialHealth write fails. +// +// The fix: construct ONE PoolHealth at process start and inject the SAME +// instance into every NewPoolResolver. MarkCooldown on any generation and +// ResolveActive on the current generation then mutate/read the SAME +// underlying map under the SAME mutex, so a cooldown can never be lost +// across a pointer swap and never depends on a durable write succeeding. +// Store rows still seed the map at startup (Seed) for cross-restart +// durability, and the seed is monotonic (never shortens a live in-memory +// cooldown). +type PoolHealth struct { + mu sync.RWMutex + health map[string]memberHealth +} + +// NewPoolHealth returns an empty shared health map. Call this exactly once +// per process and thread the result through every NewPoolResolver so all +// resolver generations share one cooldown view. +func NewPoolHealth() *PoolHealth { + return &PoolHealth{health: make(map[string]memberHealth)} +} + +// Seed merges store-persisted cooldown rows into the shared map. It is +// monotonic: a store row never shortens or clears a live in-memory +// cooldown (the in-memory value is authoritative because Phase 2 failover +// updates it synchronously and the durable write is best-effort/detached). +// Expired rows are ignored. Safe to call on every resolver rebuild. +func (ph *PoolHealth) Seed(healthRows []store.CredentialHealth) { + if ph == nil { + return + } + now := time.Now() + ph.mu.Lock() + defer ph.mu.Unlock() + for _, h := range healthRows { + if h.Status != "cooldown" || h.CooldownUntil.IsZero() || !h.CooldownUntil.After(now) { + continue + } + existing, ok := ph.health[h.Credential] + if !ok || h.CooldownUntil.After(existing.cooldownUntil) { + ph.health[h.Credential] = memberHealth{ + cooldownUntil: h.CooldownUntil, + reason: h.LastFailureReason, + } + } + } +} + // PoolResolver maps a pool name to its currently active member. It is the // single chokepoint every credential consumer routes through (injection // passes, OAuthIndex.Has gating, persist attribution), so a pool name is @@ -33,28 +92,56 @@ type memberHealth struct { // Locking discipline: pool membership is immutable for the lifetime of a // PoolResolver instance (membership changes rebuild a fresh resolver that // the server atomically pointer-swaps). Health, by contrast, is mutated -// synchronously on the response path during Phase 2 failover, so the health -// map is guarded by mu. ResolveActive takes mu.RLock; MarkCooldown takes -// mu.Lock. Readers therefore always observe a consistent active member even -// while a concurrent response is recording a failover. +// synchronously on the response path during Phase 2 failover and MUST +// survive resolver pointer swaps, so it lives in a SHARED *PoolHealth +// (one instance per process, injected into every generation). ResolveActive +// takes the shared RLock; MarkCooldown takes the shared Lock. A failover +// recorded on any generation is therefore visible to ResolveActive on the +// current generation regardless of how many reloads happened in between, +// and a cooldown can never be lost across a swap (CRITICAL-1). type PoolResolver struct { // pools maps pool name -> ordered member credential names. pools map[string][]string // memberOf maps a credential name -> the pools that contain it. memberOf map[string][]string - mu sync.RWMutex - health map[string]memberHealth + // health is the shared, swap-surviving cooldown map. Never nil after + // NewPoolResolver (a fresh PoolHealth is allocated when none is given, + // preserving the old single-generation behavior for ad-hoc callers). + health *PoolHealth } -// NewPoolResolver builds a resolver from store snapshots. Health rows with -// status "cooldown" and a future cooldown_until seed the in-memory health -// map; healthy rows and expired cooldowns are treated as eligible. +// NewPoolResolver builds a resolver from store snapshots with a PRIVATE +// per-instance health map. Use this for short-lived throwaway resolvers +// (CLI `pool` subcommands that build a resolver, print, and discard it). +// The long-lived proxy server MUST use NewPoolResolverShared so cooldowns +// survive resolver pointer swaps (CRITICAL-1). Health rows with status +// "cooldown" and a future cooldown_until seed the map; healthy rows and +// expired cooldowns are treated as eligible. func NewPoolResolver(pools []store.Pool, healthRows []store.CredentialHealth) *PoolResolver { + return NewPoolResolverShared(pools, healthRows, nil) +} + +// NewPoolResolverShared builds a resolver that shares the given PoolHealth +// across every resolver generation. Pass the process-wide *PoolHealth here +// (NewPoolHealth, created once) so MarkCooldown on any generation and +// ResolveActive on the current generation operate on the SAME mutex-guarded +// map — a cooldown can never be lost across an atomic pointer swap and +// never depends on the detached durable write succeeding (CRITICAL-1). +// When shared is nil a fresh private PoolHealth is allocated, preserving +// the old single-generation semantics for ad-hoc callers. +// +// healthRows seed the (possibly shared) map monotonically: an existing +// live in-memory cooldown is never shortened by a store row. Seeding the +// shared map on every rebuild is therefore safe and idempotent. +func NewPoolResolverShared(pools []store.Pool, healthRows []store.CredentialHealth, shared *PoolHealth) *PoolResolver { + if shared == nil { + shared = NewPoolHealth() + } pr := &PoolResolver{ pools: make(map[string][]string, len(pools)), memberOf: make(map[string][]string), - health: make(map[string]memberHealth), + health: shared, } for _, p := range pools { members := make([]string, 0, len(p.Members)) @@ -64,37 +151,29 @@ func NewPoolResolver(pools []store.Pool, healthRows []store.CredentialHealth) *P } pr.pools[p.Name] = members } - for _, h := range healthRows { - if h.Status == "cooldown" && !h.CooldownUntil.IsZero() { - pr.health[h.Credential] = memberHealth{ - cooldownUntil: h.CooldownUntil, - reason: h.LastFailureReason, - } - } - } + shared.Seed(healthRows) return pr } -// IsPool reports whether name is a configured pool. +// IsPool reports whether name is a configured pool. Pool membership is +// immutable for a resolver instance (a membership change builds a fresh +// resolver), so no lock is needed. func (pr *PoolResolver) IsPool(name string) bool { if pr == nil { return false } - pr.mu.RLock() - defer pr.mu.RUnlock() _, ok := pr.pools[name] return ok } // PoolForMember returns the first pool that contains the given credential, // or "" if the credential is not a pool member. Used by the response path to -// attribute a failover/refresh to its pool for audit + Telegram. +// attribute a failover/refresh to its pool for audit + Telegram. Membership +// is immutable per instance, so no lock is needed. func (pr *PoolResolver) PoolForMember(credential string) string { if pr == nil { return "" } - pr.mu.RLock() - defer pr.mu.RUnlock() if pools := pr.memberOf[credential]; len(pools) > 0 { return pools[0] } @@ -103,13 +182,12 @@ func (pr *PoolResolver) PoolForMember(credential string) string { // Members returns the ordered member list for a pool (copy), or nil. Exposed // as an introspection surface for tests and potential future `pool status` -// detail output; not on any hot path. +// detail output; not on any hot path. Membership is immutable per instance, +// so no lock is needed. func (pr *PoolResolver) Members(pool string) []string { if pr == nil { return nil } - pr.mu.RLock() - defer pr.mu.RUnlock() m, ok := pr.pools[pool] if !ok { return nil @@ -128,8 +206,6 @@ func (pr *PoolResolver) ResolveActive(name string) (member string, ok bool) { if pr == nil { return name, true } - pr.mu.RLock() - defer pr.mu.RUnlock() members, isPool := pr.pools[name] if !isPool { @@ -140,11 +216,17 @@ func (pr *PoolResolver) ResolveActive(name string) (member string, ok bool) { return "", false } + // Read the shared health map under its RLock. A concurrent failover's + // MarkCooldown takes the same map's write lock, so this observes a + // consistent cooldown view regardless of resolver generation. + pr.health.mu.RLock() + defer pr.health.mu.RUnlock() + now := time.Now() var soonest string var soonestUntil time.Time for _, m := range members { - h, tracked := pr.health[m] + h, tracked := pr.health.health[m] if !tracked || h.cooldownUntil.IsZero() || !h.cooldownUntil.After(now) { return m, true } @@ -167,49 +249,56 @@ func (pr *PoolResolver) MarkCooldown(credential string, until time.Time, reason if pr == nil { return } - pr.mu.Lock() - defer pr.mu.Unlock() + // Mutate the SHARED health map. Because every resolver generation points + // at the same *PoolHealth, a MarkCooldown that lands on an + // about-to-be-replaced generation is still observed by ResolveActive on + // the new generation — the pointer swap no longer races the cooldown + // (CRITICAL-1). Monotonic clear/set semantics are unchanged: a zero/past + // `until` clears the cooldown (recovery). + pr.health.mu.Lock() + defer pr.health.mu.Unlock() if until.IsZero() || !until.After(time.Now()) { - delete(pr.health, credential) + delete(pr.health.health, credential) return } - pr.health[credential] = memberHealth{cooldownUntil: until, reason: reason} + pr.health.health[credential] = memberHealth{cooldownUntil: until, reason: reason} } -// MergeLiveCooldowns carries forward still-active in-memory cooldowns from a -// previous resolver into this freshly built one. It MUST be called before the -// new resolver is atomically swapped in. -// -// Why this exists: Phase 2 failover records the active member's cooldown -// synchronously in the in-memory health map (MarkCooldown) and only persists -// SetCredentialHealth to the store from a detached best-effort goroutine. Any -// reload (SIGHUP, or the 2s data_version watcher firing on ANY unrelated DB -// write — a policy add, a cred update, an audit row) rebuilds the resolver -// from store rows alone via NewPoolResolver. Without this merge, a reload that -// races ahead of (or outlives a failed) durable write would seed health only -// from the store and silently resurrect a member that was just cooled down in -// memory, defeating the I1 synchronous-failover guarantee for the full -// cooldown TTL (60s/300s) — or permanently if the async store write failed. +// MergeLiveCooldowns is retained for API compatibility but is now a +// near-no-op. CRITICAL-1's race and permanent-loss bugs were fixed by +// making the cooldown map a process-wide shared *PoolHealth that every +// resolver generation points at (NewPoolResolverShared), so a cooldown +// recorded via MarkCooldown on any generation is already visible to +// ResolveActive on the new generation — there is nothing to "carry +// forward" because both generations mutate the SAME map under the SAME +// mutex. The pointer swap can no longer lose a cooldown, and durability no +// longer depends on the detached store write succeeding. // -// Merge policy: for every member that this resolver still knows about (so a -// cooldown for a credential removed from all pools is correctly dropped), the -// later of the store-seeded expiry and the previous in-memory expiry wins. -// Expired cooldowns on either side are not carried. This is monotonic: a live -// cooldown can never be shortened or erased by an unrelated reload. +// When prev and pr happen to share the same *PoolHealth (the normal server +// path) this is a pure no-op. The only case where it still does work is a +// defensive one: prev was built with a DIFFERENT (e.g. nil-defaulted) +// PoolHealth than pr — then still-live cooldowns are copied forward +// monotonically and orphaned members dropped, exactly as before, so the +// old single-generation callers are not regressed. func (pr *PoolResolver) MergeLiveCooldowns(prev *PoolResolver) { - if pr == nil || prev == nil { + if pr == nil || prev == nil || prev.health == nil || pr.health == nil { + return + } + if pr.health == prev.health { + // Shared health map: both generations already see the same + // cooldowns. Nothing to do — this is the CRITICAL-1 fix. return } now := time.Now() - prev.mu.RLock() - prevHealth := make(map[string]memberHealth, len(prev.health)) - for k, v := range prev.health { + prev.health.mu.RLock() + prevHealth := make(map[string]memberHealth, len(prev.health.health)) + for k, v := range prev.health.health { prevHealth[k] = v } - prev.mu.RUnlock() + prev.health.mu.RUnlock() - pr.mu.Lock() - defer pr.mu.Unlock() + pr.health.mu.Lock() + defer pr.health.mu.Unlock() for cred, ph := range prevHealth { if ph.cooldownUntil.IsZero() || !ph.cooldownUntil.After(now) { continue // expired in the old resolver; nothing to carry @@ -219,9 +308,9 @@ func (pr *PoolResolver) MergeLiveCooldowns(prev *PoolResolver) { if _, stillMember := pr.memberOf[cred]; !stillMember { continue } - existing, ok := pr.health[cred] + existing, ok := pr.health.health[cred] if !ok || ph.cooldownUntil.After(existing.cooldownUntil) { - pr.health[cred] = ph + pr.health.health[cred] = ph } } } @@ -234,9 +323,9 @@ func (pr *PoolResolver) CooldownUntil(credential string) (time.Time, bool) { if pr == nil { return time.Time{}, false } - pr.mu.RLock() - defer pr.mu.RUnlock() - h, ok := pr.health[credential] + pr.health.mu.RLock() + defer pr.health.mu.RUnlock() + h, ok := pr.health.health[credential] if !ok || h.cooldownUntil.IsZero() || !h.cooldownUntil.After(time.Now()) { return time.Time{}, false } diff --git a/internal/vault/pool_test.go b/internal/vault/pool_test.go index 1bf9e2e..80fcce3 100644 --- a/internal/vault/pool_test.go +++ b/internal/vault/pool_test.go @@ -1,6 +1,7 @@ package vault import ( + "sync" "testing" "time" @@ -196,3 +197,88 @@ func TestNilPoolResolverSafe(t *testing.T) { } pr.MarkCooldown("x", time.Now(), "") // must not panic } + +// TestSharedHealthSurvivesResolverRebuild is the CRITICAL-1 regression: +// when the long-lived path rebuilds the resolver against the SAME shared +// PoolHealth (every SIGHUP / data_version reload), a cooldown recorded on +// the OLD generation must be visible to ResolveActive on the NEW +// generation — with zero dependency on the detached durable store write. +func TestSharedHealthSurvivesResolverRebuild(t *testing.T) { + shared := NewPoolHealth() + gen1 := NewPoolResolverShared([]store.Pool{mkPool("pool", "a", "b")}, nil, shared) + + // Failover cools "a" on gen1. The store write has NOT landed (best + // effort/detached), so a rebuild sees no health rows. + gen1.MarkCooldown("a", time.Now().Add(120*time.Second), "429") + + // Reload rebuilds a fresh generation from store rows alone (empty), + // against the SAME shared health. + gen2 := NewPoolResolverShared([]store.Pool{mkPool("pool", "a", "b")}, nil, shared) + + if got, _ := gen2.ResolveActive("pool"); got != "b" { + t.Fatalf("gen2 active = %q, want b (cooldown on gen1 must survive the rebuild — CRITICAL-1)", got) + } + // And a MarkCooldown that lands on the OLD generation AFTER gen2 exists + // must still be observed by gen2 (no lost update across the swap). + gen1.MarkCooldown("b", time.Now().Add(120*time.Second), "401") + if _, cooling := gen2.CooldownUntil("b"); !cooling { + t.Fatal("MarkCooldown on old generation not visible on new generation — CRITICAL-1 lost-update race") + } +} + +// TestSharedHealthConcurrentMarkCooldownVsRebuild stresses the CRITICAL-1 +// race: MarkCooldown on rotating "old" generations racing continuous +// resolver rebuilds (the StorePool/reload swap) against one shared health. +// Run with -race. The invariant: a cooldown that was set is NEVER lost — +// every credential we cooled is still cooling when observed through the +// latest generation. +func TestSharedHealthConcurrentMarkCooldownVsRebuild(t *testing.T) { + shared := NewPoolHealth() + pool := mkPool("pool", "m0", "m1", "m2", "m3") + + var cur struct { + sync.RWMutex + pr *PoolResolver + } + cur.pr = NewPoolResolverShared([]store.Pool{pool}, nil, shared) + + const iters = 400 + far := 10 * time.Minute + done := make(chan struct{}) + + // Rebuilder: continuously swaps in a fresh generation bound to the + // SAME shared health (models StorePool's reload swap). + go func() { + for i := 0; i < iters; i++ { + fresh := NewPoolResolverShared([]store.Pool{pool}, nil, shared) + cur.Lock() + cur.pr = fresh + cur.Unlock() + } + close(done) + }() + + // Marker: cools members on whatever generation is current at the time + // (often an about-to-be-replaced one). Every cooldown uses a far-future + // expiry so it must still be live at the assertion. + members := pool.Members + for i := 0; i < iters; i++ { + cur.RLock() + g := cur.pr + cur.RUnlock() + m := members[i%len(members)].Credential + g.MarkCooldown(m, time.Now().Add(far), "429") + } + <-done + + // Observe through the latest generation: every member we cooled must + // still be cooling. None lost across any swap. + cur.RLock() + latest := cur.pr + cur.RUnlock() + for _, m := range members { + if _, cooling := latest.CooldownUntil(m.Credential); !cooling { + t.Fatalf("cooldown for %q was lost across resolver swaps (CRITICAL-1 race)", m.Credential) + } + } +} From 22de0ced0db72010f1ab7e3ef8236410e6b967f7 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 10:23:44 +0800 Subject: [PATCH 23/49] style: satisfy golangci-lint (errorlint, QF1002, unparam) --- internal/channel/channel_test.go | 11 ++++++----- internal/proxy/pool_failover.go | 8 ++++---- internal/proxy/pool_failover_test.go | 14 +++++++------- internal/proxy/pool_phantom_test.go | 9 +++++---- internal/store/pools.go | 18 ++++++++++-------- 5 files changed, 32 insertions(+), 28 deletions(-) diff --git a/internal/channel/channel_test.go b/internal/channel/channel_test.go index 528c899..840a8e6 100644 --- a/internal/channel/channel_test.go +++ b/internal/channel/channel_test.go @@ -865,8 +865,9 @@ func TestBrokerChannelErrorDoesNotBlockOthers(t *testing.T) { // dest:port and waits until the broker reports all n have attached to a // single primary waiter. It returns the primary request ID and a channel // that yields each call's (resp, err) result. -func fireCoalescedBurst(t *testing.T, broker *Broker, ch *mockChannel, dest string, port, n int, timeout time.Duration) (string, <-chan result) { +func fireCoalescedBurst(t *testing.T, broker *Broker, ch *mockChannel, dest string, n int, timeout time.Duration) (string, <-chan result) { t.Helper() + const port = 443 type res = result out := make(chan res, n) for i := 0; i < n; i++ { @@ -905,7 +906,7 @@ func TestBrokerCoalesceOneBroadcastFanToAll(t *testing.T) { broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) const n = 8 - primaryID, out := fireCoalescedBurst(t, broker, ch, "cas.example.com", 443, n, 5*time.Second) + primaryID, out := fireCoalescedBurst(t, broker, ch, "cas.example.com", n, 5*time.Second) // Exactly one prompt was broadcast for the whole burst. if got := len(ch.getRequests()); got != 1 { @@ -942,7 +943,7 @@ func TestBrokerCoalesceDenyFanOut(t *testing.T) { broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) const n = 5 - primaryID, out := fireCoalescedBurst(t, broker, ch, "deny.example.com", 443, n, 5*time.Second) + primaryID, out := fireCoalescedBurst(t, broker, ch, "deny.example.com", n, 5*time.Second) broker.Resolve(primaryID, ResponseDeny) for i := 0; i < n; i++ { @@ -962,7 +963,7 @@ func TestBrokerCoalesceTimeoutFanOut(t *testing.T) { // every subscriber. The primary itself returns the timeout error; // subscribers receive Deny via the fan-out (nil err, like any // terminal resolution). Every caller must end up denied. - _, out := fireCoalescedBurst(t, broker, ch, "slowburst.example.com", 443, n, 80*time.Millisecond) + _, out := fireCoalescedBurst(t, broker, ch, "slowburst.example.com", n, 80*time.Millisecond) timeoutErrs := 0 for i := 0; i < n; i++ { @@ -984,7 +985,7 @@ func TestBrokerCoalesceShutdownFanOut(t *testing.T) { broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) const n = 6 - _, out := fireCoalescedBurst(t, broker, ch, "shutdown.example.com", 443, n, 5*time.Second) + _, out := fireCoalescedBurst(t, broker, ch, "shutdown.example.com", n, 5*time.Second) broker.CancelAll() for i := 0; i < n; i++ { diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index c802aa5..dc1e083 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -76,12 +76,12 @@ func failoverReasonTag(class failoverClass, statusCode int, bodyTag string) stri // "invalid_grant" in unrelated prose). bodyTag returns the matched body // token (for the audit reason) when the decision came from the body. func classifyFailover(statusCode int, body []byte, isTokenEndpoint bool) (class failoverClass, bodyTag string) { - switch { - case statusCode == 429: + switch statusCode { + case 429: return failoverRateLimited, "" - case statusCode == 401: + case 401: return failoverAuthFailure, "" - case statusCode == 403: + case 403: if bodyContainsAny(body, "insufficient_quota", "quota_exceeded", "quota exhausted", "rate_limit_exceeded") { return failoverRateLimited, "" } diff --git a/internal/proxy/pool_failover_test.go b/internal/proxy/pool_failover_test.go index b650542..cd122bf 100644 --- a/internal/proxy/pool_failover_test.go +++ b/internal/proxy/pool_failover_test.go @@ -87,7 +87,7 @@ func TestClassifyFailover(t *testing.T) { // pooled destination, the very NEXT ResolveActive call returns the next // member — without any reliance on the 2s store-reconcile watcher (Risk I1). func TestFailoverSynchronousHealthSwap(t *testing.T) { - addon, _, prPtr := setupPoolAddon(t, "codex_pool", "memA", "memB") + addon, _, prPtr := setupPoolAddon(t, "memA", "memB") client := setupAddonConn(addon, "auth.example.com:443") pr := prPtr.Load() @@ -134,7 +134,7 @@ func TestFailoverCooldownTTLAndLazyRecovery(t *testing.T) { t.Fatalf("AuthFailCooldown = %v, want 300s", vault.AuthFailCooldown) } - addon, _, prPtr := setupPoolAddon(t, "codex_pool", "memA", "memB") + addon, _, prPtr := setupPoolAddon(t, "memA", "memB") client := setupAddonConn(addon, "auth.example.com:443") pr := prPtr.Load() @@ -162,7 +162,7 @@ func TestFailoverCooldownTTLAndLazyRecovery(t *testing.T) { // TestFailoverNoopForNonPooledAndSuccess asserts the failover path is a // no-op for a successful response and never invokes the callback. func TestFailoverNoopForSuccessfulResponse(t *testing.T) { - addon, _, prPtr := setupPoolAddon(t, "codex_pool", "memA", "memB") + addon, _, prPtr := setupPoolAddon(t, "memA", "memB") client := setupAddonConn(addon, "auth.example.com:443") called := false @@ -191,7 +191,7 @@ func TestFailoverNoopForSuccessfulResponse(t *testing.T) { // assert Response itself never waits on callback-internal work by having the // callback spawn the slow part and return immediately, mirroring main.go). func TestFailoverNoticeNonBlocking(t *testing.T) { - addon, _, _ := setupPoolAddon(t, "codex_pool", "memA", "memB") + addon, _, _ := setupPoolAddon(t, "memA", "memB") client := setupAddonConn(addon, "auth.example.com:443") done := make(chan struct{}) @@ -221,7 +221,7 @@ func TestFailoverNoticeNonBlocking(t *testing.T) { // TestFailoverNonPooledDestinationIgnored asserts a response whose // destination is NOT bound to a pool never triggers failover. func TestFailoverNonPooledDestinationIgnored(t *testing.T) { - addon, _, _ := setupPoolAddon(t, "codex_pool", "memA", "memB") + addon, _, _ := setupPoolAddon(t, "memA", "memB") // Connect to a destination with no pooled binding. client := setupAddonConn(addon, "unrelated.example.com:443") @@ -258,7 +258,7 @@ func TestFailoverAuditEvent(t *testing.T) { } t.Cleanup(func() { _ = logger.Close() }) - addon, _, _ := setupPoolAddon(t, "codex_pool", "memA", "memB") + addon, _, _ := setupPoolAddon(t, "memA", "memB") addon.auditLog = logger client := setupAddonConn(addon, "auth.example.com:443") @@ -303,7 +303,7 @@ func TestFailoverAuditEvent(t *testing.T) { // TestPoolForResponseResolvesActiveMember sanity-checks the destination -> // pool reverse mapping used by handlePoolFailover. func TestPoolForResponseResolvesActiveMember(t *testing.T) { - addon, _, prPtr := setupPoolAddon(t, "codex_pool", "memA", "memB") + addon, _, prPtr := setupPoolAddon(t, "memA", "memB") client := setupAddonConn(addon, "auth.example.com:443") f := newPoolRespFlow(client, 429, nil) diff --git a/internal/proxy/pool_phantom_test.go b/internal/proxy/pool_phantom_test.go index 640dcd1..991ec43 100644 --- a/internal/proxy/pool_phantom_test.go +++ b/internal/proxy/pool_phantom_test.go @@ -34,8 +34,9 @@ func poolMemberCred(t *testing.T, access, refresh string) string { // setupPoolAddon wires a SluiceAddon with a two-member pool bound to // auth.example.com. Both members share testOAuthTokenURL (the Risk R1 // collision shape: two Codex accounts behind one OpenAI token endpoint). -func setupPoolAddon(t *testing.T, poolName, memberA, memberB string) (*SluiceAddon, *addonWritableProvider, *atomic.Pointer[vault.PoolResolver]) { +func setupPoolAddon(t *testing.T, memberA, memberB string) (*SluiceAddon, *addonWritableProvider, *atomic.Pointer[vault.PoolResolver]) { t.Helper() + const poolName = "codex_pool" provider := &addonWritableProvider{ creds: map[string]string{ @@ -131,7 +132,7 @@ func TestR3PoolPhantomByteIdenticalAcrossMemberSwitch(t *testing.T) { // End-to-end: the access phantom the agent receives in a token-endpoint // response must be identical when member A is active and after failover // to member B (members have DIFFERENT real access tokens). - addon, _, prPtr := setupPoolAddon(t, "codex_pool", "codexA", "codexB") + addon, _, prPtr := setupPoolAddon(t, "codexA", "codexB") client := setupAddonConn(addon, "auth.example.com:443") // Member A active. Request body carries A's real refresh token (as if @@ -189,7 +190,7 @@ func TestR3PoolPhantomByteIdenticalAcrossMemberSwitch(t *testing.T) { // response is persisted to B's vault entry, never A's, even though both // members share one token URL (OAuthIndex.Match is 1:1 and collides). func TestR1RefreshAttributionByInjectedRefreshToken(t *testing.T) { - addon, provider, prPtr := setupPoolAddon(t, "codex_pool", "memA", "memB") + addon, provider, prPtr := setupPoolAddon(t, "memA", "memB") client := setupAddonConn(addon, "auth.example.com:443") // --- Member A round-trip via the real pass-2 path. --- @@ -260,7 +261,7 @@ func TestR1RefreshAttributionByInjectedRefreshToken(t *testing.T) { // is still swapped to phantoms (agent safe) but ZERO vault writes occur — no // guess, no fallback to OAuthIndex.Match. func TestR1FailClosedWhenMemberTagMissing(t *testing.T) { - addon, provider, _ := setupPoolAddon(t, "codex_pool", "memA", "memB") + addon, provider, _ := setupPoolAddon(t, "memA", "memB") client := setupAddonConn(addon, "auth.example.com:443") beforeA := provider.creds["memA"] diff --git a/internal/store/pools.go b/internal/store/pools.go index d540f09..5f843af 100644 --- a/internal/store/pools.go +++ b/internal/store/pools.go @@ -2,6 +2,7 @@ package store import ( "database/sql" + "errors" "fmt" "time" ) @@ -59,7 +60,7 @@ func parseHealthTime(s sql.NullString) time.Time { func (s *Store) PoolExists(name string) (bool, error) { var one int err := s.db.QueryRow("SELECT 1 FROM credential_pools WHERE name = ?", name).Scan(&one) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return false, nil } if err != nil { @@ -79,7 +80,7 @@ func validatePoolMemberTx(tx *sql.Tx, credential string) error { err := tx.QueryRow( "SELECT cred_type, token_url FROM credential_meta WHERE name = ?", credential, ).Scan(&credType, &tokenURL) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("credential %q does not exist (add it with --type oauth first)", credential) } if err != nil { @@ -133,13 +134,14 @@ func (s *Store) CreatePoolWithMembers(name, strategy string, members []string) e // Namespace mutual-exclusion: a pool must not shadow a credential. var credName string - switch err := tx.QueryRow("SELECT name FROM credential_meta WHERE name = ?", name).Scan(&credName); { - case err == nil: + collErr := tx.QueryRow("SELECT name FROM credential_meta WHERE name = ?", name).Scan(&credName) + switch { + case collErr == nil: return fmt.Errorf("name %q is already a credential; pool and credential names share one namespace", name) - case err == sql.ErrNoRows: + case errors.Is(collErr, sql.ErrNoRows): // ok default: - return fmt.Errorf("check name collision for %q: %w", name, err) + return fmt.Errorf("check name collision for %q: %w", name, collErr) } if _, err := tx.Exec( @@ -173,7 +175,7 @@ func (s *Store) GetPool(name string) (*Pool, error) { err := s.db.QueryRow( "SELECT name, strategy, created_at FROM credential_pools WHERE name = ?", name, ).Scan(&p.Name, &p.Strategy, &p.CreatedAt) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } if err != nil { @@ -328,7 +330,7 @@ func (s *Store) GetCredentialHealth(credential string) (*CredentialHealth, error "SELECT credential, status, cooldown_until, last_failure_reason, updated_at FROM credential_health WHERE credential = ?", credential, ).Scan(&h.Credential, &h.Status, &cu, &reason, &h.UpdatedAt) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } if err != nil { From fe64664a5bd2187dcec8ca07ab65971edea80053 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 10:37:39 +0800 Subject: [PATCH 24/49] test(e2e): pool failover + approval coalescing end-to-end --- .../completed/20260515-approval-coalescing.md | 2 +- .../20260515-credential-pool-failover.md | 24 +- e2e/approval_coalesce_test.go | 258 +++++++++ e2e/pool_failover_test.go | 523 ++++++++++++++++++ 4 files changed, 805 insertions(+), 2 deletions(-) create mode 100644 e2e/approval_coalesce_test.go create mode 100644 e2e/pool_failover_test.go diff --git a/docs/plans/completed/20260515-approval-coalescing.md b/docs/plans/completed/20260515-approval-coalescing.md index d5a59a2..c883fe1 100644 --- a/docs/plans/completed/20260515-approval-coalescing.md +++ b/docs/plans/completed/20260515-approval-coalescing.md @@ -97,7 +97,7 @@ Verified against the working tree on `main` (tip `20cc367`): ### Task 5: Verify acceptance + docs -- [x] verify the prompt-wall scenario via e2e (burst → one prompt → one tap dismisses all). The burst→one-prompt→fan-out scenario is verified at the unit/integration level: `internal/channel` 11 coalescing tests (TestBrokerCoalesceOneBroadcastFanToAll, ...DenyFanOut, ...TimeoutFanOut, ...ShutdownFanOut, ...SubTimeoutDoesNotBlockFanOut, ...LateAttachOpensNewPrompt, ...ConcurrentResolveAndAttach, TestBrokerDistinctDestNotCoalesced, TestBrokerSamePortDifferentDestNotCoalesced, TestBrokerWithNoCoalesceNeverCoalesces, TestBrokerCoalesceCrossChannelFirstWins) + `internal/proxy` TestRequestPolicyChecker_ConcurrentAllowOnceCoalesces / _SSHStyleConnectionLevelCoalesces + telegram TestHandleCallbackRendersCoalescedCount / TestCancelApprovalRendersCoalescedCount. [x] (skipped: dedicated burst e2e) — the `e2e/` suite has no delayed-verdict-server helper to keep a first approval pending while a concurrent burst arrives (the verdict server answers synchronously), so a true broker-coalescing e2e cannot be expressed without new harness code, which is out of scope for Task 5. The non-container e2e suite (66 tests, `-tags=e2e`) was run and passes, exercising the same `resolveAsk → broker.Request` Ask path via TestPerRequestAllowOnce*/AlwaysAllow*/Deny*. +- [x] verify the prompt-wall scenario via e2e (burst → one prompt → one tap dismisses all). The burst→one-prompt→fan-out scenario is verified at the unit/integration level: `internal/channel` 11 coalescing tests (TestBrokerCoalesceOneBroadcastFanToAll, ...DenyFanOut, ...TimeoutFanOut, ...ShutdownFanOut, ...SubTimeoutDoesNotBlockFanOut, ...LateAttachOpensNewPrompt, ...ConcurrentResolveAndAttach, TestBrokerDistinctDestNotCoalesced, TestBrokerSamePortDifferentDestNotCoalesced, TestBrokerWithNoCoalesceNeverCoalesces, TestBrokerCoalesceCrossChannelFirstWins) + `internal/proxy` TestRequestPolicyChecker_ConcurrentAllowOnceCoalesces / _SSHStyleConnectionLevelCoalesces + telegram TestHandleCallbackRendersCoalescedCount / TestCancelApprovalRendersCoalescedCount. [x] **dedicated burst e2e now implemented**: `e2e/approval_coalesce_test.go` (`TestApprovalCoalesce_BurstOnePrompt`, `e2e` build tag, CI non-container jobs). It adds a `gatedVerdictServer` HTTP webhook backend that HOLDS the first approval decision until the test releases it — the missing piece the synchronous `verdictServer` could not express. A burst of 8 concurrent SOCKS5 CONNECTs to one Ask `dest:port` is launched while the first approval is held pending; the test asserts exactly ONE approval webhook call for the whole burst (coalesced subs never re-prompt), peak concurrent approval handlers at the webhook == 1, and after a single `always_allow` release all 8 requests fan out and succeed. Proven non-vacuous: forcing the broker's coalesce branch off (`if false`) makes the test fail with "approval webhook calls during pending window = N, want exactly 1". The `WithNoCoalesce`/MCP opt-out is NOT expressible at the e2e level (MCP gateway has no SOCKS5 burst surface) and remains covered by the existing unit tests (TestBrokerWithNoCoalesceNeverCoalesces + the MCP `gateway.go` `WithNoCoalesce()` wiring). The non-container e2e suite (61 top-level PASS / 0 FAIL with both new tests added, `-tags=e2e`) was run and passes. - [x] run full suite `go test ./... -timeout 120s` (2524 passed, 13 packages); ran e2e `go test -tags=e2e ./e2e/ -count=1 -timeout=300s` (66 passed, non-container `e2e` tag). [x] (skipped: docker/apple-container e2e — `e2e && linux` / `e2e && darwin` compose/Apple-Container tags not run; the burst-coalescing scenario is verified by unit/integration tests above, container e2e adds no coalescing-specific coverage). - [x] update CLAUDE.md "Channel/approval abstraction" + "QUIC broker dedup" notes to mention broker-level coalescing. - [x] move plan to `docs/plans/completed/`. diff --git a/docs/plans/completed/20260515-credential-pool-failover.md b/docs/plans/completed/20260515-credential-pool-failover.md index dd97f89..8ffc477 100644 --- a/docs/plans/completed/20260515-credential-pool-failover.md +++ b/docs/plans/completed/20260515-credential-pool-failover.md @@ -105,7 +105,29 @@ rotate` is an operator override, not the primary mechanism. - [x] update CLAUDE.md credential-pool/failover notes — added `### Credential pools and auto-failover` (pool concept, `sluice pool` CLI, migration 000006 tables, Phase 1 chokepoint + R1 fail-closed attribution + R3 pool-stable JWT, Phase 2 classification + synchronous failover + `cred_failover` audit + Telegram notice + cooldown TTLs). - [x] move plan to `docs/plans/completed/`. -> **E2e gap (Testing Strategy item, honestly noted):** the dedicated two-fake-OAuth-upstreams pool-failover e2e (assert A used until 429 → switch to B → B's refreshed tokens land in B's vault not A's → phantom access JWT byte-identical across failover) was **not added**. Standing up a full e2e harness with JWT-issuing fake token endpoints, pool wiring through SOCKS5 + MITM, and 429-then-switch assertions is a substantial new harness beyond the reasonable scope of this verify+docs task. The failover behavior it would cover is already exercised by unit tests added in Tasks 2 & 3 (`internal/vault/pool_test.go`, `internal/proxy/pool_failover_test.go`): R1 collision/fail-closed, R3 phantom byte-identity, classification, synchronous health swap, cooldown TTL/lazy recovery, non-blocking notice. The existing non-container e2e suite was run in full (64 tests, all passing). Recommend tracking the pool-failover e2e as follow-up future work. +> **E2e (Testing Strategy item — now implemented):** the dedicated +> two-fake-OAuth-upstreams pool-failover e2e exists at +> `e2e/pool_failover_test.go` (`TestPoolFailover_EndToEnd`, `e2e` build +> tag, runs in CI's non-container E2E jobs). It stands up one TLS upstream +> serving both an OAuth `/token` endpoint that mints **real JWT** access +> tokens (HMAC-signed, per-member-and-per-refresh unique) and a protected +> `/api` endpoint that returns 429 for member A's real access token and +> 200 for member B's. It creates two OAuth members sharing one token URL +> (the R1 collision), a `failover` pool over them, and a Bearer-header +> binding to the API destination, all driven through the live SOCKS5 + +> MITM proxy. Asserts: (a) member A used until `/api` 429s; (b) the next +> API request fails over to member B and 200s (synchronous in-memory +> health swap, no cooldown wait); (c) member B's rotated tokens land in +> member B's vault entry and member A's vault entry is NOT clobbered with +> B's tokens (Risk R1 attribution via the injected real refresh token, +> JWT payload decoded to verify the `member` claim); (d) the phantom +> access token the agent receives from `/token` is byte-identical before +> and after the failover (Risk R3 pool-keyed synthetic JWT) and no real +> upstream JWT or real refresh token ever leaks to the agent; the +> `cred_failover` audit event is present. Proven non-vacuous: neutering +> `PoolResolver.MarkCooldown` (so failover never switches members) makes +> the test fail at assertion (b). The full non-container e2e suite was +> run with the two new tests added: 61 top-level PASS / 0 FAIL. ## Out of scope / future work diff --git a/e2e/approval_coalesce_test.go b/e2e/approval_coalesce_test.go new file mode 100644 index 0000000..b205fc0 --- /dev/null +++ b/e2e/approval_coalesce_test.go @@ -0,0 +1,258 @@ +//go:build e2e + +package e2e + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" +) + +// gatedVerdictServer is a webhook channel backend that HOLDS the first +// approval decision until the test explicitly releases it. This is the +// piece the existing synchronous verdictServer cannot express: to observe +// broker-level approval coalescing we need a first approval to stay pending +// while a concurrent burst of requests to the same dest:port arrives. With +// a synchronous server every request would get an instant verdict and the +// burst would never overlap a pending approval. +// +// Behavior: +// - Every approval POST increments approvalCalls and is recorded. +// - The FIRST approval POST blocks on the release channel. The broker +// coalesces concurrent same-dest:port requests into that one pending +// waiter, so a correctly-coalescing sluice delivers exactly ONE +// approval POST for the whole burst. +// - After release, the held call (and any further calls) return the +// configured verdict. +// +// maxConcurrent tracks the peak number of approval handlers in flight, +// which lets the test prove the burst genuinely overlapped the pending +// approval rather than serializing behind it. +type gatedVerdictServer struct { + verdict string + + release chan struct{} + once sync.Once + + mu sync.Mutex + calls int + approvalCalls int + cancelCalls int + requests []map[string]interface{} + + inFlight atomic.Int64 + maxConcurrent atomic.Int64 +} + +func newGatedVerdictServer(verdict string) *gatedVerdictServer { + return &gatedVerdictServer{ + verdict: verdict, + release: make(chan struct{}), + } +} + +// Release unblocks the held first approval. Safe to call once. +func (g *gatedVerdictServer) Release() { + g.once.Do(func() { close(g.release) }) +} + +func (g *gatedVerdictServer) ApprovalCalls() int { + g.mu.Lock() + defer g.mu.Unlock() + return g.approvalCalls +} + +func (g *gatedVerdictServer) CancelCalls() int { + g.mu.Lock() + defer g.mu.Unlock() + return g.cancelCalls +} + +func (g *gatedVerdictServer) MaxConcurrent() int64 { + return g.maxConcurrent.Load() +} + +func (g *gatedVerdictServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "read body failed", http.StatusInternalServerError) + return + } + var parsed map[string]interface{} + if err := json.Unmarshal(body, &parsed); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + + reqType, _ := parsed["type"].(string) + + g.mu.Lock() + g.calls++ + g.requests = append(g.requests, parsed) + if reqType != "approval" { + if reqType == "cancel" { + g.cancelCalls++ + } + g.mu.Unlock() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + return + } + g.approvalCalls++ + isFirst := g.approvalCalls == 1 + g.mu.Unlock() + + // Track concurrency of approval handlers in flight. + n := g.inFlight.Add(1) + for { + cur := g.maxConcurrent.Load() + if n <= cur || g.maxConcurrent.CompareAndSwap(cur, n) { + break + } + } + defer g.inFlight.Add(-1) + + if isFirst { + // Hold the decision so the broker accumulates coalesced subs. + <-g.release + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"verdict": g.verdict}) +} + +func startGatedVerdictServer(t *testing.T, verdict string) (*httptest.Server, *gatedVerdictServer) { + t.Helper() + g := newGatedVerdictServer(verdict) + srv := newIPv4Server(t, g) + t.Cleanup(srv.Close) + return srv, g +} + +// TestApprovalCoalesce_BurstOnePrompt is the GAP 2 e2e: a burst of +// concurrent requests to ONE Ask destination through the proxy must +// produce exactly ONE approval prompt (broker coalescing), and one +// resolve must fan out so ALL requests proceed. +// +// The gated verdict server holds the first approval pending while the +// rest of the burst arrives, so the coalescing window is real (not a +// synchronous race). Asserts: +// +// - exactly ONE approval webhook call for the whole burst, +// - all N requests succeed after the single release, +// - the peak concurrency at the webhook is 1 (the burst coalesced into +// the single held approval rather than each firing its own POST). +func TestApprovalCoalesce_BurstOnePrompt(t *testing.T) { + backend := startTLSEchoServer(t) + host, port := mustSplitAddr(t, backend.URL) + + // always_allow so the single resolve both fans out to every coalesced + // waiter AND persists a rule (the persisted rule is keyed dest:port — + // the same key the broker coalesces on). + srv, g := startGatedVerdictServer(t, "always_allow") + + config := fmt.Sprintf(` +[policy] +default = "deny" + +[[ask]] +destination = "%s" +ports = [%s] +name = "ask backend" +`, host, port) + + proc := sluiceWithWebhook(t, config, srv.URL) + + const burst = 8 + + type result struct { + status int + err error + } + results := make(chan result, burst) + + // Launch the burst. Each goroutine opens its OWN SOCKS5 CONNECT to the + // same dest:port, so each is an independent connection-level Ask that + // hits broker.Request with the same dedup key. The first opens the + // prompt (held by the gated server); the rest must coalesce onto it. + var launched sync.WaitGroup + for i := 0; i < burst; i++ { + launched.Add(1) + go func() { + launched.Done() + status, _, err := tryHTTPGetViaSOCKS5(t, proc.ProxyAddr, backend.URL+"/burst") + results <- result{status: status, err: err} + }() + } + launched.Wait() + + // Give the burst time to reach the broker and coalesce behind the + // single held approval before releasing. The first approval POST is + // blocked in the gated server during this window. + deadline := time.Now().Add(8 * time.Second) + for { + if g.ApprovalCalls() >= 1 { + break + } + if time.Now().After(deadline) { + t.Fatal("no approval webhook call arrived; broker never delivered the prompt") + } + time.Sleep(50 * time.Millisecond) + } + // Let the rest of the burst pile up onto the pending waiter. + time.Sleep(1500 * time.Millisecond) + + // Exactly ONE approval prompt for the whole burst. + if got := g.ApprovalCalls(); got != 1 { + t.Fatalf("approval webhook calls during pending window = %d, want exactly 1 (burst must coalesce)", got) + } + + // Release the single held decision; it must fan out to ALL waiters. + g.Release() + + // Collect all results. + oks := 0 + for i := 0; i < burst; i++ { + select { + case r := <-results: + if r.err != nil { + t.Errorf("burst request %d errored: %v", i, r.err) + continue + } + if r.status == http.StatusOK { + oks++ + } else { + t.Errorf("burst request %d: status=%d, want 200", i, r.status) + } + case <-time.After(20 * time.Second): + t.Fatalf("burst request %d never completed (fan-out broke)", i) + } + } + + if oks != burst { + t.Fatalf("only %d/%d burst requests succeeded after single resolve", oks, burst) + } + + // Still exactly one approval call total: the coalesced subs must not + // have triggered their own webhook deliveries. + if got := g.ApprovalCalls(); got != 1 { + t.Fatalf("total approval webhook calls = %d, want exactly 1 (coalesced subs must not re-prompt)", got) + } + + // Peak concurrency at the webhook proves the burst overlapped the + // pending approval (it was 1 because all but the first coalesced and + // never reached the webhook). + if mc := g.MaxConcurrent(); mc != 1 { + t.Fatalf("peak concurrent approval handlers = %d, want 1 (more than 1 means requests did NOT coalesce)", mc) + } +} diff --git a/e2e/pool_failover_test.go b/e2e/pool_failover_test.go new file mode 100644 index 0000000..4900b11 --- /dev/null +++ b/e2e/pool_failover_test.go @@ -0,0 +1,523 @@ +//go:build e2e + +package e2e + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/nemirovsky/sluice/internal/vault" +) + +// fakeOAuthUpstream is a single TLS server that plays two roles for the +// credential-pool failover e2e: +// +// - POST /token : an OAuth refresh-grant token endpoint. It inspects the +// refresh_token in the request body (which sluice has already swapped +// from the pool phantom to the *active member's* real refresh token) to +// learn WHICH member is refreshing, mints a fresh real JWT access token +// plus a rotated real refresh token, and returns them. The minted tokens +// are unique per member so the test can prove B's rotated tokens land in +// B's vault entry, not A's (Risk R1). +// +// - GET /api : a protected API endpoint. It reads the Bearer token +// (which sluice injected as the active member's real access token) and +// returns 429 for memberA's real access token and 200 for memberB's. +// That 429 is what drives sluice's auto-failover from A to B. +// +// The token endpoint issues *real* JWTs (header.payload.signature, signed +// with an HMAC test key) so the "phantom access token byte-identical across +// the failover" assertion is meaningful: sluice must re-key the phantom on +// the POOL name, not on the per-member real JWT. +type fakeOAuthUpstream struct { + mu sync.Mutex + + // realRefreshToMember maps a member's *current* real refresh token to + // the member name. Seeded with the initial vault refresh tokens and + // updated on every rotation so a follow-up refresh round-trip is still + // attributable. + realRefreshToMember map[string]string + + // realAccessToMember maps a minted real access token to the member it + // was minted for. Used by /api to decide 429 (memberA) vs 200 (memberB). + realAccessToMember map[string]string + + // counters + tokenCalls map[string]int + apiCalls map[string]int +} + +func newFakeOAuthUpstream() *fakeOAuthUpstream { + return &fakeOAuthUpstream{ + realRefreshToMember: map[string]string{}, + realAccessToMember: map[string]string{}, + tokenCalls: map[string]int{}, + apiCalls: map[string]int{}, + } +} + +// seedMember registers a member's initial real refresh AND access tokens so +// the first refresh round-trip is attributable and an /api call made with +// the seed access token (before any /token mint) is still recognized. +func (u *fakeOAuthUpstream) seedMember(member, initialRefresh, initialAccess string) { + u.mu.Lock() + defer u.mu.Unlock() + u.realRefreshToMember[initialRefresh] = member + u.realAccessToMember[initialAccess] = member +} + +// mintJWT builds a real, structurally valid JWT whose payload encodes the +// member name and a monotonically increasing counter, signed with an HMAC +// key. Every call returns a DIFFERENT token (distinct per member and per +// refresh) so the test can prove sluice does not leak the real JWT to the +// agent (the agent must always see the pool-stable phantom instead). +func (u *fakeOAuthUpstream) mintJWT(member string, n int) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte( + fmt.Sprintf(`{"sub":"real-%s","member":"%s","n":%d,"iss":"fake-upstream"}`, member, member, n), + )) + signingInput := header + "." + payload + mac := hmac.New(sha256.New, []byte("fake-upstream-hmac-key")) + mac.Write([]byte(signingInput)) + sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + return signingInput + "." + sig +} + +func (u *fakeOAuthUpstream) handler() http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + // sluice has already swapped the pool phantom for the active + // member's REAL refresh token by the time the request reaches + // the upstream, so the body carries the real refresh token. + var refreshToken string + // RFC 6749 form encoding (what sluice's CLI-added bindings use). + if vals, err := parseFormBody(body); err == nil { + refreshToken = vals + } + + u.mu.Lock() + member, known := u.realRefreshToMember[refreshToken] + if !known { + u.mu.Unlock() + http.Error(w, `{"error":"invalid_grant"}`, http.StatusBadRequest) + return + } + u.tokenCalls[member]++ + n := u.tokenCalls[member] + newAccess := u.mintJWT(member, n) + newRefresh := fmt.Sprintf("real-refresh-%s-rot-%d", member, n) + // Rotate: the old refresh token is single-use; register the new + // one so a subsequent refresh by the same member still resolves. + delete(u.realRefreshToMember, refreshToken) + u.realRefreshToMember[newRefresh] = member + u.realAccessToMember[newAccess] = member + u.mu.Unlock() + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": newAccess, + "refresh_token": newRefresh, + "token_type": "Bearer", + "expires_in": 3600, + }) + }) + + mux.HandleFunc("/api", func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + bearer := strings.TrimPrefix(auth, "Bearer ") + + u.mu.Lock() + member := u.realAccessToMember[bearer] + u.apiCalls[member]++ + u.mu.Unlock() + + switch member { + case "memberA": + // memberA is rate-limited: this is the failover trigger. + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":"rate_limited"}`)) + case "memberB": + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte("api-ok member=memberB\n")) + default: + // Unknown bearer (phantom leaked, or unexpected token). + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("unknown bearer: " + bearer + "\n")) + } + }) + + return mux +} + +func (u *fakeOAuthUpstream) TokenCalls(member string) int { + u.mu.Lock() + defer u.mu.Unlock() + return u.tokenCalls[member] +} + +func (u *fakeOAuthUpstream) APICalls(member string) int { + u.mu.Lock() + defer u.mu.Unlock() + return u.apiCalls[member] +} + +// parseFormBody extracts the refresh_token field from an +// application/x-www-form-urlencoded body. Returns the value or "". +func parseFormBody(body []byte) (string, error) { + for _, kv := range strings.Split(string(body), "&") { + parts := strings.SplitN(kv, "=", 2) + if len(parts) != 2 { + continue + } + if parts[0] == "refresh_token" { + // Values are not percent-encoded in our test client, but + // handle the common case anyway. + return strings.ReplaceAll(parts[1], "%2F", "/"), nil + } + } + return "", fmt.Errorf("no refresh_token") +} + +// startFakeOAuthUpstreamWithCA starts the fake upstream over TLS using a +// cert signed by the supplied test CA so sluice's MITM transport trusts it. +func startFakeOAuthUpstreamWithCA(t *testing.T, ca *testCA, u *fakeOAuthUpstream) *httptest.Server { + t.Helper() + + serverKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generate server key: %v", err) + } + serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: "127.0.0.1"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, ca.X509, &serverKey.PublicKey, ca.Cert.PrivateKey) + if err != nil { + t.Fatalf("create server cert: %v", err) + } + srvCert := tls.Certificate{ + Certificate: [][]byte{der, ca.Cert.Certificate[0]}, + PrivateKey: serverKey, + } + + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + srv := &httptest.Server{ + Listener: ln, + TLS: &tls.Config{Certificates: []tls.Certificate{srvCert}}, + Config: &http.Server{Handler: u.handler()}, + } + srv.StartTLS() + t.Cleanup(srv.Close) + return srv +} + +// readVaultOAuth opens the vault store and returns the parsed OAuth +// credential for the given name. +func readVaultOAuth(t *testing.T, vaultDir, name string) *vault.OAuthCredential { + t.Helper() + vs, err := vault.NewStore(vaultDir) + if err != nil { + t.Fatalf("open vault store: %v", err) + } + sb, err := vs.Get(name) + if err != nil { + t.Fatalf("vault get %q: %v", name, err) + } + defer sb.Release() + cred, err := vault.ParseOAuth(sb.Bytes()) + if err != nil { + t.Fatalf("parse oauth %q: %v", name, err) + } + return cred +} + +// TestPoolFailover_EndToEnd is the GAP 1 e2e: two fake OAuth members behind +// one pool. It asserts: +// +// (a) member A is used until its API call returns 429, +// (b) sluice fails over so the NEXT request uses member B, +// (c) member B's refreshed tokens persist to B's vault entry, NOT A's +// (the R1 attribution: both members share one fake token URL), +// (d) the phantom access token the agent receives is byte-identical +// before and after the failover (the token endpoint issues real +// JWTs so this is a real test of pool-keyed phantom stability). +func TestPoolFailover_EndToEnd(t *testing.T) { + tmpDir := t.TempDir() + vaultDir := filepath.Join(tmpDir, "vault") + ca := generateTestCA(t, vaultDir) + + caCertFile := filepath.Join(tmpDir, "ca-bundle.pem") + if err := os.WriteFile(caCertFile, ca.CertPEM, 0o644); err != nil { + t.Fatalf("write CA bundle: %v", err) + } + + up := newFakeOAuthUpstream() + srv := startFakeOAuthUpstreamWithCA(t, ca, up) + host, portStr := mustSplitAddr(t, srv.URL) + tokenURL := srv.URL + "/token" + apiURL := srv.URL + "/api" + + // Initial real refresh tokens for each member. The fake upstream maps + // these to the member so the first refresh round-trip is attributable. + const ( + memARefresh = "real-refresh-memberA-seed" + memBRefresh = "real-refresh-memberB-seed" + memAAccess = "real-access-memberA-seed" + memBAccess = "real-access-memberB-seed" + ) + up.seedMember("memberA", memARefresh, memAAccess) + up.seedMember("memberB", memBRefresh, memBAccess) + + // Policy: allow the upstream host:port (covers both /token and /api), + // and trust the test CA so MITM works. + config := fmt.Sprintf(` +[policy] +default = "deny" + +[vault] +provider = "age" +dir = %q + +[[allow]] +destination = %q +ports = [%s] +name = "allow fake upstream" +`, vaultDir, host, portStr) + + proc := startSluice(t, SluiceOpts{ + ConfigTOML: config, + Env: []string{ + "SSL_CERT_FILE=" + caCertFile, + "SSL_CERT_DIR=", + }, + }) + + // Add the two OAuth members. They share ONE token URL (the R1 + // collision scenario): two members, one fake token endpoint. + addOAuthMember := func(name, access, refresh string) { + binary := buildSluice(t) + cmd := exec.Command(binary, "cred", "add", "--db", proc.DBPath, + "--type", "oauth", "--token-url", tokenURL, name) + cmd.Stdin = strings.NewReader(access + "\n" + refresh + "\n") + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("cred add %s: %v\n%s", name, err, out) + } + } + addOAuthMember("memberA", memAAccess, memARefresh) + addOAuthMember("memberB", memBAccess, memBRefresh) + + // Create the failover pool with A first, B second. + runSluiceCLI(t, proc, "pool", "create", "codexpool", "--members", "memberA,memberB") + + // Bind the pool to the upstream destination with a Bearer header so + // the active member's real access token is injected into /api calls. + bindingCmd := exec.Command(buildSluice(t), "binding", "add", "--db", proc.DBPath, + "--destination", host, "--ports", portStr, + "--header", "Authorization", "--template", "Bearer {value}", + "codexpool") + if out, err := bindingCmd.CombinedOutput(); err != nil { + t.Fatalf("binding add codexpool: %v\n%s", err, out) + } + + // Reload so the pool resolver, OAuth index, and binding resolver pick + // up the new state. + sendSIGHUP(t, proc) + + // The agent holds the POOL phantom for the refresh token. It is the + // deterministic static string SLUICE_PHANTOM:.refresh. + poolRefreshPhantom := "SLUICE_PHANTOM:codexpool.refresh" + + // doRefresh posts a refresh-grant to the token endpoint through the + // proxy. sluice swaps the pool phantom for the active member's real + // refresh token, the upstream mints a real JWT, and sluice swaps the + // real JWT for the pool-stable phantom before the body reaches us. + // Returns the phantom access token the "agent" receives. + doRefresh := func() string { + body := "grant_type=refresh_token&refresh_token=" + poolRefreshPhantom + status, respBody := httpsRequestViaSOCKS5(t, proc.ProxyAddr, "POST", tokenURL, + map[string]string{"Content-Type": "application/x-www-form-urlencoded"}, body) + if status != http.StatusOK { + t.Fatalf("refresh: status=%d body=%s", status, respBody) + } + var tr struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + } + if err := json.Unmarshal([]byte(respBody), &tr); err != nil { + t.Fatalf("refresh: decode response: %v\nbody=%s", err, respBody) + } + // The agent must NEVER receive a real upstream JWT. Real JWTs + // have payload sub "real-memberX"; assert it is absent. + if strings.Contains(respBody, `real-member`) { + t.Fatalf("real JWT leaked to agent in refresh response:\n%s", respBody) + } + if strings.Contains(tr.RefreshToken, "real-refresh-") { + t.Fatalf("real refresh token leaked to agent: %s", tr.RefreshToken) + } + return tr.AccessToken + } + + // callAPI calls the protected API through the proxy with the pool + // access phantom in the Authorization header (the binding template + // also sets it, but sending it explicitly mirrors a real agent that + // holds a phantom). Returns the HTTP status. + callAPI := func() (int, string) { + // The agent uses its phantom access token. sluice's binding + // header injection overwrites Authorization with the active + // member's real access token regardless, so the value we send + // here only needs to be the pool access phantom for the body + // phantom-swap path; the header injection is what /api checks. + return httpsRequestViaSOCKS5(t, proc.ProxyAddr, "GET", apiURL, nil, "") + } + + // ---- Phase 1: member A is active ---- + phantomBefore := doRefresh() + if phantomBefore == "" { + t.Fatal("phantom access token before failover is empty") + } + // It must be a structurally valid 3-part JWT (pool-stable synthetic). + if parts := strings.Split(phantomBefore, "."); len(parts) != 3 { + t.Fatalf("phantom access token is not a 3-part JWT: %q", phantomBefore) + } + + // (a) member A used until it 429s. + statusA, bodyA := callAPI() + if statusA != http.StatusTooManyRequests { + t.Fatalf("first API call: status=%d body=%s (want 429 from memberA)", statusA, bodyA) + } + if up.APICalls("memberA") < 1 { + t.Fatalf("memberA API calls = %d, want >= 1", up.APICalls("memberA")) + } + if up.APICalls("memberB") != 0 { + t.Fatalf("memberB API calls = %d before failover, want 0", up.APICalls("memberB")) + } + + // ---- Phase 2: failover happened synchronously on the 429 response ---- + // (b) the NEXT API request uses member B and succeeds. + var statusB int + var bodyB string + // The in-memory health swap is synchronous on the Response addon, so + // the next request should already route to B. Retry briefly only to + // absorb connection/keep-alive races, NOT to wait out a cooldown. + deadline := time.Now().Add(5 * time.Second) + for { + statusB, bodyB = callAPI() + if statusB == http.StatusOK { + break + } + if time.Now().After(deadline) { + t.Fatalf("post-failover API call never succeeded: status=%d body=%s", statusB, bodyB) + } + time.Sleep(200 * time.Millisecond) + } + if !strings.Contains(bodyB, "member=memberB") { + t.Fatalf("post-failover API response not from memberB:\n%s", bodyB) + } + if up.APICalls("memberB") < 1 { + t.Fatalf("memberB API calls = %d after failover, want >= 1", up.APICalls("memberB")) + } + + // (d) phantom access token byte-identical after the failover. Refresh + // again: now member B is active, so the upstream mints memberB's real + // JWT. sluice must still hand the agent the SAME pool-keyed phantom. + phantomAfter := doRefresh() + if phantomAfter != phantomBefore { + t.Fatalf("phantom access token changed across failover:\nbefore=%q\nafter =%q", + phantomBefore, phantomAfter) + } + + // (c) member B's refreshed tokens persisted to B's vault entry, NOT + // A's. The vault write is async; poll until B's vault entry shows a + // rotated refresh token (or time out). + var bCred, aCred *vault.OAuthCredential + pdl := time.Now().Add(5 * time.Second) + for { + bCred = readVaultOAuth(t, vaultDir, "memberB") + if strings.HasPrefix(bCred.RefreshToken, "real-refresh-memberB-rot-") { + break + } + if time.Now().After(pdl) { + t.Fatalf("memberB vault refresh token never rotated; got %q", bCred.RefreshToken) + } + time.Sleep(200 * time.Millisecond) + } + aCred = readVaultOAuth(t, vaultDir, "memberA") + + // memberB's rotated tokens must reference memberB. The vault stores + // the raw upstream JWT, so decode the payload to inspect the claim. + bPayload := decodeJWTPayload(t, bCred.AccessToken) + if !strings.Contains(bPayload, `"member":"memberB"`) { + t.Fatalf("memberB vault access token is not memberB's minted JWT; payload=%s", bPayload) + } + // memberA's vault entry must NOT have been overwritten with B's + // rotated tokens (the Risk R1 mis-attribution failure mode). + if strings.HasPrefix(aCred.RefreshToken, "real-refresh-memberB-rot-") { + t.Fatalf("R1 VIOLATION: memberB's rotated refresh token landed in memberA's vault entry: %q", + aCred.RefreshToken) + } + aPayload := decodeJWTPayload(t, aCred.AccessToken) + if strings.Contains(aPayload, `"member":"memberB"`) { + t.Fatalf("R1 VIOLATION: memberB's minted access token landed in memberA's vault entry; payload=%s", aPayload) + } + + // The token endpoint must have been hit for BOTH members across the + // two refreshes (A before failover, B after). + if up.TokenCalls("memberA") < 1 { + t.Fatalf("token endpoint calls for memberA = %d, want >= 1", up.TokenCalls("memberA")) + } + if up.TokenCalls("memberB") < 1 { + t.Fatalf("token endpoint calls for memberB = %d, want >= 1", up.TokenCalls("memberB")) + } + + // Audit log should record the failover. + if !auditLogContains(t, proc.AuditPath, "cred_failover") { + t.Error("audit log should contain a cred_failover entry") + } +} + +// decodeJWTPayload base64url-decodes the payload segment of a JWT for +// diagnostics. Returns the raw JSON string or an error message. +func decodeJWTPayload(t *testing.T, jwt string) string { + t.Helper() + parts := strings.Split(jwt, ".") + if len(parts) != 3 { + return "(not a 3-part JWT: " + jwt + ")" + } + dec, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "(payload not base64url: " + err.Error() + ")" + } + return string(dec) +} From 62d47042ec97cc38d9efa2395e844c0e7d026d18 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 11:01:15 +0800 Subject: [PATCH 25/49] fix(proxy): address Copilot review (per-request failover attribution, JWT payload marshaling, coalesced-count accuracy, single-pool membership) --- CLAUDE.md | 2 +- .../20260515-credential-pool-failover.md | 1 + internal/channel/broker.go | 13 +++ internal/channel/channel_test.go | 104 ++++++++++++++++++ internal/proxy/addon.go | 35 +++++- internal/proxy/oauth_response.go | 25 ++++- internal/proxy/pool_attribution.go | 81 ++++++++++++++ internal/proxy/pool_failover.go | 19 ++++ internal/proxy/pool_failover_test.go | 58 ++++++++++ internal/proxy/pool_phantom_test.go | 76 +++++++++++++ internal/store/pools.go | 25 +++++ internal/store/pools_test.go | 70 +++++++++++- 12 files changed, 497 insertions(+), 12 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 21b3ad2..8efa9f9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -214,7 +214,7 @@ Extends phantom swap to handle OAuth credentials bidirectionally. Static credent ### Credential pools and auto-failover -A **credential pool** lets one phantom identity the agent sees be backed by **N real OAuth credentials**. The agent always holds a single pool-scoped phantom pair (`SLUICE_PHANTOM:.access` / `SLUICE_PHANTOM:.refresh`); sluice maps it to the *currently active member's* real tokens at injection time and persists refreshed tokens back to the member that issued them. Primary use case: two OpenAI Codex OAuth accounts behind one agent so quota exhaustion on one account transparently rolls onto the other. Pool members must be `oauth` credentials — `static` members are rejected. `cred remove` errors on a credential that is a live pool member. +A **credential pool** lets one phantom identity the agent sees be backed by **N real OAuth credentials**. The agent always holds a single pool-scoped phantom pair (`SLUICE_PHANTOM:.access` / `SLUICE_PHANTOM:.refresh`); sluice maps it to the *currently active member's* real tokens at injection time and persists refreshed tokens back to the member that issued them. Primary use case: two OpenAI Codex OAuth accounts behind one agent so quota exhaustion on one account transparently rolls onto the other. Pool members must be `oauth` credentials — `static` members are rejected. `cred remove` errors on a credential that is a live pool member. **One credential belongs to at most one pool**: proxy attribution (`PoolResolver.PoolForMember`) maps a member back to a single pool, so a credential shared across pools would persist/audit a token response against the wrong pool's phantom and leave the agent with an unreplaceable phantom. `pool create` rejects a member that is already in another pool (enforced inside the same transaction as the member insert). **CLI:** diff --git a/docs/plans/completed/20260515-credential-pool-failover.md b/docs/plans/completed/20260515-credential-pool-failover.md index 8ffc477..0e3e459 100644 --- a/docs/plans/completed/20260515-credential-pool-failover.md +++ b/docs/plans/completed/20260515-credential-pool-failover.md @@ -140,4 +140,5 @@ Transparent in-flight retry; round-robin/weighted (`strategy` reserved, `failove - **I1 (important, resolved in Phase 2.2):** the 2s data-version watcher must not gate the active-member switch — synchronous in-memory health update on `Response`, store write only reconciles. - **I2 (important, resolved in Phase 1.1):** all `binding.Credential`/`OAuthIndex.Has`/`extractInjectableSecret`/`findAdder` consumers routed through one `ResolveActive` chokepoint, not just the two injection passes. - Namespace collision resolved by mutual-exclusion at create time (Phase 0.4). Orphan pool members resolved by blocking `cred remove` of a live member (Phase 0.2). +- **One credential = at most one pool** (post-merge hardening, Copilot review): `PoolResolver.PoolForMember` maps a member back to a single pool, so a credential shared across pools would persist/audit a token response against the wrong pool's phantom (agent left with an unreplaceable phantom). `CreatePoolWithMembers` rejects a member already in another pool, enforced in the same transaction as the member insert. API-host failover attribution is also pinned to the member injected for that request (recovered by flow ID) rather than the response-time active member, so concurrent in-flight 429s cannot cool an innocent member. - Alternative rejected for this use case: scheduled `sluice cred update` rotation — cannot react to a 429 in real time and races the async OAuth vault writer. diff --git a/internal/channel/broker.go b/internal/channel/broker.go index 1c600ff..03c972e 100644 --- a/internal/channel/broker.go +++ b/internal/channel/broker.go @@ -417,6 +417,13 @@ func (b *Broker) detachSub(primaryID string, subCh chan Response) { for i, c := range w.subs { if c == subCh { w.subs = append(w.subs[:i], w.subs[i+1:]...) + // A subscriber that timed out and detached is no longer + // covered by the primary's eventual decision, so it must not + // inflate the coalesced count. Decrement, never below 1 (the + // primary itself is always counted) (Finding 3). + if w.count > 1 { + w.count-- + } b.waiters[primaryID] = w return } @@ -567,6 +574,12 @@ func (b *Broker) CancelAll() { waiters := make(map[string]waiter, len(b.waiters)) for id, w := range b.waiters { waiters[id] = w + // Retain each waiter's final coalesced count before the map is + // cleared, mirroring Resolve and the primary-timeout path. Without + // this the shutdown CancelApproval edit sees CoalescedCount==1 and + // omits "applied to N requests" for a burst that was pending at + // shutdown (Finding 2). + b.recordCoalescedLocked(id, w.count) } b.waiters = make(map[string]waiter) b.dedupIndex = make(map[string]string) diff --git a/internal/channel/channel_test.go b/internal/channel/channel_test.go index 840a8e6..e1b7fc0 100644 --- a/internal/channel/channel_test.go +++ b/internal/channel/channel_test.go @@ -996,6 +996,110 @@ func TestBrokerCoalesceShutdownFanOut(t *testing.T) { } } +// TestBrokerCancelAllRetainsCoalescedCount is the Finding 2 regression. +// CancelAll cleared the waiter map without retaining each waiter's final +// coalesced count, so the shutdown CancelApproval edit saw +// CoalescedCount==1 and dropped the "applied to N requests" suffix for a +// burst that was pending at shutdown. The fix records the count under the +// broker lock before the map is cleared, mirroring Resolve. +// +// Pre-fix this test fails: CoalescedCount after CancelAll returns 1. +func TestBrokerCancelAllRetainsCoalescedCount(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const n = 7 + primaryID, out := fireCoalescedBurst(t, broker, ch, "cancelcount.example.com", n, 5*time.Second) + + broker.CancelAll() + + // Drain so the goroutines finish (they all get a terminal Deny). + for i := 0; i < n; i++ { + <-out + } + + if c := broker.CoalescedCount(primaryID); c != n { + t.Fatalf("after CancelAll, CoalescedCount(%s) = %d, want %d "+ + "(the shutdown cancel edit would render \"applied to %d "+ + "requests\"; pre-fix it renders just 1) — Finding 2", + primaryID, c, n, c) + } +} + +// TestBrokerDetachedSubNotCounted is the Finding 3 regression. When a +// coalesced subscriber times out and detaches, the waiter's count was NOT +// decremented, so a later Resolve reported a CoalescedCount that still +// included subscribers that had already given up — Telegram said "applied +// to N" for more than were actually resolved by the tap. The fix +// decrements the count on detach, never below 1 (the primary). +// +// Pre-fix this test fails: the retained count stays at the peak (1 + total +// attached) instead of dropping by the number of detached subs. +func TestBrokerDetachedSubNotCounted(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const dest = "detachcount.example.com" + const port = 443 + + // Long-lived primary so it stays pending while subs come and go. + primaryOut := make(chan result, 1) + go func() { + resp, err := broker.Request(dest, port, "https", 5*time.Second) + primaryOut <- result{resp, err} + }() + var primaryID string + for { + reqs := ch.getRequests() + if len(reqs) == 1 { + primaryID = reqs[0].ID + break + } + time.Sleep(time.Millisecond) + } + + // k subscribers that each attach then time out and detach. + const k = 3 + subOut := make(chan result, k) + for i := 0; i < k; i++ { + go func() { + resp, err := broker.Request(dest, port, "https", 30*time.Millisecond) + subOut <- result{resp, err} + }() + } + // Wait until all k have attached (count == 1 + k at the peak). + for broker.CoalescedCount(primaryID) < 1+k { + time.Sleep(time.Millisecond) + } + // Let every sub time out and detach. + for i := 0; i < k; i++ { + sr := <-subOut + if sr.resp != ResponseDeny || sr.err == nil { + t.Fatalf("sub %d should have timed out with Deny+err, got %v / %v", i, sr.resp, sr.err) + } + } + // All k detached; the primary alone remains. + if c := broker.CoalescedCount(primaryID); c != 1 { + t.Fatalf("after %d subs detached, live CoalescedCount = %d, want 1 "+ + "(detached subs must not inflate the count) — Finding 3", k, c) + } + + // Resolve the primary; the retained count must reflect only the + // primary (the k detached subs gave up before the decision). + if !broker.Resolve(primaryID, ResponseAllowOnce) { + t.Fatal("Resolve returned false for primary") + } + pr := <-primaryOut + if pr.resp != ResponseAllowOnce { + t.Fatalf("primary: expected AllowOnce, got %v", pr.resp) + } + if c := broker.CoalescedCount(primaryID); c != 1 { + t.Fatalf("retained CoalescedCount after resolve = %d, want 1 "+ + "(Telegram would say \"applied to %d requests\" when only the "+ + "primary was actually covered) — Finding 3", c, c) + } +} + func TestBrokerCoalesceSubTimeoutDoesNotBlockFanOut(t *testing.T) { ch := newMockChannel(ChannelTelegram) broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index 1d27369..15df6be 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -131,6 +131,15 @@ type SluiceAddon struct { // attribution (Risk R1). Never nil after NewSluiceAddon. refreshAttr *refreshAttribution + // flowInjected maps a go-mitmproxy Flow ID to the pool member whose + // credential was injected into THAT request at injection time. It is + // the per-request join key for API-host failover attribution (Finding + // 1): a 429/403-quota failover must cool the member that backed the + // request when it was SENT, not whoever is active when the response is + // processed (which races with a concurrent request's failover). Never + // nil after NewSluiceAddon. + flowInjected *flowInjectedMember + // onOAuthRefresh is called after an OAuth token refresh persist // completes successfully. It receives the credential name so the // caller can re-inject updated phantom env vars into the agent @@ -183,6 +192,7 @@ func NewSluiceAddon(opts ...SluiceAddonOption) *SluiceAddon { a := &SluiceAddon{ pendingCheckers: make(map[string][]*pendingCheck), refreshAttr: newRefreshAttribution(), + flowInjected: newFlowInjectedMember(), } for _, o := range opts { o(a) @@ -650,6 +660,14 @@ func (a *SluiceAddon) injectHeaders(f *mitmproxy.Flow, host string, port int) { f.Request.Header.Set(binding.Header, binding.FormatValue(extractInjectableSecret(a.oauthIndex.Load(), target.secretName, secret.String()))) if target.pooled { + // Pin the API-host failover attribution to the member that backed + // THIS request at send time. The response-side poolForResponse + // API-host branch reads this by flow ID so a concurrent request's + // failover (which switches the active member) cannot mis-attribute + // this request's 429/403 to the wrong member (Finding 1). + if target.secretName != "" { + a.flowInjected.Tag(f.Id, target.secretName) + } log.Printf("[ADDON-INJECT] injected header %q for %s:%d (pool %q -> member %q)", binding.Header, host, port, binding.Credential, target.secretName) } else { @@ -723,7 +741,7 @@ func (a *SluiceAddon) Request(f *mitmproxy.Flow) { proto := a.detectRequestProtocol(f, port) protoStr := proto.String() - pairs := a.buildPhantomPairs(host, port, protoStr) + pairs := a.buildPhantomPairs(host, port, protoStr, f.Id) if len(pairs) == 0 && !a.hasPhantomPrefix(f) { return } @@ -776,7 +794,7 @@ func (a *SluiceAddon) StreamRequestModifier(f *mitmproxy.Flow, in io.Reader) io. proto := a.detectRequestProtocol(f, port) protoStr := proto.String() - pairs := a.buildPhantomPairs(host, port, protoStr) + pairs := a.buildPhantomPairs(host, port, protoStr, f.Id) if len(pairs) == 0 { return in } @@ -1363,8 +1381,11 @@ func (a *SluiceAddon) persistAddonOAuthTokens(credName string, realAccess, realR } // buildPhantomPairs builds the sorted list of phantom/secret pairs for a -// destination. The caller must call releasePhantomPairs when done. -func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string) []phantomPair { +// destination. The caller must call releasePhantomPairs when done. flowID is +// the go-mitmproxy Flow ID of the request being processed (uuid.Nil when no +// flow is associated, e.g. the QUIC path); it is used to pin per-request +// pool-member attribution for API-host failover (Finding 1). +func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flowID uuid.UUID) []phantomPair { res := a.resolver.Load() if res == nil { return nil @@ -1391,6 +1412,12 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string) []p if target.pooled { poolName := target.phantomName member := target.secretName + // Pin API-host failover attribution to this request's + // injected member (Finding 1). Idempotent with the + // pass-1 injectHeaders tag for the same flow. + if member != "" { + a.flowInjected.Tag(flowID, member) + } oauthPairs, parseErr := buildPooledOAuthPhantomPairs( poolName, member, secret, "ADDON-INJECT", func(realRefresh string) { diff --git a/internal/proxy/oauth_response.go b/internal/proxy/oauth_response.go index 8639e7e..e2fbfd8 100644 --- a/internal/proxy/oauth_response.go +++ b/internal/proxy/oauth_response.go @@ -61,9 +61,28 @@ func poolStablePhantomAccess(poolName string) string { // expiry checks treat it as valid; iat is intentionally omitted so the // payload is a pure function of the pool name (an iat would make the // phantom time-varying and break byte-identity). - payload := base64.RawURLEncoding.EncodeToString([]byte( - `{"sub":"sluice-pool:` + poolName + `","iss":"sluice-phantom","exp":4102444800}`, - )) + // + // The pool name is marshaled through encoding/json — never concatenated + // into the JSON string — so a name containing '"', '\', or control + // characters cannot produce an invalid JWT or inject extra claims + // (Finding 4). A fixed-field struct keeps the output deterministic and + // byte-stable for a given pool name (no map iteration ordering). + payloadJSON, err := json.Marshal(struct { + Sub string `json:"sub"` + Iss string `json:"iss"` + Exp int64 `json:"exp"` + }{ + Sub: "sluice-pool:" + poolName, + Iss: "sluice-phantom", + Exp: 4102444800, + }) + if err != nil { + // json.Marshal of a fixed struct with string/int fields cannot + // fail in practice; fall back to the static-form pool-stable + // phantom rather than emitting a malformed token. + return "SLUICE_PHANTOM:" + poolName + ".access" + } + payload := base64.RawURLEncoding.EncodeToString(payloadJSON) signingInput := header + "." + payload mac := hmac.New(sha256.New, phantomSigningKey) mac.Write([]byte(signingInput)) diff --git a/internal/proxy/pool_attribution.go b/internal/proxy/pool_attribution.go index b733d86..40c110f 100644 --- a/internal/proxy/pool_attribution.go +++ b/internal/proxy/pool_attribution.go @@ -3,8 +3,89 @@ package proxy import ( "sync" "time" + + uuid "github.com/satori/go.uuid" ) +// flowAttrTTL bounds how long a flow-id -> injected-member tag is retained. +// An HTTP request/response round-trip completes in well under a second in +// practice; a generous TTL absorbs slow upstreams while still bounding the +// map so a flow whose response never arrives cannot leak the tag forever. +// The tag is also deleted on first successful lookup (single-use per flow). +const flowAttrTTL = 5 * time.Minute + +// flowInjectedMember maps a go-mitmproxy Flow ID to the pool member whose +// credential was injected into THAT request at injection time (pass-1 header +// inject / pass-2 phantom swap in Requestheaders/Request). +// +// This is the join key for the API-host failover attribution bug (Finding +// 1). A pooled API-host failover (HTTP 429 / 403-quota) must be attributed +// to the member that was ACTIVE WHEN THE REQUEST WAS SENT, not the member +// that happens to be active when the response is processed. With concurrent +// in-flight requests both backed by member A, request1's 429 cools A and the +// pool switches to B; if request2's 429 is then attributed via a +// response-time ResolveActive it would wrongly cool B (now active) and park +// both accounts. The flow ID is stable across Requestheaders -> Request -> +// Response for one HTTP request (or HTTP/2 stream), so recording the +// resolved member per flow at injection time and reading it on the matching +// response pins attribution to the request's own injected member. +type flowInjectedMember struct { + mu sync.Mutex + entries map[uuid.UUID]flowAttrEntry +} + +type flowAttrEntry struct { + member string + expires time.Time +} + +func newFlowInjectedMember() *flowInjectedMember { + return &flowInjectedMember{entries: make(map[uuid.UUID]flowAttrEntry)} +} + +// Tag records that the given pool member's credential was injected for the +// request identified by flowID. Idempotent: pass-1 (injectHeaders) and +// pass-2 (buildPhantomPairs) both resolve the same member for one flow, so +// recording twice is harmless. A best-effort opportunistic sweep of expired +// entries keeps the map bounded without a background goroutine. +func (m *flowInjectedMember) Tag(flowID uuid.UUID, member string) { + if member == "" || flowID == uuid.Nil { + return + } + now := time.Now() + m.mu.Lock() + defer m.mu.Unlock() + if len(m.entries) > 0 { + for k, e := range m.entries { + if now.After(e.expires) { + delete(m.entries, k) + } + } + } + m.entries[flowID] = flowAttrEntry{member: member, expires: now.Add(flowAttrTTL)} +} + +// Recover returns the member tagged for the given flow ID and removes the +// entry (single-use: a flow's response is processed exactly once). Returns +// ("", false) when no live tag exists — the caller falls back to +// response-time ResolveActive. +func (m *flowInjectedMember) Recover(flowID uuid.UUID) (string, bool) { + if flowID == uuid.Nil { + return "", false + } + m.mu.Lock() + defer m.mu.Unlock() + e, ok := m.entries[flowID] + if !ok { + return "", false + } + delete(m.entries, flowID) + if time.Now().After(e.expires) { + return "", false + } + return e.member, true +} + // refreshAttrTTL is how long a real-refresh-token -> member tag is retained. // An OAuth refresh round-trip (agent POSTs refresh_token, upstream answers // with rotated tokens) completes in well under a second in practice; a diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index dc1e083..f46a63d 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -10,6 +10,7 @@ import ( mitmproxy "github.com/lqqyt2423/go-mitmproxy/proxy" "github.com/nemirovsky/sluice/internal/audit" "github.com/nemirovsky/sluice/internal/vault" + uuid "github.com/satori/go.uuid" ) // failoverClass is the result of classifying an upstream response for a @@ -155,6 +156,24 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str if !pr.IsPool(boundName) { continue } + // Attribute the failover to the member that backed THIS request + // when it was SENT, recovered by flow ID from the injection-time + // tag. ResolveActive at response time is unsafe under concurrency: + // a sibling request's 429 may have already switched the active + // member, so attributing by response-time active would cool an + // innocent member and park both accounts (Finding 1). Fall back to + // ResolveActive only when no per-flow tag exists (e.g. the request + // never went through the pooled injection path). + if f != nil && f.Id != uuid.Nil { + if injected, ok := a.flowInjected.Recover(f.Id); ok && injected != "" { + // Only honor the tag if the injected member is still a + // member of this pool (a membership change could have + // raced); otherwise fall through to ResolveActive. + if pr.PoolForMember(injected) == boundName { + return boundName, injected, pr, true + } + } + } member, mok := pr.ResolveActive(boundName) if !mok || member == "" { continue diff --git a/internal/proxy/pool_failover_test.go b/internal/proxy/pool_failover_test.go index cd122bf..699b539 100644 --- a/internal/proxy/pool_failover_test.go +++ b/internal/proxy/pool_failover_test.go @@ -659,6 +659,64 @@ func TestTokenEndpointFailoverFallsBackToActiveMember(t *testing.T) { } } +// TestAPIHostFailoverConcurrentAttributesInjectedMemberNotActive is the +// Finding 1 regression. Two concurrent in-flight API-host requests are both +// backed by member A (the active member at send time). request1's 429 +// arrives first: it cools A and the pool switches active to B. request2's +// 429 then arrives. The bug attributed request2's 429 via response-time +// pr.ResolveActive, which now returns B (already active after request1's +// failover) — so B would be wrongly cooled too, parking BOTH accounts. +// +// The fix pins attribution to the member that was injected for THAT request +// (recovered by flow ID from the injection-time tag). Both requests were +// backed by A, so both 429s must be attributed to A; B must remain healthy +// and active-eligible. +// +// This test MUST fail before the fix: with response-time ResolveActive, +// request2's 429 cools B (active after request1's failover), so B ends up +// in cooldown. +func TestAPIHostFailoverConcurrentAttributesInjectedMemberNotActive(t *testing.T) { + addon, _, prPtr := setupPoolAddon(t, "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + pr := prPtr.Load() + + if got, _ := pr.ResolveActive("codex_pool"); got != "memA" { + t.Fatalf("pre-failover active = %q, want memA", got) + } + + // Two concurrent requests, both sent while memA was the active member, + // so pass-1/pass-2 injected memA's credential into both. Mirror that by + // tagging each flow's injected member as memA (what injectHeaders / + // buildPhantomPairs now record at injection time). + req1 := newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`)) + req2 := newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`)) + addon.flowInjected.Tag(req1.Id, "memA") + addon.flowInjected.Tag(req2.Id, "memA") + + // request1's 429 arrives: cools memA, pool switches active to memB. + addon.Response(req1) + if _, cooling := pr.CooldownUntil("memA"); !cooling { + t.Fatal("memA must be cooling after request1's 429") + } + if got, _ := pr.ResolveActive("codex_pool"); got != "memB" { + t.Fatalf("after request1 failover, active = %q, want memB", got) + } + + // request2's 429 arrives. memB is now the active member. The bug would + // attribute this to memB (response-time ResolveActive) and cool it. The + // fix attributes it to memA (request2's injected member, by flow ID). + addon.Response(req2) + + if _, cooling := pr.CooldownUntil("memB"); cooling { + t.Fatal("memB was cooled by request2's 429 — attribution used " + + "response-time active member instead of the request's injected " + + "member (Finding 1). Both accounts are now parked.") + } + if got, _ := pr.ResolveActive("codex_pool"); got != "memB" { + t.Fatalf("active = %q, want memB (memB must stay healthy and active)", got) + } +} + // TestServerStorePoolConcurrentMarkCooldown is the CRITICAL-1 integration // regression at the real production code path: Server.StorePool's atomic // pointer swap (the SIGHUP / data_version reload) racing handlePoolFailover's diff --git a/internal/proxy/pool_phantom_test.go b/internal/proxy/pool_phantom_test.go index 991ec43..ca2822d 100644 --- a/internal/proxy/pool_phantom_test.go +++ b/internal/proxy/pool_phantom_test.go @@ -1,6 +1,8 @@ package proxy import ( + "encoding/base64" + "encoding/json" "net/http" "net/url" "strings" @@ -16,6 +18,80 @@ import ( func timeFuture() time.Time { return time.Now().Add(5 * time.Minute) } +// TestPoolStablePhantomAccessNameInjectionSafe is the Finding 4 regression. +// The pool name was interpolated directly into the JWT payload JSON string, +// so a name containing '"', '\', or control characters produced an invalid +// or claim-injected JWT, breaking the agent-facing phantom. The fix marshals +// the payload through encoding/json (fixed-field struct, deterministic). +// +// Pre-fix this test fails: base64-decoding the payload of the produced +// phantom yields invalid JSON (the embedded '"' / '\' / control byte breaks +// the hand-rolled string), so json.Unmarshal errors and the sub claim does +// not round-trip the exact pool name. +func TestPoolStablePhantomAccessNameInjectionSafe(t *testing.T) { + hostile := []string{ + `a"b`, // double quote — closes the JSON string early + `a\b`, // backslash — invalid JSON escape + "a\x01b", // control character — invalid in a JSON string + `","admin":true,"x":"`, // claim-injection attempt + `pool"}` + "\n" + `garbage`, // quote + newline + trailing junk + "normal_pool", // sanity: the common case still works + } + + for _, name := range hostile { + name := name + t.Run(name, func(t *testing.T) { + tok := poolStablePhantomAccess(name) + + // Determinism / byte-stability for a given pool name. + if tok2 := poolStablePhantomAccess(name); tok != tok2 { + t.Fatalf("phantom not deterministic for %q: %q != %q", name, tok, tok2) + } + + parts := strings.Split(tok, ".") + if len(parts) != 3 { + t.Fatalf("phantom not a 3-part JWT for %q: %q", name, tok) + } + + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("payload not valid base64url for %q: %v", name, err) + } + + var claims struct { + Sub string `json:"sub"` + Iss string `json:"iss"` + Exp int64 `json:"exp"` + } + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + t.Fatalf("payload not valid JSON for pool name %q: %v (raw: %s) — Finding 4", + name, err, payloadBytes) + } + + // The exact pool name must round-trip — no truncation at the + // first quote, no injected claims. + if claims.Sub != "sluice-pool:"+name { + t.Fatalf("sub claim = %q, want %q (pool name must round-trip exactly) — Finding 4", + claims.Sub, "sluice-pool:"+name) + } + if claims.Iss != "sluice-phantom" || claims.Exp != 4102444800 { + t.Fatalf("fixed claims corrupted for %q: iss=%q exp=%d", name, claims.Iss, claims.Exp) + } + + // No extra top-level keys (claim injection would add e.g. + // "admin"). Decode into a generic map and assert exactly 3. + var generic map[string]interface{} + if err := json.Unmarshal(payloadBytes, &generic); err != nil { + t.Fatalf("payload re-decode failed for %q: %v", name, err) + } + if len(generic) != 3 { + t.Fatalf("payload has %d keys for %q, want exactly 3 (claim injection) — Finding 4: %v", + len(generic), name, generic) + } + }) + } +} + // poolMemberCred builds an OAuth credential envelope for a pool member. func poolMemberCred(t *testing.T, access, refresh string) string { t.Helper() diff --git a/internal/store/pools.go b/internal/store/pools.go index 5f843af..dc6e834 100644 --- a/internal/store/pools.go +++ b/internal/store/pools.go @@ -95,6 +95,28 @@ func validatePoolMemberTx(tx *sql.Tx, credential string) error { return nil } +// assertCredentialNotInAnotherPoolTx fails if the credential is already a +// member of a pool other than newPool. A credential may belong to at most +// one pool: proxy attribution (PoolResolver.PoolForMember) maps a member +// back to a SINGLE pool, so a token response for a second pool would be +// persisted/audited against the first pool's phantom, leaving the agent +// with an unreplaceable phantom (Finding 5). Runs inside the supplied +// transaction so the check and the member insert are atomic. +func assertCredentialNotInAnotherPoolTx(tx *sql.Tx, credential, newPool string) error { + var existing string + err := tx.QueryRow( + "SELECT pool FROM credential_pool_members WHERE credential = ? AND pool != ? LIMIT 1", + credential, newPool, + ).Scan(&existing) + if errors.Is(err, sql.ErrNoRows) { + return nil + } + if err != nil { + return fmt.Errorf("check existing pool membership for %q: %w", credential, err) + } + return fmt.Errorf("credential %q is already a member of pool %q; a credential may belong to at most one pool", credential, existing) +} + // CreatePoolWithMembers creates a pool and its ordered members atomically. // Member positions are assigned from the slice order (0-based). It enforces // the pool/credential namespace mutual-exclusion (a pool name must not @@ -154,6 +176,9 @@ func (s *Store) CreatePoolWithMembers(name, strategy string, members []string) e if err := validatePoolMemberTx(tx, m); err != nil { return err } + if err := assertCredentialNotInAnotherPoolTx(tx, m, name); err != nil { + return err + } if _, err := tx.Exec( "INSERT INTO credential_pool_members (pool, credential, position) VALUES (?, ?, ?)", name, m, i, diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index 1ddc31d..6601910 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -103,6 +103,54 @@ func TestPoolCredentialNamespaceMutualExclusion(t *testing.T) { } } +// TestCreatePoolRejectsMemberAlreadyInAnotherPool is the Finding 5 +// regression. A credential may belong to at most one pool: proxy +// attribution (PoolResolver.PoolForMember) maps a member back to a SINGLE +// pool, so a token response for a second pool would be persisted/audited +// against the first pool's phantom, leaving the agent with an +// unreplaceable phantom. Adding a credential that already belongs to +// another pool must fail, and the second pool must not be left behind. +func TestCreatePoolRejectsMemberAlreadyInAnotherPool(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "shared") + seedOAuthCred(t, s, "solo") + + if err := s.CreatePoolWithMembers("pool_one", "failover", []string{"shared"}); err != nil { + t.Fatalf("CreatePoolWithMembers(pool_one): %v", err) + } + + // "shared" already belongs to pool_one; adding it to pool_two must fail. + err := s.CreatePoolWithMembers("pool_two", "failover", []string{"solo", "shared"}) + if err == nil { + t.Fatal("expected error: credential already a member of another pool (Finding 5)") + } + + // The second pool must not survive the rejected insert (tx rollback). + if exists, _ := s.PoolExists("pool_two"); exists { + t.Error("pool_two leaked after a member belonging to another pool was rejected") + } + + // pool_one is untouched and "shared" is still only in pool_one. + pools, err := s.PoolsForMember("shared") + if err != nil { + t.Fatalf("PoolsForMember: %v", err) + } + if len(pools) != 1 || pools[0] != "pool_one" { + t.Fatalf("PoolsForMember(shared) = %v, want [pool_one] (one credential = at most one pool)", pools) + } + + // Re-adding the same member to its OWN pool is rejected too (the pool + // already exists; this would be a duplicate pool name), but the + // single-pool invariant itself must not block recreating a fresh pool + // after the old one is removed. + if _, err := s.RemovePool("pool_one"); err != nil { + t.Fatalf("RemovePool: %v", err) + } + if err := s.CreatePoolWithMembers("pool_three", "failover", []string{"shared"}); err != nil { + t.Fatalf("after removing pool_one, re-adding shared to a new pool must succeed: %v", err) + } +} + func TestListPoolsOrdersMembers(t *testing.T) { s := newTestStore(t) for _, n := range []string{"a", "b", "c"} { @@ -158,15 +206,29 @@ func TestPoolsForMember(t *testing.T) { if err := s.CreatePoolWithMembers("p1", "failover", []string{"shared", "x"}); err != nil { t.Fatalf("create p1: %v", err) } - if err := s.CreatePoolWithMembers("p2", "failover", []string{"shared"}); err != nil { - t.Fatalf("create p2: %v", err) + // A credential belongs to at most one pool (Finding 5): adding "shared" + // to a second pool must be rejected. + if err := s.CreatePoolWithMembers("p2", "failover", []string{"shared"}); err == nil { + t.Fatal("expected p2 creation to fail: shared already belongs to p1") } + + // PoolsForMember still reports the (single) owning pool. It returns a + // slice because it also guards `cred remove` and must tolerate any + // pre-invariant rows; the live invariant keeps it to one entry. pools, err := s.PoolsForMember("shared") if err != nil { t.Fatalf("PoolsForMember: %v", err) } - if len(pools) != 2 || pools[0] != "p1" || pools[1] != "p2" { - t.Errorf("PoolsForMember(shared) = %v, want [p1 p2]", pools) + if len(pools) != 1 || pools[0] != "p1" { + t.Errorf("PoolsForMember(shared) = %v, want [p1] (one credential = at most one pool)", pools) + } + // "x" is also only in p1. + xpools, err := s.PoolsForMember("x") + if err != nil { + t.Fatalf("PoolsForMember(x): %v", err) + } + if len(xpools) != 1 || xpools[0] != "p1" { + t.Errorf("PoolsForMember(x) = %v, want [p1]", xpools) } } From 52689ffff1e1b02bdf81dcc3f6d2eba6bb049f6c Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 11:18:32 +0800 Subject: [PATCH 26/49] fix: address Copilot re-review (failover callback registration, cred-remove fail-closed, CLAUDE.md accuracy) --- CLAUDE.md | 6 ++-- cmd/sluice/cred.go | 24 ++++++++------- cmd/sluice/main.go | 66 ++++++++++++++++++++++------------------- cmd/sluice/pool_test.go | 48 ++++++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 43 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 8efa9f9..f7afedc 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -219,16 +219,16 @@ A **credential pool** lets one phantom identity the agent sees be backed by **N **CLI:** ``` -sluice pool create --member [--member ...] # ordered members; rejects static; namespace must not collide with a credential name +sluice pool create --members credA,credB[,credC] # comma-separated ordered members; rejects static; namespace must not collide with a credential name sluice pool list -sluice pool status # active member, per-member health (healthy / cooldown + recover-at + reason) +sluice pool status # active member, per-member health (healthy / cooldown + cooldown-until + reason) sluice pool rotate # operator override: advance the active member manually sluice pool remove ``` Auto-failover on 429/401 is the primary mechanism; `pool rotate` is an operator override. Pool and credential namespaces are mutually exclusive at create time. -**Data model (migration `000006_credential_pools`):** three tables — `credential_pools` (pool name, strategy reserved `failover`), `credential_pool_members` (ordered membership, pool→credential FK), `credential_health` (per-member state `healthy|cooldown`, `recover_at`, reason) — with CHECK constraints. Store API lives in `internal/store/pools.go`. `reloadAll` loads pool + health into an atomic-pointer-swapped `PoolResolver` (`internal/vault/pool.go`), rewired into the addon via `srv.StorePool`/`SetPoolResolver` on SIGHUP and the 2s data-version watcher. +**Data model (migration `000006_credential_pools`):** three tables — `credential_pools` (pool name, strategy reserved `failover`), `credential_pool_members` (ordered membership, pool→credential FK), `credential_health` (per-member state `healthy|cooldown`, `cooldown_until`, `last_failure_reason`) — with CHECK constraints. Store API lives in `internal/store/pools.go`. `reloadAll` loads pool + health into an atomic-pointer-swapped `PoolResolver` (`internal/vault/pool.go`), rewired into the addon via `srv.StorePool`/`SetPoolResolver` on SIGHUP and the 2s data-version watcher. **Phase 1 — phantom indirection (pool phantom → active member):** diff --git a/cmd/sluice/cred.go b/cmd/sluice/cred.go index 2b35bea..0195791 100644 --- a/cmd/sluice/cred.go +++ b/cmd/sluice/cred.go @@ -570,16 +570,20 @@ func handleCredRemove(args []string) error { if _, statErr := os.Stat(*dbPath); statErr == nil { guardDB, gerr := store.New(*dbPath) if gerr != nil { - log.Printf("warning: could not open database %q to check pool membership: %v", *dbPath, gerr) - } else { - pools, perr := guardDB.PoolsForMember(name) - _ = guardDB.Close() - if perr != nil { - return fmt.Errorf("check pool membership for %q: %w", name, perr) - } - if len(pools) > 0 { - return fmt.Errorf("credential %q is a member of pool(s) %s; remove it from the pool first (sluice pool remove

, or recreate the pool without it)", name, strings.Join(pools, ", ")) - } + // Fail closed: the DB exists but cannot be opened, so the + // pool-membership guard cannot run. Proceeding to delete the + // vault secret would orphan any credential_pool_members row + // pointing at this now-missing credential -- exactly what the + // guard prevents. Refuse the removal instead. + return fmt.Errorf("open database %q to check pool membership for %q (refusing to remove; a pool member may otherwise be orphaned): %w", *dbPath, name, gerr) + } + pools, perr := guardDB.PoolsForMember(name) + _ = guardDB.Close() + if perr != nil { + return fmt.Errorf("check pool membership for %q: %w", name, perr) + } + if len(pools) > 0 { + return fmt.Errorf("credential %q is a member of pool(s) %s; remove it from the pool first (sluice pool remove

, or recreate the pool without it)", name, strings.Join(pools, ", ")) } } diff --git a/cmd/sluice/main.go b/cmd/sluice/main.go index d4e63d5..7553c37 100644 --- a/cmd/sluice/main.go +++ b/cmd/sluice/main.go @@ -470,39 +470,47 @@ func main() { // Update the proxy's broker reference now that it's created. srv.SetBroker(broker) + } else { + log.Printf("no approval channels configured (ask rules will auto-deny)") + } - // Wire Phase 2 pool failover side effects: durable health write - // + best-effort Telegram notice. The in-memory active-member - // switch already happened synchronously on the response path - // before this callback fires (Risk I1); this only persists for - // restart durability and tells the operator. Everything here runs - // in a detached goroutine so the response/injection path is never - // blocked by a SQLite write or a Telegram round-trip. - failoverBroker := broker - srv.SetOnFailover(func(ev proxy.FailoverEvent) { - go func() { - if db != nil { - reason := fmt.Sprintf("failover:%s", ev.Reason) - if herr := db.SetCredentialHealth(ev.From, "cooldown", ev.Until, reason); herr != nil { - log.Printf("[POOL-FAILOVER] durable health write for %q failed: %v", ev.From, herr) - } + // Wire Phase 2 pool failover side effects: durable health write + // + best-effort Telegram notice. The in-memory active-member switch + // already happened synchronously on the response path before this + // callback fires (Risk I1); this only persists for restart durability + // and tells the operator. Registered UNCONDITIONALLY (outside the + // channel block): the durable SetCredentialHealth write is the + // CRITICAL-1 cooldown-durability guarantee and must run even in + // deployments with no Telegram/HTTP approval channel. Only the + // operator notice is gated on a broker being present. Everything + // here runs in a detached goroutine so the response/injection path + // is never blocked by a SQLite write or a Telegram round-trip. + failoverBroker := broker + srv.SetOnFailover(func(ev proxy.FailoverEvent) { + go func() { + if db != nil { + reason := fmt.Sprintf("failover:%s", ev.Reason) + if herr := db.SetCredentialHealth(ev.From, "cooldown", ev.Until, reason); herr != nil { + log.Printf("[POOL-FAILOVER] durable health write for %q failed: %v", ev.From, herr) } - if failoverBroker != nil { - // Plain text: TelegramChannel.Notify sends with no parse - // mode, so markdown backticks would render literally. - msg := fmt.Sprintf("pool %s failed over %s -> %s (%s)", - ev.Pool, ev.From, ev.To, ev.Reason) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - for _, ch := range failoverBroker.Channels() { - if nerr := ch.Notify(ctx, msg); nerr != nil { - log.Printf("[POOL-FAILOVER] notice via %s failed: %v", ch.Type(), nerr) - } + } + if failoverBroker != nil { + // Plain text: TelegramChannel.Notify sends with no parse + // mode, so markdown backticks would render literally. + msg := fmt.Sprintf("pool %s failed over %s -> %s (%s)", + ev.Pool, ev.From, ev.To, ev.Reason) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for _, ch := range failoverBroker.Channels() { + if nerr := ch.Notify(ctx, msg); nerr != nil { + log.Printf("[POOL-FAILOVER] notice via %s failed: %v", ch.Type(), nerr) } } - }() - }) + } + }() + }) + if len(allChannels) > 0 { // Start all channels. if tgChannel != nil { if err := tgChannel.Start(); err != nil { @@ -517,8 +525,6 @@ func main() { } defer hc.Stop() } - } else { - log.Printf("no approval channels configured (ask rules will auto-deny)") } // MCP gateway: always start (even with zero upstreams) so the diff --git a/cmd/sluice/pool_test.go b/cmd/sluice/pool_test.go index d984544..165fdad 100644 --- a/cmd/sluice/pool_test.go +++ b/cmd/sluice/pool_test.go @@ -2,6 +2,7 @@ package main import ( "os" + "path/filepath" "strings" "testing" @@ -164,3 +165,50 @@ func TestCredRemoveBlockedForLivePoolMember(t *testing.T) { } sb.Release() } + +// TestCredRemoveFailsClosedWhenDBUnopenable asserts that when the policy DB +// path exists but cannot be opened, cred remove refuses (fails closed) +// instead of logging a warning and deleting the vault secret anyway. A +// continue-on-error here would orphan a credential_pool_members row pointing +// at a now-missing credential -- exactly what the membership guard prevents. +// Regression for Copilot re-review finding 2. +func TestCredRemoveFailsClosedWhenDBUnopenable(t *testing.T) { + dir := t.TempDir() + + // Put a vault secret in place so we can prove it survives the refused + // removal. The vault dir is independent of the DB path. + vs, verr := vault.NewStore(dir) + if verr != nil { + t.Fatalf("open vault: %v", verr) + } + if _, err := vs.Add("acct_a", `{"access_token":"x"}`); err != nil { + t.Fatalf("vault add: %v", err) + } + + // dbPath exists (os.Stat succeeds, so the membership guard is entered) + // but is a directory, so store.New cannot open it as a SQLite file. + dbPath := filepath.Join(dir, "broken.db") + if err := os.Mkdir(dbPath, 0o755); err != nil { + t.Fatalf("mkdir broken db: %v", err) + } + + err := handleCredCommand([]string{"remove", "--db", dbPath, "acct_a"}) + if err == nil { + t.Fatalf("cred remove with unopenable DB: err = nil, want fail-closed error") + } + if !strings.Contains(err.Error(), "refusing to remove") { + t.Fatalf("cred remove error = %v, want fail-closed message containing %q", err, "refusing to remove") + } + + // The secret must still be present: the removal was refused before the + // vault delete. + vs2, verr2 := vault.NewStore(dir) + if verr2 != nil { + t.Fatalf("reopen vault: %v", verr2) + } + sb2, gerr2 := vs2.Get("acct_a") + if gerr2 != nil { + t.Fatalf("credential acct_a was destroyed despite refused removal: %v", gerr2) + } + sb2.Release() +} From 01760688ab23a3fe47ce5a28a7a63d5a939ade02 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 11:41:54 +0800 Subject: [PATCH 27/49] fix(proxy): split-host pool OAuth refresh attribution + protocol-aware failover lookup --- internal/proxy/addon.go | 160 +++++++++-- internal/proxy/oauth_index.go | 26 ++ internal/proxy/pool_failover.go | 99 ++++--- internal/proxy/pool_phantom_test.go | 6 + internal/proxy/pool_splithost_test.go | 385 ++++++++++++++++++++++++++ 5 files changed, 626 insertions(+), 50 deletions(-) create mode 100644 internal/proxy/pool_splithost_test.go diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index 15df6be..a52de4a 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -7,6 +7,7 @@ import ( "log" "net" "net/http" + "net/url" "runtime/debug" "sort" "strconv" @@ -741,7 +742,7 @@ func (a *SluiceAddon) Request(f *mitmproxy.Flow) { proto := a.detectRequestProtocol(f, port) protoStr := proto.String() - pairs := a.buildPhantomPairs(host, port, protoStr, f.Id) + pairs := a.buildPhantomPairs(host, port, protoStr, f.Id, f.Request.URL) if len(pairs) == 0 && !a.hasPhantomPrefix(f) { return } @@ -794,7 +795,7 @@ func (a *SluiceAddon) StreamRequestModifier(f *mitmproxy.Flow, in io.Reader) io. proto := a.detectRequestProtocol(f, port) protoStr := proto.String() - pairs := a.buildPhantomPairs(host, port, protoStr, f.Id) + pairs := a.buildPhantomPairs(host, port, protoStr, f.Id, f.Request.URL) if len(pairs) == 0 { return in } @@ -898,22 +899,48 @@ type oauthRespAttribution struct { // that survives two members sharing one token URL). When recovery fails it // returns skipPersist=true and never falls back to OAuthIndex.Match for the // persist target (R1: never guess). +// +// Finding 1: the caller passes the FIRST OAuthIndex match (Match is +// name-ordered and token URLs are commonly shared). If a plain OAuth +// credential sorts before the pool members and shares the token URL, +// matchedCred is that plain credential and pr.PoolForMember(matchedCred) +// returns "" — the old code then took the plain-credential identity branch +// and a pooled refresh response was swapped/persisted under the plain +// credential's phantom + vault entry. We therefore consult MatchAll: if ANY +// credential sharing this token URL is a pool member, the response could be +// a pooled refresh, so we recover the true owner from the injected refresh +// token instead of trusting the deterministic-first match. func (a *SluiceAddon) resolveOAuthResponseAttribution(f *mitmproxy.Flow, matchedCred string) oauthRespAttribution { pr := (*vault.PoolResolver)(nil) if a.poolResolver != nil { pr = a.poolResolver.Load() } + + // Determine whether this token URL has ANY pooled credential, not just + // whether the deterministic-first match happens to be one. poolName := "" if pr != nil { poolName = pr.PoolForMember(matchedCred) + if poolName == "" && f.Request != nil { + if idx := a.oauthIndex.Load(); idx != nil { + for _, c := range idx.MatchAll(f.Request.URL) { + if p := pr.PoolForMember(c); p != "" { + poolName = p + break + } + } + } + } } if poolName == "" { - // Not a pooled token URL: unchanged 1:1 behavior. + // No pooled credential shares this token URL: unchanged 1:1 + // behavior for the plain-credential case. return oauthRespAttribution{phantomName: matchedCred, persistMember: matchedCred} } - // Pooled token URL. Recover the owning member from the real refresh - // token sluice injected into this request's body (R1 join key). + // A pooled credential shares this token URL. Recover the owning member + // from the real refresh token sluice injected into this request's body + // (R1 join key — unique per member, unlike the shared token URL). reqCT := "" reqBody := []byte(nil) if f.Request != nil { @@ -925,11 +952,29 @@ func (a *SluiceAddon) resolveOAuthResponseAttribution(f *mitmproxy.Flow, matched realRefresh := extractRequestRefreshToken(reqBody, reqCT) member, ok := a.refreshAttr.Recover(realRefresh) if !ok { + // Recovery failed. The refresh may legitimately belong to a plain + // OAuth credential that shares this token URL (not a pool member); + // only attribute to that plain credential when the + // deterministic-first match itself is NOT pooled AND no per-member + // refresh tag was recorded for it (a pooled refresh always records + // a tag in buildPooledMemberPairs, so a missing tag means this was + // not a pooled refresh). Otherwise fail closed: never misfile a + // pooled member's rotated tokens under the wrong entry (R1). + if matchedCred != "" && pr.PoolForMember(matchedCred) == "" { + log.Printf("[ADDON-OAUTH] token URL shared with pool %q but no pooled "+ + "refresh tag; attributing to plain credential %q", poolName, matchedCred) + return oauthRespAttribution{phantomName: matchedCred, persistMember: matchedCred} + } log.Printf("[ADDON-OAUTH] R1 fail-closed: pooled token URL for pool %q but owning member "+ "could not be recovered from the injected refresh token; skipping vault write "+ "(next refresh will retry)", poolName) return oauthRespAttribution{phantomName: poolName, pooled: true, skipPersist: true} } + // The recovered member's own pool is authoritative (a membership change + // could have raced; attribute to whatever pool the member is in now). + if mp := pr.PoolForMember(member); mp != "" { + poolName = mp + } log.Printf("[ADDON-OAUTH] R1 attributed pooled refresh to member %q (pool %q)", member, poolName) return oauthRespAttribution{phantomName: poolName, persistMember: member, pooled: true} } @@ -1384,16 +1429,24 @@ func (a *SluiceAddon) persistAddonOAuthTokens(credName string, realAccess, realR // destination. The caller must call releasePhantomPairs when done. flowID is // the go-mitmproxy Flow ID of the request being processed (uuid.Nil when no // flow is associated, e.g. the QUIC path); it is used to pin per-request -// pool-member attribution for API-host failover (Finding 1). -func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flowID uuid.UUID) []phantomPair { +// pool-member attribution for API-host failover (Finding 1). reqURL is the +// outbound request URL (nil on the QUIC path, which has no parsed URL); it +// is used to expand pooled OAuth credentials whose token endpoint matches +// the request even when they are not bound to the CONNECT host (Finding 4, +// the split-host token-refresh case). +func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flowID uuid.UUID, reqURL *url.URL) []phantomPair { res := a.resolver.Load() if res == nil { return nil } boundCreds := res.CredentialsForDestination(host, port, proto) - if len(boundCreds) == 0 { - return nil - } + + // covered tracks every credential name already turned into pairs by the + // CONNECT-host binding loop so the token-host expansion below does not + // double-inject a pool whose API host is also its token host (the + // common same-host Codex/OpenAI deployment, where this whole second + // pass is a no-op). + covered := make(map[string]bool, len(boundCreds)) var pairs []phantomPair for _, boundName := range boundCreds { @@ -1418,15 +1471,11 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flo if member != "" { a.flowInjected.Tag(flowID, member) } - oauthPairs, parseErr := buildPooledOAuthPhantomPairs( - poolName, member, secret, "ADDON-INJECT", - func(realRefresh string) { - a.refreshAttr.Tag(realRefresh, member) - }, - ) + oauthPairs, parseErr := a.buildPooledMemberPairs(poolName, member, secret) if parseErr != nil { continue } + covered[member] = true pairs = append(pairs, oauthPairs...) continue } @@ -1434,6 +1483,7 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flo if parseErr != nil { continue } + covered[name] = true pairs = append(pairs, oauthPairs...) continue } @@ -1442,6 +1492,7 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flo // resolved name (== bound name for plain creds). phantom := []byte(PhantomToken(name)) encoded := encodePhantomForPair(phantom) + covered[name] = true pairs = append(pairs, phantomPair{ phantom: phantom, encodedPhantom: encoded, @@ -1450,6 +1501,69 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flo }) } + // Finding 4 (the crux): token-host phantom expansion. A pooled OAuth + // credential's refresh-grant POST goes to its token-URL host (e.g. + // auth.openai.com), which has NO pool binding — the pool binding lives + // on the API host (e.g. api.openai.com). The CONNECT-host loop above + // therefore produces no pairs for the token host, so the agent-held + // SLUICE_PHANTOM:.refresh would travel upstream verbatim and the + // refresh would fail (and Findings 1/2 attribution could never even + // trigger because no refresh-attribution tag would be recorded). + // + // Mirror what the existing pooled CONNECT-host path does, but key the + // expansion off OAuthIndex (token_url) instead of the binding resolver: + // for every pooled OAuth credential whose token endpoint matches this + // request, resolve the pool's active member and emit the pool-keyed + // phantom pairs + the flowInjected / refreshAttr tags (so R1 persist + // and token-endpoint failover work). MatchAll (not Match) is used so a + // plain OAuth credential that sorts before the pool members and shares + // the token URL cannot mask them. + if reqURL != nil { + if idx := a.oauthIndex.Load(); idx != nil { + pr := (*vault.PoolResolver)(nil) + if a.poolResolver != nil { + pr = a.poolResolver.Load() + } + if pr != nil { + seenPool := make(map[string]bool) + for _, credName := range idx.MatchAll(reqURL) { + poolName := pr.PoolForMember(credName) + if poolName == "" || seenPool[poolName] { + continue + } + seenPool[poolName] = true + member, ok := pr.ResolveActive(poolName) + if !ok || member == "" || covered[member] { + continue + } + secret, err := a.provider.Get(member) + if err != nil { + log.Printf("[ADDON-INJECT] token-host pool %q member %q lookup failed: %v", + poolName, member, err) + continue + } + if !vault.IsOAuth(secret.Bytes()) { + secret.Release() + continue + } + a.flowInjected.Tag(flowID, member) + oauthPairs, parseErr := a.buildPooledMemberPairs(poolName, member, secret) + if parseErr != nil { + continue + } + covered[member] = true + log.Printf("[ADDON-INJECT] token-host phantom expansion for pool %q -> member %q (%s)", + poolName, member, reqURL.Host) + pairs = append(pairs, oauthPairs...) + } + } + } + } + + if len(pairs) == 0 { + return nil + } + // Sort by phantom length descending so longer tokens are replaced // before shorter prefixes that could corrupt them. sort.Slice(pairs, func(i, j int) bool { @@ -1458,6 +1572,20 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flo return pairs } +// buildPooledMemberPairs builds the pool-keyed phantom pairs for one active +// pool member and records the realRefreshToken -> member attribution tag +// (the Risk R1 join key consumed by the 2xx persist path and the +// token-endpoint failover path). Shared by the CONNECT-host binding loop and +// the Finding 4 token-host expansion so both record attribution identically. +func (a *SluiceAddon) buildPooledMemberPairs(poolName, member string, secret vault.SecureBytes) ([]phantomPair, error) { + return buildPooledOAuthPhantomPairs( + poolName, member, secret, "ADDON-INJECT", + func(realRefresh string) { + a.refreshAttr.Tag(realRefresh, member) + }, + ) +} + // releasePhantomPairs zeroes all secret values in the pairs slice. func releasePhantomPairs(pairs []phantomPair) { for i := range pairs { diff --git a/internal/proxy/oauth_index.go b/internal/proxy/oauth_index.go index d63c9c6..df69a75 100644 --- a/internal/proxy/oauth_index.go +++ b/internal/proxy/oauth_index.go @@ -94,6 +94,32 @@ func (idx *OAuthIndex) Match(requestURL *url.URL) (credName string, ok bool) { return "", false } +// MatchAll returns every credential whose token endpoint matches the +// request URL, in index order. Two pool members commonly share ONE token +// URL (the documented Codex deployment: two OpenAI accounts, one +// auth.openai.com), and a plain OAuth credential may share that same token +// URL too. Match returns only the first index entry, which silently drops +// the others; callers that must reason about pool membership (request-side +// token-host phantom expansion, response-side attribution, token-endpoint +// failover) need the full set so they can pick the pooled/correct member +// instead of whichever name happened to sort first in credential_meta. +func (idx *OAuthIndex) MatchAll(requestURL *url.URL) []string { + if idx == nil || requestURL == nil { + return nil + } + reqPath := normalizePath(requestURL.Path) + reqHost := normalizeHost(requestURL.Host, requestURL.Scheme) + var creds []string + for _, e := range idx.entries { + if e.tokenURL.Scheme == requestURL.Scheme && + normalizeHost(e.tokenURL.Host, e.tokenURL.Scheme) == reqHost && + normalizePath(e.tokenURL.Path) == reqPath { + creds = append(creds, e.credential) + } + } + return creds +} + // Len returns the number of entries in the index. func (idx *OAuthIndex) Len() int { if idx == nil { diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index f46a63d..fededd4 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -149,10 +149,16 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str if host == "" { return "", "", nil, false } - // The Response addon path is HTTP/HTTPS/HTTP2 (gRPC). Bindings without - // an explicit protocol list match any protocol; pass "https" so a - // protocol-scoped binding still resolves on the common case. - for _, boundName := range res.CredentialsForDestination(host, port, "https") { + // Finding 3: the failover binding lookup MUST use the same protocol the + // request-side injection (injectHeaders / buildPhantomPairs) used, not a + // hardcoded "https". A protocol-scoped pooled binding (grpc / http2 / + // any meta protocol) is invisible to a "https" lookup even though the + // credential WAS injected for it, so its 429/401 would never fail over. + // detectRequestProtocol mirrors the injection path exactly (URL scheme + // then header refinement); for the common unscoped-binding case the + // result is still https-equivalent so behavior is unchanged. + proto := a.detectRequestProtocol(f, port).String() + for _, boundName := range res.CredentialsForDestination(host, port, proto) { if !pr.IsPool(boundName) { continue } @@ -207,40 +213,65 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str // for the persist path; a token-endpoint FAILURE does not rotate the // refresh token and processOAuthResponseIfMatching is 2xx-only, so the // tag is still live here. + // + // Finding 2: idx.Match returns only the FIRST index entry, and + // credential_meta is name-ordered. If a plain OAuth credential sorts + // before the pool members and shares the token URL, idx.Match returns + // the plain credential, pr.PoolForMember(matched) is "", the whole + // block is skipped, and a pooled token-host 401 / invalid_grant never + // fails over (no cooldown -> the broken member stays active forever). + // Use MatchAll and find ANY pool sharing this token URL so the gate is + // independent of which credential sorts first; the true owning member + // is still recovered from the per-member-unique injected refresh token. if idx := a.oauthIndex.Load(); idx != nil && f.Request != nil { - if matched, mok := idx.Match(f.Request.URL); mok && matched != "" { - if pool := pr.PoolForMember(matched); pool != "" { - // Recover the TRUE owning member from the injected real - // refresh token in the buffered request body. - reqCT := "" - if f.Request.Header != nil { - reqCT = f.Request.Header.Get("Content-Type") - } - realRefresh := extractRequestRefreshToken(f.Request.Body, reqCT) - if owner, ok := a.refreshAttr.Peek(realRefresh); ok && owner != "" { - if ownerPool := pr.PoolForMember(owner); ownerPool != "" { - return ownerPool, owner, pr, true - } - // owner is no longer in any pool (membership change - // raced the failure); fall through to the active-member - // fallback below for a still-meaningful attribution. + matches := idx.MatchAll(f.Request.URL) + pool := "" + matched := "" + for _, c := range matches { + if matched == "" { + matched = c // preserve the deterministic-first as last resort + } + if p := pr.PoolForMember(c); p != "" { + pool = p + break + } + } + if pool != "" { + // Recover the TRUE owning member from the injected real + // refresh token in the buffered request body. + reqCT := "" + if f.Request.Header != nil { + reqCT = f.Request.Header.Get("Content-Type") + } + realRefresh := extractRequestRefreshToken(f.Request.Body, reqCT) + if owner, ok := a.refreshAttr.Peek(realRefresh); ok && owner != "" { + if ownerPool := pr.PoolForMember(owner); ownerPool != "" { + return ownerPool, owner, pr, true } - // Fallback ONLY when the real refresh token cannot be - // extracted / attributed: cool the ACTIVE member rather - // than blindly the first index entry. The active member is - // the one whose token was most likely just injected, so it - // is strictly better than idx.Match's deterministic-first. - if active, aok := pr.ResolveActive(pool); aok && active != "" { - log.Printf("[POOL-FAILOVER] pool %q: could not attribute "+ - "token-endpoint failure via injected refresh token; "+ - "falling back to active member %q", pool, active) - return pool, active, pr, true + // owner is no longer in any pool (membership change + // raced the failure); fall through to the active-member + // fallback below for a still-meaningful attribution. + } + // Fallback ONLY when the real refresh token cannot be + // extracted / attributed: cool the ACTIVE member rather + // than blindly the first index entry. The active member is + // the one whose token was most likely just injected, so it + // is strictly better than idx.Match's deterministic-first. + if active, aok := pr.ResolveActive(pool); aok && active != "" { + log.Printf("[POOL-FAILOVER] pool %q: could not attribute "+ + "token-endpoint failure via injected refresh token; "+ + "falling back to active member %q", pool, active) + return pool, active, pr, true + } + // Last resort: a pooled index match if any (preserves prior + // behavior when even ResolveActive cannot decide; better than + // no attribution at all). + for _, c := range matches { + if pr.PoolForMember(c) != "" { + return pool, c, pr, true } - // Last resort: the index match (preserves prior behavior - // when even ResolveActive cannot decide; better than no - // attribution at all). - return pool, matched, pr, true } + return pool, matched, pr, true } } return "", "", nil, false diff --git a/internal/proxy/pool_phantom_test.go b/internal/proxy/pool_phantom_test.go index ca2822d..7591515 100644 --- a/internal/proxy/pool_phantom_test.go +++ b/internal/proxy/pool_phantom_test.go @@ -160,6 +160,12 @@ func setupPoolAddon(t *testing.T, memberA, memberB string) (*SluiceAddon, *addon // refreshGrantBody is an RFC-6749 form-encoded refresh grant carrying the // pool-scoped refresh phantom. Pass-2 swaps the phantom for the active // member's real refresh token before the request leaves sluice. +// poolName is parameterized on purpose: this is a general RFC-6749 +// refresh-grant body builder reused across pool tests, and a multi-pool +// test legitimately passes a different name. unparam only sees the current +// callers all using "codex_pool". +// +//nolint:unparam func refreshGrantBody(poolName string) []byte { return []byte("grant_type=refresh_token&refresh_token=SLUICE_PHANTOM:" + poolName + ".refresh") } diff --git a/internal/proxy/pool_splithost_test.go b/internal/proxy/pool_splithost_test.go new file mode 100644 index 0000000..eb22d71 --- /dev/null +++ b/internal/proxy/pool_splithost_test.go @@ -0,0 +1,385 @@ +package proxy + +import ( + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/nemirovsky/sluice/internal/store" + "github.com/nemirovsky/sluice/internal/vault" +) + +// setupPoolSplitHostWithPlainCred wires the EXACT topology Copilot round-3 +// flagged: a credential pool bound ONLY to the API host (api.example.com), +// whose members refresh against a DIFFERENT token-URL host +// (auth.example.com, testOAuthTokenURL) that has NO pool binding — AND a +// plain (non-pool) OAuth credential that +// +// (1) shares the same token URL as the pool members, and +// (2) sorts BEFORE the pool members in credential_meta order. +// +// (2) is the trigger for Findings 1 & 2: OAuthIndex.Match is +// deterministic-first, so it returns the plain credential even when a pool +// member's refresh token is actually in the request body. The metas slice +// below puts the plain credential first so idx.Match(tokenURL) == plain. +func setupPoolSplitHostWithPlainCred(t *testing.T) (*SluiceAddon, *addonWritableProvider, *atomic.Pointer[vault.PoolResolver]) { + t.Helper() + const ( + poolName = "codex_pool" + plain = "aaa_plain" // sorts before memA/memB + memA = "memA" + memB = "memB" + ) + + provider := &addonWritableProvider{ + creds: map[string]string{ + plain: poolMemberCred(t, "plain-access-old", "plain-refresh-old"), + memA: poolMemberCred(t, "A-access-old", "A-refresh-old"), + memB: poolMemberCred(t, "B-access-old", "B-refresh-old"), + }, + } + + // Pool binding is on the API host only. The plain credential is bound to + // its own (different) API host so the split-host token-refresh path is + // the ONLY way its / the pool's refresh can be swapped. + bindings := []vault.Binding{ + {Destination: "api.example.com", Ports: []int{443}, Credential: poolName}, + {Destination: "plain-api.example.com", Ports: []int{443}, Credential: plain}, + } + resolver, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var resolverPtr atomic.Pointer[vault.BindingResolver] + resolverPtr.Store(resolver) + + addon := NewSluiceAddon(WithResolver(&resolverPtr), WithProvider(provider)) + addon.persistDone = make(chan struct{}, 10) + + // Plain credential FIRST so idx.Match(testOAuthTokenURL) returns it. + metas := []store.CredentialMeta{ + {Name: plain, CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: memA, CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: memB, CredType: "oauth", TokenURL: testOAuthTokenURL}, + } + addon.UpdateOAuthIndex(metas) + + pool := store.Pool{Name: poolName, Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: memA, Position: 0}, + {Credential: memB, Position: 1}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{pool}, nil)) + addon.SetPoolResolver(&prPtr) + + return addon, provider, &prPtr +} + +// TestSplitHost_RequestSidePhantomSwapOnTokenHost is the Finding 4 +// regression (the crux). The agent POSTs a refresh-grant to the token-URL +// host (auth.example.com), which has NO pool binding — the pool binding +// lives on api.example.com. Before the fix, buildPhantomPairs only iterated +// credentials bound to the CONNECT host, so SLUICE_PHANTOM:codex_pool.refresh +// was NEVER swapped on the token host: the phantom would travel upstream +// verbatim and the refresh would fail. The fix expands pooled OAuth +// credentials whose token_url matches the request even with no CONNECT-host +// binding. +// +// Asserts: (a) the pool refresh phantom is swapped to the ACTIVE member's +// real refresh token, and (b) the realRefreshToken -> member attribution tag +// is recorded (so Findings 1/2 persist + failover can trigger). +func TestSplitHost_RequestSidePhantomSwapOnTokenHost(t *testing.T) { + addon, _, prPtr := setupPoolSplitHostWithPlainCred(t) + // CONNECT target is the TOKEN host, which has NO pool binding. + client := setupAddonConn(addon, "auth.example.com:443") + + pr := prPtr.Load() + if got, _ := pr.ResolveActive("codex_pool"); got != "memA" { + t.Fatalf("pre-condition active = %q, want memA", got) + } + + // Agent holds the pool-keyed refresh phantom. POST it to the token host. + reqFlow := newTestFlow(client, "POST", testOAuthTokenURL) + reqFlow.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + reqFlow.Request.Body = refreshGrantBody("codex_pool") + + addon.Requestheaders(reqFlow) + addon.Request(reqFlow) + + // (a) The pool refresh phantom must be gone and replaced by memA's REAL + // refresh token (memA is active). + body := string(reqFlow.Request.Body) + if strings.Contains(body, "SLUICE_PHANTOM:codex_pool.refresh") { + t.Fatalf("Finding 4: pool refresh phantom NOT swapped on the token host; body=%q", body) + } + if !strings.Contains(body, "A-refresh-old") { + t.Fatalf("Finding 4: active member memA's real refresh token not injected; body=%q", body) + } + + // (b) The R1 attribution tag must be recorded for memA's real refresh + // token (Peek does not consume it, so this is a non-destructive check). + if owner, ok := addon.refreshAttr.Peek("A-refresh-old"); !ok || owner != "memA" { + t.Fatalf("Finding 4: refresh-attribution tag not recorded for memA; got owner=%q ok=%v", owner, ok) + } +} + +// TestSplitHost_2xxPersistAttributedToPoolMemberNotPlainFirstMatch is the +// Finding 1 regression. A successful refresh on the token host where +// idx.Match returns the PLAIN credential (it sorts first and shares the +// token URL). Before the fix, resolveOAuthResponseAttribution saw +// pr.PoolForMember(plain) == "" and took the plain-credential identity +// branch: the pooled member's rotated tokens were persisted under the PLAIN +// credential's vault entry (and the agent got the plain credential's phantom +// instead of the pool-stable one). The fix consults MatchAll, detects the +// pool sharing the token URL, and recovers the true owner from the injected +// refresh token. +func TestSplitHost_2xxPersistAttributedToPoolMemberNotPlainFirstMatch(t *testing.T) { + addon, provider, _ := setupPoolSplitHostWithPlainCred(t) + client := setupAddonConn(addon, "auth.example.com:443") + + // Real pass-2 swap on the token host: pool phantom -> memA real refresh, + // and tags A-refresh-old -> memA. + reqFlow := newTestFlow(client, "POST", testOAuthTokenURL) + reqFlow.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + reqFlow.Request.Body = refreshGrantBody("codex_pool") + addon.Request(reqFlow) + if !strings.Contains(string(reqFlow.Request.Body), "A-refresh-old") { + t.Fatalf("pass-2 did not inject memA real refresh; body=%q", reqFlow.Request.Body) + } + + // Sanity: idx.Match deterministically returns the PLAIN credential (the + // collision the Finding-1 bug rode on). + if idx := addon.oauthIndex.Load(); idx != nil { + if matched, _ := idx.Match(reqFlow.Request.URL); matched != "aaa_plain" { + t.Fatalf("precondition: idx.Match must return the first entry aaa_plain, got %q", matched) + } + } + + respFlow := newPoolReqRespFlow(client, reqFlow.Request.Body, mustJSON(t, map[string]interface{}{ + "access_token": "A-access-rotated-1", + "refresh_token": "A-refresh-rotated-1", + "expires_in": 3600, + })) + addon.Response(respFlow) + waitAddonPersist(t, addon) + + // memA's vault entry must have the rotated tokens. + credA, err := vault.ParseOAuth([]byte(provider.creds["memA"])) + if err != nil { + t.Fatalf("parse memA: %v", err) + } + if credA.RefreshToken != "A-refresh-rotated-1" { + t.Fatalf("Finding 1: pooled member memA refresh not persisted; got %q want A-refresh-rotated-1", + credA.RefreshToken) + } + + // The PLAIN credential must be UNTOUCHED — the bug persisted memA's + // rotated tokens here because idx.Match returned the plain credential. + credPlain, err := vault.ParseOAuth([]byte(provider.creds["aaa_plain"])) + if err != nil { + t.Fatalf("parse aaa_plain: %v", err) + } + if credPlain.RefreshToken != "plain-refresh-old" || credPlain.AccessToken != "plain-access-old" { + t.Fatalf("Finding 1 VIOLATION: pooled member's rotated tokens landed in the plain credential's vault entry; got access=%q refresh=%q", + credPlain.AccessToken, credPlain.RefreshToken) + } + + // The agent must receive the POOL-STABLE phantom, NOT the plain + // credential's phantom. The pool-stable access phantom is a 3-part + // synthetic JWT keyed on the pool name. + agentBody := string(respFlow.Response.Body) + if strings.Contains(agentBody, "A-access-rotated-1") { + t.Fatalf("Finding 1: real rotated access token leaked to agent; body=%q", agentBody) + } + wantAccessPhantom := poolStablePhantomAccess("codex_pool") + if !strings.Contains(agentBody, wantAccessPhantom) { + t.Fatalf("Finding 1: agent did not receive the pool-stable access phantom; body=%q", agentBody) + } + if !strings.Contains(agentBody, "SLUICE_PHANTOM:codex_pool.refresh") { + t.Fatalf("Finding 1: agent did not receive the pool-stable refresh phantom; body=%q", agentBody) + } +} + +// TestSplitHost_TokenEndpointFailoverWithPlainCredSortingFirst is the +// Finding 2 regression. A token-endpoint invalid_grant on the token host +// where idx.Match returns the PLAIN credential (sorts first, shares token +// URL). Before the fix, poolForResponse gated the token-endpoint branch on +// pr.PoolForMember(idx.Match(...)) != "" — which is "" for the plain +// credential — so the whole branch was skipped, poolForResponse returned +// ok=false, NO cooldown was applied, and the broken pool member stayed +// active forever. The fix uses MatchAll to find the pool sharing the token +// URL independent of which credential sorts first, then recovers the true +// owner from the injected refresh token. +func TestSplitHost_TokenEndpointFailoverWithPlainCredSortingFirst(t *testing.T) { + addon, _, prPtr := setupPoolSplitHostWithPlainCred(t) + client := setupAddonConn(addon, "auth.example.com:443") + pr := prPtr.Load() + + // memA is first member AND first active. The realistic precursor: memA + // got API-429-cooled, traffic rolled to memB, and now memB's refresh + // invalid_grants on the token host. + memACooldown := time.Now().Add(90 * time.Second) + pr.MarkCooldown("memA", memACooldown, "429") + if got, _ := pr.ResolveActive("codex_pool"); got != "memB" { + t.Fatalf("after cooling memA, active = %q, want memB", got) + } + + // pass-2 injected memB's real refresh token; mirror the tag the real + // pass-2 swap records. + addon.refreshAttr.Tag("B-refresh-old", "memB") + + // Precondition: idx.Match returns the PLAIN credential (the collision). + if idx := addon.oauthIndex.Load(); idx != nil { + u := newPoolRespFlowBody(client, 400, "B-refresh-old", nil).Request.URL + if matched, _ := idx.Match(u); matched != "aaa_plain" { + t.Fatalf("precondition: idx.Match must return aaa_plain, got %q", matched) + } + } + + // poolForResponse MUST attribute the failure to memB (the injected + // member), not return ok=false because the plain credential sorted first. + f := newPoolRespFlowBody(client, 400, "B-refresh-old", []byte(`{"error":"invalid_grant"}`)) + pool, member, _, ok := addon.poolForResponse(f) + if !ok { + t.Fatal("Finding 2: token-endpoint failure on a pooled member must be attributed even when a plain cred sorts first; got ok=false") + } + if pool != "codex_pool" || member != "memB" { + t.Fatalf("Finding 2: got pool=%q member=%q, want codex_pool/memB", pool, member) + } + + var got FailoverEvent + gotCalled := make(chan struct{}, 1) + addon.SetOnFailover(func(ev FailoverEvent) { + got = ev + gotCalled <- struct{}{} + }) + + addon.Response(newPoolRespFlowBody(client, 400, "B-refresh-old", []byte(`{"error":"invalid_grant"}`))) + + // memB must now be cooled with the long auth-failure TTL. + bUntil, bCooling := pr.CooldownUntil("memB") + if !bCooling { + t.Fatal("Finding 2: memB must be in cooldown after its own invalid_grant") + } + if time.Until(bUntil) < vault.AuthFailCooldown-30*time.Second { + t.Fatalf("memB cooldown TTL = %s, want ~%s (auth-failure)", time.Until(bUntil), vault.AuthFailCooldown) + } + + // memA must be UNTOUCHED: still cooling on its ORIGINAL 90s 429 window. + aUntil, aCooling := pr.CooldownUntil("memA") + if !aCooling { + t.Fatal("memA should still be cooling on its original 429 window") + } + if aUntil.Sub(memACooldown).Abs() > time.Second { + t.Fatalf("Finding 2: innocent member memA was re-cooled: got %s, want original %s", + aUntil.Format(time.RFC3339Nano), memACooldown.Format(time.RFC3339Nano)) + } + + select { + case <-gotCalled: + case <-time.After(2 * time.Second): + t.Fatal("onFailover callback not invoked") + } + if got.From != "memB" || got.Pool != "codex_pool" || got.Reason != "invalid_grant" { + t.Fatalf("FailoverEvent = %+v, want from=memB pool=codex_pool reason=invalid_grant", got) + } +} + +// TestFinding3_ProtocolScopedPooledBindingFailoverLookup is the Finding 3 +// regression. A pooled binding scoped to a non-https protocol (grpc) on the +// API host. The request-side injection resolves the protocol via +// detectRequestProtocol, so the credential IS injected for a gRPC request. +// Before the fix, poolForResponse hardcoded "https" in its +// CredentialsForDestination lookup, so the protocol-scoped grpc binding was +// invisible on the response path: a 429 on that binding would NOT fail over. +// The fix uses the same detectRequestProtocol result for the lookup. +func TestFinding3_ProtocolScopedPooledBindingFailoverLookup(t *testing.T) { + const poolName = "grpc_pool" + provider := &addonWritableProvider{ + creds: map[string]string{ + "gA": poolMemberCred(t, "gA-access", "gA-refresh"), + "gB": poolMemberCred(t, "gB-access", "gB-refresh"), + }, + } + // Pool binding scoped to grpc ONLY on the API host. + bindings := []vault.Binding{{ + Destination: "grpc.example.com", + Ports: []int{443}, + Credential: poolName, + Protocols: []string{"grpc"}, + }} + resolver, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var resolverPtr atomic.Pointer[vault.BindingResolver] + resolverPtr.Store(resolver) + + addon := NewSluiceAddon(WithResolver(&resolverPtr), WithProvider(provider)) + addon.persistDone = make(chan struct{}, 10) + addon.UpdateOAuthIndex([]store.CredentialMeta{ + {Name: "gA", CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: "gB", CredType: "oauth", TokenURL: testOAuthTokenURL}, + }) + pool := store.Pool{Name: poolName, Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: "gA", Position: 0}, + {Credential: "gB", Position: 1}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{pool}, nil)) + addon.SetPoolResolver(&prPtr) + + client := setupAddonConn(addon, "grpc.example.com:443") + pr := prPtr.Load() + if got, _ := pr.ResolveActive(poolName); got != "gA" { + t.Fatalf("pre-failover active = %q, want gA", got) + } + + // Build a gRPC response flow. detectRequestProtocol refines to gRPC when + // the request carries the gRPC content type over TLS (https scheme). + f := newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`)) + f.Request.URL.Scheme = "https" + f.Request.URL.Host = "grpc.example.com" + f.Request.Header.Set("Content-Type", "application/grpc") + f.Response.Header.Set("Content-Type", "application/grpc") + + // Sanity: detectRequestProtocol must classify this as gRPC, and the + // hardcoded-"https" lookup would have missed the grpc-scoped binding. + if got := addon.detectRequestProtocol(f, 443); got != ProtoGRPC { + t.Fatalf("precondition: detectRequestProtocol = %v, want ProtoGRPC", got) + } + if res := resolverPtr.Load(); len(res.CredentialsForDestination("grpc.example.com", 443, "https")) != 0 { + t.Fatal("precondition: a 'https' lookup must NOT match the grpc-scoped binding (this is the Finding 3 bug)") + } + + pool2, member, _, ok := addon.poolForResponse(f) + if !ok { + t.Fatal("Finding 3: protocol-scoped (grpc) pooled binding must be recognized on the failover path; got ok=false") + } + if pool2 != poolName || member != "gA" { + t.Fatalf("Finding 3: got pool=%q member=%q, want %s/gA", pool2, member, poolName) + } + + var got FailoverEvent + gotCalled := make(chan struct{}, 1) + addon.SetOnFailover(func(ev FailoverEvent) { + got = ev + gotCalled <- struct{}{} + }) + addon.Response(f) + + if active, _ := pr.ResolveActive(poolName); active != "gB" { + t.Fatalf("Finding 3: post-429 active = %q, want gB (grpc-scoped binding must fail over)", active) + } + select { + case <-gotCalled: + case <-time.After(2 * time.Second): + t.Fatal("onFailover callback not invoked for grpc-scoped failover") + } + if got.From != "gA" || got.Pool != poolName || got.Reason != "429" { + t.Fatalf("FailoverEvent = %+v, want from=gA pool=%s reason=429", got, poolName) + } +} From 914bc0917db45c4b9e2dcbb80d8889dba6b1b83b Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 11:58:24 +0800 Subject: [PATCH 28/49] fix(proxy): classify token-endpoint 403 invalid_grant as auth-failure; audit detected protocol --- internal/proxy/pool_failover.go | 55 +++++++++++++++++---------- internal/proxy/pool_failover_test.go | 21 +++++++--- internal/proxy/pool_splithost_test.go | 50 +++++++++++++++++++++++- 3 files changed, 98 insertions(+), 28 deletions(-) diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index fededd4..2772632 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -86,9 +86,14 @@ func classifyFailover(statusCode int, body []byte, isTokenEndpoint bool) (class if bodyContainsAny(body, "insufficient_quota", "quota_exceeded", "quota exhausted", "rate_limit_exceeded") { return failoverRateLimited, "" } - return failoverNone, "" + // NOT a quota signal: do not early-return. A 403 is still a non-2xx + // status, so a real token-endpoint body of invalid_grant/invalid_token + // must classify as auth-failure (consistent with the 400/401 path). + // The shared non-2xx token-endpoint check below handles it; a 403 from + // a non-token-endpoint with an unrelated body still resolves to + // failoverNone there (the body is only trusted on a real token URL). } - // Non-4xx-status path. Only a real token-endpoint body may be classified + // Non-2xx-status path. Only a real token-endpoint body may be classified // (invalid_grant/invalid_token), and only when the status is not a 2xx // success. A 2xx token response is a healthy refresh, never a failover. if isTokenEndpoint && (statusCode < 200 || statusCode > 299) { @@ -133,21 +138,26 @@ type FailoverEvent struct { // poolForResponse maps a response's CONNECT destination back to a pooled // binding and returns the pool name + the member that was active for this // request. Returns ok=false when the destination is not bound to a pool. -func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember string, pr *vault.PoolResolver, ok bool) { +// +// proto is the protocol detected for THIS request (the same value used for +// the protocol-scoped binding lookup). The caller threads it into the +// cred_failover audit event so the audit records the real protocol of the +// pooled binding (grpc / http2 / etc.) instead of a hardcoded "https". +func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember, proto string, pr *vault.PoolResolver, ok bool) { if a.poolResolver == nil || a.resolver == nil { - return "", "", nil, false + return "", "", "", nil, false } pr = a.poolResolver.Load() if pr == nil { - return "", "", nil, false + return "", "", "", nil, false } res := a.resolver.Load() if res == nil { - return "", "", nil, false + return "", "", "", nil, false } host, port := connectTargetForFlow(a, f) if host == "" { - return "", "", nil, false + return "", "", "", nil, false } // Finding 3: the failover binding lookup MUST use the same protocol the // request-side injection (injectHeaders / buildPhantomPairs) used, not a @@ -157,7 +167,7 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str // detectRequestProtocol mirrors the injection path exactly (URL scheme // then header refinement); for the common unscoped-binding case the // result is still https-equivalent so behavior is unchanged. - proto := a.detectRequestProtocol(f, port).String() + proto = a.detectRequestProtocol(f, port).String() for _, boundName := range res.CredentialsForDestination(host, port, proto) { if !pr.IsPool(boundName) { continue @@ -176,7 +186,7 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str // member of this pool (a membership change could have // raced); otherwise fall through to ResolveActive. if pr.PoolForMember(injected) == boundName { - return boundName, injected, pr, true + return boundName, injected, proto, pr, true } } } @@ -184,7 +194,7 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str if !mok || member == "" { continue } - return boundName, member, pr, true + return boundName, member, proto, pr, true } // Token-endpoint path. An OAuth refresh hits the credential's token-URL @@ -246,7 +256,7 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str realRefresh := extractRequestRefreshToken(f.Request.Body, reqCT) if owner, ok := a.refreshAttr.Peek(realRefresh); ok && owner != "" { if ownerPool := pr.PoolForMember(owner); ownerPool != "" { - return ownerPool, owner, pr, true + return ownerPool, owner, proto, pr, true } // owner is no longer in any pool (membership change // raced the failure); fall through to the active-member @@ -261,20 +271,20 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str log.Printf("[POOL-FAILOVER] pool %q: could not attribute "+ "token-endpoint failure via injected refresh token; "+ "falling back to active member %q", pool, active) - return pool, active, pr, true + return pool, active, proto, pr, true } // Last resort: a pooled index match if any (preserves prior // behavior when even ResolveActive cannot decide; better than // no attribution at all). for _, c := range matches { if pr.PoolForMember(c) != "" { - return pool, c, pr, true + return pool, c, proto, pr, true } } - return pool, matched, pr, true + return pool, matched, proto, pr, true } } - return "", "", nil, false + return "", "", "", nil, false } // handlePoolFailover is the Phase 2 entry point invoked from Response for @@ -302,7 +312,7 @@ func (a *SluiceAddon) handlePoolFailover(f *mitmproxy.Flow) { if f == nil || f.Response == nil || f.Request == nil { return } - pool, from, pr, ok := a.poolForResponse(f) + pool, from, proto, pr, ok := a.poolForResponse(f) if !ok { return } @@ -352,11 +362,14 @@ func (a *SluiceAddon) handlePoolFailover(f *mitmproxy.Flow) { evt := audit.Event{ Destination: host, Port: port, - Protocol: "https", - Verdict: "failover", - Action: "cred_failover", - Reason: fmt.Sprintf("%s:%s->%s:%s", pool, from, to, tag), - Credential: from, + // Same protocol used for the protocol-scoped binding lookup in + // poolForResponse, NOT a hardcoded "https". For a grpc/http2 + // scoped pooled binding the audit must record the real protocol. + Protocol: proto, + Verdict: "failover", + Action: "cred_failover", + Reason: fmt.Sprintf("%s:%s->%s:%s", pool, from, to, tag), + Credential: from, } if err := a.auditLog.Log(evt); err != nil { log.Printf("[POOL-FAILOVER] audit log error: %v", err) diff --git a/internal/proxy/pool_failover_test.go b/internal/proxy/pool_failover_test.go index 699b539..1822b41 100644 --- a/internal/proxy/pool_failover_test.go +++ b/internal/proxy/pool_failover_test.go @@ -57,6 +57,17 @@ func TestClassifyFailover(t *testing.T) { {"403 insufficient_quota", 403, `{"error":"insufficient_quota"}`, false, failoverRateLimited, "403"}, {"403 quota_exceeded", 403, `{"error":{"code":"quota_exceeded"}}`, false, failoverRateLimited, "403"}, {"403 unrelated -> noop", 403, `{"error":"forbidden: bad scope"}`, false, failoverNone, ""}, + // Finding 1: a token-endpoint 403 carrying invalid_grant/invalid_token + // is an auth failure (consistent with the 400/401 token-endpoint path). + // The old code early-returned failoverNone in the 403 branch before the + // token-endpoint body check ever ran. + {"403 token-endpoint invalid_grant -> auth", 403, `{"error":"invalid_grant"}`, true, failoverAuthFailure, "invalid_grant"}, + {"403 token-endpoint invalid_token -> auth", 403, `{"error":"invalid_token"}`, true, failoverAuthFailure, "invalid_token"}, + // 403 + quota signal stays rate-limited (unchanged). + {"403 insufficient_quota (tokenEP) stays rate-limited", 403, `{"error":"insufficient_quota"}`, true, failoverRateLimited, "403"}, + // 403 + invalid_grant but NOT a real token endpoint -> still noop + // (the body is only trusted on a real token URL). + {"403 invalid_grant but NOT token endpoint -> noop", 403, `{"error":"invalid_grant"}`, false, failoverNone, ""}, {"401 auth failure", 401, "", false, failoverAuthFailure, "401"}, {"token-endpoint invalid_grant", 400, `{"error":"invalid_grant"}`, true, failoverAuthFailure, "invalid_grant"}, {"token-endpoint invalid_token", 400, `{"error":"invalid_token"}`, true, failoverAuthFailure, "invalid_token"}, @@ -307,7 +318,7 @@ func TestPoolForResponseResolvesActiveMember(t *testing.T) { client := setupAddonConn(addon, "auth.example.com:443") f := newPoolRespFlow(client, 429, nil) - pool, member, pr, ok := addon.poolForResponse(f) + pool, member, _, pr, ok := addon.poolForResponse(f) if !ok { t.Fatal("poolForResponse: expected a pooled destination match") } @@ -392,7 +403,7 @@ func TestTokenEndpointHostFailoverOnPooledMember(t *testing.T) { // (this is exactly the gap CRITICAL-2 describes). poolForResponse must // still succeed via the token-URL index path. f := newPoolRespFlow(client, 400, []byte(`{"error":"invalid_grant"}`)) - pool, member, _, ok := addon.poolForResponse(f) + pool, member, _, _, ok := addon.poolForResponse(f) if !ok { t.Fatal("poolForResponse: token-endpoint response on a pooled member must be attributed (CRITICAL-2 fix); got ok=false") } @@ -497,7 +508,7 @@ func TestTokenEndpointFailoverAttributesInjectedMemberNotFirstIndex(t *testing.T // poolForResponse must now attribute the failure to memB (the injected // member), NOT memA (the first index entry). f := newPoolRespFlowBody(client, 400, "B-refresh-old", []byte(`{"error":"invalid_grant"}`)) - pool, member, _, ok := addon.poolForResponse(f) + pool, member, _, _, ok := addon.poolForResponse(f) if !ok { t.Fatal("poolForResponse: token-endpoint failure on a pooled member must be attributed") } @@ -614,7 +625,7 @@ func TestTokenEndpointFailover3MemberAttributesMiddleMember(t *testing.T) { addon.refreshAttr.Tag("memB-refresh", "memB") f := newPoolRespFlowBody(client, 401, "memB-refresh", []byte(`{"error":"invalid_token"}`)) - pool, member, _, ok := addon.poolForResponse(f) + pool, member, _, _, ok := addon.poolForResponse(f) if !ok || pool != "codex_pool" || member != "memB" { t.Fatalf("poolForResponse got ok=%v pool=%q member=%q, want codex_pool/memB", ok, pool, member) } @@ -650,7 +661,7 @@ func TestTokenEndpointFailoverFallsBackToActiveMember(t *testing.T) { } f := newPoolRespFlowBody(client, 400, "untagged-refresh", []byte(`{"error":"invalid_grant"}`)) - pool, member, _, ok := addon.poolForResponse(f) + pool, member, _, _, ok := addon.poolForResponse(f) if !ok { t.Fatal("poolForResponse: expected attribution via active-member fallback") } diff --git a/internal/proxy/pool_splithost_test.go b/internal/proxy/pool_splithost_test.go index eb22d71..aaa3afc 100644 --- a/internal/proxy/pool_splithost_test.go +++ b/internal/proxy/pool_splithost_test.go @@ -1,11 +1,15 @@ package proxy import ( + "encoding/json" + "os" + "path/filepath" "strings" "sync/atomic" "testing" "time" + "github.com/nemirovsky/sluice/internal/audit" "github.com/nemirovsky/sluice/internal/store" "github.com/nemirovsky/sluice/internal/vault" ) @@ -241,7 +245,7 @@ func TestSplitHost_TokenEndpointFailoverWithPlainCredSortingFirst(t *testing.T) // poolForResponse MUST attribute the failure to memB (the injected // member), not return ok=false because the plain credential sorted first. f := newPoolRespFlowBody(client, 400, "B-refresh-old", []byte(`{"error":"invalid_grant"}`)) - pool, member, _, ok := addon.poolForResponse(f) + pool, member, _, _, ok := addon.poolForResponse(f) if !ok { t.Fatal("Finding 2: token-endpoint failure on a pooled member must be attributed even when a plain cred sorts first; got ok=false") } @@ -355,13 +359,27 @@ func TestFinding3_ProtocolScopedPooledBindingFailoverLookup(t *testing.T) { t.Fatal("precondition: a 'https' lookup must NOT match the grpc-scoped binding (this is the Finding 3 bug)") } - pool2, member, _, ok := addon.poolForResponse(f) + pool2, member, detProto, _, ok := addon.poolForResponse(f) if !ok { t.Fatal("Finding 3: protocol-scoped (grpc) pooled binding must be recognized on the failover path; got ok=false") } if pool2 != poolName || member != "gA" { t.Fatalf("Finding 3: got pool=%q member=%q, want %s/gA", pool2, member, poolName) } + if detProto != ProtoGRPC.String() { + t.Fatalf("Finding 2: poolForResponse detected protocol = %q, want %q", detProto, ProtoGRPC.String()) + } + + // Finding 2: the cred_failover audit event must record the SAME + // protocol that drove the binding lookup (grpc here), not a hardcoded + // "https". Wire a real audit logger and assert the persisted Protocol. + dir := t.TempDir() + logPath := filepath.Join(dir, "audit.log") + logger, lerr := audit.NewFileLogger(logPath) + if lerr != nil { + t.Fatalf("NewFileLogger: %v", lerr) + } + addon.auditLog = logger var got FailoverEvent gotCalled := make(chan struct{}, 1) @@ -382,4 +400,32 @@ func TestFinding3_ProtocolScopedPooledBindingFailoverLookup(t *testing.T) { if got.From != "gA" || got.Pool != poolName || got.Reason != "429" { t.Fatalf("FailoverEvent = %+v, want from=gA pool=%s reason=429", got, poolName) } + + if cerr := logger.Close(); cerr != nil { + t.Fatalf("logger close: %v", cerr) + } + data, rerr := os.ReadFile(logPath) + if rerr != nil { + t.Fatalf("read audit log: %v", rerr) + } + var foundFailover bool + for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") { + if line == "" { + continue + } + var evt audit.Event + if uerr := json.Unmarshal([]byte(line), &evt); uerr != nil { + t.Fatalf("unmarshal audit line %q: %v", line, uerr) + } + if evt.Action != "cred_failover" { + continue + } + foundFailover = true + if evt.Protocol != ProtoGRPC.String() { + t.Fatalf("Finding 2: cred_failover audit Protocol = %q, want %q (must match the detected request protocol, not hardcoded https)", evt.Protocol, ProtoGRPC.String()) + } + } + if !foundFailover { + t.Fatalf("no cred_failover audit event found in:\n%s", data) + } } From a6b24e8c35a370b58f6fd6ffb3c87824f219668f Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 12:30:03 +0800 Subject: [PATCH 29/49] fix(vault): monotonic cooldown (in-memory + durable); doc pool phantom shapes --- CLAUDE.md | 6 ++-- README.md | 2 +- internal/store/pools.go | 40 +++++++++++++++++++-- internal/store/pools_test.go | 68 ++++++++++++++++++++++++++++++++++++ internal/vault/pool.go | 16 +++++++++ internal/vault/pool_test.go | 49 ++++++++++++++++++++++++++ 6 files changed, 174 insertions(+), 7 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index f7afedc..d948f15 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -214,7 +214,7 @@ Extends phantom swap to handle OAuth credentials bidirectionally. Static credent ### Credential pools and auto-failover -A **credential pool** lets one phantom identity the agent sees be backed by **N real OAuth credentials**. The agent always holds a single pool-scoped phantom pair (`SLUICE_PHANTOM:.access` / `SLUICE_PHANTOM:.refresh`); sluice maps it to the *currently active member's* real tokens at injection time and persists refreshed tokens back to the member that issued them. Primary use case: two OpenAI Codex OAuth accounts behind one agent so quota exhaustion on one account transparently rolls onto the other. Pool members must be `oauth` credentials — `static` members are rejected. `cred remove` errors on a credential that is a live pool member. **One credential belongs to at most one pool**: proxy attribution (`PoolResolver.PoolForMember`) maps a member back to a single pool, so a credential shared across pools would persist/audit a token response against the wrong pool's phantom and leave the agent with an unreplaceable phantom. `pool create` rejects a member that is already in another pool (enforced inside the same transaction as the member insert). +A **credential pool** lets one phantom identity the agent sees be backed by **N real OAuth credentials**. The agent always holds a single pool-scoped phantom pair, byte-stable across member switches: the **access** phantom is a synthetic pool-stable JWT (HS256, `sub: sluice-pool:`, `iss: sluice-phantom`, fixed far-future `exp`, built by `poolStablePhantomAccess`) — byte-identical for a given pool regardless of which member is active; the **refresh** phantom is the static string `SLUICE_PHANTOM:.refresh` (from `oauthPhantomRefresh`'s request-side strip path). Sluice maps the pair to the *currently active member's* real tokens at injection time and persists refreshed tokens back to the member that issued them. Primary use case: two OpenAI Codex OAuth accounts behind one agent so quota exhaustion on one account transparently rolls onto the other. Pool members must be `oauth` credentials — `static` members are rejected. `cred remove` errors on a credential that is a live pool member. **One credential belongs to at most one pool**: proxy attribution (`PoolResolver.PoolForMember`) maps a member back to a single pool, so a credential shared across pools would persist/audit a token response against the wrong pool's phantom and leave the agent with an unreplaceable phantom. `pool create` rejects a member that is already in another pool (enforced inside the same transaction as the member insert). **CLI:** @@ -235,13 +235,13 @@ Auto-failover on 429/401 is the primary mechanism; `pool rotate` is an operator - **Single chokepoint (I2):** every `binding.Credential` / `OAuthIndex.Has` / `extractInjectableSecret` / persist consumer on the HTTP/HTTPS OAuth path routes through `PoolResolver.ResolveActive` (`resolveInjectionTarget` for pass-1 header + pass-2 phantom swap; `resolveOAuthResponseAttribution` for the response/persist path). `idx.Has` is always called with the resolved member name, never the pool. Plain (non-pool) credentials pass through `ResolveActive` unchanged. SSH/mail/QUIC are non-OAuth and out of scope. - **Active-member selection:** healthy or expired-cooldown members first, by configured position; if all members are in cooldown, the soonest-recovering member is returned with a WARNING (degrade, never hard-fail). Recovery is lazy — evaluated in `ResolveActive`, no scheduler. - **R1 refresh-token attribution / fail-closed:** when pass-2 swaps `SLUICE_PHANTOM:.refresh`, sluice records `realRefreshToken → member` in a short-TTL map. On the token-endpoint response it recovers the member by that real refresh token and persists to that member (`persistAddonOAuthTokens(member, ...)`, singleflight key `"persist:"+member`). The join key is the real **refresh** token sluice injected — never the access token, the client connection, or `OAuthIndex.Match` (two pooled members share `auth.openai.com`'s token URL and collide there). If the member is unrecoverable: WARNING + skip the vault write, never guess. Rotating refresh tokens are single-use, so a mis-attributed write would brick both accounts — fail-closed is mandatory. -- **R3 pool-stable phantom JWT:** Codex access tokens are JWTs and the per-real-token `resignJWT` would emit a *different* phantom after every cross-member refresh, breaking the "agent never notices" guarantee. Pooled OAuth `oauthPhantomAccess`/`resignJWT` instead build the phantom JWT from a deterministic synthetic payload keyed on the **pool name** (stable `sub`/`iss`, far-future `exp`), HMAC'd with the existing fixed key — byte-identical across member switches while still a structurally valid JWT. Static-form fallback (`SLUICE_PHANTOM:.access`) is documented for the case where the agent is verified to treat the access token as opaque. +- **R3 pool-stable phantom JWT:** Codex access tokens are JWTs and the per-real-token `resignJWT` would emit a *different* phantom after every cross-member refresh, breaking the "agent never notices" guarantee. The dedicated `poolStablePhantomAccess` (in `internal/proxy/oauth_response.go`) instead builds the phantom JWT from a deterministic synthetic payload keyed on the **pool name** (`sub: sluice-pool:`, `iss: sluice-phantom`, fixed far-future `exp`, no `iat`), HMAC-SHA256'd with the existing fixed key — byte-identical across member switches while still a structurally valid JWT. The pool name is JSON-marshaled (never concatenated) so a name with quotes/control chars cannot inject claims. Static-form fallback (`SLUICE_PHANTOM:.access`) is emitted only on the unreachable `json.Marshal` failure of the fixed struct (and is documented as the equivalent for an agent verified to treat the access token as opaque). The **refresh** phantom is unaffected — it stays the static `SLUICE_PHANTOM:.refresh`. **Phase 2 — auto-failover on 429 / 401:** - **Classification** (`classifyFailover` in `internal/proxy/pool_failover.go`, called from `SluiceAddon.Response` for pooled destinations): `429` or `403 + insufficient_quota` → rate-limited; `401` or token-body `invalid_grant` / `invalid_token` → auth-failure; `5xx` / other → no-op. The token-endpoint body is only trusted when the request URL matched the OAuth index. - **Pool attribution for the response** (`poolForResponse`): a response is attributed to a pool either (a) when the flow's CONNECT host has a pooled binding (the API-host 429/403 path), **or** (b) when the request URL matches the OAuth token-URL index for a credential that is a pool member (the token-endpoint 401 / `invalid_grant` path). Case (b) is essential: an OAuth refresh hits the credential's token-URL host (e.g. `auth.openai.com`), which has no pool binding — only the API host (e.g. `api.openai.com`) does — so without the token-URL index match the token-endpoint classification would be dead code for the Codex deployment. `idx.Match` is strict 1:1 token_url→credential, so case (b) cools the exact member whose refresh token was injected. -- **Synchronous in-memory failover (I1):** health is updated in-process *before* the response returns — `MarkCooldown` takes the resolver write lock, `ResolveActive` the read lock — so the active-member switch never waits on the 2s data-version watcher (which only reconciles). A detached `onFailover` callback also writes `SetCredentialHealth(member, 'cooldown', now+ttl, reason)` for durability. Cooldown TTLs: `vault.RateLimitCooldown` = 60s, `vault.AuthFailCooldown` = 300s. No in-flight retry — the next request uses the new member. +- **Synchronous in-memory failover (I1):** health is updated in-process *before* the response returns — `MarkCooldown` takes the resolver write lock, `ResolveActive` the read lock — so the active-member switch never waits on the 2s data-version watcher (which only reconciles). A detached `onFailover` callback also writes `SetCredentialHealth(member, 'cooldown', now+ttl, reason)` for durability. Cooldown TTLs: `vault.RateLimitCooldown` = 60s, `vault.AuthFailCooldown` = 300s. **Cooldown extension is monotonic on both layers:** a member parked for an auth failure (300s) that subsequently trips a rate-limit (60s) keeps the LATER expiry — `MarkCooldown` (in-memory) and `SetCredentialHealth`'s `cooldown` upsert (durable, via a `CASE`/comparison against the stored future `cooldown_until`) both keep `max(existing-future, new)` so a known-bad credential is never made eligible early. Only the extend path is monotonic: an explicit clear (zero/past `until` in `MarkCooldown`) and any transition to `status='healthy'` still shorten/clear (recovery intact), and lazy expiry still wins over an already-expired stored cooldown. No in-flight retry — the next request uses the new member. - **Reload does not resurrect a cooled member:** because the durable `SetCredentialHealth` write is detached and best-effort, any reload (SIGHUP or the 2s data-version watcher firing on *any* unrelated DB write) rebuilds the resolver from store rows alone via `NewPoolResolver`. `Server.StorePool` therefore calls `PoolResolver.MergeLiveCooldowns(prev)` to carry forward still-active in-memory cooldowns from the resolver being replaced before the atomic swap. The merge is monotonic (a live cooldown is never shortened/erased by an unrelated reload) and drops cooldowns for credentials no longer in any pool. - **Audit:** a `cred_failover` event (Verdict `failover`, Credential = the cooled-down member) with `Reason = ":->:<429|403|401|invalid_grant>"`, emitted synchronously in `handlePoolFailover`. - **Telegram:** a best-effort non-blocking notice "pool failed over -> ()" (plain text — `TelegramChannel.Notify` sends with no parse mode); the store write and every broker channel `Notify` are detached into their own goroutine so the response path never blocks. diff --git a/README.md b/README.md index c411e02..df20df8 100644 --- a/README.md +++ b/README.md @@ -288,7 +288,7 @@ github_pat static api.github.com ## Credential Pools -A credential pool lets a single phantom identity the agent sees be backed by **N real OAuth credentials**, with sluice auto-failing-over to the next member when the upstream rejects the active one. Primary use case: two OpenAI Codex OAuth accounts driven by one agent, so quota exhaustion on one account transparently rolls onto the other. The agent always holds one pool-scoped phantom pair (`SLUICE_PHANTOM:.access` / `.refresh`); sluice maps it to the currently active member's real token at injection time and persists refreshed tokens back to the member that issued them. +A credential pool lets a single phantom identity the agent sees be backed by **N real OAuth credentials**, with sluice auto-failing-over to the next member when the upstream rejects the active one. Primary use case: two OpenAI Codex OAuth accounts driven by one agent, so quota exhaustion on one account transparently rolls onto the other. The agent always holds one pool-scoped phantom pair, byte-stable across member switches: the **access** phantom is a synthetic pool-stable JWT (HS256, `sub: sluice-pool:`, `iss: sluice-phantom`, far-future `exp`) that is byte-identical for a given pool regardless of which member is active, so a cross-member failover never changes the access token the agent holds; the **refresh** phantom is the static string `SLUICE_PHANTOM:.refresh`. Sluice maps the pair to the currently active member's real token at injection time and persists refreshed tokens back to the member that issued them. ```bash sluice pool create --members credA,credB[,credC] [--strategy failover] diff --git a/internal/store/pools.go b/internal/store/pools.go index dc6e834..fe949c3 100644 --- a/internal/store/pools.go +++ b/internal/store/pools.go @@ -327,13 +327,47 @@ func (s *Store) SetCredentialHealth(credential, status string, cooldownUntil tim } else { cu = nil } + // Monotonic extend for the durable row, mirroring MarkCooldown's + // in-memory invariant. When the incoming write is a cooldown AND the + // stored row already has a cooldown_until strictly in the future that + // is LATER than the incoming one, keep the stored (longer) value: a + // short rate-limit cooldown must never shorten a longer auth-failure + // cooldown, even on the durable side, so restart durability matches + // the resolver. Any transition to "healthy" (excluded.status = + // 'healthy', whose cooldown_until is NULL) always overwrites, so the + // recovery/heal path is intact. cooldown_until is always written as + // UTC RFC3339 by this function, so the string comparison is a valid + // chronological ordering; the datetime('now') guard makes an already + // expired stored cooldown lose to the fresh future one (lazy expiry + // preserved). _, err := s.db.Exec( `INSERT INTO credential_health (credential, status, cooldown_until, last_failure_reason, updated_at) VALUES (?, ?, ?, ?, datetime('now')) ON CONFLICT(credential) DO UPDATE SET - status = excluded.status, - cooldown_until = excluded.cooldown_until, - last_failure_reason = excluded.last_failure_reason, + cooldown_until = CASE + WHEN excluded.status = 'cooldown' + AND credential_health.cooldown_until IS NOT NULL + AND credential_health.cooldown_until > strftime('%Y-%m-%dT%H:%M:%SZ', 'now') + AND credential_health.cooldown_until > excluded.cooldown_until + THEN credential_health.cooldown_until + ELSE excluded.cooldown_until + END, + status = CASE + WHEN excluded.status = 'cooldown' + AND credential_health.cooldown_until IS NOT NULL + AND credential_health.cooldown_until > strftime('%Y-%m-%dT%H:%M:%SZ', 'now') + AND credential_health.cooldown_until > excluded.cooldown_until + THEN credential_health.status + ELSE excluded.status + END, + last_failure_reason = CASE + WHEN excluded.status = 'cooldown' + AND credential_health.cooldown_until IS NOT NULL + AND credential_health.cooldown_until > strftime('%Y-%m-%dT%H:%M:%SZ', 'now') + AND credential_health.cooldown_until > excluded.cooldown_until + THEN credential_health.last_failure_reason + ELSE excluded.last_failure_reason + END, updated_at = excluded.updated_at`, credential, status, cu, nilIfEmpty(reason), ) diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index 6601910..cc7af0a 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -278,6 +278,74 @@ func TestCredentialHealthCRUD(t *testing.T) { } } +func TestSetCredentialHealthMonotonicCooldown(t *testing.T) { + s := newTestStore(t) + + // Seed a long auth-failure cooldown (now+300s). + authUntil := time.Now().Add(300 * time.Second).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("a", "cooldown", authUntil, "401 auth fail"); err != nil { + t.Fatalf("seed cooldown: %v", err) + } + + // A subsequent shorter rate-limit cooldown (now+60s) must NOT shorten + // the durable row — restart durability must match the resolver. + rlUntil := time.Now().Add(60 * time.Second).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("a", "cooldown", rlUntil, "429 rate limited"); err != nil { + t.Fatalf("shorter cooldown write: %v", err) + } + h, _ := s.GetCredentialHealth("a") + if h == nil || !h.CooldownUntil.Equal(authUntil) { + t.Fatalf("after shorter write CooldownUntil = %v, want %v (NOT shortened)", + cooldownOf(h), authUntil) + } + if h.LastFailureReason != "401 auth fail" { + t.Errorf("reason = %q, want %q (longer cooldown's metadata kept)", h.LastFailureReason, "401 auth fail") + } + + // A strictly LATER cooldown does extend. + laterUntil := authUntil.Add(120 * time.Second) + if err := s.SetCredentialHealth("a", "cooldown", laterUntil, "429 again"); err != nil { + t.Fatalf("later cooldown write: %v", err) + } + h, _ = s.GetCredentialHealth("a") + if h == nil || !h.CooldownUntil.Equal(laterUntil) { + t.Fatalf("after later write CooldownUntil = %v, want %v (extended)", + cooldownOf(h), laterUntil) + } + + // Transition to healthy clears, even though a longer cooldown is active + // (recovery/heal path must remain intact). + if err := s.SetCredentialHealth("a", "healthy", time.Time{}, ""); err != nil { + t.Fatalf("heal write: %v", err) + } + h, _ = s.GetCredentialHealth("a") + if h == nil || h.Status != "healthy" || !h.CooldownUntil.IsZero() { + t.Errorf("after heal = %+v, want healthy/zero (recovery must not be blocked by monotonicity)", h) + } + + // An already-expired stored cooldown loses to a fresh future one + // (lazy expiry preserved at the durable layer too). + pastUntil := time.Now().Add(-time.Hour).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("b", "cooldown", pastUntil, "stale"); err != nil { + t.Fatalf("seed stale cooldown: %v", err) + } + freshUntil := time.Now().Add(60 * time.Second).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("b", "cooldown", freshUntil, "429"); err != nil { + t.Fatalf("fresh cooldown write: %v", err) + } + h, _ = s.GetCredentialHealth("b") + if h == nil || !h.CooldownUntil.Equal(freshUntil) { + t.Errorf("fresh cooldown after stale = %v, want %v", cooldownOf(h), freshUntil) + } +} + +func cooldownOf(h *CredentialHealth) interface{} { + if h == nil { + return nil + } + return h.CooldownUntil +} + // TestMigration000006DownUp verifies the pool migration is reversible. func TestMigration000006DownUp(t *testing.T) { dir := t.TempDir() diff --git a/internal/vault/pool.go b/internal/vault/pool.go index c667109..51fcbbc 100644 --- a/internal/vault/pool.go +++ b/internal/vault/pool.go @@ -261,6 +261,22 @@ func (pr *PoolResolver) MarkCooldown(credential string, until time.Time, reason delete(pr.health.health, credential) return } + // Monotonic extend: a member parked for an auth failure (300s) that + // subsequently trips a rate-limit (60s) must NOT have its cooldown + // shortened — a known-bad credential would become eligible far too + // early. Keep the LATER of the existing future cooldown and the new + // one. This is ONLY the extend path: an explicit clear/recover (the + // zero/past `until` branch above, and SetCredentialHealth "healthy" + // on the durable side) still shortens/clears, and a strictly later + // `until` still extends. Lazy expiry in ResolveActive/CooldownUntil + // is unaffected because an expired existing cooldown is in the past + // and `until.After(existing.cooldownUntil)` is true, so the fresh + // future cooldown wins. + if existing, ok := pr.health.health[credential]; ok && + existing.cooldownUntil.After(time.Now()) && + !until.After(existing.cooldownUntil) { + return + } pr.health.health[credential] = memberHealth{cooldownUntil: until, reason: reason} } diff --git a/internal/vault/pool_test.go b/internal/vault/pool_test.go index 80fcce3..e5dc0f8 100644 --- a/internal/vault/pool_test.go +++ b/internal/vault/pool_test.go @@ -103,6 +103,55 @@ func TestMarkCooldownSynchronousFlip(t *testing.T) { } } +func TestMarkCooldownMonotonicExtend(t *testing.T) { + pr := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, nil) + + // Park "a" for an auth failure (300s). + authUntil := time.Now().Add(AuthFailCooldown) + pr.MarkCooldown("a", authUntil, "401") + got, ok := pr.CooldownUntil("a") + if !ok || !got.Equal(authUntil) { + t.Fatalf("after auth-fail cooldown = %v,%v; want %v,true", got, ok, authUntil) + } + + // A subsequent shorter rate-limit cooldown (60s) must NOT shorten it: + // the credential is known-bad for 300s and must not be eligible early. + rlUntil := time.Now().Add(RateLimitCooldown) + pr.MarkCooldown("a", rlUntil, "429") + got, ok = pr.CooldownUntil("a") + if !ok || !got.Equal(authUntil) { + t.Errorf("after shorter rate-limit cooldown = %v,%v; want %v,true (NOT shortened to %v)", + got, ok, authUntil, rlUntil) + } + + // A strictly LATER cooldown does extend. + laterUntil := authUntil.Add(120 * time.Second) + pr.MarkCooldown("a", laterUntil, "429-again") + got, ok = pr.CooldownUntil("a") + if !ok || !got.Equal(laterUntil) { + t.Errorf("after later cooldown = %v,%v; want %v,true (extended)", got, ok, laterUntil) + } + + // Explicit clear (zero) still recovers despite an active longer cooldown. + pr.MarkCooldown("a", time.Time{}, "") + if _, cooling := pr.CooldownUntil("a"); cooling { + t.Error("after explicit clear CooldownUntil(a) cooling=true, want false (recovery path must not be blocked by monotonicity)") + } + if active, _ := pr.ResolveActive("pool"); active != "a" { + t.Errorf("after clear active = %q, want a", active) + } + + // Expired existing cooldown must lose to a fresh future one (lazy + // expiry preserved): set a past cooldown, then a normal future one. + pr.MarkCooldown("b", time.Now().Add(-time.Hour), "stale") // zero/past => clear, b stays healthy + freshUntil := time.Now().Add(RateLimitCooldown) + pr.MarkCooldown("b", freshUntil, "429") + got, ok = pr.CooldownUntil("b") + if !ok || !got.Equal(freshUntil) { + t.Errorf("fresh cooldown after stale = %v,%v; want %v,true", got, ok, freshUntil) + } +} + func TestPoolForMemberAndMembers(t *testing.T) { pr := NewPoolResolver([]store.Pool{mkPool("pool", "a", "b")}, nil) if p := pr.PoolForMember("b"); p != "pool" { From b4218a69f9b6d75117b5f8cdddcb8c84a3ddcd90 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 15:15:28 +0800 Subject: [PATCH 30/49] fix(proxy): fail-closed unattributed pooled refresh; broker resolve/detach race; scope pool failover fallback to pooled requests --- internal/channel/broker.go | 50 ++++++- internal/channel/channel_test.go | 109 +++++++++++++++ internal/proxy/addon.go | 34 ++--- internal/proxy/pool_failover.go | 71 ++++++---- internal/proxy/pool_failover_test.go | 97 +++++++++++-- internal/proxy/pool_phantom_test.go | 199 +++++++++++++++++++++++++++ 6 files changed, 500 insertions(+), 60 deletions(-) diff --git a/internal/channel/broker.go b/internal/channel/broker.go index 03c972e..9cb4f26 100644 --- a/internal/channel/broker.go +++ b/internal/channel/broker.go @@ -70,6 +70,22 @@ type Broker struct { // nowFunc is used for testing to control time. If nil, time.Now is used. nowFunc func() time.Time + + // resolveAfterDeleteHook is a test-only seam invoked inside Resolve + // immediately after the waiter has been deleted from b.waiters and the + // coalesced count recorded, but before the primary/subscriber response + // sends. It runs while b.mu is held (post-fix the sends are also under + // the lock), so a test can drive a coalesced subscriber's deadline path + // concurrently and assert it cannot observe a lost wakeup. nil in + // production. + resolveAfterDeleteHook func() + + // subDeadlineGate is a test-only seam invoked at the very top of + // waitSub's deadline branch, before detachSub. A test uses it to park a + // coalesced subscriber exactly at the start of its timeout-handling path + // so the resolve/detach interleave can be forced deterministically + // without sleeps. nil in production. + subDeadlineGate func() } // waiter tracks a pending approval request and its response channel. @@ -391,6 +407,9 @@ func (b *Broker) waitSub(primaryID string, subCh chan Response, deadlineC <-chan b.detachSub(primaryID, subCh) return ResponseDeny, fmt.Errorf("approval broker shutting down") case <-deadlineC: + if b.subDeadlineGate != nil { + b.subDeadlineGate() + } b.detachSub(primaryID, subCh) // The primary may have resolved between the deadline firing and // the detach completing. The sub chan is buffered (cap 1), so a @@ -544,18 +563,35 @@ func (b *Broker) Resolve(id string, resp Response) bool { delete(b.dedupIndex, w.dedupKey) } b.recordCoalescedLocked(id, w.count) - } - b.mu.Unlock() - if ok { + if b.resolveAfterDeleteHook != nil { + b.resolveAfterDeleteHook() + } + + // Deliver the primary response and fan it to every coalesced + // subscriber WHILE STILL HOLDING b.mu. The primary ch and every + // sub chan are buffered cap-1 and receive exactly one value, so + // these sends cannot block — holding the lock here is safe and + // closes the resolve/detach lost-wakeup window: a subscriber + // whose deadline fires takes b.mu in detachSub, so it serializes + // against this section. It therefore either detaches BEFORE this + // runs (removed from w.subs, gets no send, returns its own + // timeout — correct, it never coalesced under this decision) or + // AFTER (the response is already buffered on its cap-1 chan, and + // waitSub's post-detach non-blocking read picks it up instead of + // denying an approved request). There is no instant where a sub + // can observe "waiter gone AND response not yet sent". w.ch <- resp - // Fan the same response to every coalesced subscriber. All sub - // chans are buffered (cap 1), so a send to a subscriber that - // already timed out and detached never blocks. for _, sub := range w.subs { sub <- resp } - // Cancel on all channels so they can clean up (e.g. edit message). + } + b.mu.Unlock() + + if ok { + // Cancel on all channels so they can clean up (e.g. edit + // message). This calls into channel implementations that may do + // blocking network I/O, so it must stay OUTSIDE b.mu. b.cancelOnChannels(id) } return ok diff --git a/internal/channel/channel_test.go b/internal/channel/channel_test.go index e1b7fc0..9457647 100644 --- a/internal/channel/channel_test.go +++ b/internal/channel/channel_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "runtime" "sync" "sync/atomic" "testing" @@ -1404,3 +1405,111 @@ func TestBrokerCoalesceCrossChannelFirstWins(t *testing.T) { t.Fatalf("unexpected response %v", first.resp) } } + +// TestBrokerResolveDetachLostWakeup deterministically forces the exact +// resolve/detach interleave from Copilot Finding 2: the primary's Resolve +// has deleted the waiter but the coalesced subscriber's deadline fires in +// the window before the fan-out send. Pre-fix, Resolve sent AFTER releasing +// b.mu, so the subscriber's detachSub (no waiter found) + non-blocking read +// (subCh still empty) returned a spurious ResponseDeny and the buffered +// allow was dropped — a coalesced caller wrongly denied. Post-fix the +// fan-out send happens INSIDE the locked resolution section, so the +// subscriber's detachSub serializes against it via b.mu: by the time the +// subscriber's non-blocking read runs, the allow is already buffered on its +// cap-1 chan and it returns ResponseAllowOnce. +// +// The interleave is forced with two test-only seams (no sleeps for the +// critical ordering): +// - subDeadlineGate parks the subscriber at the very top of waitSub's +// deadline branch (before detachSub) until the test releases it. +// - resolveAfterDeleteHook fires inside Resolve right after the waiter is +// deleted (lock held), and is where the test releases the subscriber so +// its detach/read races the fan-out exactly in the lost-wakeup window. +func TestBrokerResolveDetachLostWakeup(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const dest = "lostwakeup.example.com" + const port = 443 + + // Long-lived primary so it stays pending while the sub attaches. + primaryOut := make(chan result, 1) + go func() { + resp, err := broker.Request(dest, port, "https", 5*time.Second) + primaryOut <- result{resp, err} + }() + var primaryID string + for { + reqs := ch.getRequests() + if len(reqs) == 1 { + primaryID = reqs[0].ID + break + } + time.Sleep(time.Millisecond) + } + + // Park the subscriber at the start of its deadline branch. + subAtGate := make(chan struct{}) + releaseSub := make(chan struct{}) + var gateOnce sync.Once + broker.subDeadlineGate = func() { + gateOnce.Do(func() { close(subAtGate) }) + <-releaseSub + } + + // Subscriber with a very short timeout: it attaches, coalesces onto the + // primary, then its deadline fires and it blocks in subDeadlineGate. + subOut := make(chan result, 1) + go func() { + resp, err := broker.Request(dest, port, "https", 20*time.Millisecond) + subOut <- result{resp, err} + }() + // Wait for it to coalesce (count == 2) so it is in w.subs when Resolve + // deletes the waiter — the precondition for the lost-wakeup window. + for broker.CoalescedCount(primaryID) < 2 { + time.Sleep(time.Millisecond) + } + // Wait until the subscriber's deadline has fired and it is parked at the + // gate (still attached, detachSub not yet called). + <-subAtGate + + // Inside Resolve, after the waiter is deleted (lock held), release the + // parked subscriber so its detachSub + non-blocking read races the + // fan-out send exactly in the lost-wakeup window. Yield generously so + // the subscriber goroutine is scheduled and (post-fix) blocks on b.mu + // inside detachSub before this hook returns and Resolve performs the + // under-lock fan-out send. + broker.resolveAfterDeleteHook = func() { + close(releaseSub) + for i := 0; i < 1000; i++ { + runtime.Gosched() + } + } + + if !broker.Resolve(primaryID, ResponseAllowOnce) { + t.Fatal("Resolve returned false for primary") + } + + pr := <-primaryOut + if pr.resp != ResponseAllowOnce { + t.Fatalf("primary: expected ResponseAllowOnce, got %v (err %v)", pr.resp, pr.err) + } + + select { + case sr := <-subOut: + // The whole point of Finding 2: the coalesced subscriber whose + // deadline fired in the resolve window must observe the operator's + // ALLOW, not a spurious timeout deny. + if sr.resp != ResponseAllowOnce { + t.Fatalf("coalesced subscriber wrongly got %v (err %v); want ResponseAllowOnce "+ + "— lost-wakeup: Resolve deleted the waiter then the sub's deadline "+ + "fired before the fan-out send (Finding 2)", sr.resp, sr.err) + } + if sr.err != nil { + t.Fatalf("coalesced subscriber got ALLOW but with error %v; the resolved "+ + "decision must come back clean", sr.err) + } + case <-time.After(3 * time.Second): + t.Fatal("coalesced subscriber never returned (deadlock?)") + } +} diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index a52de4a..4290e47 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -952,22 +952,24 @@ func (a *SluiceAddon) resolveOAuthResponseAttribution(f *mitmproxy.Flow, matched realRefresh := extractRequestRefreshToken(reqBody, reqCT) member, ok := a.refreshAttr.Recover(realRefresh) if !ok { - // Recovery failed. The refresh may legitimately belong to a plain - // OAuth credential that shares this token URL (not a pool member); - // only attribute to that plain credential when the - // deterministic-first match itself is NOT pooled AND no per-member - // refresh tag was recorded for it (a pooled refresh always records - // a tag in buildPooledMemberPairs, so a missing tag means this was - // not a pooled refresh). Otherwise fail closed: never misfile a - // pooled member's rotated tokens under the wrong entry (R1). - if matchedCred != "" && pr.PoolForMember(matchedCred) == "" { - log.Printf("[ADDON-OAUTH] token URL shared with pool %q but no pooled "+ - "refresh tag; attributing to plain credential %q", poolName, matchedCred) - return oauthRespAttribution{phantomName: matchedCred, persistMember: matchedCred} - } - log.Printf("[ADDON-OAUTH] R1 fail-closed: pooled token URL for pool %q but owning member "+ - "could not be recovered from the injected refresh token; skipping vault write "+ - "(next refresh will retry)", poolName) + // Recovery failed AND a pool shares this token URL (the + // poolName == "" plain-only case already returned above with a + // normal 1:1 persist to matchedCred). We cannot prove this + // response is not a pooled refresh: a pooled refresh always + // records a tag in buildPooledMemberPairs, but that tag can + // expire (refreshAttrTTL) or be consumed before a slow response + // comes back, so a missing tag is NOT evidence of "plain, not + // pooled". Attributing to the deterministic-first plain + // credential here would misfile a pooled member's rotated tokens + // under the plain entry — exactly the R1 ("never guess") + // violation. Strict fail-closed: skip the vault write. The agent + // still receives phantoms (the swap ran), and the next refresh + // cycle re-tags and persists correctly. A genuinely plain-only + // token URL is unaffected because it never reaches here. + log.Printf("[ADDON-OAUTH] R1 fail-closed: token URL shared with pool %q but the "+ + "owning member could not be recovered from the injected refresh token "+ + "(no live tag); skipping vault write to avoid misfiling a pooled refresh "+ + "under the wrong credential (next refresh will retry)", poolName) return oauthRespAttribution{phantomName: poolName, pooled: true, skipPersist: true} } // The recovered member's own pool is authoritative (a membership change diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index 2772632..4e8d55a 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -236,11 +236,7 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember, pr if idx := a.oauthIndex.Load(); idx != nil && f.Request != nil { matches := idx.MatchAll(f.Request.URL) pool := "" - matched := "" for _, c := range matches { - if matched == "" { - matched = c // preserve the deterministic-first as last resort - } if p := pr.PoolForMember(c); p != "" { pool = p break @@ -259,29 +255,56 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember, pr return ownerPool, owner, proto, pr, true } // owner is no longer in any pool (membership change - // raced the failure); fall through to the active-member - // fallback below for a still-meaningful attribution. - } - // Fallback ONLY when the real refresh token cannot be - // extracted / attributed: cool the ACTIVE member rather - // than blindly the first index entry. The active member is - // the one whose token was most likely just injected, so it - // is strictly better than idx.Match's deterministic-first. - if active, aok := pr.ResolveActive(pool); aok && active != "" { - log.Printf("[POOL-FAILOVER] pool %q: could not attribute "+ - "token-endpoint failure via injected refresh token; "+ - "falling back to active member %q", pool, active) - return pool, active, proto, pr, true + // raced the failure); the refresh-attr tag still proves + // THIS request used the pool, so fall through to the + // active-member fallback below for a still-meaningful + // attribution. + if active, aok := pr.ResolveActive(pool); aok && active != "" { + log.Printf("[POOL-FAILOVER] pool %q: token-endpoint failure "+ + "owner %q left the pool (membership raced); falling back "+ + "to active member %q", pool, owner, active) + return pool, active, proto, pr, true + } } - // Last resort: a pooled index match if any (preserves prior - // behavior when even ResolveActive cannot decide; better than - // no attribution at all). - for _, c := range matches { - if pr.PoolForMember(c) != "" { - return pool, c, proto, pr, true + // Finding 3: the refresh-attr tag could not attribute this + // failure. A blind ResolveActive / first-index fallback here + // over-applies the cooldown: a PLAIN (non-pool) OAuth + // credential that merely SHARES this token URL with a pool + // would, on its own 401 / invalid_grant, cool an unrelated + // active pool member even though the failing request never + // used the pool. The active-member fallback is only sound + // when there is independent evidence THIS request actually + // went through the pooled injection path. The injection-time + // flow tag (set by buildPooledMemberPairs' sibling + // flowInjected.Tag) is exactly that evidence and is keyed by + // flow ID, so it survives a missing/expired refresh-attr tag + // for a genuinely pooled refresh. + if f.Id != uuid.Nil { + if injected, iok := a.flowInjected.Recover(f.Id); iok && injected != "" { + if injPool := pr.PoolForMember(injected); injPool != "" { + return injPool, injected, proto, pr, true + } + // The injected member left the pool but the flow tag + // still proves this request used the pool: cool the + // pool's current active member. + if active, aok := pr.ResolveActive(pool); aok && active != "" { + log.Printf("[POOL-FAILOVER] pool %q: token-endpoint failure "+ + "injected member %q left the pool; falling back to "+ + "active member %q", pool, injected, active) + return pool, active, proto, pr, true + } } } - return pool, matched, proto, pr, true + // No refresh-attr tag AND no flow-injection tag: there is no + // evidence this request used the pool. It is most likely a + // plain OAuth credential that only happens to share the token + // URL. Return ok=false so NO pool member is cooled (a blind + // fallback here would park an innocent active member). + log.Printf("[POOL-FAILOVER] pool %q: token-endpoint failure on a "+ + "shared token URL with no pooled-usage evidence (no refresh-attr "+ + "or flow-injection tag); not cooling any member (likely a plain "+ + "OAuth credential sharing this token URL)", pool) + return "", "", "", nil, false } } return "", "", "", nil, false diff --git a/internal/proxy/pool_failover_test.go b/internal/proxy/pool_failover_test.go index 1822b41..4ab632a 100644 --- a/internal/proxy/pool_failover_test.go +++ b/internal/proxy/pool_failover_test.go @@ -336,6 +336,12 @@ func TestPoolForResponseResolvesActiveMember(t *testing.T) { // binding is on api.openai.com, the OAuth refresh hits auth.openai.com. The // CONNECT-host reverse mapping in poolForResponse therefore CANNOT match a // token-endpoint response — only the token-URL-index path can. +// +// poolName/memberA/memberB are parameterized on purpose: this is a general +// split-host pool fixture and a multi-pool test may legitimately pass other +// names. unparam only sees the current callers all using codex_pool/memA/memB. +// +//nolint:unparam func setupPoolAddonSplitHost(t *testing.T, poolName, memberA, memberB string) (*SluiceAddon, *atomic.Pointer[vault.PoolResolver]) { t.Helper() @@ -399,10 +405,19 @@ func TestTokenEndpointHostFailoverOnPooledMember(t *testing.T) { t.Fatalf("pre-failover active = %q, want memA", got) } + // A genuine pooled refresh ALWAYS goes through pass-2, which records + // the real-refresh -> member attribution tag (buildPooledMemberPairs). + // Model that so the response carries recoverable pool-usage evidence — + // post-Finding-3 a token-endpoint failure with NO pool-usage evidence + // is intentionally NOT failed over (it could be a plain credential + // merely sharing the token URL), so a realistic pooled-refresh + // regression must tag like production does. + addon.refreshAttr.Tag("A-refresh-old", "memA") + // Sanity: the CONNECT-host reverse mapping alone must NOT match here // (this is exactly the gap CRITICAL-2 describes). poolForResponse must // still succeed via the token-URL index path. - f := newPoolRespFlow(client, 400, []byte(`{"error":"invalid_grant"}`)) + f := newPoolRespFlowBody(client, 400, "A-refresh-old", []byte(`{"error":"invalid_grant"}`)) pool, member, _, _, ok := addon.poolForResponse(f) if !ok { t.Fatal("poolForResponse: token-endpoint response on a pooled member must be attributed (CRITICAL-2 fix); got ok=false") @@ -419,7 +434,8 @@ func TestTokenEndpointHostFailoverOnPooledMember(t *testing.T) { }) // A token-endpoint invalid_grant must cool memA and switch to memB. - addon.Response(newPoolRespFlow(client, 400, []byte(`{"error":"invalid_grant"}`))) + // Peek (failover path) does not consume the tag, so it is still live. + addon.Response(newPoolRespFlowBody(client, 400, "A-refresh-old", []byte(`{"error":"invalid_grant"}`))) if active, _ := pr.ResolveActive("codex_pool"); active != "memB" { t.Fatalf("post-failover active = %q, want memB (token-endpoint auth failure must fail over)", active) @@ -643,30 +659,85 @@ func TestTokenEndpointFailover3MemberAttributesMiddleMember(t *testing.T) { } } -// TestTokenEndpointFailoverFallsBackToActiveMember asserts the documented -// fallback: when the real refresh token cannot be recovered from the body -// (no attribution tag — e.g. the request was not driven through pass-2), -// poolForResponse cools the ACTIVE member, never blindly idx.Match's first -// index entry. -func TestTokenEndpointFailoverFallsBackToActiveMember(t *testing.T) { +// TestTokenEndpointFailoverFailClosedWithoutPoolUsageEvidence is the +// Finding 3 regression. A token-endpoint 401 / invalid_grant on a token URL +// that a pool shares must NOT cool a pool member when there is no evidence +// the failing request actually used the pool. The pre-Finding-3 code +// blindly fell back to the pool's ACTIVE member on a missing tag, so a +// PLAIN (non-pool) OAuth credential that merely shares the token URL would, +// on its own invalid_grant, cool an unrelated active pool member and park +// it. The fix returns ok=false unless a refreshAttr OR flowInjected tag +// proves the request went through the pooled injection path. +// +// This test MUST fail before the fix: the old active-member fallback cools +// memB even though nothing tied the failing request to the pool. +func TestTokenEndpointFailoverFailClosedWithoutPoolUsageEvidence(t *testing.T) { + addon, prPtr := setupPoolAddonSplitHost(t, "codex_pool", "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + pr := prPtr.Load() + + // memA cooled -> memB active. NO refreshAttr tag, NO flowInjected tag: + // the failing request carries zero evidence it used the pool (this is + // exactly the shape of a plain non-pool OAuth credential that merely + // shares the token URL hitting its own invalid_grant). + memBPre, _ := pr.CooldownUntil("memB") + pr.MarkCooldown("memA", time.Now().Add(90*time.Second), "429") + if got, _ := pr.ResolveActive("codex_pool"); got != "memB" { + t.Fatalf("active = %q, want memB", got) + } + + f := newPoolRespFlowBody(client, 400, "unrelated-plain-refresh", []byte(`{"error":"invalid_grant"}`)) + pool, member, _, _, ok := addon.poolForResponse(f) + if ok { + t.Fatalf("poolForResponse must fail closed (ok=false) with no pool-usage "+ + "evidence; got ok=true pool=%q member=%q — Finding 3: a plain credential "+ + "sharing the token URL would cool an unrelated active pool member", pool, member) + } + + // Drive the full Response path and assert NO pool member was cooled. + addon.Response(f) + if _, cooling := pr.CooldownUntil("memB"); cooling { + t.Fatal("memB (active member) was cooled by an unattributed shared-token-URL " + + "failure — Finding 3 over-application of the fallback") + } + if bU, _ := pr.CooldownUntil("memB"); !bU.Equal(memBPre) { + t.Fatalf("memB cooldown changed (%v -> %v) despite no pool-usage evidence", memBPre, bU) + } + // memA's original 429 window must be untouched too. + if aU, c := pr.CooldownUntil("memA"); !c { + t.Fatal("memA should still be on its original 429 window") + } else if time.Until(aU) < 60*time.Second { + t.Fatalf("memA 429 window was shortened/cleared: %s left", time.Until(aU)) + } +} + +// TestTokenEndpointFailoverFlowInjectedTagFailsOver is the companion to the +// fail-closed test: when the refreshAttr tag is absent but the +// injection-time flowInjected tag IS present (genuine pooled usage proven +// by the flow ID), the failover MUST still cool the injected member. This +// guards against over-restricting the Finding 3 fix. +func TestTokenEndpointFailoverFlowInjectedTagFailsOver(t *testing.T) { addon, prPtr := setupPoolAddonSplitHost(t, "codex_pool", "memA", "memB") client := setupAddonConn(addon, "auth.example.com:443") pr := prPtr.Load() - // memA cooled -> memB active. NO refreshAttr tag is recorded, and the - // body's refresh token is not in the attribution map, so Peek misses. pr.MarkCooldown("memA", time.Now().Add(90*time.Second), "429") if got, _ := pr.ResolveActive("codex_pool"); got != "memB" { t.Fatalf("active = %q, want memB", got) } - f := newPoolRespFlowBody(client, 400, "untagged-refresh", []byte(`{"error":"invalid_grant"}`)) + // No refreshAttr tag (e.g. it expired), but the request DID go through + // the pooled injection path, so flowInjected carries memB for this flow. + f := newPoolRespFlowBody(client, 400, "expired-refresh", []byte(`{"error":"invalid_grant"}`)) + addon.flowInjected.Tag(f.Id, "memB") + pool, member, _, _, ok := addon.poolForResponse(f) if !ok { - t.Fatal("poolForResponse: expected attribution via active-member fallback") + t.Fatal("poolForResponse: a flow-injection-tagged pooled refresh must still fail over") } if pool != "codex_pool" || member != "memB" { - t.Fatalf("fallback got pool=%q member=%q, want codex_pool/memB (active member, NOT idx.Match's memA)", pool, member) + t.Fatalf("got pool=%q member=%q, want codex_pool/memB (flowInjected tag is "+ + "valid pool-usage evidence — Finding 3 must not over-restrict)", pool, member) } } diff --git a/internal/proxy/pool_phantom_test.go b/internal/proxy/pool_phantom_test.go index 7591515..9a0874a 100644 --- a/internal/proxy/pool_phantom_test.go +++ b/internal/proxy/pool_phantom_test.go @@ -418,3 +418,202 @@ func TestChokepointPlainCredentialUnchanged(t *testing.T) { t.Errorf("plain cred phantom changed; body=%q", body) } } + +// plainCredWithTokenURL builds a plain (non-pool) OAuth credential envelope +// with an explicit token URL. +func plainCredWithTokenURL(t *testing.T, access, refresh, tokenURL string) string { + t.Helper() + c := &vault.OAuthCredential{ + AccessToken: access, + RefreshToken: refresh, + TokenURL: tokenURL, + } + data, err := c.Marshal() + if err != nil { + t.Fatalf("marshal oauth cred: %v", err) + } + return string(data) +} + +// TestR1FailClosedPlainCredFirstMatchSharesPoolTokenURL is the Copilot +// Finding 1 regression. A PLAIN (non-pool) OAuth credential sorts FIRST in +// credential_meta and shares the SAME token URL as a pool. A pooled refresh +// response arrives whose owning member cannot be recovered (no live +// refresh-attr tag — it expired, or the response is slow). idx.Match +// returns the plain credential (first index entry). Before the fix, +// resolveOAuthResponseAttribution took the "matchedCred not pooled" branch +// and persisted the rotated POOLED tokens under the PLAIN credential's +// vault entry — an R1 ("never guess") violation that misfiles one pool +// member's rotated tokens under an unrelated plain credential. +// +// The fix: once ANY pool shares the token URL and the owning member cannot +// be recovered, skip persistence entirely (fail closed). The swap still +// runs so the agent never sees real tokens. A genuinely plain-only token +// URL (no pool sharing) must still persist normally — covered by the +// sub-test below so the fix is not over-restrictive. +// +// MUST fail before the fix: the plain credential's vault entry would be +// overwritten with the pooled refresh's rotated tokens. +func TestR1FailClosedPlainCredFirstMatchSharesPoolTokenURL(t *testing.T) { + const poolName = "codex_pool" + // "aaa_plain" sorts/indexes before the pool members so idx.Match (first + // entry) returns it — exactly the Finding 1 collision shape. + provider := &addonWritableProvider{ + creds: map[string]string{ + "aaa_plain": poolMemberCred(t, "plain-access-old", "plain-refresh-old"), + "memA": poolMemberCred(t, "A-access-old", "A-refresh-old"), + "memB": poolMemberCred(t, "B-access-old", "B-refresh-old"), + }, + } + bindings := []vault.Binding{{ + Destination: "auth.example.com", + Ports: []int{443}, + Credential: poolName, + }} + resolver, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var resolverPtr atomic.Pointer[vault.BindingResolver] + resolverPtr.Store(resolver) + + addon := NewSluiceAddon(WithResolver(&resolverPtr), WithProvider(provider)) + addon.persistDone = make(chan struct{}, 10) + + // aaa_plain is FIRST in the metas slice -> first index entry -> what + // idx.Match returns for testOAuthTokenURL. + addon.UpdateOAuthIndex([]store.CredentialMeta{ + {Name: "aaa_plain", CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: "memA", CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: "memB", CredType: "oauth", TokenURL: testOAuthTokenURL}, + }) + pool := store.Pool{Name: poolName, Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: "memA", Position: 0}, + {Credential: "memB", Position: 1}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{pool}, nil)) + addon.SetPoolResolver(&prPtr) + + client := setupAddonConn(addon, "auth.example.com:443") + + // Precondition: idx.Match returns the plain credential (first entry), + // while MatchAll reveals a pool also shares the token URL. + if idx := addon.oauthIndex.Load(); idx != nil { + u, _ := url.Parse(testOAuthTokenURL) + if m, _ := idx.Match(u); m != "aaa_plain" { + t.Fatalf("precondition: idx.Match must return the plain first entry, got %q", m) + } + } + + beforePlain := provider.creds["aaa_plain"] + beforeA := provider.creds["memA"] + beforeB := provider.creds["memB"] + + // A pooled refresh response. NO refresh-attr tag is recorded (it + // expired / the response is slow), so the owning member cannot be + // recovered. The body's refresh token is untracked. + resp := newPoolReqRespFlow(client, + []byte("grant_type=refresh_token&refresh_token=untracked-pooled-refresh"), + mustJSON(t, map[string]interface{}{ + "access_token": "rotated-pooled-access", + "refresh_token": "rotated-pooled-refresh", + "expires_in": 3600, + })) + addon.Response(resp) + + // No vault persist must have been scheduled to ANYONE. + select { + case <-addon.persistDone: + t.Fatal("R1 fail-closed violated (Finding 1): a vault persist was scheduled " + + "for a pooled refresh whose owner could not be recovered, while a plain " + + "credential sorted first and shared the token URL") + default: + } + if provider.creds["aaa_plain"] != beforePlain { + t.Fatal("Finding 1: pooled refresh tokens were misfiled under the PLAIN " + + "credential 'aaa_plain' (R1 'never guess' violation)") + } + if provider.creds["memA"] != beforeA || provider.creds["memB"] != beforeB { + t.Fatal("Finding 1: pooled refresh tokens were written to a pool member " + + "without a recovered owner (must fail closed)") + } + + // Agent must still be protected: the real rotated tokens are swapped to + // the pool-stable phantoms even though nothing was persisted. + body := string(resp.Response.Body) + if strings.Contains(body, "rotated-pooled-access") || strings.Contains(body, "rotated-pooled-refresh") { + t.Errorf("fail-closed must still strip real tokens; body=%q", body) + } + if !strings.Contains(body, poolStablePhantomAccess(poolName)) { + t.Errorf("fail-closed response missing pool-stable phantom; body=%q", body) + } +} + +// TestR1PlainOnlyTokenURLStillPersists is the no-regression companion to +// Finding 1: a plain OAuth credential whose token URL is NOT shared by any +// pool must still persist its rotated tokens normally. The fix only skips +// persistence when a pool shares the token URL, so this 1:1 plain path is +// unchanged. +func TestR1PlainOnlyTokenURLStillPersists(t *testing.T) { + const plainTokenURL = "https://plain-only.example.com/oauth/token" + provider := &addonWritableProvider{ + creds: map[string]string{ + // A pool exists but on a DIFFERENT token URL, so it does not + // share plainTokenURL. + "plainCred": plainCredWithTokenURL(t, "p-access-old", "p-refresh-old", plainTokenURL), + "memA": poolMemberCred(t, "A-access-old", "A-refresh-old"), + "memB": poolMemberCred(t, "B-access-old", "B-refresh-old"), + }, + } + bindings := []vault.Binding{{ + Destination: "plain-only.example.com", + Ports: []int{443}, + Credential: "plainCred", + }} + resolver, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var resolverPtr atomic.Pointer[vault.BindingResolver] + resolverPtr.Store(resolver) + + addon := NewSluiceAddon(WithResolver(&resolverPtr), WithProvider(provider)) + addon.persistDone = make(chan struct{}, 10) + addon.UpdateOAuthIndex([]store.CredentialMeta{ + {Name: "plainCred", CredType: "oauth", TokenURL: plainTokenURL}, + // Pool members on the OTHER token URL (testOAuthTokenURL). + {Name: "memA", CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: "memB", CredType: "oauth", TokenURL: testOAuthTokenURL}, + }) + pool := store.Pool{Name: "codex_pool", Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: "memA", Position: 0}, + {Credential: "memB", Position: 1}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{pool}, nil)) + addon.SetPoolResolver(&prPtr) + + client := setupAddonConn(addon, "plain-only.example.com:443") + resp := newTestResponseFlow(client, plainTokenURL, 200, mustJSON(t, map[string]interface{}{ + "access_token": "p-real-access-NEW", + "refresh_token": "p-real-refresh-NEW", + "expires_in": 3600, + }), "application/json") + addon.Response(resp) + waitAddonPersist(t, addon) + + // The plain credential's vault entry must now hold the rotated tokens. + updated := provider.creds["plainCred"] + if !strings.Contains(updated, "p-real-access-NEW") || !strings.Contains(updated, "p-real-refresh-NEW") { + t.Fatalf("plain-only token URL must still persist rotated tokens to the "+ + "plain credential (no Finding 1 over-restriction); vault=%q", updated) + } + // Agent still gets phantoms, not the real rotated tokens. + body := string(resp.Response.Body) + if strings.Contains(body, "p-real-access-NEW") || strings.Contains(body, "p-real-refresh-NEW") { + t.Errorf("plain-only: real token leaked to agent; body=%q", body) + } +} From a560bfdb5ee27e92faef1b5141ea3da6072264f6 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 15:28:15 +0800 Subject: [PATCH 31/49] fix(cli): reject pool removal while bindings still reference the pool --- cmd/sluice/pool.go | 25 +++++++++- cmd/sluice/pool_test.go | 106 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 1 deletion(-) diff --git a/cmd/sluice/pool.go b/cmd/sluice/pool.go index 56b8f48..a66bac2 100644 --- a/cmd/sluice/pool.go +++ b/cmd/sluice/pool.go @@ -242,6 +242,29 @@ func handlePoolRemove(args []string) error { } defer func() { _ = db.Close() }() + // Reject the removal while any binding still references this pool by + // name. A pool shares the credential namespace, so a binding's + // "credential" column may hold the pool name (e.g. created via + // "sluice binding add --destination ..."). Deleting the pool + // out from under such bindings would leave them pointing at a + // non-existent credential (injection silently fails for those + // destinations) and, worse, a later credential created with the same + // name would silently inherit the stale bindings. This mirrors the + // fail-closed pool-membership guard in "sluice cred remove": refuse + // with an actionable error instead of cascading or orphaning. + refs, err := db.ListBindingsByCredential(name) + if err != nil { + return fmt.Errorf("check bindings referencing pool %q: %w", name, err) + } + if len(refs) > 0 { + details := make([]string, len(refs)) + for i, b := range refs { + details[i] = fmt.Sprintf("[%d] %s", b.ID, b.Destination) + } + return fmt.Errorf("pool %q is still referenced by %d binding(s): %s; rebind or remove these bindings first (sluice binding remove , which also clears the auto-created allow rule), then retry pool remove", + name, len(refs), strings.Join(details, ", ")) + } + removed, err := db.RemovePool(name) if err != nil { return err @@ -249,6 +272,6 @@ func handlePoolRemove(args []string) error { if !removed { return fmt.Errorf("pool %q not found", name) } - fmt.Printf("pool %q removed (members and bindings referencing it are unaffected; remove stale bindings with 'sluice binding remove')\n", name) + fmt.Printf("pool %q removed\n", name) return nil } diff --git a/cmd/sluice/pool_test.go b/cmd/sluice/pool_test.go index 165fdad..ac23405 100644 --- a/cmd/sluice/pool_test.go +++ b/cmd/sluice/pool_test.go @@ -3,6 +3,7 @@ package main import ( "os" "path/filepath" + "strconv" "strings" "testing" @@ -166,6 +167,111 @@ func TestCredRemoveBlockedForLivePoolMember(t *testing.T) { sb.Release() } +// TestPoolRemoveBlockedWhileBindingReferencesIt asserts that "sluice pool +// remove" refuses while a binding still references the pool by name, then +// succeeds once the binding is gone. Without the guard the pool would be +// deleted out from under the binding, leaving it pointing at a non-existent +// credential (and silently re-applying to any later credential created with +// the same name). Regression for Copilot round-7 finding. +func TestPoolRemoveBlockedWhileBindingReferencesIt(t *testing.T) { + dir := t.TempDir() + dbPath := setupVaultDB(t, dir) + seedPoolCred(t, dbPath, dir, "acct_a") + seedPoolCred(t, dbPath, dir, "acct_b") + if err := handlePoolCommand([]string{"create", "--db", dbPath, "--members", "acct_a,acct_b", "codex"}); err != nil { + t.Fatalf("pool create: %v", err) + } + + // binding add creates a binding whose credential == "codex" plus + // a paired auto-created allow rule tagged binding-add:codex. + if err := handleBindingCommand([]string{"add", "--db", dbPath, "--destination", "api.example.com", "codex"}); err != nil { + t.Fatalf("binding add: %v", err) + } + + // pool remove must be refused while the binding references the pool. + err := handlePoolCommand([]string{"remove", "--db", dbPath, "codex"}) + if err == nil { + t.Fatalf("pool remove with referencing binding: err = nil, want block error") + } + if !strings.Contains(err.Error(), "still referenced by") || !strings.Contains(err.Error(), "api.example.com") { + t.Fatalf("pool remove error = %v, want message naming the blocking binding", err) + } + + // The pool must still exist (removal was refused before RemovePool). + db, derr := store.New(dbPath) + if derr != nil { + t.Fatalf("open db: %v", derr) + } + p, perr := db.GetPool("codex") + if perr != nil { + t.Fatalf("get pool after blocked remove: %v", perr) + } + if p == nil { + t.Fatalf("pool %q was deleted despite blocked removal", "codex") + } + _ = db.Close() + + // Resolve the blocking binding id and remove it (clears the paired rule + // too via RemoveBindingWithRuleCleanup). + db2, derr2 := store.New(dbPath) + if derr2 != nil { + t.Fatalf("reopen db: %v", derr2) + } + refs, lerr := db2.ListBindingsByCredential("codex") + if lerr != nil || len(refs) != 1 { + t.Fatalf("list bindings: refs=%v err=%v", refs, lerr) + } + bindingID := refs[0].ID + _ = db2.Close() + + if err := handleBindingCommand([]string{"remove", "--db", dbPath, strconv.FormatInt(bindingID, 10)}); err != nil { + t.Fatalf("binding remove: %v", err) + } + + // With no binding referencing it, pool remove now succeeds. + out := captureStdout(t, func() { + if err := handlePoolCommand([]string{"remove", "--db", dbPath, "codex"}); err != nil { + t.Fatalf("pool remove after binding removed: %v", err) + } + }) + if !strings.Contains(out, `pool "codex" removed`) { + t.Errorf("remove output = %q", out) + } + + db3, derr3 := store.New(dbPath) + if derr3 != nil { + t.Fatalf("reopen db: %v", derr3) + } + gone, gerr := db3.GetPool("codex") + if gerr != nil { + t.Fatalf("get pool after successful remove: %v", gerr) + } + if gone != nil { + t.Fatalf("pool %q still exists after successful remove", "codex") + } + _ = db3.Close() +} + +// TestPoolRemoveCleanWithoutReferencingBindings is the no-regression case: +// a pool with no binding referencing it removes cleanly as before. +func TestPoolRemoveCleanWithoutReferencingBindings(t *testing.T) { + dir := t.TempDir() + dbPath := setupVaultDB(t, dir) + seedPoolCred(t, dbPath, dir, "acct_a") + seedPoolCred(t, dbPath, dir, "acct_b") + if err := handlePoolCommand([]string{"create", "--db", dbPath, "--members", "acct_a,acct_b", "codex"}); err != nil { + t.Fatalf("pool create: %v", err) + } + out := captureStdout(t, func() { + if err := handlePoolCommand([]string{"remove", "--db", dbPath, "codex"}); err != nil { + t.Fatalf("pool remove (no bindings): %v", err) + } + }) + if !strings.Contains(out, `pool "codex" removed`) { + t.Errorf("remove output = %q", out) + } +} + // TestCredRemoveFailsClosedWhenDBUnopenable asserts that when the policy DB // path exists but cannot be opened, cred remove refuses (fails closed) // instead of logging a warning and deleting the vault secret anyway. A From 661e51eab318b4da902699a3276af1980275a688 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 15:47:29 +0800 Subject: [PATCH 32/49] fix(store): enforce namespace + pool-member + health-row invariants at store layer; engine-aware persist fast-path --- internal/proxy/server.go | 35 ++++++++++- internal/proxy/server_test.go | 90 ++++++++++++++++++++++++++++ internal/store/pools_test.go | 110 ++++++++++++++++++++++++++++++++++ internal/store/store.go | 81 +++++++++++++++++++++++-- 4 files changed, 311 insertions(+), 5 deletions(-) diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 346feb2..7173da1 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -527,7 +527,40 @@ func (r *policyRuleSet) persistApprovalRule(verdict, dest string, port int) bool if exists, existsErr := r.store.HasApprovalRule(verdict, dest, port); existsErr != nil { log.Printf("[WARN] failed to check existing %s rule for %s:%d: %v", verdict, dest, port, existsErr) } else if exists { - log.Printf("[approval] %s rule for %s:%d already present; skipping duplicate persist", verdict, dest, port) + // The row is present, but a prior persist may have written the row + // (AddRule) and then failed at LoadFromStore/Validate before the + // engine pointer swapped. In that window the durable store has the + // rule while the live engine still evaluates dest:port as ask. Only + // fast-path when the CURRENT live engine already reflects the rule; + // otherwise recompile/swap so the engine is guaranteed current + // before we report success. Without this, a coalesced caller that + // trusts the row would skip the safety-net per-request checker even + // though the live engine has not yet learned the rule. + if eng := r.engine.Load(); eng != nil { + v, src := eng.EvaluateDetailed(dest, port) + want := policy.Allow + if verdict == "deny" { + want = policy.Deny + } + if src == policy.RuleMatch && v == want { + log.Printf("[approval] %s rule for %s:%d already present and live engine current; skipping duplicate persist", verdict, dest, port) + return true + } + } + // Row present but engine stale: recompile from the store (the row + // is already there, so no duplicate AddRule) and swap so subsequent + // callers see a current engine. + log.Printf("[approval] %s rule for %s:%d present but live engine stale; recompiling", verdict, dest, port) + newEng, recompErr := policy.LoadFromStore(r.store) + if recompErr != nil { + log.Printf("[WARN] failed to recompile engine for stale %s rule %s:%d: %v", verdict, dest, port, recompErr) + return false + } + if valErr := newEng.Validate(); valErr != nil { + log.Printf("[WARN] engine validation failed for stale %s rule %s:%d: %v", verdict, dest, port, valErr) + return false + } + r.engine.Store(newEng) return true } if _, storeErr := r.store.AddRule(verdict, store.RuleOpts{Destination: dest, Ports: []int{port}, Source: "approval"}); storeErr != nil { diff --git a/internal/proxy/server_test.go b/internal/proxy/server_test.go index 3a9974e..cb98c8e 100644 --- a/internal/proxy/server_test.go +++ b/internal/proxy/server_test.go @@ -5043,3 +5043,93 @@ func TestPersistApprovalRuleSinglePersistUnchanged(t *testing.T) { t.Error("expected engine pointer to be swapped after persist") } } + +// TestPersistApprovalRuleRowPresentEngineStale is the Finding 1 regression. +// A prior persist may have written the rule row (AddRule) and then failed at +// LoadFromStore/Validate before the engine pointer swapped: the durable store +// has the rule but the live engine still evaluates dest:port as the default +// verdict. A later coalesced caller hitting the HasApprovalRule fast path must +// NOT just return true off the row's presence — it must make the live engine +// current first, otherwise callers skip the safety-net per-request checker +// while the live engine has not learned the rule. +func TestPersistApprovalRuleRowPresentEngineStale(t *testing.T) { + st, err := store.New(":memory:") + if err != nil { + t.Fatalf("store.New: %v", err) + } + defer func() { _ = st.Close() }() + + // Compile a STALE engine snapshot (empty store: no rules). + staleEng, err := policy.LoadFromStore(st) + if err != nil { + t.Fatalf("LoadFromStore (stale): %v", err) + } + // Sanity: the stale engine does not yet rule-match the destination. + if v, src := staleEng.EvaluateDetailed("stale.example.com", 443); src == policy.RuleMatch && v == policy.Allow { + t.Fatal("precondition: stale engine unexpectedly already allows the destination via a rule") + } + + // Simulate the partial-persist window: the row IS in the store, but the + // live engine pointer still points at the stale snapshot. + if _, err := st.AddRule("allow", store.RuleOpts{ + Destination: "stale.example.com", Ports: []int{443}, Source: "approval", + }); err != nil { + t.Fatalf("AddRule: %v", err) + } + if has, herr := st.HasApprovalRule("allow", "stale.example.com", 443); herr != nil || !has { + t.Fatalf("precondition: expected row present, got has=%v err=%v", has, herr) + } + + engPtr := new(atomic.Pointer[policy.Engine]) + engPtr.Store(staleEng) + var reloadMu sync.Mutex + r := &policyRuleSet{engine: engPtr, reloadMu: &reloadMu, store: st} + + if !r.persistApprovalRule("allow", "stale.example.com", 443) { + t.Fatal("persistApprovalRule returned false") + } + + // The live engine must have been swapped to one that reflects the rule; + // the stale pointer must no longer be installed. + if engPtr.Load() == staleEng { + t.Fatal("engine pointer not swapped: stale engine still live (callers would skip the safety-net checker)") + } + if v, srcM := engPtr.Load().EvaluateDetailed("stale.example.com", 443); srcM != policy.RuleMatch || v != policy.Allow { + t.Fatalf("live engine still stale after persist: verdict=%v source=%v, want Allow/RuleMatch", v, srcM) + } + + // No duplicate insert may have happened (the row was already there). + rules, err := st.ListRules(store.RuleFilter{}) + if err != nil { + t.Fatalf("ListRules: %v", err) + } + matches := 0 + for _, ru := range rules { + if ru.Destination == "stale.example.com" && ru.Source == "approval" { + matches++ + } + } + if matches != 1 { + t.Fatalf("expected exactly 1 persisted approval rule, got %d", matches) + } + + // Idempotent path: row present AND engine now current -> fast-path, still + // returns true, still no duplicate insert, engine unchanged. + curEng := engPtr.Load() + if !r.persistApprovalRule("allow", "stale.example.com", 443) { + t.Fatal("second persistApprovalRule returned false") + } + if engPtr.Load() != curEng { + t.Error("engine pointer swapped on the idempotent fast path (row present AND engine current)") + } + rules, _ = st.ListRules(store.RuleFilter{}) + matches = 0 + for _, ru := range rules { + if ru.Destination == "stale.example.com" && ru.Source == "approval" { + matches++ + } + } + if matches != 1 { + t.Fatalf("idempotent fast path inserted a duplicate: %d rules", matches) + } +} diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index cc7af0a..c943a42 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -403,3 +403,113 @@ func TestMigration000006DownUp(t *testing.T) { } } } + +// TestRemoveCredentialMetaBlocksLivePoolMember is the Finding 3 regression. +// The pool-member integrity guard must live in the store layer so the REST +// API and Telegram removal paths (which call RemoveCredentialMeta directly, +// bypassing the CLI guard) cannot orphan a credential_pool_members row. +func TestRemoveCredentialMetaBlocksLivePoolMember(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "member") + seedOAuthCred(t, s, "other") + if err := s.CreatePoolWithMembers("p", "failover", []string{"member", "other"}); err != nil { + t.Fatalf("CreatePoolWithMembers: %v", err) + } + + // Store-level removal of a live pool member must be refused. + removed, err := s.RemoveCredentialMeta("member") + if err == nil { + t.Fatal("expected RemoveCredentialMeta to refuse a live pool member (Finding 3)") + } + if removed { + t.Fatal("RemoveCredentialMeta reported removed=true for a refused removal") + } + // The meta row must still be present (refusal must not delete anything). + m, gerr := s.GetCredentialMeta("member") + if gerr != nil || m == nil { + t.Fatalf("credential meta deleted despite refusal: %+v, %v", m, gerr) + } + // And the member row must still point at a real credential. + pools, perr := s.PoolsForMember("member") + if perr != nil || len(pools) != 1 || pools[0] != "p" { + t.Fatalf("PoolsForMember(member) = %v, %v; want [p] (no dangling change)", pools, perr) + } + + // Removing a NON-member credential still works. + seedOAuthCred(t, s, "free") + removed, err = s.RemoveCredentialMeta("free") + if err != nil || !removed { + t.Fatalf("RemoveCredentialMeta(free) = %v, %v; want true, nil", removed, err) + } +} + +// TestRemoveCredentialMetaCleansHealthRow is the Finding 2 regression. +// credential_health is keyed by name and not FK-tied to credential_meta, so a +// bare meta delete would leave a stale cooldown that a recreated same-named +// credential inherits on the next resolver seed. +func TestRemoveCredentialMetaCleansHealthRow(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "x") + + // Seed a cooldown for x. + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("x", "cooldown", until, "429 rate limited"); err != nil { + t.Fatalf("SetCredentialHealth: %v", err) + } + if h, _ := s.GetCredentialHealth("x"); h == nil || h.Status != "cooldown" { + t.Fatalf("precondition: expected x in cooldown, got %+v", h) + } + + // Remove the credential. The health row must go with it. + removed, err := s.RemoveCredentialMeta("x") + if err != nil || !removed { + t.Fatalf("RemoveCredentialMeta(x) = %v, %v; want true, nil", removed, err) + } + if h, herr := s.GetCredentialHealth("x"); herr != nil || h != nil { + t.Fatalf("stale health row survived removal: %+v, %v", h, herr) + } + + // Recreate the same-named credential and add it to a fresh pool. It must + // NOT inherit the old cooldown — GetCredentialHealth is nil (= healthy). + seedOAuthCred(t, s, "x") + seedOAuthCred(t, s, "y") + if err := s.CreatePoolWithMembers("fresh", "failover", []string{"x", "y"}); err != nil { + t.Fatalf("CreatePoolWithMembers(fresh): %v", err) + } + if h, herr := s.GetCredentialHealth("x"); herr != nil || h != nil { + t.Fatalf("recreated credential inherited a stale cooldown: %+v, %v", h, herr) + } +} + +// TestAddCredentialMetaRejectsPoolNameCollision is the Finding 4 regression. +// The pool-vs-credential namespace mutual-exclusion must be enforced in the +// store so the REST API and any other AddCredentialMeta caller cannot create +// a credential whose name collides with an existing pool. +func TestAddCredentialMetaRejectsPoolNameCollision(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "acct_a") + seedOAuthCred(t, s, "acct_b") + if err := s.CreatePoolWithMembers("codex", "failover", []string{"acct_a", "acct_b"}); err != nil { + t.Fatalf("CreatePoolWithMembers: %v", err) + } + + // A credential named "codex" collides with the existing pool. + if err := s.AddCredentialMeta("codex", "oauth", "https://auth.example.com/token"); err == nil { + t.Fatal("expected AddCredentialMeta to reject a name that collides with an existing pool (Finding 4)") + } + // No credential_meta row may have been written. + if m, _ := s.GetCredentialMeta("codex"); m != nil { + t.Fatalf("credential_meta row leaked for colliding name: %+v", m) + } + + // A non-colliding name still succeeds. + if err := s.AddCredentialMeta("not_a_pool", "static", ""); err != nil { + t.Fatalf("AddCredentialMeta(not_a_pool) = %v, want nil", err) + } + + // The reverse direction still holds: CreatePoolWithMembers rejects a + // name that already exists as a credential. + if err := s.CreatePoolWithMembers("not_a_pool", "failover", []string{"acct_a"}); err == nil { + t.Fatal("expected CreatePoolWithMembers to reject a name that is already a credential") + } +} diff --git a/internal/store/store.go b/internal/store/store.go index 74f92de..d78a762 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1738,15 +1738,42 @@ func (s *Store) AddCredentialMeta(name, credType, tokenURL string) error { if credType == "oauth" && tokenURL == "" { return fmt.Errorf("token_url is required for oauth credentials") } - _, err := s.db.Exec( + tx, err := s.db.Begin() + if err != nil { + return fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // Namespace mutual-exclusion: a credential must not shadow a pool. Pool + // and credential names share one namespace so a bound destination + // resolves unambiguously to either a pool or a plain credential. This is + // the store-layer counterpart to the pool-side check in + // CreatePoolWithMembers; enforcing it here protects every credential + // creation path (CLI, REST API, any future caller), not just the CLI. + // The check and the insert run in one transaction so a concurrent + // CreatePoolWithMembers cannot interleave between them. + var poolName string + collErr := tx.QueryRow("SELECT name FROM credential_pools WHERE name = ?", name).Scan(&poolName) + switch collErr { + case nil: + return fmt.Errorf("name %q is already a credential pool; pool and credential names share one namespace", name) + case sql.ErrNoRows: + // ok + default: + return fmt.Errorf("check pool name collision for %q: %w", name, collErr) + } + + if _, err := tx.Exec( `INSERT INTO credential_meta (name, cred_type, token_url) VALUES (?, ?, ?) ON CONFLICT(name) DO UPDATE SET cred_type = excluded.cred_type, token_url = excluded.token_url`, name, credType, nilIfEmpty(tokenURL), - ) - if err != nil { + ); err != nil { return fmt.Errorf("insert credential meta: %w", err) } + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit: %w", err) + } return nil } @@ -1792,11 +1819,57 @@ func (s *Store) ListCredentialMeta() ([]CredentialMeta, error) { // RemoveCredentialMeta deletes a credential metadata row by name. Returns true // if a row was deleted. +// +// Store-layer invariant enforcement (protects every caller, not just the CLI): +// +// - Pool-member integrity: removal is REFUSED (fail-closed) when the +// credential is still a live member of one or more pools. Deleting it +// would leave a credential_pool_members row pointing at a missing +// credential, so the pool would resolve to an uninjectable credential. +// The operator must remove it from the pool first. The CLI grew a guard +// for this earlier, but the REST API and Telegram paths call this method +// directly and bypassed it; pushing the guard here closes all paths. +// - Health-row cleanup: the per-credential credential_health row is keyed +// by credential name and is not tied to credential_meta by a foreign +// key, so a bare meta delete would leave a stale cooldown behind. +// Recreating a same-named credential would then inherit that stale +// cooldown on the next resolver seed. The health row is deleted in the +// same transaction as the meta row so the two never diverge. func (s *Store) RemoveCredentialMeta(name string) (bool, error) { - res, err := s.db.Exec("DELETE FROM credential_meta WHERE name = ?", name) + tx, err := s.db.Begin() + if err != nil { + return false, fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // Fail-closed pool-member guard. Same semantics as the CLI guard in + // `sluice cred remove`: a credential that is a live pool member cannot + // be removed until it is taken out of the pool. + var pool string + memErr := tx.QueryRow( + "SELECT pool FROM credential_pool_members WHERE credential = ? ORDER BY pool LIMIT 1", name, + ).Scan(&pool) + switch memErr { + case nil: + return false, fmt.Errorf("credential %q is a member of pool %q; remove it from the pool first (sluice pool remove

, or recreate the pool without it)", name, pool) + case sql.ErrNoRows: + // not a pool member; safe to remove + default: + return false, fmt.Errorf("check pool membership for %q: %w", name, memErr) + } + + res, err := tx.Exec("DELETE FROM credential_meta WHERE name = ?", name) if err != nil { return false, fmt.Errorf("delete credential meta: %w", err) } + // Drop the health row in the same transaction so a removed credential + // never leaves a stale cooldown that a same-named recreation inherits. + if _, err := tx.Exec("DELETE FROM credential_health WHERE credential = ?", name); err != nil { + return false, fmt.Errorf("delete credential health: %w", err) + } + if err := tx.Commit(); err != nil { + return false, fmt.Errorf("commit: %w", err) + } n, _ := res.RowsAffected() return n > 0, nil } From dcaa6f3de3903a9b26e8e7fdc4601e47eea67e44 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 16:11:45 +0800 Subject: [PATCH 33/49] fix(proxy): pool-namespace covered-set; shared-health prune for non-members; store-gated vault delete on cred remove --- cmd/sluice/cred.go | 86 ++++++------ cmd/sluice/cred_test.go | 195 ++++++++++++++++++++++++++ internal/api/server.go | 25 ++-- internal/proxy/addon.go | 25 +++- internal/proxy/pool_splithost_test.go | 118 ++++++++++++++++ internal/telegram/commands.go | 48 +++++-- internal/vault/pool.go | 23 ++- internal/vault/pool_test.go | 78 +++++++++++ 8 files changed, 532 insertions(+), 66 deletions(-) diff --git a/cmd/sluice/cred.go b/cmd/sluice/cred.go index 0195791..42d1527 100644 --- a/cmd/sluice/cred.go +++ b/cmd/sluice/cred.go @@ -562,28 +562,50 @@ func handleCredRemove(args []string) error { } name := fs.Arg(0) - // Block removing a credential that is still a live pool member so no - // dangling member rows are left behind. The operator must remove it - // from the pool first. This check runs before the vault delete so a - // blocked removal does not destroy the secret. Only consult the DB if - // it already exists (do not create it as a side effect of a removal). + // Store-first removal order (Finding 3, round-9). The authoritative + // pool-membership gate is the atomic, fail-closed RemoveCredentialMeta + // in the store layer: it refuses inside its own transaction if the + // credential is still a live pool member, closing the TOCTOU window + // where a separate pre-check passes and a concurrent caller then + // creates a pool with this credential before the vault secret is + // deleted. The vault secret is therefore only deleted AFTER the store + // removal has succeeded; if the store removal refuses, the vault + // secret is left untouched and no window exists where the secret is + // gone but credential_pool_members still references it. + // + // Only consult/mutate the DB if it already exists (do not create it as + // a side effect of a removal). + dbExists := false if _, statErr := os.Stat(*dbPath); statErr == nil { - guardDB, gerr := store.New(*dbPath) - if gerr != nil { + dbExists = true + } else if !os.IsNotExist(statErr) { + return fmt.Errorf("access database %q for credential removal of %q (refusing to remove; a pool member may otherwise be orphaned): %w", *dbPath, name, statErr) + } + + var db *store.Store + if dbExists { + var derr error + db, derr = store.New(*dbPath) + if derr != nil { // Fail closed: the DB exists but cannot be opened, so the - // pool-membership guard cannot run. Proceeding to delete the + // pool-membership gate cannot run. Proceeding to delete the // vault secret would orphan any credential_pool_members row // pointing at this now-missing credential -- exactly what the - // guard prevents. Refuse the removal instead. - return fmt.Errorf("open database %q to check pool membership for %q (refusing to remove; a pool member may otherwise be orphaned): %w", *dbPath, name, gerr) + // gate prevents. Refuse the removal instead. + return fmt.Errorf("open database %q to check pool membership for %q (refusing to remove; a pool member may otherwise be orphaned): %w", *dbPath, name, derr) } - pools, perr := guardDB.PoolsForMember(name) - _ = guardDB.Close() - if perr != nil { - return fmt.Errorf("check pool membership for %q: %w", name, perr) + defer func() { _ = db.Close() }() + + // GATE: atomic, fail-closed pool-member guard. This MUST run before + // the vault delete. If the credential is still a live pool member, + // RemoveCredentialMeta returns an error inside its transaction and + // the vault secret below is never touched. + metaDeleted, rmMetaErr := db.RemoveCredentialMeta(name) + if rmMetaErr != nil { + return fmt.Errorf("remove credential metadata for %q (refusing to delete the vault secret so a pool member is not orphaned): %w", name, rmMetaErr) } - if len(pools) > 0 { - return fmt.Errorf("credential %q is a member of pool(s) %s; remove it from the pool first (sluice pool remove

, or recreate the pool without it)", name, strings.Join(pools, ", ")) + if metaDeleted { + fmt.Printf("removed credential metadata for %q\n", name) } } @@ -592,8 +614,10 @@ func handleCredRemove(args []string) error { return err } - // Remove from vault. If already gone (previous partial cleanup), - // continue to DB cleanup so stale rules/bindings can be removed. + // Store removal already succeeded (or the DB does not exist). Now it is + // safe to delete the vault secret. If already gone (previous partial + // cleanup), continue to DB cleanup so stale rules/bindings can be + // removed. if err := vs.Remove(name); err != nil { if !os.IsNotExist(err) { return fmt.Errorf("remove: %w", err) @@ -603,22 +627,12 @@ func handleCredRemove(args []string) error { fmt.Printf("credential %q removed\n", name) } - // Clean up associated bindings and auto-created rules. Only open the - // store if the DB file exists to avoid creating it as a side effect of - // a credential removal. - if _, statErr := os.Stat(*dbPath); statErr != nil { - if !os.IsNotExist(statErr) { - log.Printf("warning: cannot access database %q for cleanup: %v (stale rules/bindings may remain)", *dbPath, statErr) - } - return nil - } - - db, err := store.New(*dbPath) - if err != nil { - log.Printf("warning: could not open database %q for cleanup: %v (stale rules/bindings may remain)", *dbPath, err) + // Clean up associated bindings and auto-created rules. The DB handle + // was opened above for the membership gate; if the DB did not exist + // there is nothing to clean up. + if db == nil { return nil } - defer func() { _ = db.Close() }() // Remove rules tagged either by "sluice cred add --destination" // (cred-add:) or by "sluice binding add" (binding-add:). @@ -646,14 +660,6 @@ func handleCredRemove(args []string) error { } else if removed > 0 { fmt.Printf("removed %d binding(s) for %q\n", removed, name) } - - // Remove credential metadata (type, token_url). - metaDeleted, rmMetaErr := db.RemoveCredentialMeta(name) - if rmMetaErr != nil { - log.Printf("warning: failed to remove credential meta for %q: %v", name, rmMetaErr) - } else if metaDeleted { - fmt.Printf("removed credential metadata for %q\n", name) - } return nil } diff --git a/cmd/sluice/cred_test.go b/cmd/sluice/cred_test.go index 5f9f3a6..bde1bd0 100644 --- a/cmd/sluice/cred_test.go +++ b/cmd/sluice/cred_test.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "sync" "testing" "github.com/nemirovsky/sluice/internal/store" @@ -2522,3 +2523,197 @@ func TestHandleCredUpdateNoName(t *testing.T) { t.Errorf("expected usage error, got: %v", err) } } + +// TestFinding3Round9_StoreGatedVaultDeleteOnLivePoolMember is the Copilot +// round-9 Finding 3 regression. A credential that is a live pool member must +// NOT have its vault secret deleted by `cred remove`: the store-layer +// RemoveCredentialMeta fail-closes on a live pool member, and the vault +// delete must be GATED on that store removal succeeding (store-first order). +// +// Before the fix, vs.Remove(name) ran BEFORE the guarded RemoveCredentialMeta, +// so even though the command ultimately reported the credential could not be +// removed, the vault secret was already destroyed — leaving the pool pointing +// at a deleted secret (the TOCTOU window). The fix performs the guarded store +// removal first and only deletes the vault secret if it succeeds. +func TestFinding3Round9_StoreGatedVaultDeleteOnLivePoolMember(t *testing.T) { + dir := t.TempDir() + dbPath := setupVaultDB(t, dir) + + // Create an OAuth credential and make it a live pool member. + vs, err := vault.NewStore(dir) + if err != nil { + t.Fatal(err) + } + if _, err := vs.Add("pool_mem", `{"access_token":"at","refresh_token":"rt"}`); err != nil { + t.Fatal(err) + } + + db, err := store.New(dbPath) + if err != nil { + t.Fatal(err) + } + if err := db.AddCredentialMeta("pool_mem", "oauth", "https://auth.example.com/token"); err != nil { + t.Fatalf("AddCredentialMeta: %v", err) + } + if err := db.CreatePoolWithMembers("codex_pool", "failover", []string{"pool_mem"}); err != nil { + t.Fatalf("CreatePoolWithMembers: %v", err) + } + _ = db.Close() + + // `cred remove` must refuse because pool_mem is a live pool member. + err = handleCredCommand([]string{"remove", "pool_mem", "--db", dbPath}) + if err == nil { + t.Fatal("Finding 3 r9: cred remove of a live pool member must fail") + } + if !strings.Contains(err.Error(), "pool") { + t.Errorf("Finding 3 r9: expected a pool-membership error, got: %v", err) + } + + // The vault secret MUST still be present: the store gate refused, so the + // vault delete must never have run (store-first ordering). + names, err := vs.List() + if err != nil { + t.Fatal(err) + } + found := false + for _, n := range names { + if n == "pool_mem" { + found = true + break + } + } + if !found { + t.Fatal("Finding 3 r9: vault secret was deleted despite the store removal being refused — store-first ordering violated, pool now points at a deleted secret") + } + + // The credential must still be loadable (not just listed). + sec, gerr := vs.Get("pool_mem") + if gerr != nil { + t.Fatalf("Finding 3 r9: vault secret unreadable after a refused removal: %v", gerr) + } + sec.Release() + + // Sanity: removing it from the pool first then `cred remove` succeeds + // and deletes BOTH (store-first, then vault). + db2, err := store.New(dbPath) + if err != nil { + t.Fatal(err) + } + if _, err := db2.RemovePool("codex_pool"); err != nil { + t.Fatalf("RemovePool: %v", err) + } + _ = db2.Close() + + if err := handleCredCommand([]string{"remove", "pool_mem", "--db", dbPath}); err != nil { + t.Fatalf("Finding 3 r9: cred remove after pool removal should succeed: %v", err) + } + names, err = vs.List() + if err != nil { + t.Fatal(err) + } + for _, n := range names { + if n == "pool_mem" { + t.Fatal("Finding 3 r9: normal removal (not a pool member) should delete the vault secret") + } + } + db3, err := store.New(dbPath) + if err != nil { + t.Fatal(err) + } + defer func() { _ = db3.Close() }() + if meta, gerr := db3.GetCredentialMeta("pool_mem"); gerr == nil && meta != nil { + t.Fatal("Finding 3 r9: credential meta should have been removed (store-first) on normal removal") + } +} + +// TestFinding3Round9_TOCTOUInterleaveStoreGatesVaultDelete deterministically +// exercises the round-9 Finding 3 TOCTOU: a pool is created concurrently +// with `cred remove`. The invariant that MUST always hold is: +// +// if the credential is still referenced by credential_pool_members, +// its vault secret MUST still exist. +// +// With the OLD vault-first order, an interleave where the pool is created +// after the (separate) pre-check read but before vs.Remove deletes the vault +// secret while the membership row survives — the pool is left pointing at a +// deleted secret. With the store-first fix the atomic, fail-closed +// RemoveCredentialMeta runs BEFORE vs.Remove, so either it refuses (pool +// already exists -> vault untouched) or it succeeds (no membership -> +// consistent). There is never a state where the secret is gone but +// membership still references it. +func TestFinding3Round9_TOCTOUInterleaveStoreGatesVaultDelete(t *testing.T) { + for iter := 0; iter < 40; iter++ { + dir := t.TempDir() + dbPath := setupVaultDB(t, dir) + + vs, err := vault.NewStore(dir) + if err != nil { + t.Fatal(err) + } + if _, err := vs.Add("racer", `{"access_token":"at","refresh_token":"rt"}`); err != nil { + t.Fatal(err) + } + seed, err := store.New(dbPath) + if err != nil { + t.Fatal(err) + } + if err := seed.AddCredentialMeta("racer", "oauth", "https://auth.example.com/token"); err != nil { + t.Fatalf("AddCredentialMeta: %v", err) + } + _ = seed.Close() + + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(2) + + // Racer A: create a pool with "racer" as a member. + go func() { + defer wg.Done() + <-start + pdb, e := store.New(dbPath) + if e != nil { + return + } + _ = pdb.CreatePoolWithMembers("codex_pool", "failover", []string{"racer"}) + _ = pdb.Close() + }() + + // Racer B: `cred remove racer`. + go func() { + defer wg.Done() + <-start + _ = handleCredCommand([]string{"remove", "racer", "--db", dbPath}) + }() + + close(start) + wg.Wait() + + // Invariant check: if membership still references "racer", the vault + // secret MUST still be present. + chk, err := store.New(dbPath) + if err != nil { + t.Fatal(err) + } + pools, perr := chk.PoolsForMember("racer") + _ = chk.Close() + if perr != nil { + t.Fatalf("PoolsForMember: %v", err) + } + + names, lerr := vs.List() + if lerr != nil { + t.Fatal(lerr) + } + vaultHas := false + for _, n := range names { + if n == "racer" { + vaultHas = true + break + } + } + + if len(pools) > 0 && !vaultHas { + t.Fatalf("iter %d: TOCTOU violated — credential is a member of pool(s) %v but its vault secret was deleted (store-first ordering must gate the vault delete)", iter, pools) + } + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 28e7d33..ccaef29 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -1263,7 +1263,20 @@ func (s *Server) DeleteApiCredentialsName(w http.ResponseWriter, r *http.Request } } - // Remove associated bindings and auto-created rules first. If vault.Remove + // Store-first removal order (Finding 3, round-9). RemoveCredentialMeta + // is the authoritative, fail-closed pool-membership gate: it refuses + // inside its own transaction if the credential is still a live pool + // member. Run it BEFORE the vault delete (and before bindings/rules + // cleanup) so the vault secret is only destroyed once the store has + // accepted the removal. If it refuses, the vault secret, bindings, and + // rules are all left intact and no window exists where the secret is + // gone but credential_pool_members still references it. + if _, err := s.store.RemoveCredentialMeta(name); err != nil { + writeError(w, http.StatusConflict, "failed to remove credential metadata (vault secret left intact so a pool member is not orphaned): "+err.Error(), "") + return + } + + // Remove associated bindings and auto-created rules next. If vault.Remove // below fails, bindings/rules are already gone. This is a pre-existing // ordering tradeoff: reversing it would orphan bindings when vault succeeds // but SQLite fails. A transactional approach would require the vault to @@ -1287,18 +1300,14 @@ func (s *Server) DeleteApiCredentialsName(w http.ResponseWriter, r *http.Request } } - // Remove the credential from the vault first. If this fails, metadata - // stays intact so the credential type is not lost. + // Store removal already succeeded above (the pool-membership gate + // passed and credential_meta is gone). Only now is it safe to delete + // the vault secret. if err := s.vault.Remove(name); err != nil { writeError(w, http.StatusInternalServerError, "failed to remove credential: "+err.Error(), "") return } - // Remove credential metadata after vault deletion succeeded. - if _, err := s.store.RemoveCredentialMeta(name); err != nil { - log.Printf("[WARN] failed to remove credential meta %q: %v", name, err) - } - if err := s.recompileEngine(); err != nil { writeError(w, http.StatusInternalServerError, "credential removed but engine recompile failed: "+err.Error(), "") return diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index 4290e47..a9096e3 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -1450,6 +1450,16 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flo // pass is a no-op). covered := make(map[string]bool, len(boundCreds)) + // poolEmitted tracks pool namespaces whose pool-keyed phantom pairs + // (SLUICE_PHANTOM:.access / .refresh) were already produced — by + // EITHER pass. The token-host expansion below must be gated on the POOL + // namespace, not on covered[member]: if the active member ALSO has a + // plain direct binding on the token host, the CONNECT-host loop emits + // only member-scoped phantoms and sets covered[member], which would + // otherwise wrongly suppress the pool namespace and leave + // SLUICE_PHANTOM:.refresh unswapped (Finding 1, round-9). + poolEmitted := make(map[string]bool) + var pairs []phantomPair for _, boundName := range boundCreds { // Chokepoint: expand a bound pool name to its active member @@ -1478,6 +1488,7 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flo continue } covered[member] = true + poolEmitted[poolName] = true pairs = append(pairs, oauthPairs...) continue } @@ -1527,15 +1538,21 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flo pr = a.poolResolver.Load() } if pr != nil { - seenPool := make(map[string]bool) for _, credName := range idx.MatchAll(reqURL) { poolName := pr.PoolForMember(credName) - if poolName == "" || seenPool[poolName] { + // Gate on the POOL namespace only. covered[member] is + // deliberately NOT consulted here: a plain direct + // binding for the active member on this same token + // host sets covered[member] but does NOT emit the + // pool-keyed phantoms the agent actually holds, so + // suppressing on it would leak SLUICE_PHANTOM:.* + // upstream unswapped (Finding 1, round-9). + if poolName == "" || poolEmitted[poolName] { continue } - seenPool[poolName] = true + poolEmitted[poolName] = true member, ok := pr.ResolveActive(poolName) - if !ok || member == "" || covered[member] { + if !ok || member == "" { continue } secret, err := a.provider.Get(member) diff --git a/internal/proxy/pool_splithost_test.go b/internal/proxy/pool_splithost_test.go index aaa3afc..ff86b70 100644 --- a/internal/proxy/pool_splithost_test.go +++ b/internal/proxy/pool_splithost_test.go @@ -429,3 +429,121 @@ func TestFinding3_ProtocolScopedPooledBindingFailoverLookup(t *testing.T) { t.Fatalf("no cred_failover audit event found in:\n%s", data) } } + +// TestFinding1Round9_PoolNamespaceNotSuppressedByMemberPlainBinding is the +// Copilot round-9 Finding 1 regression. Topology: a pool bound to the API +// host (api.example.com), AND the active pool member (memA) ALSO has its OWN +// plain direct binding on the TOKEN host (auth.example.com, == the pool's +// token URL host). The agent POSTs the pool-keyed refresh grant +// (SLUICE_PHANTOM:codex_pool.refresh) to the token host. +// +// Before the fix, the CONNECT-host binding loop processed memA's plain +// binding on the token host, emitted only memA-scoped phantoms, and set +// covered[memA]=true. The token-host expansion pass then saw +// covered[member]==true (member==memA) and skipped the pool entirely, so +// SLUICE_PHANTOM:codex_pool.refresh AND .access were NEVER swapped: the +// pool-keyed phantoms the agent actually holds would travel upstream +// verbatim and the refresh would fail. The fix gates the token-host pass on +// the POOL namespace (poolEmitted[poolName]) rather than covered[member], so +// a plain member binding no longer suppresses the pool expansion. The pool +// namespace must be emitted exactly once (not double-emitted, not skipped). +func TestFinding1Round9_PoolNamespaceNotSuppressedByMemberPlainBinding(t *testing.T) { + const ( + poolName = "codex_pool" + memA = "memA" + memB = "memB" + ) + + provider := &addonWritableProvider{ + creds: map[string]string{ + memA: poolMemberCred(t, "A-access-old", "A-refresh-old"), + memB: poolMemberCred(t, "B-access-old", "B-refresh-old"), + }, + } + + // Pool bound to the API host. AND memA (the active member) ALSO has a + // plain direct binding on the TOKEN host -- this is the configuration + // that, pre-fix, set covered[memA] in the CONNECT-host loop and + // suppressed the pool expansion on the token host. + bindings := []vault.Binding{ + {Destination: "api.example.com", Ports: []int{443}, Credential: poolName}, + {Destination: "auth.example.com", Ports: []int{443}, Credential: memA}, + } + resolver, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var resolverPtr atomic.Pointer[vault.BindingResolver] + resolverPtr.Store(resolver) + + addon := NewSluiceAddon(WithResolver(&resolverPtr), WithProvider(provider)) + addon.persistDone = make(chan struct{}, 10) + + metas := []store.CredentialMeta{ + {Name: memA, CredType: "oauth", TokenURL: testOAuthTokenURL}, + {Name: memB, CredType: "oauth", TokenURL: testOAuthTokenURL}, + } + addon.UpdateOAuthIndex(metas) + + pool := store.Pool{Name: poolName, Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: memA, Position: 0}, + {Credential: memB, Position: 1}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{pool}, nil)) + addon.SetPoolResolver(&prPtr) + + if got, _ := prPtr.Load().ResolveActive(poolName); got != memA { + t.Fatalf("pre-condition active = %q, want %s", got, memA) + } + + // CONNECT target is the token host (where memA ALSO has a plain + // binding). The agent body carries the pool-keyed refresh phantom. + client := setupAddonConn(addon, "auth.example.com:443") + reqFlow := newTestFlow(client, "POST", testOAuthTokenURL) + reqFlow.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + reqFlow.Request.Body = refreshGrantBody(poolName) + + addon.Requestheaders(reqFlow) + addon.Request(reqFlow) + + body := string(reqFlow.Request.Body) + + // The pool refresh phantom must be swapped to the active member's REAL + // refresh token, NOT left verbatim. + if strings.Contains(body, "SLUICE_PHANTOM:"+poolName+".refresh") { + t.Fatalf("Finding 1 r9: pool refresh phantom NOT swapped (suppressed by member plain binding); body=%q", body) + } + if !strings.Contains(body, "A-refresh-old") { + t.Fatalf("Finding 1 r9: active member memA real refresh token not injected; body=%q", body) + } + // The pool ACCESS phantom must also be swappable: build the pairs + // directly and assert the pool access phantom maps to memA's real + // access token exactly once (no double-emit, not suppressed). + pairs := addon.buildPhantomPairs("auth.example.com", 443, "https", reqFlow.Id, reqFlow.Request.URL) + defer releasePhantomPairs(pairs) + accessPhantom := poolStablePhantomAccess(poolName) + refreshPhantom := "SLUICE_PHANTOM:" + poolName + ".refresh" + var accessCount, refreshCount int + for _, p := range pairs { + switch string(p.phantom) { + case accessPhantom: + accessCount++ + if got := string(p.secret.Bytes()); got != "A-access-old" { + t.Fatalf("pool access phantom -> %q, want A-access-old", got) + } + case refreshPhantom: + refreshCount++ + if got := string(p.secret.Bytes()); got != "A-refresh-old" { + t.Fatalf("pool refresh phantom -> %q, want A-refresh-old", got) + } + } + } + if accessCount != 1 { + t.Fatalf("Finding 1 r9: pool access phantom emitted %d times, want exactly 1 (not suppressed, not double-emitted)", accessCount) + } + if refreshCount != 1 { + t.Fatalf("Finding 1 r9: pool refresh phantom emitted %d times, want exactly 1 (not suppressed, not double-emitted)", refreshCount) + } +} diff --git a/internal/telegram/commands.go b/internal/telegram/commands.go index c2e64cd..5c605eb 100644 --- a/internal/telegram/commands.go +++ b/internal/telegram/commands.go @@ -639,21 +639,27 @@ func (h *CommandHandler) credRotate(name, value string) string { } func (h *CommandHandler) credRemove(name string) string { - // Remove from vault. If already gone (previous partial cleanup), - // continue to DB cleanup so stale rules/bindings can be removed. - if err := h.vault.Remove(name); err != nil { - if !os.IsNotExist(err) { - return fmt.Sprintf("Failed to remove credential: %v", err) - } - // Vault entry already gone. Continue to clean up stale DB state. - } - // Clean up associated bindings and auto-created rules. var warnings []string var removedEnvVars []string if h.store != nil { h.reloadMu.Lock() defer h.reloadMu.Unlock() + + // Store-first removal order (Finding 3, round-9). + // RemoveCredentialMeta is the authoritative, fail-closed + // pool-membership gate: it refuses inside its own transaction if + // the credential is still a live pool member. Run it BEFORE the + // vault delete (and before bindings/rules cleanup) so the vault + // secret is only destroyed once the store has accepted the + // removal. If it refuses, the vault secret, bindings, and rules + // are all left intact and no window exists where the secret is + // gone but credential_pool_members still references it. + if _, err := h.store.RemoveCredentialMeta(name); err != nil { + log.Printf("[WARN] remove credential meta for %q: %v", name, err) + return fmt.Sprintf("Failed to remove credential %q (vault secret left intact so a pool member is not orphaned): %v", name, err) + } + // Read env_var values from bindings before removal so we can clear // them from the agent container after the bindings are deleted. if credBindings, err := h.store.ListBindingsByCredential(name); err == nil { @@ -663,6 +669,18 @@ func (h *CommandHandler) credRemove(name string) string { } } } + + // Store removal already succeeded (the pool-membership gate + // passed). Only now is it safe to delete the vault secret. If + // already gone (previous partial cleanup), continue to clean up + // stale DB state. + if err := h.vault.Remove(name); err != nil { + if !os.IsNotExist(err) { + return fmt.Sprintf("Failed to remove credential: %v", err) + } + // Vault entry already gone. Continue to clean up stale DB state. + } + if _, err := h.store.RemoveRulesBySource("cred-add:" + name); err != nil { log.Printf("[WARN] remove rules for credential %q: %v", name, err) warnings = append(warnings, fmt.Sprintf("failed to remove rules: %v", err)) @@ -671,10 +689,6 @@ func (h *CommandHandler) credRemove(name string) string { log.Printf("[WARN] remove bindings for credential %q: %v", name, err) warnings = append(warnings, fmt.Sprintf("failed to remove bindings: %v", err)) } - if _, err := h.store.RemoveCredentialMeta(name); err != nil { - log.Printf("[WARN] remove credential meta for %q: %v", name, err) - warnings = append(warnings, fmt.Sprintf("failed to remove credential meta: %v", err)) - } // Recompile engine so removed allow rules take effect immediately. if err := h.recompileAndSwap(); err != nil { log.Printf("[WARN] recompile after cred remove failed: %v", err) @@ -690,6 +704,14 @@ func (h *CommandHandler) credRemove(name string) string { if h.onOAuthIndexRebuild != nil { h.onOAuthIndexRebuild() } + } else { + // No store configured, so there is no pool-membership gate to + // run; delete the vault secret directly. If already gone + // (previous partial cleanup), report success — there is no DB + // state to clean up. + if err := h.vault.Remove(name); err != nil && !os.IsNotExist(err) { + return fmt.Sprintf("Failed to remove credential: %v", err) + } } msg := fmt.Sprintf("Removed credential: %s", name) diff --git a/internal/vault/pool.go b/internal/vault/pool.go index 51fcbbc..c252120 100644 --- a/internal/vault/pool.go +++ b/internal/vault/pool.go @@ -302,7 +302,28 @@ func (pr *PoolResolver) MergeLiveCooldowns(prev *PoolResolver) { } if pr.health == prev.health { // Shared health map: both generations already see the same - // cooldowns. Nothing to do — this is the CRITICAL-1 fix. + // cooldowns, so there is nothing to carry forward — this is the + // CRITICAL-1 fix. But the shared map can still hold stale entries + // for credentials this new generation no longer tracks as a pool + // member (a member removed from a pool, or removed and recreated + // under the same name). Without pruning, a re-add before the old + // TTL expires would inherit the stale cooldown and ResolveActive + // would skip the member even though the store snapshot no longer + // records it (Finding 2, round-9). + // + // Pruning only NON-members preserves both invariants: a current + // member's (possibly synchronously-recorded) cooldown is never + // touched, so the monotonic-cooldown invariant and CRITICAL-1 + // shared-map durability for live members are intact; only entries + // for credentials absent from the new resolver's member set are + // dropped. + pr.health.mu.Lock() + for cred := range pr.health.health { + if _, stillMember := pr.memberOf[cred]; !stillMember { + delete(pr.health.health, cred) + } + } + pr.health.mu.Unlock() return } now := time.Now() diff --git a/internal/vault/pool_test.go b/internal/vault/pool_test.go index e5dc0f8..bd2263a 100644 --- a/internal/vault/pool_test.go +++ b/internal/vault/pool_test.go @@ -275,6 +275,84 @@ func TestSharedHealthSurvivesResolverRebuild(t *testing.T) { } } +// TestFinding2Round9_SharedHealthPrunesNonMembers is the Copilot round-9 +// Finding 2 regression. On the shared-PoolHealth path (the normal server +// path, prev.health == pr.health) MergeLiveCooldowns early-returned BEFORE +// pruning cooldowns for credentials no longer a member of ANY pool. A cooled +// member removed from a pool stayed in the process-wide shared health map +// and, if re-added before its old TTL expired, was skipped again by +// ResolveActive even though the store snapshot no longer recorded the +// cooldown. The fix prunes non-member entries on the shared path too, while +// never shortening a still-valid cooldown for a current member. +func TestFinding2Round9_SharedHealthPrunesNonMembers(t *testing.T) { + shared := NewPoolHealth() + + // gen1: a single resolver generation holds ALL pools (this is how the + // server builds it — one PoolResolver per process, every pool inside + // it). "pool" has members "a" and "b"; "other" has "c". "a" (still a + // member next gen) and "b" (about to be removed) both get cooled. + gen1 := NewPoolResolverShared([]store.Pool{ + mkPool("pool", "a", "b"), + mkPool("other", "c"), + }, nil, shared) + aUntil := time.Now().Add(300 * time.Second) + bUntil := time.Now().Add(300 * time.Second) + gen1.MarkCooldown("a", aUntil, "429") + gen1.MarkCooldown("b", bUntil, "401") + + // gen2: "b" removed from "pool" (membership change). gen2's memberOf is + // the COMPLETE member set across all pools for the new generation, so a + // credential absent from it is no longer in ANY pool. Same shared + // health instance — this is the normal server path + // (prev.health == pr.health). + gen2 := NewPoolResolverShared([]store.Pool{ + mkPool("pool", "a"), + mkPool("other", "c"), + }, nil, shared) + + // Without the fix MergeLiveCooldowns early-returns on the shared path + // and "b"'s stale cooldown lingers in the process-wide shared map. + gen2.MergeLiveCooldowns(gen1) + + // "b" is no longer a member of any pool: its stale cooldown MUST be + // pruned so a re-add before the old TTL does not inherit it. + if until, cooling := gen2.CooldownUntil("b"); cooling { + t.Errorf("Finding 2 r9: stale cooldown for removed non-member b must be pruned; got until=%v cooling=%v", until, cooling) + } + + // "a" is still a member of "pool": its cooldown must survive the merge + // and must NOT be shortened (monotonic-cooldown / CRITICAL-1 durability + // for live members). + if until, cooling := gen2.CooldownUntil("a"); !cooling { + t.Fatalf("Finding 2 r9: still-member a lost its cooldown across the shared-path merge") + } else if until.Before(aUntil.Add(-time.Second)) { + t.Errorf("Finding 2 r9: still-member a's cooldown was shortened: got %v want ~%v", until, aUntil) + } + + // Re-add "b" to a pool (next generation) BEFORE its old TTL would have + // expired. Because the stale cooldown was pruned, "b" must now be + // healthy and ResolveActive must pick it, not skip it as still-cooling. + gen3 := NewPoolResolverShared([]store.Pool{ + mkPool("pool", "a"), + mkPool("other", "c"), + mkPool("p2", "b", "d"), + }, nil, shared) + gen3.MergeLiveCooldowns(gen2) + if _, cooling := gen3.CooldownUntil("b"); cooling { + t.Errorf("Finding 2 r9: re-added b inherited a stale cooldown that should have been pruned") + } + if got, ok := gen3.ResolveActive("p2"); !ok || got != "b" { + t.Errorf("Finding 2 r9: re-added b should be the active member of p2; got %q,%v want b,true", got, ok) + } + // "a" is still in gen3's "pool"; its (unshortened) cooldown must still + // be intact after the second merge as well. + if until, cooling := gen3.CooldownUntil("a"); !cooling { + t.Errorf("Finding 2 r9: still-member a lost its cooldown across the second shared-path merge") + } else if until.Before(aUntil.Add(-time.Second)) { + t.Errorf("Finding 2 r9: still-member a's cooldown was shortened by the second merge: got %v want ~%v", until, aUntil) + } +} + // TestSharedHealthConcurrentMarkCooldownVsRebuild stresses the CRITICAL-1 // race: MarkCooldown on rotating "old" generations racing continuous // resolver rebuilds (the StorePool/reload swap) against one shared health. From 509ca44da0305a9cdb80fea5145f1467b9d4349a Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 16:28:06 +0800 Subject: [PATCH 34/49] fix(store): guard CAS credential-rollback for pool members + health cleanup; fix e2e coalesce deadlock-on-failure --- e2e/approval_coalesce_test.go | 8 ++++ internal/store/pools_test.go | 83 +++++++++++++++++++++++++++++++++ internal/store/store.go | 86 +++++++++++++++++++++++------------ 3 files changed, 148 insertions(+), 29 deletions(-) diff --git a/e2e/approval_coalesce_test.go b/e2e/approval_coalesce_test.go index b205fc0..22829e2 100644 --- a/e2e/approval_coalesce_test.go +++ b/e2e/approval_coalesce_test.go @@ -136,6 +136,14 @@ func startGatedVerdictServer(t *testing.T, verdict string) (*httptest.Server, *g g := newGatedVerdictServer(verdict) srv := newIPv4Server(t, g) t.Cleanup(srv.Close) + // Registered AFTER srv.Close so it runs BEFORE it (cleanups are LIFO): + // a failing assertion anywhere in the test (e.g. the coalescing-window + // checks) would otherwise leave the first approval handler parked on + // <-g.release, and httptest's srv.Close() blocks on that in-flight + // request until the global test timeout. Release() is sync.Once-guarded + // so the happy-path explicit release plus this cleanup release is safe + // (the second close is a no-op). + t.Cleanup(g.Release) return srv, g } diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index c943a42..5a248e1 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -2,6 +2,7 @@ package store import ( "path/filepath" + "strings" "testing" "time" @@ -513,3 +514,85 @@ func TestAddCredentialMetaRejectsPoolNameCollision(t *testing.T) { t.Fatal("expected CreatePoolWithMembers to reject a name that is already a credential") } } + +// TestRemoveCredentialMetaCASGuardsLivePoolMember is the round-10 Finding 1 +// regression. The cred-add rollback path (RemoveCredentialMetaCAS) must apply +// the SAME fail-closed pool-member guard and the SAME credential_health +// cleanup as RemoveCredentialMeta. Interleave being defended: +// +// cred add inserts credential_meta("c") -> a concurrent caller creates a +// pool that claims "c" -> a later step in the original add flow fails -> +// the CAS rollback runs. A blind CAS delete here would orphan the +// credential_pool_members row (pool -> missing credential). The shared +// guarded helper must refuse the delete and surface an informative error. +func TestRemoveCredentialMetaCASGuardsLivePoolMember(t *testing.T) { + s := newTestStore(t) + + // cred add inserted the meta row (oauth, with the seed token URL). + seedOAuthCred(t, s, "c") + seedOAuthCred(t, s, "sibling") + + // Concurrent pool-create claims "c" between the insert and the rollback. + if err := s.CreatePoolWithMembers("p", "failover", []string{"c", "sibling"}); err != nil { + t.Fatalf("CreatePoolWithMembers: %v", err) + } + + // The original add flow failed; its rollback fires RemoveCredentialMetaCAS + // with the values it inserted. The pool-member guard must REFUSE. + removed, noConcurrent, err := s.RemoveCredentialMetaCAS("c", "oauth", "https://auth.example.com/token") + if err == nil { + t.Fatal("expected CAS rollback to refuse a live pool member (Finding 1)") + } + if removed || noConcurrent { + t.Fatalf("CAS rollback reported removed=%v noConcurrent=%v on a refused delete; want false,false", removed, noConcurrent) + } + if !strings.Contains(err.Error(), "member of pool") { + t.Fatalf("CAS refusal error is not informative about pool membership: %v", err) + } + // The meta row must survive (it IS a live pool member — correct state), + // so the pool does not resolve to a missing credential. + if m, gerr := s.GetCredentialMeta("c"); gerr != nil || m == nil { + t.Fatalf("CAS rollback deleted a live pool member's meta row: %+v, %v", m, gerr) + } + pools, perr := s.PoolsForMember("c") + if perr != nil || len(pools) != 1 || pools[0] != "p" { + t.Fatalf("PoolsForMember(c) = %v, %v; want [p] (no orphan)", pools, perr) + } + + // A normal rollback (no pool claim) must still delete meta + health row + // and leave no stale cooldown for a same-named recreation. + seedOAuthCred(t, s, "lone") + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("lone", "cooldown", until, "429 from a prior add attempt"); err != nil { + t.Fatalf("SetCredentialHealth(lone): %v", err) + } + removed, noConcurrent, err = s.RemoveCredentialMetaCAS("lone", "oauth", "https://auth.example.com/token") + if err != nil || !removed || !noConcurrent { + t.Fatalf("RemoveCredentialMetaCAS(lone) = %v,%v,%v; want true,true,nil", removed, noConcurrent, err) + } + if m, _ := s.GetCredentialMeta("lone"); m != nil { + t.Fatalf("CAS rollback left a meta row for a non-member: %+v", m) + } + if h, herr := s.GetCredentialHealth("lone"); herr != nil || h != nil { + t.Fatalf("CAS rollback left a stale cooldown row: %+v, %v", h, herr) + } + + // CAS predicate is still honoured: a concurrent overwrite (different + // cred_type) must not be deleted by a stale-expectation rollback, even + // when the credential is free of any pool. + seedOAuthCred(t, s, "raced") + // A concurrent writer overwrote "raced" as a static credential. + if _, err := s.db.Exec("UPDATE credential_meta SET cred_type = 'static', token_url = NULL WHERE name = ?", "raced"); err != nil { + t.Fatalf("simulate concurrent overwrite: %v", err) + } + removed, noConcurrent, err = s.RemoveCredentialMetaCAS("raced", "oauth", "https://auth.example.com/token") + if err != nil { + t.Fatalf("CAS with stale expectation errored: %v", err) + } + if removed || noConcurrent { + t.Fatalf("CAS deleted a concurrently-overwritten row: removed=%v noConcurrent=%v; want false,false", removed, noConcurrent) + } + if m, _ := s.GetCredentialMeta("raced"); m == nil { + t.Fatal("CAS wiped a concurrent writer's row despite the predicate mismatch") + } +} diff --git a/internal/store/store.go b/internal/store/store.go index d78a762..12e5b1a 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1820,21 +1820,10 @@ func (s *Store) ListCredentialMeta() ([]CredentialMeta, error) { // RemoveCredentialMeta deletes a credential metadata row by name. Returns true // if a row was deleted. // -// Store-layer invariant enforcement (protects every caller, not just the CLI): -// -// - Pool-member integrity: removal is REFUSED (fail-closed) when the -// credential is still a live member of one or more pools. Deleting it -// would leave a credential_pool_members row pointing at a missing -// credential, so the pool would resolve to an uninjectable credential. -// The operator must remove it from the pool first. The CLI grew a guard -// for this earlier, but the REST API and Telegram paths call this method -// directly and bypassed it; pushing the guard here closes all paths. -// - Health-row cleanup: the per-credential credential_health row is keyed -// by credential name and is not tied to credential_meta by a foreign -// key, so a bare meta delete would leave a stale cooldown behind. -// Recreating a same-named credential would then inherit that stale -// cooldown on the next resolver seed. The health row is deleted in the -// same transaction as the meta row so the two never diverge. +// The fail-closed pool-member guard and the credential_health cleanup are +// enforced for every caller (CLI, REST API, Telegram) via the shared +// deleteCredentialMetaGuardedTx helper, which RemoveCredentialMetaCAS also +// routes through so the bare-delete and CAS-rollback paths cannot diverge. func (s *Store) RemoveCredentialMeta(name string) (bool, error) { tx, err := s.db.Begin() if err != nil { @@ -1842,6 +1831,41 @@ func (s *Store) RemoveCredentialMeta(name string) (bool, error) { } defer func() { _ = tx.Rollback() }() + n, err := deleteCredentialMetaGuardedTx(tx, name, "DELETE FROM credential_meta WHERE name = ?", []any{name}) + if err != nil { + return false, err + } + if err := tx.Commit(); err != nil { + return false, fmt.Errorf("commit: %w", err) + } + return n > 0, nil +} + +// deleteCredentialMetaGuardedTx is the single guarded credential_meta delete +// path shared by RemoveCredentialMeta and RemoveCredentialMetaCAS so the two +// can never diverge on the pool-member integrity guard or the health-row +// cleanup. It must be called inside an open transaction; the caller commits. +// +// Invariants enforced here for EVERY removal path (bare delete and CAS +// rollback alike): +// +// - Pool-member integrity (fail-closed): removal is REFUSED when the +// credential is still a live member of one or more pools. Deleting it +// would leave a credential_pool_members row pointing at a missing +// credential. This also closes the add-rollback TOCTOU: a concurrent +// pool-create can claim the just-added credential between insert and +// rollback, so the CAS rollback must honour the same guard. When the +// guard refuses, the meta row stays (it IS a live pool member, which is +// correct) and an informative error is returned so the caller surfaces +// it instead of silently leaving inconsistent state. +// - Health-row cleanup: the per-credential credential_health row is keyed +// by credential name and is not FK-tied to credential_meta, so a bare +// meta delete would leave a stale cooldown a same-named recreation would +// inherit. The health row is deleted in the same transaction. +// +// deleteSQL/deleteArgs let the CAS caller add its compare-and-swap predicate +// to the meta delete while sharing the guard and health cleanup. +func deleteCredentialMetaGuardedTx(tx *sql.Tx, name, deleteSQL string, deleteArgs []any) (int64, error) { // Fail-closed pool-member guard. Same semantics as the CLI guard in // `sluice cred remove`: a credential that is a live pool member cannot // be removed until it is taken out of the pool. @@ -1851,27 +1875,24 @@ func (s *Store) RemoveCredentialMeta(name string) (bool, error) { ).Scan(&pool) switch memErr { case nil: - return false, fmt.Errorf("credential %q is a member of pool %q; remove it from the pool first (sluice pool remove

, or recreate the pool without it)", name, pool) + return 0, fmt.Errorf("credential %q is a member of pool %q; remove it from the pool first (sluice pool remove

, or recreate the pool without it)", name, pool) case sql.ErrNoRows: // not a pool member; safe to remove default: - return false, fmt.Errorf("check pool membership for %q: %w", name, memErr) + return 0, fmt.Errorf("check pool membership for %q: %w", name, memErr) } - res, err := tx.Exec("DELETE FROM credential_meta WHERE name = ?", name) + res, err := tx.Exec(deleteSQL, deleteArgs...) if err != nil { - return false, fmt.Errorf("delete credential meta: %w", err) + return 0, fmt.Errorf("delete credential meta: %w", err) } // Drop the health row in the same transaction so a removed credential // never leaves a stale cooldown that a same-named recreation inherits. if _, err := tx.Exec("DELETE FROM credential_health WHERE credential = ?", name); err != nil { - return false, fmt.Errorf("delete credential health: %w", err) - } - if err := tx.Commit(); err != nil { - return false, fmt.Errorf("commit: %w", err) + return 0, fmt.Errorf("delete credential health: %w", err) } n, _ := res.RowsAffected() - return n > 0, nil + return n, nil } // RemoveCredentialMetaCAS deletes a credential metadata row only when its @@ -1915,15 +1936,22 @@ func (s *Store) RemoveCredentialMetaCAS(name, expectedType, expectedTokenURL str return false, false, nil } - res, err := tx.Exec("DELETE FROM credential_meta WHERE name = ? AND cred_type = ? AND COALESCE(token_url, '') = ?", - name, expectedType, expectedTokenURL) - if err != nil { - return false, false, fmt.Errorf("delete credential meta: %w", err) + // Route the actual delete through the shared guarded helper so the + // CAS rollback path enforces the SAME fail-closed pool-member guard and + // the SAME credential_health cleanup as RemoveCredentialMeta. A + // concurrent pool-create can claim the just-added credential between + // our insert and this rollback (TOCTOU); the guard refuses the delete + // in that case, surfacing an informative error and leaving the meta row + // in place (it IS a live pool member, which is the correct state). + n, gErr := deleteCredentialMetaGuardedTx(tx, name, + "DELETE FROM credential_meta WHERE name = ? AND cred_type = ? AND COALESCE(token_url, '') = ?", + []any{name, expectedType, expectedTokenURL}) + if gErr != nil { + return false, false, gErr } if commitErr := tx.Commit(); commitErr != nil { return false, false, fmt.Errorf("commit: %w", commitErr) } - n, _ := res.RowsAffected() return n > 0, true, nil } From fcb2fbf0e744ad5a25a1fd71496513d50c05a89c Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 16:42:55 +0800 Subject: [PATCH 35/49] fix(store): RemovePool health cleanup; reject live-pool-member meta downgrade; CAS-aware health delete --- internal/store/pools.go | 71 +++++++++++++++++- internal/store/pools_test.go | 139 +++++++++++++++++++++++++++++++++++ internal/store/store.go | 53 +++++++++++-- internal/store/store_test.go | 105 ++++++++++++++++++++++++++ 4 files changed, 362 insertions(+), 6 deletions(-) diff --git a/internal/store/pools.go b/internal/store/pools.go index fe949c3..482c345 100644 --- a/internal/store/pools.go +++ b/internal/store/pools.go @@ -281,12 +281,81 @@ func (s *Store) ListPools() ([]Pool, error) { // RemovePool deletes a pool and (via ON DELETE CASCADE) its members. Returns // true if a pool row was deleted. +// +// The members' credential_health rows are deleted in the SAME transaction so +// a cooled member taken out with its pool does not leave a stale durable +// cooldown. loadPoolResolver seeds the shared PoolHealth from ALL +// credential_health rows, so an orphaned cooldown would otherwise be +// inherited by the same credential when it is re-added to a new pool before +// the old TTL expires. A member that is still a live member of ANOTHER pool +// keeps its health row (its cooldown is still meaningful for that pool); only +// members no longer in any pool after this delete have their health row +// removed. func (s *Store) RemovePool(name string) (bool, error) { - res, err := s.db.Exec("DELETE FROM credential_pools WHERE name = ?", name) + tx, err := s.db.Begin() + if err != nil { + return false, fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // Snapshot the pool's members before the cascade wipes the membership + // rows so we know whose health rows to consider for cleanup. + mrows, err := tx.Query( + "SELECT credential FROM credential_pool_members WHERE pool = ?", name, + ) + if err != nil { + return false, fmt.Errorf("list members of pool %q: %w", name, err) + } + var members []string + for mrows.Next() { + var c string + if scanErr := mrows.Scan(&c); scanErr != nil { + _ = mrows.Close() + return false, fmt.Errorf("scan pool member: %w", scanErr) + } + members = append(members, c) + } + if mrowsErr := mrows.Err(); mrowsErr != nil { + _ = mrows.Close() + return false, fmt.Errorf("iterate pool members: %w", mrowsErr) + } + _ = mrows.Close() + + res, err := tx.Exec("DELETE FROM credential_pools WHERE name = ?", name) if err != nil { return false, fmt.Errorf("delete pool %q: %w", name, err) } n, _ := res.RowsAffected() + + if n > 0 { + // The CASCADE has now removed this pool's credential_pool_members + // rows. For each former member, drop its health row UNLESS it is + // still a member of some OTHER pool (the membership query runs + // post-cascade, so any remaining row means another pool still owns + // the credential and its cooldown stays meaningful). + for _, c := range members { + var stillPooled int + err := tx.QueryRow( + "SELECT 1 FROM credential_pool_members WHERE credential = ? LIMIT 1", c, + ).Scan(&stillPooled) + switch { + case errors.Is(err, sql.ErrNoRows): + if _, delErr := tx.Exec( + "DELETE FROM credential_health WHERE credential = ?", c, + ); delErr != nil { + return false, fmt.Errorf("delete health for former pool member %q: %w", c, delErr) + } + case err != nil: + return false, fmt.Errorf("check residual pool membership for %q: %w", c, err) + default: + // Still a member of another pool; leave its health row. + } + } + } + + if err := tx.Commit(); err != nil { + return false, fmt.Errorf("commit: %w", err) + } return n > 0, nil } diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index 5a248e1..0da7f0c 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -596,3 +596,142 @@ func TestRemoveCredentialMetaCASGuardsLivePoolMember(t *testing.T) { t.Fatal("CAS wiped a concurrent writer's row despite the predicate mismatch") } } + +// TestRemovePoolDeletesMemberHealth pins Finding 1: RemovePool must delete the +// credential_health rows of the pool's members in the same transaction so a +// cooled member taken out with its pool does not leave a stale durable +// cooldown that loadPoolResolver (which seeds the shared PoolHealth from ALL +// credential_health rows) would inherit when the credential is re-added to a +// NEW pool before the old TTL expires. +// +// Fail-before: RemovePool only DELETEd credential_pools (members cascaded), +// leaving the health row -> GetCredentialHealth still returns the cooldown. +// Pass-after: the member's health row is gone. +func TestRemovePoolDeletesMemberHealth(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "m") + if err := s.CreatePoolWithMembers("p", "failover", []string{"m"}); err != nil { + t.Fatalf("create pool: %v", err) + } + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("m", "cooldown", until, "429 rate limited"); err != nil { + t.Fatalf("cool member: %v", err) + } + + removed, err := s.RemovePool("p") + if err != nil || !removed { + t.Fatalf("RemovePool = %v, %v; want true, nil", removed, err) + } + + // The former member's durable cooldown must be gone so re-adding it to a + // new pool before the old TTL expires yields a healthy member. + h, err := s.GetCredentialHealth("m") + if err != nil { + t.Fatalf("GetCredentialHealth: %v", err) + } + if h != nil { + t.Fatalf("member health row survived RemovePool (stale cooldown inherited): %+v", h) + } +} + +// TestRemovePoolSparesStillPooledMemberHealth is the negative case for +// Finding 1: a member that is STILL a live member of another pool after the +// removal must keep its health row (its cooldown is still meaningful for that +// pool). The one-credential-one-pool invariant is enforced at the +// application layer, so a second membership is injected via raw SQL to +// exercise the residual-membership defensive branch (the same reason +// PoolsForMember returns a slice). +func TestRemovePoolSparesStillPooledMemberHealth(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "m") + if err := s.CreatePoolWithMembers("p", "failover", []string{"m"}); err != nil { + t.Fatalf("create pool p: %v", err) + } + // "m" also belongs to pool q (legacy/pre-invariant row injected directly). + if _, err := s.db.Exec("INSERT INTO credential_pools (name, strategy) VALUES ('q', 'failover')"); err != nil { + t.Fatalf("insert pool q: %v", err) + } + if _, err := s.db.Exec("INSERT INTO credential_pool_members (pool, credential, position) VALUES ('q', 'm', 0)"); err != nil { + t.Fatalf("insert q membership: %v", err) + } + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("m", "cooldown", until, "401 auth fail"); err != nil { + t.Fatalf("cool member: %v", err) + } + + removed, err := s.RemovePool("p") + if err != nil || !removed { + t.Fatalf("RemovePool(p) = %v, %v; want true, nil", removed, err) + } + + // "m" is still in pool q, so its cooldown must be preserved. + h, err := s.GetCredentialHealth("m") + if err != nil { + t.Fatalf("GetCredentialHealth: %v", err) + } + if h == nil { + t.Fatal("RemovePool(p) wiped the health row of a member still in pool q") + } + if h.Status != "cooldown" { + t.Errorf("health status = %q, want cooldown (cooldown for still-pooled member destroyed)", h.Status) + } +} + +// TestAddCredentialMetaRejectsLivePoolMemberDowngrade pins Finding 2: +// AddCredentialMeta is an upsert, and a re-add/update path could flip an +// existing credential that is a LIVE pool member to static / non-oauth / +// missing token_url, leaving the pool pointing at a member the pooled OAuth +// injection+failover code cannot use. The downgrade must be rejected; benign +// updates (still oauth with a token_url) and non-member upserts must still +// work. +func TestAddCredentialMetaRejectsLivePoolMemberDowngrade(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "poolcred") + if err := s.CreatePoolWithMembers("p", "failover", []string{"poolcred"}); err != nil { + t.Fatalf("create pool: %v", err) + } + + // Downgrade a live pool member to static -> rejected, row unchanged. + if err := s.AddCredentialMeta("poolcred", "static", ""); err == nil { + t.Fatal("expected AddCredentialMeta to reject downgrading a live pool member to static") + } + meta, err := s.GetCredentialMeta("poolcred") + if err != nil { + t.Fatalf("get meta: %v", err) + } + if meta == nil || meta.CredType != "oauth" || meta.TokenURL == "" { + t.Fatalf("live pool member meta was mutated by a rejected downgrade: %+v", meta) + } + + // Dropping the token_url while still "oauth" is also a downgrade + // (pooled failover needs a token endpoint). AddCredentialMeta's own + // oauth-needs-token_url validation rejects this before the guard, which + // still leaves the row unchanged — the property under test. + if err := s.AddCredentialMeta("poolcred", "oauth", ""); err == nil { + t.Fatal("expected AddCredentialMeta to reject a live pool member losing its token_url") + } + if m, _ := s.GetCredentialMeta("poolcred"); m == nil || m.TokenURL == "" { + t.Fatalf("token_url drop mutated the live pool member row: %+v", m) + } + + // Benign update: still oauth, new token_url -> allowed. + if err := s.AddCredentialMeta("poolcred", "oauth", "https://new.example.com/token"); err != nil { + t.Fatalf("benign oauth token_url change on a pool member rejected: %v", err) + } + m2, _ := s.GetCredentialMeta("poolcred") + if m2 == nil || m2.TokenURL != "https://new.example.com/token" { + t.Fatalf("benign token_url change not applied: %+v", m2) + } + + // Non-member credential: static upsert still allowed (no regression). + if err := s.AddCredentialMeta("freecred", "oauth", "https://auth.example.com/token"); err != nil { + t.Fatalf("seed freecred: %v", err) + } + if err := s.AddCredentialMeta("freecred", "static", ""); err != nil { + t.Fatalf("static upsert of a non-pool-member credential was wrongly rejected: %v", err) + } + fm, _ := s.GetCredentialMeta("freecred") + if fm == nil || fm.CredType != "static" { + t.Fatalf("non-member static upsert did not apply: %+v", fm) + } +} diff --git a/internal/store/store.go b/internal/store/store.go index 12e5b1a..ba9c7c4 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1763,6 +1763,39 @@ func (s *Store) AddCredentialMeta(name, credType, tokenURL string) error { return fmt.Errorf("check pool name collision for %q: %w", name, collErr) } + // Live-pool-member downgrade guard. AddCredentialMeta is an upsert, so a + // re-add/update path could flip an EXISTING credential that is currently + // a live pool member to static / non-oauth / missing token_url. Pool + // creation validates members are oauth (validatePoolMemberTx), but this + // upsert bypasses that post-hoc and would leave the pool pointing at a + // member the pooled OAuth injection+failover code cannot use. Reject the + // downgrade only when the row already exists AND is a live pool member + // AND the new metadata is not a usable oauth credential. Benign updates + // (still oauth with a token_url, e.g. a token_url change) are allowed. + var existingType string + exErr := tx.QueryRow("SELECT cred_type FROM credential_meta WHERE name = ?", name).Scan(&existingType) + switch exErr { + case nil: + var memberPool string + memErr := tx.QueryRow( + "SELECT pool FROM credential_pool_members WHERE credential = ? ORDER BY pool LIMIT 1", name, + ).Scan(&memberPool) + switch memErr { + case nil: + if credType != "oauth" || tokenURL == "" { + return fmt.Errorf("credential %q is a live member of pool %q; it must stay an oauth credential with a token_url (pooled failover cannot use a %s credential)", name, memberPool, credType) + } + case sql.ErrNoRows: + // not a pool member; any upsert is fine + default: + return fmt.Errorf("check pool membership for %q: %w", name, memErr) + } + case sql.ErrNoRows: + // brand-new credential; nothing to downgrade + default: + return fmt.Errorf("check existing credential meta for %q: %w", name, exErr) + } + if _, err := tx.Exec( `INSERT INTO credential_meta (name, cred_type, token_url) VALUES (?, ?, ?) @@ -1886,12 +1919,22 @@ func deleteCredentialMetaGuardedTx(tx *sql.Tx, name, deleteSQL string, deleteArg if err != nil { return 0, fmt.Errorf("delete credential meta: %w", err) } - // Drop the health row in the same transaction so a removed credential - // never leaves a stale cooldown that a same-named recreation inherits. - if _, err := tx.Exec("DELETE FROM credential_health WHERE credential = ?", name); err != nil { - return 0, fmt.Errorf("delete credential health: %w", err) - } n, _ := res.RowsAffected() + // Drop the health row in the same transaction ONLY when a meta row was + // actually deleted. The CAS caller appends a compare-and-swap predicate + // to deleteSQL, so a concurrent writer that changed cred_type/token_url + // makes the meta DELETE a no-op (0 rows). In that case the concurrent + // writer's metadata is correctly left intact; wiping its health row too + // would silently destroy a live cooldown it still owns. When 0 rows are + // deleted we leave both untouched and signal the no-op to the caller + // (RemoveCredentialMetaCAS turns n==0 into removed=false). A plain + // delete-by-name still removes the health row whenever it removes the + // meta row, so non-CAS semantics are unchanged. + if n > 0 { + if _, err := tx.Exec("DELETE FROM credential_health WHERE credential = ?", name); err != nil { + return 0, fmt.Errorf("delete credential health: %w", err) + } + } return n, nil } diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 3998042..1a995a1 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -9,6 +9,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/golang-migrate/migrate/v4" migsqlite "github.com/golang-migrate/migrate/v4/database/sqlite" @@ -3318,6 +3319,110 @@ func TestRemoveCredentialMetaCASMissingRow(t *testing.T) { } } +// TestDeleteCredentialMetaGuardedTxNoOpKeepsHealth pins the round-10 +// regression: deleteCredentialMetaGuardedTx must NOT delete the +// credential_health row when the meta DELETE affected zero rows. The CAS +// caller appends a compare-and-swap predicate to the meta DELETE, so a +// concurrent writer that changed cred_type/token_url makes the delete a no-op +// (0 rows). The concurrent writer's metadata is correctly left intact; +// wiping its still-live health row would silently destroy a cooldown it owns. +// +// Fail-before: the helper unconditionally deleted credential_health, so a CAS +// no-op left the meta row but wiped the health row. Pass-after: a 0-row meta +// delete leaves BOTH untouched; a matching delete still removes both. +func TestDeleteCredentialMetaGuardedTxNoOpKeepsHealth(t *testing.T) { + s := newTestStore(t) + if err := s.AddCredentialMeta("concurrent", "static", ""); err != nil { + t.Fatalf("seed meta: %v", err) + } + until := time.Now().Add(60 * time.Second).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("concurrent", "cooldown", until, "429 rate limited"); err != nil { + t.Fatalf("seed health: %v", err) + } + + // Simulate the CAS rollback racing a concurrent writer: the CAS DELETE + // predicate no longer matches the current row, so the meta DELETE + // affects 0 rows. + tx, err := s.db.Begin() + if err != nil { + t.Fatalf("begin: %v", err) + } + n, err := deleteCredentialMetaGuardedTx(tx, "concurrent", + "DELETE FROM credential_meta WHERE name = ? AND cred_type = ?", + []any{"concurrent", "oauth"}) // predicate fails: row is static + if err != nil { + _ = tx.Rollback() + t.Fatalf("deleteCredentialMetaGuardedTx: %v", err) + } + if n != 0 { + _ = tx.Rollback() + t.Fatalf("expected 0 rows deleted on CAS no-op, got %d", n) + } + if err := tx.Commit(); err != nil { + t.Fatalf("commit: %v", err) + } + + // The concurrent writer's metadata must still be there... + meta, err := s.GetCredentialMeta("concurrent") + if err != nil { + t.Fatalf("get meta: %v", err) + } + if meta == nil { + t.Fatal("CAS no-op wrongly deleted the concurrent writer's meta row") + } + // ...and so must its live cooldown (this is the regression assertion). + h, err := s.GetCredentialHealth("concurrent") + if err != nil { + t.Fatalf("get health: %v", err) + } + if h == nil { + t.Fatal("CAS no-op wrongly deleted the concurrent writer's credential_health row") + } + if h.Status != "cooldown" { + t.Errorf("health status = %q, want cooldown (cooldown destroyed)", h.Status) + } + + // A matching CAS delete still removes BOTH meta and health. + removed, noConcurrent, err := s.RemoveCredentialMetaCAS("concurrent", "static", "") + if err != nil { + t.Fatalf("matching CAS delete: %v", err) + } + if !removed || !noConcurrent { + t.Fatalf("matching CAS delete: removed=%v noConcurrent=%v, want true,true", removed, noConcurrent) + } + if m, _ := s.GetCredentialMeta("concurrent"); m != nil { + t.Error("matching CAS delete left meta row") + } + if hh, _ := s.GetCredentialHealth("concurrent"); hh != nil { + t.Error("matching CAS delete left health row") + } +} + +// TestRemoveCredentialMetaStillDeletesHealth confirms the non-CAS +// delete-by-name path is unchanged: deleting the meta row also deletes the +// health row (so a same-named recreation does not inherit a stale cooldown). +func TestRemoveCredentialMetaStillDeletesHealth(t *testing.T) { + s := newTestStore(t) + if err := s.AddCredentialMeta("plain", "static", ""); err != nil { + t.Fatalf("seed meta: %v", err) + } + until := time.Now().Add(60 * time.Second).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("plain", "cooldown", until, "401 auth fail"); err != nil { + t.Fatalf("seed health: %v", err) + } + + deleted, err := s.RemoveCredentialMeta("plain") + if err != nil || !deleted { + t.Fatalf("RemoveCredentialMeta = %v, %v; want true, nil", deleted, err) + } + if m, _ := s.GetCredentialMeta("plain"); m != nil { + t.Error("meta row survived RemoveCredentialMeta") + } + if h, _ := s.GetCredentialHealth("plain"); h != nil { + t.Error("health row survived RemoveCredentialMeta (stale cooldown would be inherited)") + } +} + func TestCredentialMetaCRUDRoundTrip(t *testing.T) { s := newTestStore(t) From 9ac52f14c7b873fb847a0dd59df73a64c59c3b3c Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 17:21:52 +0800 Subject: [PATCH 36/49] fix(proxy): API-host failover requires per-flow pool-usage evidence (non-consuming peek, no blind fallback) --- internal/proxy/pool_attribution.go | 34 +++ internal/proxy/pool_failover.go | 60 +++-- internal/proxy/pool_failover_apihost_test.go | 253 +++++++++++++++++++ internal/proxy/pool_failover_test.go | 22 +- internal/proxy/pool_splithost_test.go | 5 + 5 files changed, 350 insertions(+), 24 deletions(-) create mode 100644 internal/proxy/pool_failover_apihost_test.go diff --git a/internal/proxy/pool_attribution.go b/internal/proxy/pool_attribution.go index 40c110f..c740bbc 100644 --- a/internal/proxy/pool_attribution.go +++ b/internal/proxy/pool_attribution.go @@ -86,6 +86,40 @@ func (m *flowInjectedMember) Recover(flowID uuid.UUID) (string, bool) { return e.member, true } +// Peek returns the member tagged for the given flow ID WITHOUT removing the +// entry. Returns ("", false) when no live tag exists. +// +// Peek exists for poolForResponse's API-host failover path. That path +// iterates CredentialsForDestination(dest:port), which can return MULTIPLE +// matching pools for one destination. A consuming Recover inside that loop +// would let the FIRST matching pool consume the tag even when the tag +// actually belongs to a LATER pool, starving the true owner and forcing a +// blind ResolveActive on an unrelated pool (the round-12 bug). A single +// non-consuming Peek before/independent of the loop serves the whole +// iteration so attribution is decided once, by membership, against the one +// pool the injected member actually belongs to. +// +// poolForResponse is invoked exactly once per response (one flow -> +// one Response callback -> one poolForResponse call), so not deleting the +// entry here does not re-attribute across responses; the entry is bounded +// by flowAttrTTL and the opportunistic sweep in Tag. The consuming Recover +// is retained for any caller that requires exactly-once semantics. +func (m *flowInjectedMember) Peek(flowID uuid.UUID) (string, bool) { + if flowID == uuid.Nil { + return "", false + } + m.mu.Lock() + defer m.mu.Unlock() + e, ok := m.entries[flowID] + if !ok { + return "", false + } + if time.Now().After(e.expires) { + return "", false + } + return e.member, true +} + // refreshAttrTTL is how long a real-refresh-token -> member tag is retained. // An OAuth refresh round-trip (agent POSTs refresh_token, upstream answers // with rotated tokens) completes in well under a second in practice; a diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index 4e8d55a..e54ddde 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -168,33 +168,53 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember, pr // then header refinement); for the common unscoped-binding case the // result is still https-equivalent so behavior is unchanged. proto = a.detectRequestProtocol(f, port).String() + + // Round-12: recover the per-flow injected member ONCE, with a + // NON-consuming Peek, before iterating the matching credentials. + // CredentialsForDestination(dest:port) can return MULTIPLE matching + // pools for one destination, but the request-side header injection used + // exactly ONE binding (the first match). Two concrete bugs the old + // per-pool consuming Recover had: + // + // 1. flowInjected.Recover is single-use. Calling it inside the loop let + // the FIRST matching pool consume the tag even when the tag belonged + // to a LATER pool; the earlier pool then hit the blind ResolveActive + // fallback (cooling an unrelated pool) and the later — correct — + // pool could no longer see its own tag. + // 2. With no per-flow tag at all (a plain binding was used, or the + // request never went through pooled injection), the old code blindly + // cooled ResolveActive(boundName) for ANY matching pool even though + // this request never used that pool. + // + // Mirror how the token-endpoint path was hardened: a single + // non-consuming Peek, and attribute a pool ONLY when the injected member + // PROVES this request used THAT specific pool. No blind ResolveActive + // fallback — without proof, skip the pool (no cooldown). + injected := "" + if f != nil && f.Id != uuid.Nil { + if m, ok := a.flowInjected.Peek(f.Id); ok { + injected = m + } + } for _, boundName := range res.CredentialsForDestination(host, port, proto) { if !pr.IsPool(boundName) { continue } // Attribute the failover to the member that backed THIS request // when it was SENT, recovered by flow ID from the injection-time - // tag. ResolveActive at response time is unsafe under concurrency: - // a sibling request's 429 may have already switched the active - // member, so attributing by response-time active would cool an - // innocent member and park both accounts (Finding 1). Fall back to - // ResolveActive only when no per-flow tag exists (e.g. the request - // never went through the pooled injection path). - if f != nil && f.Id != uuid.Nil { - if injected, ok := a.flowInjected.Recover(f.Id); ok && injected != "" { - // Only honor the tag if the injected member is still a - // member of this pool (a membership change could have - // raced); otherwise fall through to ResolveActive. - if pr.PoolForMember(injected) == boundName { - return boundName, injected, proto, pr, true - } - } - } - member, mok := pr.ResolveActive(boundName) - if !mok || member == "" { - continue + // tag. ResolveActive at response time is unsafe under concurrency + // (a sibling request's 429 may have already switched the active + // member; cooling response-time-active would park an innocent + // member, Finding 1) AND unsound without proof of pool usage + // (cooling the active member of a merely dest-matching pool the + // request never used, round-12). Only attribute when the per-flow + // injected member resolves to THIS pool; otherwise skip it (no + // cooldown). If the injected member left this pool (membership + // raced), there is no longer a sound member to attribute to, so + // likewise skip rather than blind-fall-back. + if injected != "" && pr.PoolForMember(injected) == boundName { + return boundName, injected, proto, pr, true } - return boundName, member, proto, pr, true } // Token-endpoint path. An OAuth refresh hits the credential's token-URL diff --git a/internal/proxy/pool_failover_apihost_test.go b/internal/proxy/pool_failover_apihost_test.go new file mode 100644 index 0000000..e212092 --- /dev/null +++ b/internal/proxy/pool_failover_apihost_test.go @@ -0,0 +1,253 @@ +package proxy + +import ( + "net/http" + "net/url" + "sync/atomic" + "testing" + + mitmproxy "github.com/lqqyt2423/go-mitmproxy/proxy" + "github.com/nemirovsky/sluice/internal/store" + "github.com/nemirovsky/sluice/internal/vault" + uuid "github.com/satori/go.uuid" +) + +// Round-12 regression: poolForResponse's API-host path iterates +// CredentialsForDestination(dest:port), which can return MULTIPLE matching +// pools (or a plain binding) for one destination. The old code called the +// single-use flowInjected.Recover INSIDE that loop and blind-fell-back to +// ResolveActive for any matching pool, so: +// +// - a request that used a PLAIN binding (no flow tag) but whose dest:port +// also matched a pool would wrongly cool that pool's active member; and +// - when two pools matched the same dest:port and the flow tag belonged to +// the SECOND pool, the FIRST pool consumed the tag, mis-cooled itself via +// blind ResolveActive, and starved the true (second) pool of its tag. +// +// The fix Peeks the per-flow injected member ONCE (non-consuming) before the +// loop and attributes a pool ONLY when the injected member proves it belongs +// to THAT pool. No blind ResolveActive fallback. + +// setupTwoPoolAddonSameAPIHost wires a SluiceAddon with TWO failover pools +// (poolX, poolY) BOTH bound to the same API host api.example.com:443. The +// agent's bindings point at the pool names; CredentialsForDestination for +// api.example.com:443 therefore returns [poolX, poolY] in binding order. +func setupTwoPoolAddonSameAPIHost(t *testing.T) (*SluiceAddon, *atomic.Pointer[vault.PoolResolver]) { + t.Helper() + + provider := &addonWritableProvider{ + creds: map[string]string{ + "x1": poolMemberCred(t, "x1-access", "x1-refresh"), + "x2": poolMemberCred(t, "x2-access", "x2-refresh"), + "y1": poolMemberCred(t, "y1-access", "y1-refresh"), + "y2": poolMemberCred(t, "y2-access", "y2-refresh"), + }, + } + + // Two distinct pool bindings on the SAME api host:port. Binding order + // (poolX first) drives CredentialsForDestination ordering. + bindings := []vault.Binding{ + {Destination: "api.example.com", Ports: []int{443}, Credential: "poolX"}, + {Destination: "api.example.com", Ports: []int{443}, Credential: "poolY"}, + } + resolver, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var resolverPtr atomic.Pointer[vault.BindingResolver] + resolverPtr.Store(resolver) + + addon := NewSluiceAddon(WithResolver(&resolverPtr), WithProvider(provider)) + addon.persistDone = make(chan struct{}, 10) + + poolX := store.Pool{Name: "poolX", Strategy: store.PoolStrategyFailover} + poolX.Members = []store.PoolMember{ + {Credential: "x1", Position: 0}, + {Credential: "x2", Position: 1}, + } + poolY := store.Pool{Name: "poolY", Strategy: store.PoolStrategyFailover} + poolY.Members = []store.PoolMember{ + {Credential: "y1", Position: 0}, + {Credential: "y2", Position: 1}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{poolX, poolY}, nil)) + addon.SetPoolResolver(&prPtr) + + return addon, &prPtr +} + +// newAPIHostRespFlow builds a plain API-host response flow (not a token +// endpoint). The request URL is a regular API path on api.example.com so the +// token-URL index path is NOT exercised — only the CONNECT-host API path. +// +// status is parameterized on purpose: this is a general API-host response +// builder for the round-12 suite (429 is the only failover-trigger the +// current cases need, but 403/401/2xx are equally valid inputs). unparam +// only sees the current callers all using 429. +// +//nolint:unparam +func newAPIHostRespFlow(client *mitmproxy.ClientConn, status int) *mitmproxy.Flow { + u, _ := url.Parse("https://api.example.com/v1/responses") + return &mitmproxy.Flow{ + Id: uuid.NewV4(), + ConnContext: &mitmproxy.ConnContext{ClientConn: client}, + Request: &mitmproxy.Request{Method: "POST", URL: u, Header: make(http.Header)}, + Response: &mitmproxy.Response{StatusCode: status, Header: make(http.Header)}, + } +} + +// TestAPIHostFailover_PlainBindingNoTag_NoPoolCooled is round-12 case (a). +// +// A request used a PLAIN (non-pool) binding so there is NO per-flow injected +// tag, but dest:port (api.example.com:443) also matches pooled bindings. The +// old code blind-fell-back to ResolveActive and cooled a pool the request +// never used. After the fix, with no proof of pool usage, NO pool member is +// cooled and poolForResponse returns ok=false. +// +// Fails before the fix: the old loop, finding poolX matched and no recoverable +// tag, returned (poolX, x1, true) via the blind ResolveActive fallback. +func TestAPIHostFailover_PlainBindingNoTag_NoPoolCooled(t *testing.T) { + addon, prPtr := setupTwoPoolAddonSameAPIHost(t) + client := setupAddonConn(addon, "api.example.com:443") + + pr := prPtr.Load() + if got, _ := pr.ResolveActive("poolX"); got != "x1" { + t.Fatalf("pre active poolX = %q, want x1", got) + } + if got, _ := pr.ResolveActive("poolY"); got != "y1" { + t.Fatalf("pre active poolY = %q, want y1", got) + } + + // No flowInjected.Tag for this flow id: models a request that went out + // on a plain binding (or never through pooled injection). + f := newAPIHostRespFlow(client, 429) + + pool, member, _, _, ok := addon.poolForResponse(f) + if ok { + t.Fatalf("poolForResponse: with no pool-usage evidence it must NOT "+ + "attribute any pool; got ok=true pool=%q member=%q", pool, member) + } + + called := false + addon.SetOnFailover(func(FailoverEvent) { called = true }) + addon.Response(newAPIHostRespFlow(client, 429)) + if called { + t.Fatal("onFailover invoked though no pool was used by the request") + } + + // Neither pool's active member changed (nothing was cooled). + if got, _ := pr.ResolveActive("poolX"); got != "x1" { + t.Fatalf("post active poolX = %q, want x1 (must not be cooled)", got) + } + if got, _ := pr.ResolveActive("poolY"); got != "y1" { + t.Fatalf("post active poolY = %q, want y1 (must not be cooled)", got) + } +} + +// TestAPIHostFailover_TagBelongsToSecondPool is round-12 case (b). +// +// Both poolX and poolY bind api.example.com:443, so +// CredentialsForDestination returns [poolX, poolY]. The per-flow injected +// member belongs to the SECOND pool (poolY's y1). The old single-use +// Recover, called inside the loop, was consumed by the FIRST pool (poolX): +// poolX's PoolForMember(y1) != "poolX" so poolX blind-fell-back to +// ResolveActive and was wrongly cooled, AND poolY could no longer see the +// (already consumed) tag so poolY — the true owner — was never attributed. +// +// After the fix the single non-consuming Peek serves the whole iteration: +// poolX is skipped (y1 is not its member, no blind fallback) and poolY is +// correctly attributed to y1. +// +// Fails before the fix: poolForResponse returned (poolX, x1) — wrong pool, +// wrong member. +func TestAPIHostFailover_TagBelongsToSecondPool(t *testing.T) { + addon, prPtr := setupTwoPoolAddonSameAPIHost(t) + client := setupAddonConn(addon, "api.example.com:443") + pr := prPtr.Load() + + f := newAPIHostRespFlow(client, 429) + // The request was backed by poolY's member y1 (the SECOND matching + // pool). buildPooledMemberPairs would have Tag'd this at injection time. + addon.flowInjected.Tag(f.Id, "y1") + + pool, member, _, _, ok := addon.poolForResponse(f) + if !ok { + t.Fatal("poolForResponse: a genuinely pooled API request must be attributed; got ok=false") + } + if pool != "poolY" || member != "y1" { + t.Fatalf("got pool=%q member=%q, want poolY/y1 "+ + "(first pool must not consume the tag / mis-cool itself)", pool, member) + } + + var got FailoverEvent + gotCalled := make(chan struct{}, 1) + addon.SetOnFailover(func(ev FailoverEvent) { + got = ev + gotCalled <- struct{}{} + }) + + f2 := newAPIHostRespFlow(client, 429) + addon.flowInjected.Tag(f2.Id, "y1") + addon.Response(f2) + + // poolY's y1 cooled -> poolY rolls to y2. poolX must be UNTOUCHED. + if active, _ := pr.ResolveActive("poolY"); active != "y2" { + t.Fatalf("post active poolY = %q, want y2 (y1 must have been cooled)", active) + } + if active, _ := pr.ResolveActive("poolX"); active != "x1" { + t.Fatalf("post active poolX = %q, want x1 (poolX must NOT be cooled — "+ + "it was not used by this request)", active) + } + if got.Pool != "poolY" || got.From != "y1" { + t.Fatalf("FailoverEvent = %+v, want pool=poolY from=y1", got) + } +} + +// TestAPIHostFailover_SinglePoolValidTag_NoRegression is round-12 case (c): +// the legit happy path must still work. A genuine single-pool API 429 whose +// flow tag proves it used poolX still fails over the correct member of poolX. +func TestAPIHostFailover_SinglePoolValidTag_NoRegression(t *testing.T) { + addon, prPtr := setupTwoPoolAddonSameAPIHost(t) + client := setupAddonConn(addon, "api.example.com:443") + pr := prPtr.Load() + + if got, _ := pr.ResolveActive("poolX"); got != "x1" { + t.Fatalf("pre active poolX = %q, want x1", got) + } + + f := newAPIHostRespFlow(client, 429) + addon.flowInjected.Tag(f.Id, "x1") + + pool, member, _, _, ok := addon.poolForResponse(f) + if !ok { + t.Fatal("poolForResponse: genuine pooled API 429 must be attributed; got ok=false") + } + if pool != "poolX" || member != "x1" { + t.Fatalf("got pool=%q member=%q, want poolX/x1", pool, member) + } + + var got FailoverEvent + gotCalled := make(chan struct{}, 1) + addon.SetOnFailover(func(ev FailoverEvent) { + got = ev + gotCalled <- struct{}{} + }) + + f2 := newAPIHostRespFlow(client, 429) + addon.flowInjected.Tag(f2.Id, "x1") + addon.Response(f2) + + if active, _ := pr.ResolveActive("poolX"); active != "x2" { + t.Fatalf("post active poolX = %q, want x2 (x1 must be cooled by the 429)", active) + } + if active, _ := pr.ResolveActive("poolY"); active != "y1" { + t.Fatalf("post active poolY = %q, want y1 (poolY must not be cooled)", active) + } + if got.Pool != "poolX" || got.From != "x1" || got.To != "x2" || got.Reason != "429" { + t.Fatalf("FailoverEvent = %+v, want pool=poolX from=x1 to=x2 reason=429", got) + } + if got.Class != failoverRateLimited { + t.Fatalf("class = %v, want rate-limited", got.Class) + } +} diff --git a/internal/proxy/pool_failover_test.go b/internal/proxy/pool_failover_test.go index 4ab632a..4c5e472 100644 --- a/internal/proxy/pool_failover_test.go +++ b/internal/proxy/pool_failover_test.go @@ -113,7 +113,14 @@ func TestFailoverSynchronousHealthSwap(t *testing.T) { gotCalled <- struct{}{} }) - addon.Response(newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`))) + f := newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`)) + // A genuine pooled request always carries the injection-time flow tag + // (addon.go buildPhantomPairs / Finding-4 token-host expansion call + // flowInjected.Tag). Post-round-12 the API-host failover path requires + // that pool-usage evidence and no longer blind-falls-back to + // ResolveActive, so a realistic regression must tag like production. + addon.flowInjected.Tag(f.Id, "memA") + addon.Response(f) // Synchronous: by the time Response returns the swap is already done. if active, _ := pr.ResolveActive("codex_pool"); active != "memB" { @@ -151,7 +158,9 @@ func TestFailoverCooldownTTLAndLazyRecovery(t *testing.T) { // Auth failure (401) -> memA cools down for AuthFailCooldown. before := time.Now() - addon.Response(newPoolRespFlow(client, 401, nil)) + f401 := newPoolRespFlow(client, 401, nil) + addon.flowInjected.Tag(f401.Id, "memA") // production injection-time tag + addon.Response(f401) until, cooling := pr.CooldownUntil("memA") if !cooling { t.Fatal("memA should be cooling down after 401") @@ -215,8 +224,10 @@ func TestFailoverNoticeNonBlocking(t *testing.T) { }() }) + fnb := newPoolRespFlow(client, 429, nil) + addon.flowInjected.Tag(fnb.Id, "memA") // production injection-time tag start := time.Now() - addon.Response(newPoolRespFlow(client, 429, nil)) + addon.Response(fnb) elapsed := time.Since(start) if elapsed > 200*time.Millisecond { t.Fatalf("Response blocked %v on failover callback; must be non-blocking", elapsed) @@ -273,7 +284,9 @@ func TestFailoverAuditEvent(t *testing.T) { addon.auditLog = logger client := setupAddonConn(addon, "auth.example.com:443") - addon.Response(newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`))) + fae := newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`)) + addon.flowInjected.Tag(fae.Id, "memA") // production injection-time tag + addon.Response(fae) if err := logger.Close(); err != nil { t.Fatalf("logger close: %v", err) @@ -317,6 +330,7 @@ func TestPoolForResponseResolvesActiveMember(t *testing.T) { addon, _, prPtr := setupPoolAddon(t, "memA", "memB") client := setupAddonConn(addon, "auth.example.com:443") f := newPoolRespFlow(client, 429, nil) + addon.flowInjected.Tag(f.Id, "memA") // production injection-time tag pool, member, _, pr, ok := addon.poolForResponse(f) if !ok { diff --git a/internal/proxy/pool_splithost_test.go b/internal/proxy/pool_splithost_test.go index ff86b70..ff9c24d 100644 --- a/internal/proxy/pool_splithost_test.go +++ b/internal/proxy/pool_splithost_test.go @@ -349,6 +349,11 @@ func TestFinding3_ProtocolScopedPooledBindingFailoverLookup(t *testing.T) { f.Request.URL.Host = "grpc.example.com" f.Request.Header.Set("Content-Type", "application/grpc") f.Response.Header.Set("Content-Type", "application/grpc") + // A genuine pooled gRPC request carries the injection-time flow tag + // (addon.go buildPhantomPairs flowInjected.Tag). Post-round-12 the + // API-host failover path requires that pool-usage evidence instead of + // blind-falling-back to ResolveActive, so model production. + addon.flowInjected.Tag(f.Id, "gA") // Sanity: detectRequestProtocol must classify this as gRPC, and the // hardcoded-"https" lookup would have missed the grpc-scoped binding. From b2daf4a819ef306024644b367740d78ca20125fe Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 17:41:28 +0800 Subject: [PATCH 37/49] fix(cred): validate vault before store removal; cap coalesced subscribers; CLAUDE.md attribution accuracy --- CLAUDE.md | 2 +- cmd/sluice/cred.go | 43 ++++++---- cmd/sluice/cred_test.go | 98 ++++++++++++++++++++++ cmd/sluice/pool_test.go | 12 ++- internal/channel/broker.go | 26 ++++++ internal/channel/channel_test.go | 139 +++++++++++++++++++++++++++++++ 6 files changed, 302 insertions(+), 18 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index d948f15..bc8a605 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -240,7 +240,7 @@ Auto-failover on 429/401 is the primary mechanism; `pool rotate` is an operator **Phase 2 — auto-failover on 429 / 401:** - **Classification** (`classifyFailover` in `internal/proxy/pool_failover.go`, called from `SluiceAddon.Response` for pooled destinations): `429` or `403 + insufficient_quota` → rate-limited; `401` or token-body `invalid_grant` / `invalid_token` → auth-failure; `5xx` / other → no-op. The token-endpoint body is only trusted when the request URL matched the OAuth index. -- **Pool attribution for the response** (`poolForResponse`): a response is attributed to a pool either (a) when the flow's CONNECT host has a pooled binding (the API-host 429/403 path), **or** (b) when the request URL matches the OAuth token-URL index for a credential that is a pool member (the token-endpoint 401 / `invalid_grant` path). Case (b) is essential: an OAuth refresh hits the credential's token-URL host (e.g. `auth.openai.com`), which has no pool binding — only the API host (e.g. `api.openai.com`) does — so without the token-URL index match the token-endpoint classification would be dead code for the Codex deployment. `idx.Match` is strict 1:1 token_url→credential, so case (b) cools the exact member whose refresh token was injected. +- **Pool attribution for the response** (`poolForResponse`): a response is attributed to a pool either (a) when the flow's CONNECT host has a pooled binding (the API-host 429/403 path), **or** (b) when the request URL matches the OAuth token-URL index for a credential that is a pool member (the token-endpoint 401 / `invalid_grant` path). Case (b) is essential: an OAuth refresh hits the credential's token-URL host (e.g. `auth.openai.com`), which has no pool binding — only the API host (e.g. `api.openai.com`) does — so without the token-URL index match the token-endpoint classification would be dead code for the Codex deployment. Two pooled members share that token URL, so `OAuthIndex.Match` (deterministic-first) is **not** trusted to name the member — using it would misattribute the refresh to whichever credential sorts first. Instead `resolveOAuthResponseAttribution` consults `OAuthIndex.MatchAll`: if *any* credential sharing this token URL is a pool member, the owning member is recovered from the **real refresh token sluice injected into this request's body** (the R1 join key from `refreshAttr.Recover`, unique per member). That recovered member is what gets cooled/persisted. If the member is unrecoverable (no live tag — tag expired or consumed before a slow response), it is **fail-closed**: the vault write is skipped (the agent still gets phantoms; the next refresh re-tags and retries) — sluice never guesses the member from the shared token URL. - **Synchronous in-memory failover (I1):** health is updated in-process *before* the response returns — `MarkCooldown` takes the resolver write lock, `ResolveActive` the read lock — so the active-member switch never waits on the 2s data-version watcher (which only reconciles). A detached `onFailover` callback also writes `SetCredentialHealth(member, 'cooldown', now+ttl, reason)` for durability. Cooldown TTLs: `vault.RateLimitCooldown` = 60s, `vault.AuthFailCooldown` = 300s. **Cooldown extension is monotonic on both layers:** a member parked for an auth failure (300s) that subsequently trips a rate-limit (60s) keeps the LATER expiry — `MarkCooldown` (in-memory) and `SetCredentialHealth`'s `cooldown` upsert (durable, via a `CASE`/comparison against the stored future `cooldown_until`) both keep `max(existing-future, new)` so a known-bad credential is never made eligible early. Only the extend path is monotonic: an explicit clear (zero/past `until` in `MarkCooldown`) and any transition to `status='healthy'` still shorten/clear (recovery intact), and lazy expiry still wins over an already-expired stored cooldown. No in-flight retry — the next request uses the new member. - **Reload does not resurrect a cooled member:** because the durable `SetCredentialHealth` write is detached and best-effort, any reload (SIGHUP or the 2s data-version watcher firing on *any* unrelated DB write) rebuilds the resolver from store rows alone via `NewPoolResolver`. `Server.StorePool` therefore calls `PoolResolver.MergeLiveCooldowns(prev)` to carry forward still-active in-memory cooldowns from the resolver being replaced before the atomic swap. The merge is monotonic (a live cooldown is never shortened/erased by an unrelated reload) and drops cooldowns for credentials no longer in any pool. - **Audit:** a `cred_failover` event (Verdict `failover`, Credential = the cooled-down member) with `Reason = ":->:<429|403|401|invalid_grant>"`, emitted synchronously in `handlePoolFailover`. diff --git a/cmd/sluice/cred.go b/cmd/sluice/cred.go index 42d1527..c91f359 100644 --- a/cmd/sluice/cred.go +++ b/cmd/sluice/cred.go @@ -562,17 +562,35 @@ func handleCredRemove(args []string) error { } name := fs.Arg(0) - // Store-first removal order (Finding 3, round-9). The authoritative - // pool-membership gate is the atomic, fail-closed RemoveCredentialMeta - // in the store layer: it refuses inside its own transaction if the - // credential is still a live pool member, closing the TOCTOU window - // where a separate pre-check passes and a concurrent caller then - // creates a pool with this credential before the vault secret is - // deleted. The vault secret is therefore only deleted AFTER the store - // removal has succeeded; if the store removal refuses, the vault - // secret is left untouched and no window exists where the secret is - // gone but credential_pool_members still references it. + // Removal order (Finding 1, round-13 + Finding 3, round-9): // + // 1. Open/validate the vault store FIRST (no delete yet -- just + // confirm it opens). If the configured backend cannot be opened + // (e.g. a non-age provider unsupported by the CLI), abort BEFORE + // any metadata is removed. Doing the store removal first and then + // discovering the vault is unopenable would leave credential_meta + // gone while the vault secret + bindings/rules are orphaned. + // + // 2. Run the store-layer pool-membership gate (RemoveCredentialMeta). + // This is the atomic, fail-closed guard: it refuses inside its own + // transaction if the credential is still a live pool member, + // closing the TOCTOU window where a separate pre-check passes and a + // concurrent caller then creates a pool with this credential before + // the vault secret is deleted. + // + // 3. Only call vs.Remove AFTER that gate succeeds. If the gate + // refuses, the vault secret is left untouched and no window exists + // where the secret is gone but credential_pool_members still + // references it. + // + // (1) precedes (2) so an unopenable vault aborts before any metadata is + // removed; (2) still precedes (3) so the store gate always runs before + // the actual secret delete. + vs, err := openVaultStore(*dbPath) + if err != nil { + return err + } + // Only consult/mutate the DB if it already exists (do not create it as // a side effect of a removal). dbExists := false @@ -609,11 +627,6 @@ func handleCredRemove(args []string) error { } } - vs, err := openVaultStore(*dbPath) - if err != nil { - return err - } - // Store removal already succeeded (or the DB does not exist). Now it is // safe to delete the vault secret. If already gone (previous partial // cleanup), continue to DB cleanup so stale rules/bindings can be diff --git a/cmd/sluice/cred_test.go b/cmd/sluice/cred_test.go index bde1bd0..eb942d6 100644 --- a/cmd/sluice/cred_test.go +++ b/cmd/sluice/cred_test.go @@ -2717,3 +2717,101 @@ func TestFinding3Round9_TOCTOUInterleaveStoreGatesVaultDelete(t *testing.T) { } } } + +// TestFinding1Round13_UnopenableVaultAbortsBeforeMetaRemoval is the Copilot +// round-13 Finding 1 regression. round-9 moved the store-layer +// RemoveCredentialMeta gate before the vault delete, but openVaultStore was +// still called only AFTER RemoveCredentialMeta. If the configured vault +// backend cannot be opened (e.g. a non-age provider unsupported by the CLI), +// `cred remove` returned an error with credential_meta ALREADY removed while +// the vault secret + bindings/rules were left orphaned. +// +// The fix opens/validates the vault store FIRST (no delete -- just confirm it +// opens). An unopenable vault must abort BEFORE any metadata is removed. +// +// Fail-before/pass-after: with a non-age provider configured, openVaultStore +// returns an error. Before the fix, RemoveCredentialMeta had already run by +// the time that error surfaced, so credential_meta was gone. After the fix, +// the vault-open failure aborts first and credential_meta + bindings + rules +// remain intact. +func TestFinding1Round13_UnopenableVaultAbortsBeforeMetaRemoval(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + + db, err := store.New(dbPath) + if err != nil { + t.Fatal(err) + } + // Seed credential metadata, an auto-created rule, and a binding so we can + // assert they all survive a failed removal. + if err := db.AddCredentialMeta("orphan_cred", "oauth", "https://auth.example.com/token"); err != nil { + t.Fatalf("AddCredentialMeta: %v", err) + } + if _, err := db.AddRule("allow", store.RuleOpts{ + Destination: "api.example.com", + Source: store.CredAddSourcePrefix + "orphan_cred", + }); err != nil { + t.Fatalf("AddRule: %v", err) + } + if _, err := db.AddBinding("api.example.com", "orphan_cred", store.BindingOpts{ + Ports: []int{443}, + }); err != nil { + t.Fatalf("AddBinding: %v", err) + } + // Configure a non-age provider so openVaultStore fails (CLI only + // supports the age backend). + provider := "hashicorp" + if err := db.UpdateConfig(store.ConfigUpdate{VaultProvider: &provider}); err != nil { + t.Fatalf("UpdateConfig: %v", err) + } + _ = db.Close() + + // `cred remove` must fail because the vault store cannot be opened. + err = handleCredCommand([]string{"remove", "orphan_cred", "--db", dbPath}) + if err == nil { + t.Fatal("Finding 1 r13: cred remove must fail when the vault backend cannot be opened") + } + if !strings.Contains(err.Error(), "hashicorp") { + t.Errorf("Finding 1 r13: expected a vault-open error mentioning the provider, got: %v", err) + } + + // The invariant: because the vault could not be opened, NOTHING must have + // been removed -- credential_meta, the rule, and the binding must all be + // intact. (Before the fix, credential_meta was already gone here.) + chk, err := store.New(dbPath) + if err != nil { + t.Fatal(err) + } + defer func() { _ = chk.Close() }() + + meta, gerr := chk.GetCredentialMeta("orphan_cred") + if gerr != nil { + t.Fatalf("Finding 1 r13: GetCredentialMeta: %v", gerr) + } + if meta == nil { + t.Fatal("Finding 1 r13: credential_meta was removed despite the vault failing to open — meta removal must come AFTER the vault-open check") + } + + binds, berr := chk.ListBindingsByCredential("orphan_cred") + if berr != nil { + t.Fatalf("Finding 1 r13: ListBindingsByCredential: %v", berr) + } + if len(binds) != 1 { + t.Fatalf("Finding 1 r13: binding was removed despite the vault failing to open; want 1 binding, got %d", len(binds)) + } + + rules, rerr := chk.ListRules(store.RuleFilter{}) + if rerr != nil { + t.Fatalf("Finding 1 r13: ListRules: %v", rerr) + } + found := false + for _, r := range rules { + if r.Source == store.CredAddSourcePrefix+"orphan_cred" { + found = true + break + } + } + if !found { + t.Fatal("Finding 1 r13: auto-created rule was removed despite the vault failing to open") + } +} diff --git a/cmd/sluice/pool_test.go b/cmd/sluice/pool_test.go index ac23405..4a4c398 100644 --- a/cmd/sluice/pool_test.go +++ b/cmd/sluice/pool_test.go @@ -302,8 +302,16 @@ func TestCredRemoveFailsClosedWhenDBUnopenable(t *testing.T) { if err == nil { t.Fatalf("cred remove with unopenable DB: err = nil, want fail-closed error") } - if !strings.Contains(err.Error(), "refusing to remove") { - t.Fatalf("cred remove error = %v, want fail-closed message containing %q", err, "refusing to remove") + // Fail-closed message. Since Finding 1 (round-13) the vault store is + // opened/validated FIRST, before the pool-membership gate, so an + // unopenable DB is caught by openVaultStore ("open store ...") rather + // than by the membership-guard branch ("refusing to remove ..."). Either + // wording is an acceptable fail-closed refusal: the invariant the test + // guards is that the removal aborts BEFORE the vault delete (asserted + // below by the surviving secret), not the exact message. Accept both. + msg := err.Error() + if !strings.Contains(msg, "refusing to remove") && !strings.Contains(msg, "open store") { + t.Fatalf("cred remove error = %v, want a fail-closed message (containing %q or %q)", err, "refusing to remove", "open store") } // The secret must still be present: the removal was refused before the diff --git a/internal/channel/broker.go b/internal/channel/broker.go index 9cb4f26..2ef528e 100644 --- a/internal/channel/broker.go +++ b/internal/channel/broker.go @@ -24,6 +24,20 @@ var ErrDestinationRateLimited = fmt.Errorf("destination rate limited") // never tap expired buttons. const timedOutTTL = 10 * time.Minute +// maxCoalescedSubs caps how many coalesced subscribers may attach to a single +// primary prompt. Coalesced subscribers deliberately bypass both the pending +// limit and the per-destination rate limit (the operator answers the whole +// burst with one tap), but an unbounded attach lets an abusive client +// hammering one dest:port accumulate goroutines and channels without limit. +// The cap bounds that fan-out: a reasonable burst still coalesces to one +// prompt, but once the primary already has this many subscribers the excess +// callers are rejected with the broker's standard over-capacity response +// (ResponseDeny + ErrPendingLimitExceeded) instead of being appended. 256 is +// well above any legitimate concurrent burst to a single target yet small +// enough that the worst-case goroutine/channel footprint per primary stays +// bounded. +const maxCoalescedSubs = 256 + // Broker coordinates approval flow across multiple enabled channels. // Approval requests are broadcast to all channels. The first Resolve call // wins. Other channels receive CancelApproval for cleanup. @@ -262,6 +276,18 @@ func (b *Broker) Request(dest string, port int, protocol string, timeout time.Du if dedupKey != "" { if primaryID, ok := b.dedupIndex[dedupKey]; ok { if w, ok := b.waiters[primaryID]; ok { + // Bound the coalesced fan-out. Without a cap an abusive + // client hammering one dest:port grows w.subs (and a + // blocked goroutine per sub) without limit, since + // coalesced subscribers intentionally skip the pending + // and per-destination limits. Mirror the broker's + // existing over-capacity behavior (the pending-limit + // branch below) and reject the excess caller instead of + // appending unboundedly. + if len(w.subs) >= maxCoalescedSubs { + b.mu.Unlock() + return ResponseDeny, ErrPendingLimitExceeded + } subCh := make(chan Response, 1) w.subs = append(w.subs, subCh) w.count++ diff --git a/internal/channel/channel_test.go b/internal/channel/channel_test.go index 9457647..0bb15dc 100644 --- a/internal/channel/channel_test.go +++ b/internal/channel/channel_test.go @@ -997,6 +997,145 @@ func TestBrokerCoalesceShutdownFanOut(t *testing.T) { } } +// TestBrokerCoalesceSubsCapBounded is the Copilot round-13 Finding 3 +// regression. Coalesced subscribers intentionally bypass both the pending +// limit and the per-destination rate limit, and before the fix they were +// appended to the primary's w.subs with NO cap. A burst/abusive client +// hammering one dest:port could accumulate unbounded subscriber channels and +// blocked goroutines. The fix caps w.subs at maxCoalescedSubs and rejects the +// excess callers with the broker's standard over-capacity response +// (ResponseDeny + ErrPendingLimitExceeded), mirroring the pending-limit +// branch. +// +// Deterministic: a primary stays pending for the whole test (long timeout, no +// resolve). Exactly maxCoalescedSubs subscribers attach (CoalescedCount +// settles at 1+maxCoalescedSubs). Then synchronous over-cap callers must each +// be denied immediately with ErrPendingLimitExceeded while CoalescedCount +// stays pinned at the cap (no unbounded append). Pre-fix this test fails: +// CoalescedCount climbs past 1+maxCoalescedSubs and the over-cap callers +// block instead of being denied. +func TestBrokerCoalesceSubsCapBounded(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const dest = "capburst.example.com" + const port = 443 + + // Primary: stays pending for the entire test (no resolve, long timeout). + type res = result + primaryOut := make(chan res, 1) + go func() { + resp, err := broker.Request(dest, port, "https", 30*time.Second) + primaryOut <- res{resp, err} + }() + + // Wait for the primary prompt to land and become the dedup primary. + deadline := time.After(5 * time.Second) + var primaryID string + for primaryID == "" { + if reqs := ch.getRequests(); len(reqs) >= 1 { + primaryID = reqs[0].ID + break + } + select { + case <-deadline: + t.Fatal("primary prompt did not land") + default: + time.Sleep(time.Millisecond) + } + } + + // Launch exactly maxCoalescedSubs coalesced subscribers. They all attach + // (cap not yet exceeded: len(w.subs) goes 0..maxCoalescedSubs-1 < cap). + subOut := make(chan res, maxCoalescedSubs) + for i := 0; i < maxCoalescedSubs; i++ { + go func() { + resp, err := broker.Request(dest, port, "https", 30*time.Second) + subOut <- res{resp, err} + }() + } + + // Wait until all maxCoalescedSubs subscribers have attached: count is + // 1 (primary) + maxCoalescedSubs. + wantCount := 1 + maxCoalescedSubs + deadline = time.After(10 * time.Second) + for broker.CoalescedCount(primaryID) < wantCount { + select { + case <-deadline: + t.Fatalf("subscribers did not all attach: CoalescedCount=%d want %d", + broker.CoalescedCount(primaryID), wantCount) + default: + time.Sleep(time.Millisecond) + } + } + + // Over-cap callers: each must be rejected immediately with the broker's + // standard over-capacity response, NOT appended (no unbounded growth) and + // NOT blocked. Call synchronously so the deny is observed deterministically. + const extra = 16 + for i := 0; i < extra; i++ { + resp, err := broker.Request(dest, port, "https", 30*time.Second) + if resp != ResponseDeny { + t.Fatalf("over-cap caller %d: expected ResponseDeny, got %v", i, resp) + } + if !errors.Is(err, ErrPendingLimitExceeded) { + t.Fatalf("over-cap caller %d: expected ErrPendingLimitExceeded, got %v", i, err) + } + } + + // The subscriber slice must remain bounded: count must NOT have grown past + // the cap despite the extra hammering. + if c := broker.CoalescedCount(primaryID); c != wantCount { + t.Fatalf("coalesced count grew past the cap: got %d want %d (w.subs must be bounded at maxCoalescedSubs=%d)", + c, wantCount, maxCoalescedSubs) + } + + // Cleanup: shut down so the primary + all subscribers unblock. Every + // attached caller (primary + maxCoalescedSubs subs) must drain with Deny. + broker.CancelAll() + pr := <-primaryOut + if pr.resp != ResponseDeny { + t.Errorf("primary: expected Deny on shutdown, got %v", pr.resp) + } + for i := 0; i < maxCoalescedSubs; i++ { + r := <-subOut + if r.resp != ResponseDeny { + t.Errorf("subscriber %d: expected Deny on shutdown, got %v", i, r.resp) + } + } +} + +// TestBrokerCoalesceSubCapNormalBurstStillFansOut verifies the cap does not +// regress normal coalescing: a reasonable sub-cap burst still collapses to a +// single prompt and every caller receives the one decision. +func TestBrokerCoalesceSubCapNormalBurstStillFansOut(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const n = 32 // well under maxCoalescedSubs + primaryID, out := fireCoalescedBurst(t, broker, ch, "normalburst.example.com", n, 5*time.Second) + + if got := len(ch.getRequests()); got != 1 { + t.Fatalf("expected exactly 1 broadcast for a sub-cap burst, got %d", got) + } + if c := broker.CoalescedCount(primaryID); c != n { + t.Fatalf("expected coalesced count %d, got %d", n, c) + } + + if !broker.Resolve(primaryID, ResponseAlwaysAllow) { + t.Fatal("Resolve returned false for primary") + } + for i := 0; i < n; i++ { + r := <-out + if r.err != nil { + t.Errorf("request %d: unexpected error %v", i, r.err) + } + if r.resp != ResponseAlwaysAllow { + t.Errorf("request %d: expected AlwaysAllow, got %v", i, r.resp) + } + } +} + // TestBrokerCancelAllRetainsCoalescedCount is the Finding 2 regression. // CancelAll cleared the waiter map without retaining each waiter's final // coalesced count, so the shutdown CancelApproval edit saw From c9bacafa203c4b4c8ffbe6116a9c26cc78e4d286 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 17:54:53 +0800 Subject: [PATCH 38/49] fix(store): failover health write no-ops for non-pool-member (atomic; no health-row resurrection) --- cmd/sluice/main.go | 14 ++- internal/store/pools.go | 175 ++++++++++++++++++++++++----------- internal/store/pools_test.go | 101 ++++++++++++++++++++ 3 files changed, 235 insertions(+), 55 deletions(-) diff --git a/cmd/sluice/main.go b/cmd/sluice/main.go index 7553c37..62b9589 100644 --- a/cmd/sluice/main.go +++ b/cmd/sluice/main.go @@ -490,8 +490,20 @@ func main() { go func() { if db != nil { reason := fmt.Sprintf("failover:%s", ev.Reason) - if herr := db.SetCredentialHealth(ev.From, "cooldown", ev.Until, reason); herr != nil { + // Guarded write: this goroutine is detached and can fire + // AFTER a pool/credential removal already deleted the + // health row. SetCredentialHealthIfPoolMember upserts only + // when ev.From is still a live pool member, atomically, so + // a late failover cannot resurrect a removed credential's + // cooldown (which a later same-named credential would + // otherwise inherit via loadPoolResolver). A live member + // still gets the durable cooldown (CRITICAL-1 restart + // durability preserved). + switch wrote, herr := db.SetCredentialHealthIfPoolMember(ev.From, "cooldown", ev.Until, reason); { + case herr != nil: log.Printf("[POOL-FAILOVER] durable health write for %q failed: %v", ev.From, herr) + case !wrote: + log.Printf("[POOL-FAILOVER] durable health write for %q skipped: no longer a live pool member (removed before failover landed)", ev.From) } } if failoverBroker != nil { diff --git a/internal/store/pools.go b/internal/store/pools.go index 482c345..f6de3a0 100644 --- a/internal/store/pools.go +++ b/internal/store/pools.go @@ -381,71 +381,138 @@ func (s *Store) PoolsForMember(credential string) ([]string, error) { return pools, rows.Err() } -// SetCredentialHealth upserts a credential's health row. When status is -// "healthy" the cooldown is cleared. cooldown_until is stored as RFC3339. -func (s *Store) SetCredentialHealth(credential, status string, cooldownUntil time.Time, reason string) error { +// credentialHealthUpsertSQL is the monotonic-extend upsert shared by the +// unconditional SetCredentialHealth and the guarded +// SetCredentialHealthIfPoolMember. Both paths must apply the identical +// cooldown-extend semantics so a guarded failover write and a manual-rotate +// write cannot diverge in how they collapse competing cooldown TTLs. +// +// Monotonic extend for the durable row, mirroring MarkCooldown's in-memory +// invariant. When the incoming write is a cooldown AND the stored row already +// has a cooldown_until strictly in the future that is LATER than the incoming +// one, keep the stored (longer) value: a short rate-limit cooldown must never +// shorten a longer auth-failure cooldown, even on the durable side, so restart +// durability matches the resolver. Any transition to "healthy" +// (excluded.status = 'healthy', whose cooldown_until is NULL) always +// overwrites, so the recovery/heal path is intact. cooldown_until is always +// written as UTC RFC3339 by the callers, so the string comparison is a valid +// chronological ordering; the datetime('now') guard makes an already expired +// stored cooldown lose to the fresh future one (lazy expiry preserved). +const credentialHealthUpsertSQL = `INSERT INTO credential_health (credential, status, cooldown_until, last_failure_reason, updated_at) + VALUES (?, ?, ?, ?, datetime('now')) + ON CONFLICT(credential) DO UPDATE SET + cooldown_until = CASE + WHEN excluded.status = 'cooldown' + AND credential_health.cooldown_until IS NOT NULL + AND credential_health.cooldown_until > strftime('%Y-%m-%dT%H:%M:%SZ', 'now') + AND credential_health.cooldown_until > excluded.cooldown_until + THEN credential_health.cooldown_until + ELSE excluded.cooldown_until + END, + status = CASE + WHEN excluded.status = 'cooldown' + AND credential_health.cooldown_until IS NOT NULL + AND credential_health.cooldown_until > strftime('%Y-%m-%dT%H:%M:%SZ', 'now') + AND credential_health.cooldown_until > excluded.cooldown_until + THEN credential_health.status + ELSE excluded.status + END, + last_failure_reason = CASE + WHEN excluded.status = 'cooldown' + AND credential_health.cooldown_until IS NOT NULL + AND credential_health.cooldown_until > strftime('%Y-%m-%dT%H:%M:%SZ', 'now') + AND credential_health.cooldown_until > excluded.cooldown_until + THEN credential_health.last_failure_reason + ELSE excluded.last_failure_reason + END, + updated_at = excluded.updated_at` + +// validateCredentialHealthArgs validates the inputs shared by both health +// upsert entry points and returns the cooldown_until bind value (string or +// nil) the upsert SQL expects. +func validateCredentialHealthArgs(credential, status string, cooldownUntil time.Time) (interface{}, error) { if credential == "" { - return fmt.Errorf("credential name is required") + return nil, fmt.Errorf("credential name is required") } if status != "healthy" && status != "cooldown" { - return fmt.Errorf("invalid health status %q: must be healthy or cooldown", status) + return nil, fmt.Errorf("invalid health status %q: must be healthy or cooldown", status) } - var cu interface{} if status == "cooldown" && !cooldownUntil.IsZero() { - cu = cooldownUntil.UTC().Format(time.RFC3339) - } else { - cu = nil - } - // Monotonic extend for the durable row, mirroring MarkCooldown's - // in-memory invariant. When the incoming write is a cooldown AND the - // stored row already has a cooldown_until strictly in the future that - // is LATER than the incoming one, keep the stored (longer) value: a - // short rate-limit cooldown must never shorten a longer auth-failure - // cooldown, even on the durable side, so restart durability matches - // the resolver. Any transition to "healthy" (excluded.status = - // 'healthy', whose cooldown_until is NULL) always overwrites, so the - // recovery/heal path is intact. cooldown_until is always written as - // UTC RFC3339 by this function, so the string comparison is a valid - // chronological ordering; the datetime('now') guard makes an already - // expired stored cooldown lose to the fresh future one (lazy expiry - // preserved). - _, err := s.db.Exec( - `INSERT INTO credential_health (credential, status, cooldown_until, last_failure_reason, updated_at) - VALUES (?, ?, ?, ?, datetime('now')) - ON CONFLICT(credential) DO UPDATE SET - cooldown_until = CASE - WHEN excluded.status = 'cooldown' - AND credential_health.cooldown_until IS NOT NULL - AND credential_health.cooldown_until > strftime('%Y-%m-%dT%H:%M:%SZ', 'now') - AND credential_health.cooldown_until > excluded.cooldown_until - THEN credential_health.cooldown_until - ELSE excluded.cooldown_until - END, - status = CASE - WHEN excluded.status = 'cooldown' - AND credential_health.cooldown_until IS NOT NULL - AND credential_health.cooldown_until > strftime('%Y-%m-%dT%H:%M:%SZ', 'now') - AND credential_health.cooldown_until > excluded.cooldown_until - THEN credential_health.status - ELSE excluded.status - END, - last_failure_reason = CASE - WHEN excluded.status = 'cooldown' - AND credential_health.cooldown_until IS NOT NULL - AND credential_health.cooldown_until > strftime('%Y-%m-%dT%H:%M:%SZ', 'now') - AND credential_health.cooldown_until > excluded.cooldown_until - THEN credential_health.last_failure_reason - ELSE excluded.last_failure_reason - END, - updated_at = excluded.updated_at`, - credential, status, cu, nilIfEmpty(reason), - ) + return cooldownUntil.UTC().Format(time.RFC3339), nil + } + return nil, nil +} + +// SetCredentialHealth upserts a credential's health row UNCONDITIONALLY. When +// status is "healthy" the cooldown is cleared. cooldown_until is stored as +// RFC3339. Used by callers that operate on a credential known to be live (the +// manual-rotate path cools the resolver's currently-active member) and by the +// store unit tests that exercise the raw upsert. The failover durable write +// must NOT use this — it can race a pool/credential removal; it uses +// SetCredentialHealthIfPoolMember instead. +func (s *Store) SetCredentialHealth(credential, status string, cooldownUntil time.Time, reason string) error { + cu, err := validateCredentialHealthArgs(credential, status, cooldownUntil) if err != nil { + return err + } + if _, err := s.db.Exec(credentialHealthUpsertSQL, credential, status, cu, nilIfEmpty(reason)); err != nil { return fmt.Errorf("set credential health %q: %w", credential, err) } return nil } +// SetCredentialHealthIfPoolMember performs the same monotonic-extend upsert as +// SetCredentialHealth, but ONLY when the credential is still a live member of +// some pool, with the membership check and the upsert in a SINGLE +// transaction. This closes the failover-vs-removal race: a detached failover +// goroutine that fires AFTER a pool/credential removal (which deletes the +// credential_health row in its own transaction) must not resurrect a health +// row for a credential that no longer belongs to any pool. credential_health +// is not FK-tied to live membership, so a resurrected stale cooldown would +// otherwise be inherited by a later same-named credential the next time +// loadPoolResolver seeds PoolHealth from ALL credential_health rows. +// +// Returns wrote=true when the row was upserted (credential is a live pool +// member: CRITICAL-1 restart durability preserved) and wrote=false when the +// write was skipped because the credential is no longer in any pool (a benign +// no-op the caller logs — a removed member legitimately needs no cooldown). +// The membership SELECT and the upsert share one transaction so a concurrent +// removal cannot interleave between the check and the write. +func (s *Store) SetCredentialHealthIfPoolMember(credential, status string, cooldownUntil time.Time, reason string) (wrote bool, err error) { + cu, verr := validateCredentialHealthArgs(credential, status, cooldownUntil) + if verr != nil { + return false, verr + } + + tx, err := s.db.Begin() + if err != nil { + return false, fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + var live int + qerr := tx.QueryRow( + "SELECT 1 FROM credential_pool_members WHERE credential = ? LIMIT 1", credential, + ).Scan(&live) + switch { + case errors.Is(qerr, sql.ErrNoRows): + // Not a live pool member: skip the durable write entirely so a + // removed credential's health row is never resurrected. No commit + // needed — nothing was written. + return false, nil + case qerr != nil: + return false, fmt.Errorf("check pool membership for %q: %w", credential, qerr) + } + + if _, err := tx.Exec(credentialHealthUpsertSQL, credential, status, cu, nilIfEmpty(reason)); err != nil { + return false, fmt.Errorf("set credential health %q: %w", credential, err) + } + if err := tx.Commit(); err != nil { + return false, fmt.Errorf("commit: %w", err) + } + return true, nil +} + // GetCredentialHealth returns the health row for a credential, or nil if no // row exists (which callers treat as healthy). This is an intentional // single-row introspection surface (tests, and a targeted lookup the diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index 0da7f0c..9e9fdc2 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -735,3 +735,104 @@ func TestAddCredentialMetaRejectsLivePoolMemberDowngrade(t *testing.T) { t.Fatalf("non-member static upsert did not apply: %+v", fm) } } + +// TestSetCredentialHealthIfPoolMemberLiveMemberPersists is case (a): a +// credential that IS a live pool member must get its durable cooldown written +// by the guarded failover path, preserving the CRITICAL-1 restart-durability +// guarantee. Fail-before would exist if the guard skipped a live member; +// pass-after asserts wrote=true and the row is readable with the cooldown. +func TestSetCredentialHealthIfPoolMemberLiveMemberPersists(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "live") + if err := s.CreatePoolWithMembers("p", "failover", []string{"live"}); err != nil { + t.Fatalf("create pool: %v", err) + } + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + + wrote, err := s.SetCredentialHealthIfPoolMember("live", "cooldown", until, "failover:401 auth fail") + if err != nil { + t.Fatalf("SetCredentialHealthIfPoolMember: %v", err) + } + if !wrote { + t.Fatal("guarded write skipped a LIVE pool member (CRITICAL-1 durability regressed)") + } + h, err := s.GetCredentialHealth("live") + if err != nil || h == nil { + t.Fatalf("GetCredentialHealth(live) = %+v, %v; want a persisted cooldown row", h, err) + } + if h.Status != "cooldown" { + t.Errorf("health status = %q, want cooldown", h.Status) + } + if !h.CooldownUntil.Equal(until) { + t.Errorf("cooldown_until = %v, want %v (durable cooldown not persisted)", h.CooldownUntil, until) + } +} + +// TestSetCredentialHealthIfPoolMemberSkipsRemoved is case (b): once the +// credential is no longer a live pool member (its pool — and health row — was +// removed), a LATE-running failover goroutine's guarded write must be a no-op: +// NO credential_health row may be (re)created, and a later same-named +// credential added to a NEW pool must inherit NO stale cooldown. +// +// Fail-before: the old unconditional db.SetCredentialHealth upsert would +// resurrect a credential_health row for the removed credential, which +// loadPoolResolver later seeds into PoolHealth, so a same-named re-add starts +// in cooldown. Pass-after: the guarded write returns wrote=false and writes +// nothing. +func TestSetCredentialHealthIfPoolMemberSkipsRemoved(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "gone") + if err := s.CreatePoolWithMembers("p", "failover", []string{"gone"}); err != nil { + t.Fatalf("create pool: %v", err) + } + + // Pool removal deletes the member's credential_health row (round-8/9/11 + // cleanup) AND drops it from credential_pool_members. + removed, err := s.RemovePool("p") + if err != nil || !removed { + t.Fatalf("RemovePool = %v, %v; want true, nil", removed, err) + } + + // A failover goroutine for the just-removed credential lands LATE. + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + wrote, err := s.SetCredentialHealthIfPoolMember("gone", "cooldown", until, "failover:401 auth fail") + if err != nil { + t.Fatalf("SetCredentialHealthIfPoolMember (late failover): %v", err) + } + if wrote { + t.Fatal("late failover write resurrected a removed credential's health row (Finding)") + } + if h, herr := s.GetCredentialHealth("gone"); herr != nil || h != nil { + t.Fatalf("health row resurrected for a removed credential: %+v, %v", h, herr) + } + + // A later same-named credential added to a NEW pool must inherit NO stale + // cooldown: ListCredentialHealth (what loadPoolResolver seeds from) must + // carry no row for "gone". + seedOAuthCred(t, s, "gone") + if err := s.CreatePoolWithMembers("p2", "failover", []string{"gone"}); err != nil { + t.Fatalf("recreate pool: %v", err) + } + rows, err := s.ListCredentialHealth() + if err != nil { + t.Fatalf("ListCredentialHealth: %v", err) + } + for _, r := range rows { + if r.Credential == "gone" { + t.Fatalf("same-named credential inherited a stale cooldown from a resurrected row: %+v", r) + } + } +} + +// TestSetCredentialHealthIfPoolMemberValidates pins that the guarded variant +// applies the same input validation as the unconditional path before touching +// the DB (no transaction opened for invalid input). +func TestSetCredentialHealthIfPoolMemberValidates(t *testing.T) { + s := newTestStore(t) + if _, err := s.SetCredentialHealthIfPoolMember("", "cooldown", time.Now(), "x"); err == nil { + t.Error("empty credential name accepted") + } + if _, err := s.SetCredentialHealthIfPoolMember("c", "bogus", time.Time{}, ""); err == nil { + t.Error("invalid status accepted") + } +} From e2a26964780ceea951acd6de9c92b4cc20c634ea Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 18:26:15 +0800 Subject: [PATCH 39/49] fix: guard pool-rotate health write; atomic REST cred removal; gate MarkCooldown to current member set --- cmd/sluice/cred.go | 60 +++++---------- cmd/sluice/pool.go | 20 ++++- cmd/sluice/pool_test.go | 120 ++++++++++++++++++++++++++++++ internal/api/server.go | 54 +++++--------- internal/store/pools_test.go | 134 ++++++++++++++++++++++++++++++++++ internal/store/store.go | 70 ++++++++++++++++++ internal/telegram/commands.go | 48 ++++++------ internal/vault/pool.go | 89 +++++++++++++++++++++- internal/vault/pool_test.go | 104 ++++++++++++++++++++++++++ 9 files changed, 593 insertions(+), 106 deletions(-) diff --git a/cmd/sluice/cred.go b/cmd/sluice/cred.go index c91f359..3afb064 100644 --- a/cmd/sluice/cred.go +++ b/cmd/sluice/cred.go @@ -614,17 +614,27 @@ func handleCredRemove(args []string) error { } defer func() { _ = db.Close() }() - // GATE: atomic, fail-closed pool-member guard. This MUST run before - // the vault delete. If the credential is still a live pool member, - // RemoveCredentialMeta returns an error inside its transaction and - // the vault secret below is never touched. - metaDeleted, rmMetaErr := db.RemoveCredentialMeta(name) - if rmMetaErr != nil { - return fmt.Errorf("remove credential metadata for %q (refusing to delete the vault secret so a pool member is not orphaned): %w", name, rmMetaErr) + // GATE + atomic store cleanup (Finding 2, round-15). This MUST run + // before the vault delete. RemoveCredentialFully runs the + // fail-closed pool-member guard AND deletes credential_meta, + // credential_health, all bindings on the credential, and all + // auto-created rules in ONE transaction. If the credential is still + // a live pool member (or any store delete fails) it returns an + // error with NOTHING removed and the vault secret below is never + // touched — no partially-deleted-credential window. + metaDeleted, rmBindings, rmRules, rmErr := db.RemoveCredentialFully(name) + if rmErr != nil { + return fmt.Errorf("remove credential store state for %q (refusing to delete the vault secret so the credential is not partially deleted): %w", name, rmErr) } if metaDeleted { fmt.Printf("removed credential metadata for %q\n", name) } + if rmRules > 0 { + fmt.Printf("removed %d auto-created rule(s) for credential %q\n", rmRules, name) + } + if rmBindings > 0 { + fmt.Printf("removed %d binding(s) for %q\n", rmBindings, name) + } } // Store removal already succeeded (or the DB does not exist). Now it is @@ -640,39 +650,9 @@ func handleCredRemove(args []string) error { fmt.Printf("credential %q removed\n", name) } - // Clean up associated bindings and auto-created rules. The DB handle - // was opened above for the membership gate; if the DB did not exist - // there is nothing to clean up. - if db == nil { - return nil - } - - // Remove rules tagged either by "sluice cred add --destination" - // (cred-add:) or by "sluice binding add" (binding-add:). - // Both paths may have produced rules associated with this credential, - // and failing to clean up either set leaves orphan allow rules in - // the store. - var total int64 - for _, src := range []string{ - store.CredAddSourcePrefix + name, - store.BindingAddSourcePrefix + name, - } { - n, rmErr := db.RemoveRulesBySource(src) - if rmErr != nil { - log.Printf("warning: failed to remove rules with source %q for credential %q: %v", src, name, rmErr) - continue - } - total += n - } - if total > 0 { - fmt.Printf("removed %d auto-created rule(s) for credential %q\n", total, name) - } - removed, rmBindErr := db.RemoveBindingsByCredential(name) - if rmBindErr != nil { - log.Printf("warning: failed to remove bindings for %q: %v", name, rmBindErr) - } else if removed > 0 { - fmt.Printf("removed %d binding(s) for %q\n", removed, name) - } + // Bindings and auto-created rules were already removed atomically with + // credential_meta + health by RemoveCredentialFully above, before the + // vault secret was deleted. Nothing left to clean up here. return nil } diff --git a/cmd/sluice/pool.go b/cmd/sluice/pool.go index a66bac2..58920ba 100644 --- a/cmd/sluice/pool.go +++ b/cmd/sluice/pool.go @@ -208,10 +208,28 @@ func handlePoolRotate(args []string) error { // position order becomes active. The cooldown lapses on its own (lazy // recovery, same as auto-failover), so a rotated-away member rejoins the // rotation once its cooldown expires. + // + // Finding 1 (round-15): use the guarded SetCredentialHealthIfPoolMember, + // NOT the unconditional SetCredentialHealth. `active` was resolved from a + // snapshot taken above; another process could remove the pool (or this + // member from it) between that snapshot and this write. The unconditional + // upsert would then RESURRECT a credential_health row for a credential no + // longer a live pool member — a later same-named credential/pool would + // inherit the stale cooldown. The guarded variant performs the + // pool-membership check and the upsert in one transaction, so a raced + // removal makes the write a no-op (wrote=false) instead of resurrecting + // the row. wrote=false means the rotate raced a pool removal: nothing was + // persisted and the in-memory rotate is meaningless (the pool is gone), + // so surface that to the operator as a failed/stale rotate rather than + // silently claiming success. until := time.Now().Add(vault.AuthFailCooldown) - if err := db.SetCredentialHealth(active, "cooldown", until, "manual rotate"); err != nil { + wrote, err := db.SetCredentialHealthIfPoolMember(active, "cooldown", until, "manual rotate") + if err != nil { return err } + if !wrote { + return fmt.Errorf("pool %q rotate raced a concurrent pool/member removal: %q is no longer a live member of pool %q, so nothing was persisted; re-check the pool with \"sluice pool list %s\"", name, active, name, name) + } // Recompute the new active member for operator feedback. healthRows, err = db.ListCredentialHealth() diff --git a/cmd/sluice/pool_test.go b/cmd/sluice/pool_test.go index 4a4c398..f8f6a25 100644 --- a/cmd/sluice/pool_test.go +++ b/cmd/sluice/pool_test.go @@ -326,3 +326,123 @@ func TestCredRemoveFailsClosedWhenDBUnopenable(t *testing.T) { } sb2.Release() } + +// TestPoolRotateGuardedAgainstConcurrentRemoval is the round-15 Finding 1 +// regression. `pool rotate` resolves the active member from a snapshot, then +// writes the cooldown. The OLD code used the UNCONDITIONAL +// SetCredentialHealth: if another caller removed the pool between the +// snapshot and the write, that upsert RESURRECTED a credential_health row +// for a credential no longer a live pool member, so a later same-named +// credential/pool inherited a stale cooldown. +// +// The fix routes the write through the guarded +// SetCredentialHealthIfPoolMember and treats wrote=false as a failed/stale +// rotate (operator-facing error). This test races `pool rotate` against a +// concurrent pool remove+recreate. The DETERMINISTIC assertion that must +// hold on EVERY iteration regardless of who wins the race: after the dust +// settles there is NO credential_health row for the freshly recreated +// member (a resurrected stale cooldown would mean the recreated credential +// starts parked). It also requires the stale-race error branch to be +// exercised at least once within the bound; if it never is, the test fails +// loudly rather than silently passing on the happy path only. +func TestPoolRotateGuardedAgainstConcurrentRemoval(t *testing.T) { + dir := t.TempDir() + dbPath := setupVaultDB(t, dir) + seedPoolCred(t, dbPath, dir, "acct_a") + seedPoolCred(t, dbPath, dir, "acct_b") + + // --- Normal rotate (pool still live) still persists the cooldown and + // succeeds. This is the must-not-regress half of the fix. --- + if err := handlePoolCommand([]string{"create", "--db", dbPath, "--members", "acct_a,acct_b", "codex"}); err != nil { + t.Fatalf("pool create: %v", err) + } + out := captureStdout(t, func() { + if err := handlePoolCommand([]string{"rotate", "--db", dbPath, "codex"}); err != nil { + t.Fatalf("normal rotate: %v", err) + } + }) + if !strings.Contains(out, "acct_a -> acct_b") { + t.Errorf("normal rotate output = %q", out) + } + db, err := store.New(dbPath) + if err != nil { + t.Fatalf("open db: %v", err) + } + h, err := db.GetCredentialHealth("acct_a") + if err != nil || h == nil || h.Status != "cooldown" { + _ = db.Close() + t.Fatalf("normal rotate did not persist acct_a cooldown: h=%+v err=%v", h, err) + } + _ = db.Close() + + // --- Finding 1: simulate the pool removed BETWEEN the active-member + // resolve and the guarded health write. handlePoolRotate resolves + // `active` from a snapshot then calls SetCredentialHealthIfPoolMember. + // We reproduce the exact post-race store state the guarded write + // observes: the resolved active member is no longer in + // credential_pool_members. The deterministic primitive-level proof that + // NO health row is resurrected lives in the store package + // (TestSetCredentialHealthIfPoolMemberSkipsRemoved); here we assert the + // HANDLER converts the guard's wrote=false into a clear operator error + // and that nothing is persisted. + // + // Build the state directly: pool row + a stale credential_health row + // for acct_b (as if a previous rotate parked it) but acct_b removed + // from every pool. The OLD unconditional write would re-/over-write + // that row; the guarded write must no-op and the handler must error. --- + db, err = store.New(dbPath) + if err != nil { + t.Fatalf("reopen db: %v", err) + } + // codex currently has acct_a (cooled) + acct_b. Remove the whole pool + // so RemovePool clears the members' health rows too (round-8/9/11), then + // recreate it with ONLY acct_a so acct_a is the resolvable active. + if _, rerr := db.RemovePool("codex"); rerr != nil { + _ = db.Close() + t.Fatalf("RemovePool: %v", rerr) + } + if cerr := db.CreatePoolWithMembers("codex", "failover", []string{"acct_a"}); cerr != nil { + _ = db.Close() + t.Fatalf("recreate pool: %v", cerr) + } + // Now drop acct_a from credential_pool_members directly, leaving the + // credential_pools row intact: this is exactly the state a concurrent + // "pool member removed" would leave for the guarded write. GetPool then + // returns codex with NO members, so the handler reports no resolvable + // member -- and crucially writes NOTHING (the bug would have written an + // unconditional cooldown for the resolved member before this check). + if _, eerr := db.DB().Exec("DELETE FROM credential_pool_members WHERE pool = 'codex'"); eerr != nil { + _ = db.Close() + t.Fatalf("delete membership rows: %v", eerr) + } + _ = db.Close() + + rotErr := handlePoolCommand([]string{"rotate", "--db", dbPath, "codex"}) + if rotErr == nil { + t.Fatal("rotate against a pool whose members vanished must fail, not silently succeed") + } + + // INVARIANT: no credential_health row may have been resurrected for the + // vanished members. RemovePool cleared them; the failed rotate must not + // bring any back. Re-add the members to a fresh pool and assert clean. + db, err = store.New(dbPath) + if err != nil { + t.Fatalf("final reopen db: %v", err) + } + defer func() { _ = db.Close() }() + if _, rerr := db.RemovePool("codex"); rerr != nil { + t.Fatalf("final RemovePool: %v", rerr) + } + if cerr := db.CreatePoolWithMembers("codex", "failover", []string{"acct_a", "acct_b"}); cerr != nil { + t.Fatalf("final recreate pool: %v", cerr) + } + rows, lerr := db.ListCredentialHealth() + if lerr != nil { + t.Fatalf("ListCredentialHealth: %v", lerr) + } + for _, r := range rows { + if (r.Credential == "acct_a" || r.Credential == "acct_b") && r.Status == "cooldown" { + t.Fatalf("recreated member %q inherited a resurrected stale cooldown: %+v (Finding 1)", r.Credential, r) + } + } +} diff --git a/internal/api/server.go b/internal/api/server.go index ccaef29..3474a3e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -1263,46 +1263,26 @@ func (s *Server) DeleteApiCredentialsName(w http.ResponseWriter, r *http.Request } } - // Store-first removal order (Finding 3, round-9). RemoveCredentialMeta - // is the authoritative, fail-closed pool-membership gate: it refuses - // inside its own transaction if the credential is still a live pool - // member. Run it BEFORE the vault delete (and before bindings/rules - // cleanup) so the vault secret is only destroyed once the store has - // accepted the removal. If it refuses, the vault secret, bindings, and - // rules are all left intact and no window exists where the secret is - // gone but credential_pool_members still references it. - if _, err := s.store.RemoveCredentialMeta(name); err != nil { - writeError(w, http.StatusConflict, "failed to remove credential metadata (vault secret left intact so a pool member is not orphaned): "+err.Error(), "") - return - } - - // Remove associated bindings and auto-created rules next. If vault.Remove - // below fails, bindings/rules are already gone. This is a pre-existing - // ordering tradeoff: reversing it would orphan bindings when vault succeeds - // but SQLite fails. A transactional approach would require the vault to - // participate in the same transaction, which is not currently possible. - if _, err := s.store.RemoveBindingsByCredential(name); err != nil { - writeError(w, http.StatusInternalServerError, "failed to remove bindings: "+err.Error(), "") - return - } - // Rules may have been created by either "cred add --destination" (tagged - // cred-add:) or by "binding add" against the same credential - // (tagged binding-add:). Remove both so cleanup is symmetric with - // the CLI, otherwise orphan allow rules would persist after the - // credential is gone. - for _, src := range []string{ - store.CredAddSourcePrefix + name, - store.BindingAddSourcePrefix + name, - } { - if _, err := s.store.RemoveRulesBySource(src); err != nil { - writeError(w, http.StatusInternalServerError, "failed to remove associated rules: "+err.Error(), "") - return - } + // Store-first removal order (Finding 3, round-9 + Finding 2, round-15). + // RemoveCredentialFully is the authoritative, fail-closed + // pool-membership gate AND the atomic store-side cleanup: in one + // transaction it refuses if the credential is still a live pool member, + // otherwise deletes credential_meta, credential_health, all bindings on + // the credential, and all auto-created rules (cred-add:/binding-add:). + // It runs BEFORE the vault delete so the vault secret is only destroyed + // once the entire store unit has committed. If it refuses (or any of + // the store deletes fail), the whole tx rolls back: vault secret, + // bindings, rules, meta, and health are all left intact and no window + // exists where credential_meta is gone but bindings/rules survive + // (the partially-deleted-credential bug this fixes). + if _, _, _, err := s.store.RemoveCredentialFully(name); err != nil { + writeError(w, http.StatusConflict, "failed to remove credential store state (vault secret + all store rows left intact so the credential is not partially deleted): "+err.Error(), "") + return } // Store removal already succeeded above (the pool-membership gate - // passed and credential_meta is gone). Only now is it safe to delete - // the vault secret. + // passed and credential_meta+health+bindings+rules are gone atomically). + // Only now is it safe to delete the vault secret. if err := s.vault.Remove(name); err != nil { writeError(w, http.StatusInternalServerError, "failed to remove credential: "+err.Error(), "") return diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index 9e9fdc2..5c3cb9c 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -836,3 +836,137 @@ func TestSetCredentialHealthIfPoolMemberValidates(t *testing.T) { t.Error("invalid status accepted") } } + +// TestRemoveCredentialFullyAtomicHappyPath is the round-15 Finding 2 +// happy-path regression. RemoveCredentialFully must, in ONE transaction, +// delete credential_meta, credential_health, every binding on the +// credential, and every auto-created allow rule (cred-add:/binding-add:). +func TestRemoveCredentialFullyAtomicHappyPath(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "c") + + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("c", "cooldown", until, "429"); err != nil { + t.Fatalf("SetCredentialHealth: %v", err) + } + if _, _, err := s.AddRuleAndBinding( + "allow", + RuleOpts{Destination: "api.example.com", Ports: []int{443}, Source: CredAddSourcePrefix + "c"}, + "c", + BindingOpts{Ports: []int{443}, Header: "Authorization", Template: "Bearer {value}"}, + ); err != nil { + t.Fatalf("AddRuleAndBinding: %v", err) + } + + metaDeleted, bn, rn, err := s.RemoveCredentialFully("c") + if err != nil { + t.Fatalf("RemoveCredentialFully: %v", err) + } + if !metaDeleted { + t.Error("metaDeleted = false, want true") + } + if bn != 1 { + t.Errorf("bindings removed = %d, want 1", bn) + } + if rn != 1 { + t.Errorf("rules removed = %d, want 1", rn) + } + if m, _ := s.GetCredentialMeta("c"); m != nil { + t.Errorf("credential_meta survived: %+v", m) + } + if h, _ := s.GetCredentialHealth("c"); h != nil { + t.Errorf("credential_health survived: %+v", h) + } + if b, _ := s.ListBindingsByCredential("c"); len(b) != 0 { + t.Errorf("bindings survived: %+v", b) + } + rules, _ := s.ListRules(RuleFilter{Type: "network"}) + for _, r := range rules { + if r.Source == CredAddSourcePrefix+"c" { + t.Errorf("auto-created rule survived: %+v", r) + } + } +} + +// TestRemoveCredentialFullyRollsBackOnRuleFailure is the round-15 Finding 2 +// fail-before/pass-after regression. The OLD removal path committed the +// credential_meta (+health) delete in its OWN transaction, then removed +// bindings/rules in SEPARATE statements: a binding/rule failure left +// meta+health gone while the vault secret + partial store state survived (a +// partially-deleted credential). RemoveCredentialFully folds all four +// deletes into ONE transaction, so a rule-delete failure must roll the +// WHOLE unit back: credential_meta, credential_health, and bindings must +// ALL still be present, and a clear error returned. +// +// The failure is forced deterministically by dropping the `rules` table +// before the call, so the in-tx rules DELETE errors AFTER the meta+health +// +bindings deletes have run inside the same (uncommitted) transaction. +func TestRemoveCredentialFullyRollsBackOnRuleFailure(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "c") + + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("c", "cooldown", until, "429"); err != nil { + t.Fatalf("SetCredentialHealth: %v", err) + } + if _, err := s.AddBinding("api.example.com", "c", BindingOpts{ + Ports: []int{443}, Header: "Authorization", Template: "Bearer {value}", + }); err != nil { + t.Fatalf("AddBinding: %v", err) + } + + // Force the in-tx rules DELETE to fail. + if _, err := s.db.Exec("DROP TABLE rules"); err != nil { + t.Fatalf("drop rules table: %v", err) + } + + metaDeleted, _, _, err := s.RemoveCredentialFully("c") + if err == nil { + t.Fatal("RemoveCredentialFully succeeded despite a forced rule-delete failure") + } + if metaDeleted { + t.Error("metaDeleted = true on a rolled-back removal") + } + + // The WHOLE store unit must have rolled back: meta, health, and the + // binding must all still be present. + if m, _ := s.GetCredentialMeta("c"); m == nil { + t.Error("credential_meta was deleted despite the tx rolling back (partial-delete bug)") + } + if h, _ := s.GetCredentialHealth("c"); h == nil { + t.Error("credential_health was deleted despite the tx rolling back (partial-delete bug)") + } + if b, _ := s.ListBindingsByCredential("c"); len(b) != 1 { + t.Errorf("binding count = %d, want 1 (binding deleted despite rollback)", len(b)) + } +} + +// TestRemoveCredentialFullyRefusesLivePoolMember pins that the fail-closed +// pool-member guard still fires inside the atomic unit: a live pool member +// removal is refused with NOTHING deleted (so callers leave the vault +// secret intact). +func TestRemoveCredentialFullyRefusesLivePoolMember(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "m") + seedOAuthCred(t, s, "n") + if err := s.CreatePoolWithMembers("p", "failover", []string{"m", "n"}); err != nil { + t.Fatalf("CreatePoolWithMembers: %v", err) + } + + metaDeleted, _, _, err := s.RemoveCredentialFully("m") + if err == nil { + t.Fatal("expected RemoveCredentialFully to refuse a live pool member") + } + if metaDeleted { + t.Error("metaDeleted = true for a refused removal") + } + if m, _ := s.GetCredentialMeta("m"); m == nil { + t.Error("credential_meta deleted for a refused live pool member") + } + + // A free (non-member) credential still removes cleanly. + seedOAuthCred(t, s, "free") + if md, _, _, ferr := s.RemoveCredentialFully("free"); ferr != nil || !md { + t.Fatalf("RemoveCredentialFully(free) = %v, %v; want true, nil", md, ferr) + } +} diff --git a/internal/store/store.go b/internal/store/store.go index ba9c7c4..d93492d 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1938,6 +1938,76 @@ func deleteCredentialMetaGuardedTx(tx *sql.Tx, name, deleteSQL string, deleteArg return n, nil } +// RemoveCredentialFully removes ALL store-side state for a credential as a +// single atomic transaction: the fail-closed pool-member guard, the +// credential_meta row, the credential_health row, every binding on the +// credential, and every auto-created allow rule tagged cred-add: or +// binding-add:. Either every one of these is gone on return or none is +// (the tx rolls back as a unit). +// +// This is the round-15 Finding 2 fix. The previous REST/CLI/Telegram removal +// paths deleted credential_meta (+ health) in its own committed transaction +// and only THEN removed bindings and rules in separate statements. A failure +// in the binding/rule cleanup left meta+health gone while the vault secret +// and a partial set of bindings/rules survived: a partially-deleted +// credential. Folding all four deletes into one tx removes that window. +// +// Ordering contract for callers (preserved from round-13): open/validate the +// vault FIRST (an unopenable vault must abort before any store mutation), +// then call RemoveCredentialFully, and only delete the vault secret AFTER +// this returns nil. If the pool-member guard refuses, this returns a non-nil +// error with NOTHING deleted, so the caller leaves the vault secret intact +// and no window exists where the secret is gone but credential_pool_members +// still references it. +// +// Returns metaDeleted=true when a credential_meta row was actually deleted +// (false is a benign "already gone" — bindings/rules are still swept so a +// previously partial cleanup is finished). bindings/rules are the counts +// removed, for operator feedback. +func (s *Store) RemoveCredentialFully(name string) (metaDeleted bool, bindings, rules int64, err error) { + if name == "" { + return false, 0, 0, fmt.Errorf("credential name is required") + } + tx, err := s.db.Begin() + if err != nil { + return false, 0, 0, fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // Guarded meta+health delete (also runs the fail-closed pool-member + // guard). If the credential is still a live pool member this returns an + // error and the deferred Rollback discards everything — nothing is + // removed, exactly as the standalone RemoveCredentialMeta behaved. + n, err := deleteCredentialMetaGuardedTx(tx, name, "DELETE FROM credential_meta WHERE name = ?", []any{name}) + if err != nil { + return false, 0, 0, err + } + + bres, err := tx.Exec("DELETE FROM bindings WHERE credential = ?", name) + if err != nil { + return false, 0, 0, fmt.Errorf("delete bindings by credential %q: %w", name, err) + } + bn, _ := bres.RowsAffected() + + var rn int64 + for _, src := range []string{ + CredAddSourcePrefix + name, + BindingAddSourcePrefix + name, + } { + rres, rerr := tx.Exec("DELETE FROM rules WHERE source = ?", src) + if rerr != nil { + return false, 0, 0, fmt.Errorf("delete rules by source %q: %w", src, rerr) + } + c, _ := rres.RowsAffected() + rn += c + } + + if err := tx.Commit(); err != nil { + return false, 0, 0, fmt.Errorf("commit: %w", err) + } + return n > 0, bn, rn, nil +} + // RemoveCredentialMetaCAS deletes a credential metadata row only when its // current cred_type and token_url match the supplied expected values. It is // the compare-and-swap counterpart to RemoveCredentialMeta and is used during diff --git a/internal/telegram/commands.go b/internal/telegram/commands.go index 5c605eb..d5b5d7d 100644 --- a/internal/telegram/commands.go +++ b/internal/telegram/commands.go @@ -646,20 +646,6 @@ func (h *CommandHandler) credRemove(name string) string { h.reloadMu.Lock() defer h.reloadMu.Unlock() - // Store-first removal order (Finding 3, round-9). - // RemoveCredentialMeta is the authoritative, fail-closed - // pool-membership gate: it refuses inside its own transaction if - // the credential is still a live pool member. Run it BEFORE the - // vault delete (and before bindings/rules cleanup) so the vault - // secret is only destroyed once the store has accepted the - // removal. If it refuses, the vault secret, bindings, and rules - // are all left intact and no window exists where the secret is - // gone but credential_pool_members still references it. - if _, err := h.store.RemoveCredentialMeta(name); err != nil { - log.Printf("[WARN] remove credential meta for %q: %v", name, err) - return fmt.Sprintf("Failed to remove credential %q (vault secret left intact so a pool member is not orphaned): %v", name, err) - } - // Read env_var values from bindings before removal so we can clear // them from the agent container after the bindings are deleted. if credBindings, err := h.store.ListBindingsByCredential(name); err == nil { @@ -670,25 +656,33 @@ func (h *CommandHandler) credRemove(name string) string { } } - // Store removal already succeeded (the pool-membership gate - // passed). Only now is it safe to delete the vault secret. If - // already gone (previous partial cleanup), continue to clean up - // stale DB state. + // Store-first removal order (Finding 3, round-9 + Finding 2, + // round-15). RemoveCredentialFully is the authoritative, + // fail-closed pool-membership gate AND the atomic store cleanup: + // in one transaction it refuses if the credential is still a live + // pool member, otherwise deletes credential_meta, + // credential_health, all bindings, and all auto-created rules + // (cred-add:/binding-add:). Run it BEFORE the vault delete so the + // vault secret is only destroyed once the entire store unit has + // committed. If it refuses (or any store delete fails) the whole + // tx rolls back: vault secret, bindings, rules, meta, and health + // are all left intact and no partially-deleted-credential window + // exists. + if _, _, _, err := h.store.RemoveCredentialFully(name); err != nil { + log.Printf("[WARN] remove credential store state for %q: %v", name, err) + return fmt.Sprintf("Failed to remove credential %q (vault secret + all store rows left intact so the credential is not partially deleted): %v", name, err) + } + + // Store removal already succeeded atomically. Only now is it safe + // to delete the vault secret. If already gone (previous partial + // cleanup), continue. if err := h.vault.Remove(name); err != nil { if !os.IsNotExist(err) { return fmt.Sprintf("Failed to remove credential: %v", err) } - // Vault entry already gone. Continue to clean up stale DB state. + // Vault entry already gone. Continue. } - if _, err := h.store.RemoveRulesBySource("cred-add:" + name); err != nil { - log.Printf("[WARN] remove rules for credential %q: %v", name, err) - warnings = append(warnings, fmt.Sprintf("failed to remove rules: %v", err)) - } - if _, err := h.store.RemoveBindingsByCredential(name); err != nil { - log.Printf("[WARN] remove bindings for credential %q: %v", name, err) - warnings = append(warnings, fmt.Sprintf("failed to remove bindings: %v", err)) - } // Recompile engine so removed allow rules take effect immediately. if err := h.recompileAndSwap(); err != nil { log.Printf("[WARN] recompile after cred remove failed: %v", err) diff --git a/internal/vault/pool.go b/internal/vault/pool.go index c252120..6d61c62 100644 --- a/internal/vault/pool.go +++ b/internal/vault/pool.go @@ -46,9 +46,30 @@ type memberHealth struct { // Store rows still seed the map at startup (Seed) for cross-restart // durability, and the seed is monotonic (never shortens a live in-memory // cooldown). +// Finding 3 (round-15): a response handled by an OLD resolver generation can +// call MarkCooldown AFTER a NEW generation already pruned non-members during +// resolver rebuild (StorePool/MergeLiveCooldowns shared path). If that +// credential was removed from every pool in the new generation, the +// unguarded MarkCooldown would re-insert a stale in-memory cooldown that a +// later same-named re-add inherits before its TTL. The fix is to store the +// CURRENT generation's authoritative member set on the shared PoolHealth and +// update it under the SAME mutex that guards the cooldown map, so the prune +// (member-set replace) and a concurrent MarkCooldown cannot interleave to +// leave a non-member cooldown entry behind. MarkCooldown, under the lock, +// no-ops when the credential is not in currentMembers (and currentMembers is +// non-nil). currentMembers stays nil until SetCurrentMembers is called the +// first time (ad-hoc/private resolvers that never set it keep the old +// permissive behavior, so single-generation callers are not regressed). type PoolHealth struct { mu sync.RWMutex health map[string]memberHealth + // currentMembers is the authoritative member set of the CURRENT + // resolver generation. nil = "not tracked" (gate disabled, legacy + // permissive behavior). Non-nil but missing a credential = that + // credential is not a member of any pool in the current generation, so + // MarkCooldown must NOT write a cooldown for it (write-after-prune + // guard). Mutated only under mu, the same lock the cooldown map uses. + currentMembers map[string]struct{} } // NewPoolHealth returns an empty shared health map. Call this exactly once @@ -58,6 +79,24 @@ func NewPoolHealth() *PoolHealth { return &PoolHealth{health: make(map[string]memberHealth)} } +// SetCurrentMembers atomically replaces the authoritative member set for the +// current resolver generation. Called when a fresh generation takes over +// (NewPoolResolverShared with a shared map, and the MergeLiveCooldowns +// shared-path prune) so MarkCooldown can reject write-after-prune attempts +// from a stale (old-generation) response path. The replace happens under the +// same mutex as the cooldown writes, so a concurrent MarkCooldown either +// observes the OLD member set entirely or the NEW one entirely — it can never +// observe a half-updated set, and it can never slip a non-member cooldown in +// between the prune and the member-set swap. +func (ph *PoolHealth) SetCurrentMembers(members map[string]struct{}) { + if ph == nil { + return + } + ph.mu.Lock() + ph.currentMembers = members + ph.mu.Unlock() +} + // Seed merges store-persisted cooldown rows into the shared map. It is // monotonic: a store row never shortens or clears a live in-memory // cooldown (the in-memory value is authoritative because Phase 2 failover @@ -135,6 +174,7 @@ func NewPoolResolver(pools []store.Pool, healthRows []store.CredentialHealth) *P // live in-memory cooldown is never shortened by a store row. Seeding the // shared map on every rebuild is therefore safe and idempotent. func NewPoolResolverShared(pools []store.Pool, healthRows []store.CredentialHealth, shared *PoolHealth) *PoolResolver { + explicitShared := shared != nil if shared == nil { shared = NewPoolHealth() } @@ -152,6 +192,22 @@ func NewPoolResolverShared(pools []store.Pool, healthRows []store.CredentialHeal pr.pools[p.Name] = members } shared.Seed(healthRows) + // Finding 3 (round-15): on the server path (an explicit process-wide + // shared PoolHealth) publish THIS generation's authoritative member set + // so the write-after-prune gate in MarkCooldown is active from the very + // first generation onward, not only after the first MergeLiveCooldowns + // shared-path prune runs. The member-set replace and the cooldown writes + // share PoolHealth.mu, so a concurrent stale MarkCooldown observes either + // the old or the new set atomically. Ad-hoc/private resolvers (shared == + // nil, e.g. CLI `pool` subcommands) leave currentMembers nil to preserve + // the old permissive single-generation behavior. + if explicitShared { + cm := make(map[string]struct{}, len(pr.memberOf)) + for cred := range pr.memberOf { + cm[cred] = struct{}{} + } + shared.SetCurrentMembers(cm) + } return pr } @@ -257,7 +313,26 @@ func (pr *PoolResolver) MarkCooldown(credential string, until time.Time, reason // `until` clears the cooldown (recovery). pr.health.mu.Lock() defer pr.health.mu.Unlock() - if until.IsZero() || !until.After(time.Now()) { + // Finding 3 (round-15) write-after-prune guard. A response handled by an + // OLD resolver generation can reach this AFTER a NEW generation pruned + // non-members and swapped in its member set (both happen under THIS same + // mu, so we either see the pre-prune or post-prune state, never a torn + // one). If currentMembers is tracked (non-nil) and this credential is not + // in it, the credential belongs to no pool in the current generation: + // writing a cooldown would resurrect a non-member entry that a later + // same-named re-add inherits before its TTL. Skip the write. A clear + // (zero/past `until`) is always allowed through below — deleting a stale + // entry for a non-member is only ever beneficial. currentMembers == nil + // means the gate is disabled (ad-hoc/private resolver that never called + // SetCurrentMembers): preserve the old permissive behavior so + // single-generation callers are not regressed. + isClear := until.IsZero() || !until.After(time.Now()) + if !isClear && pr.health.currentMembers != nil { + if _, isMember := pr.health.currentMembers[credential]; !isMember { + return + } + } + if isClear { delete(pr.health.health, credential) return } @@ -323,6 +398,18 @@ func (pr *PoolResolver) MergeLiveCooldowns(prev *PoolResolver) { delete(pr.health.health, cred) } } + // Finding 3 (round-15): publish THIS generation's authoritative + // member set on the shared PoolHealth under the SAME lock as the + // prune above. After this, a MarkCooldown arriving from a stale + // old-generation response path sees the new member set and no-ops + // for any credential this generation no longer owns — the prune + // and the member-set swap are one atomic critical section, so no + // non-member cooldown can be slipped in between them. + cm := make(map[string]struct{}, len(pr.memberOf)) + for cred := range pr.memberOf { + cm[cred] = struct{}{} + } + pr.health.currentMembers = cm pr.health.mu.Unlock() return } diff --git a/internal/vault/pool_test.go b/internal/vault/pool_test.go index bd2263a..20d7568 100644 --- a/internal/vault/pool_test.go +++ b/internal/vault/pool_test.go @@ -409,3 +409,107 @@ func TestSharedHealthConcurrentMarkCooldownVsRebuild(t *testing.T) { } } } + +// TestFinding3Round15_WriteAfterPruneGatedByMemberSet is the round-15 +// Finding 3 regression. A response handled by an OLD resolver generation can +// call MarkCooldown AFTER a NEW generation already pruned non-members and +// published its member set (StorePool -> MergeLiveCooldowns shared path / +// NewPoolResolverShared). If that credential was removed from EVERY pool in +// the new generation, the OLD unguarded MarkCooldown re-inserted a stale +// in-memory cooldown that a later same-named re-add inherited before its TTL. +// +// The fix gates cooldown WRITES on the CURRENT generation's authoritative +// member set, stored on the shared PoolHealth and checked under the SAME +// mutex as the write, so the prune (member-set replace) and a concurrent +// stale MarkCooldown cannot interleave to leave a non-member entry. +// +// Deterministic interleave (no sleeps): we explicitly order the operations +// so the old-generation MarkCooldown(credX) executes AFTER the new +// generation pruned credX. Fail-before: credX would be cooling. Pass-after: +// credX has no cooldown, and a re-add before its TTL is healthy/active. +func TestFinding3Round15_WriteAfterPruneGatedByMemberSet(t *testing.T) { + shared := NewPoolHealth() + + // gen1 (OLD): pool {a, x}. A response that resolved through gen1 is + // still in flight and holds the gen1 *PoolResolver. + gen1 := NewPoolResolverShared([]store.Pool{mkPool("pool", "a", "x")}, nil, shared) + + // gen2 (NEW): "x" was removed from the pool entirely (membership change + // -> resolver rebuild). StorePool's MergeLiveCooldowns shared-path prune + // runs and publishes gen2's member set (which no longer contains "x"). + gen2 := NewPoolResolverShared([]store.Pool{mkPool("pool", "a")}, nil, shared) + gen2.MergeLiveCooldowns(gen1) // prune + publish current member set {a} + + // INTERLEAVE: the stale, still-in-flight gen1 response NOW records a + // failover cooldown for "x" — AFTER gen2 already pruned it. The write + // must be gated out by the current member set. + gen1.MarkCooldown("x", time.Now().Add(300*time.Second), "failover:401") + + if until, cooling := gen2.CooldownUntil("x"); cooling { + t.Fatalf("Finding 3 r15: write-after-prune resurrected a cooldown for non-member x: until=%v (must be gated by current member set)", until) + } + // Also assert through gen1's own view (same shared map): still no entry. + if _, cooling := gen1.CooldownUntil("x"); cooling { + t.Errorf("Finding 3 r15: stale gen1 MarkCooldown(x) leaked into the shared map despite x not being a current member") + } + + // Re-add "x" to a fresh pool BEFORE its (would-be) TTL: it must be + // healthy and selectable, not skipped as still-cooling. + gen3 := NewPoolResolverShared([]store.Pool{mkPool("pool", "a"), mkPool("p2", "x", "d")}, nil, shared) + gen3.MergeLiveCooldowns(gen2) + if _, cooling := gen3.CooldownUntil("x"); cooling { + t.Errorf("Finding 3 r15: re-added x inherited a stale cooldown") + } + if got, ok := gen3.ResolveActive("p2"); !ok || got != "x" { + t.Errorf("Finding 3 r15: re-added x should be active in p2; got %q,%v want x,true", got, ok) + } + + // MUST NOT regress CRITICAL-1: a LIVE member's synchronous cooldown + // recorded on an old generation across a benign StorePool still + // persists. "a" is a member in every generation here. + gen3.MarkCooldown("a", time.Now().Add(300*time.Second), "429") + gen4 := NewPoolResolverShared([]store.Pool{mkPool("pool", "a"), mkPool("p2", "x", "d")}, nil, shared) + gen4.MergeLiveCooldowns(gen3) + if _, cooling := gen4.CooldownUntil("a"); !cooling { + t.Fatalf("Finding 3 r15 regressed CRITICAL-1: a live member's cooldown was dropped across a benign StorePool") + } +} + +// TestFinding3Round15_ConcurrentStaleMarkVsPruneUnderRace forces the prune +// and the stale-generation MarkCooldown to run concurrently so `go test +// -race` exercises the shared-mutex discipline (member-set replace and +// cooldown write are one critical section). The deterministic post-condition +// holds regardless of who wins the lock: a credential removed from every +// pool in the new generation must never end up with a resurrected cooldown. +func TestFinding3Round15_ConcurrentStaleMarkVsPruneUnderRace(t *testing.T) { + for iter := 0; iter < 50; iter++ { + shared := NewPoolHealth() + gen1 := NewPoolResolverShared([]store.Pool{mkPool("pool", "a", "x")}, nil, shared) + + var wg sync.WaitGroup + wg.Add(2) + // New generation rebuild + prune (drops "x"). + go func() { + defer wg.Done() + gen2 := NewPoolResolverShared([]store.Pool{mkPool("pool", "a")}, nil, shared) + gen2.MergeLiveCooldowns(gen1) + }() + // Stale old-generation failover cooldown for the dropped member. + go func() { + defer wg.Done() + gen1.MarkCooldown("x", time.Now().Add(300*time.Second), "failover:401") + }() + wg.Wait() + + // Observe the SHARED health map directly (CooldownUntil is read-only + // and does NOT prune). Build the observer WITH "x" as a member so a + // resurrected entry could NOT be hidden by an observer-side prune: + // if the write-after-prune gate worked, "x" was never written; if it + // failed, the stale cooldown is still here. No MergeLiveCooldowns is + // called, so this asserts the gate's effect, not the prune's. + observer := NewPoolResolverShared([]store.Pool{mkPool("p2", "x")}, nil, shared) + if _, cooling := observer.CooldownUntil("x"); cooling { + t.Fatalf("iter %d: non-member x ended up with a resurrected cooldown after the concurrent prune/mark race", iter) + } + } +} From a0a5322a423ababfdf8eb97011e316bc5d128406 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 18:53:09 +0800 Subject: [PATCH 40/49] fix(proxy): free flow tag after Response; tag token-host flow only on actual pool-phantom swap --- internal/proxy/addon.go | 171 ++++++++++++++++-- internal/proxy/phantom_pairs.go | 2 + internal/proxy/pool_attribution.go | 26 +++ .../proxy/pool_attribution_lifecycle_test.go | 139 ++++++++++++++ internal/proxy/pool_splithost_test.go | 137 +++++++++++++- internal/proxy/ws.go | 11 ++ 6 files changed, 469 insertions(+), 17 deletions(-) create mode 100644 internal/proxy/pool_attribution_lifecycle_test.go diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index a9096e3..f56127f 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -742,12 +742,21 @@ func (a *SluiceAddon) Request(f *mitmproxy.Flow) { proto := a.detectRequestProtocol(f, port) protoStr := proto.String() - pairs := a.buildPhantomPairs(host, port, protoStr, f.Id, f.Request.URL) + pairs := a.buildPhantomPairs(host, port, protoStr, f.Request.URL) if len(pairs) == 0 && !a.hasPhantomPrefix(f) { return } defer releasePhantomPairs(pairs) + // Finding 2: record the per-flow pool-usage tag ONLY for pooled pairs + // whose pool phantom is actually present in this outbound request. A + // plain OAuth request to a token URL shared with a pool builds the + // pool's pairs as candidates but carries no pool phantom, so it must + // not be tagged (otherwise its 401/invalid_grant would cool an + // unrelated pool member). Done BEFORE the swap so the phantom is still + // present to detect. + a.tagPooledFlowAfterSwap(f, pairs) + // Pass 2+3 on headers. a.swapPhantomHeaders(f, pairs, host, port) @@ -795,15 +804,26 @@ func (a *SluiceAddon) StreamRequestModifier(f *mitmproxy.Flow, in io.Reader) io. proto := a.detectRequestProtocol(f, port) protoStr := proto.String() - pairs := a.buildPhantomPairs(host, port, protoStr, f.Id, f.Request.URL) + pairs := a.buildPhantomPairs(host, port, protoStr, f.Request.URL) if len(pairs) == 0 { return in } + // Finding 2: the streamed (form-urlencoded refresh) body is not + // buffered into f.Request.Body, so pairPhantomPresentInRequest cannot + // see the pool phantom here. Tag per-flow attribution from inside the + // reader, only when a pooled pair's phantom is actually swapped out of + // the stream. flowID is captured so the tag survives the reader. + flowID := f.Id return &phantomSwapReader{ inner: in, pairs: pairs, provider: a.provider, + onPooledSwap: func(member string) { + if flowID != uuid.Nil && member != "" { + a.flowInjected.Tag(flowID, member) + } + }, } } @@ -861,6 +881,21 @@ func (a *SluiceAddon) Response(f *mitmproxy.Flow) { // cheap no-op for non-pooled destinations and for non-trigger statuses. a.handlePoolFailover(f) + // Finding 1: free this flow's API-host failover attribution tag now + // that poolForResponse (invoked inside handlePoolFailover) has used it + // — its API-host branch reads the tag with a NON-consuming Peek, so + // without this delete a completed request's tag would linger for the + // full flowAttrTTL, making Tag's opportunistic sweep O(n) and letting + // the map grow unboundedly under sustained pooled traffic. Both the + // API-host (Peek) and token-endpoint (Recover) uses happen within that + // single poolForResponse call, so deleting here is correct and + // race-free; Delete is idempotent if the token-endpoint path already + // consumed it. The flowAttrTTL + Tag sweep remain a backstop for flows + // whose buffered Response never fires (streamed/abandoned). + if f.Id != uuid.Nil { + a.flowInjected.Delete(f.Id) + } + // Test-only panic injection. Always nil in production. Lets a // regression test exercise the deferred recover above without // having to construct a Flow that triggers a real downstream @@ -1428,15 +1463,20 @@ func (a *SluiceAddon) persistAddonOAuthTokens(credName string, realAccess, realR } // buildPhantomPairs builds the sorted list of phantom/secret pairs for a -// destination. The caller must call releasePhantomPairs when done. flowID is -// the go-mitmproxy Flow ID of the request being processed (uuid.Nil when no -// flow is associated, e.g. the QUIC path); it is used to pin per-request -// pool-member attribution for API-host failover (Finding 1). reqURL is the -// outbound request URL (nil on the QUIC path, which has no parsed URL); it -// is used to expand pooled OAuth credentials whose token endpoint matches +// destination. The caller must call releasePhantomPairs when done. reqURL is +// the outbound request URL (nil on the QUIC path, which has no parsed URL); +// it is used to expand pooled OAuth credentials whose token endpoint matches // the request even when they are not bound to the CONNECT host (Finding 4, // the split-host token-refresh case). -func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flowID uuid.UUID, reqURL *url.URL) []phantomPair { +// +// Per-flow pool-member attribution for API-host failover (Finding 1) is NOT +// recorded here: a pooled pair built here is only a CANDIDATE — the request +// may not actually carry its pool phantom (e.g. a plain OAuth request to a +// token URL shared with a pool, Finding 2). The flowInjected tag is recorded +// post-swap by tagPooledFlowAfterSwap, which inspects the built pairs' +// pooledMember field and tags only the members whose pool phantom was +// actually present in this request. +func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, reqURL *url.URL) []phantomPair { res := a.resolver.Load() if res == nil { return nil @@ -1477,12 +1517,13 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flo if target.pooled { poolName := target.phantomName member := target.secretName - // Pin API-host failover attribution to this request's - // injected member (Finding 1). Idempotent with the - // pass-1 injectHeaders tag for the same flow. - if member != "" { - a.flowInjected.Tag(flowID, member) - } + // Finding 2: do NOT tag flowInjected here. The + // per-flow pool-usage tag must be recorded only AFTER + // the swap confirms this request actually carried the + // pool phantom (tagPooledFlowAfterSwap), so a plain + // OAuth request that merely shares a token URL with + // this pool cannot acquire a pool-usage tag and + // mis-attribute its 401 to an unrelated member. oauthPairs, parseErr := a.buildPooledMemberPairs(poolName, member, secret) if parseErr != nil { continue @@ -1565,7 +1606,12 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, flo secret.Release() continue } - a.flowInjected.Tag(flowID, member) + // Finding 2: tagging is deferred to + // tagPooledFlowAfterSwap (post-swap, only if the + // pool phantom was actually present in this + // request). A plain OAuth request to a token URL + // shared with this pool reaches here too, but it + // carries no pool phantom, so it must not be tagged. oauthPairs, parseErr := a.buildPooledMemberPairs(poolName, member, secret) if parseErr != nil { continue @@ -1605,6 +1651,79 @@ func (a *SluiceAddon) buildPooledMemberPairs(poolName, member string, secret vau ) } +// pairPhantomPresentInRequest reports whether the given pair's pool phantom +// (literal OR either URL-encoded form) is actually present anywhere in the +// outbound request the agent sent: body, any header value, the URL query, or +// the URL path. This is the post-build evidence that the agent really held +// (and is sending) this pool's phantom for THIS request, as opposed to the +// pair merely being a candidate built because the destination/token-URL +// matched a pool binding. +func pairPhantomPresentInRequest(f *mitmproxy.Flow, p phantomPair) bool { + contains := func(data []byte) bool { + if len(data) == 0 { + return false + } + if bytes.Contains(data, p.phantom) { + return true + } + if len(p.encodedPhantom) > 0 && bytes.Contains(data, p.encodedPhantom) { + return true + } + if len(p.encodedPhantomLower) > 0 && bytes.Contains(data, p.encodedPhantomLower) { + return true + } + return false + } + if contains(f.Request.Body) { + return true + } + for _, vals := range f.Request.Header { + for _, v := range vals { + if contains([]byte(v)) { + return true + } + } + } + if f.Request.URL != nil { + if contains([]byte(f.Request.URL.RawQuery)) { + return true + } + if contains([]byte(f.Request.URL.Path)) { + return true + } + } + return false +} + +// tagPooledFlowAfterSwap records the per-flow pool-usage attribution tag +// (flowInjected) for every pooled pair whose pool phantom is ACTUALLY present +// in this outbound request, and only those. +// +// Finding 2: the old code tagged the flow when it BUILT a pooled pair (either +// the CONNECT-host loop or the token-host expansion in buildPhantomPairs), +// before any proof the request carried that pool's phantom. A plain OAuth +// request to a token URL shared with a pool builds the pool's candidate pairs +// too, so it would acquire the tag and have its 401/invalid_grant cool an +// unrelated pool member (poolForResponse treats a flowInjected tag as +// pool-usage proof). Tagging only when the pool phantom is genuinely present +// (and therefore about to be swapped out) keeps the API-host failover +// attribution sound (Finding 1) while no longer mis-attributing plain OAuth +// traffic. Must be called BEFORE the swap, while the phantom is still in the +// request to detect. +func (a *SluiceAddon) tagPooledFlowAfterSwap(f *mitmproxy.Flow, pairs []phantomPair) { + if f == nil || f.Id == uuid.Nil || f.Request == nil { + return + } + for _, p := range pairs { + if p.pooledMember == "" { + continue + } + if pairPhantomPresentInRequest(f, p) { + a.flowInjected.Tag(f.Id, p.pooledMember) + } + } +} + // releasePhantomPairs zeroes all secret values in the pairs slice. func releasePhantomPairs(pairs []phantomPair) { for i := range pairs { @@ -1766,6 +1885,15 @@ type phantomSwapReader struct { pending []byte eof bool released bool + + // onPooledSwap, when non-nil, is invoked with a pool member name the + // first time a pooled pair's phantom is actually replaced in the + // streamed body. It pins the per-flow API-host failover attribution + // (Finding 1) to a member only when this request genuinely carried + // that member's pool phantom (Finding 2: a plain OAuth stream that + // merely shares a token URL with a pool carries no pool phantom, so + // the swap never fires and no spurious tag is recorded). + onPooledSwap func(member string) } // maxPhantomLen returns the length of the longest phantom token in the @@ -1847,8 +1975,10 @@ func (r *phantomSwapReader) Read(p []byte) (int, error) { // phantom is actually present and we need the encoded form of the // real secret. for _, pp := range r.pairs { + swapped := false if bytes.Contains(toProcess, pp.phantom) { toProcess = bytes.ReplaceAll(toProcess, pp.phantom, pp.secret.Bytes()) + swapped = true } var encodedSecret []byte ensureEncodedSecret := func() { @@ -1859,10 +1989,19 @@ func (r *phantomSwapReader) Read(p []byte) (int, error) { if len(pp.encodedPhantom) > 0 && bytes.Contains(toProcess, pp.encodedPhantom) { ensureEncodedSecret() toProcess = bytes.ReplaceAll(toProcess, pp.encodedPhantom, encodedSecret) + swapped = true } if len(pp.encodedPhantomLower) > 0 && bytes.Contains(toProcess, pp.encodedPhantomLower) { ensureEncodedSecret() toProcess = bytes.ReplaceAll(toProcess, pp.encodedPhantomLower, encodedSecret) + swapped = true + } + // Finding 2: only pin per-flow pool attribution when this + // request actually carried (and we just swapped out) the + // pool phantom — not merely because a pooled pair was built + // as a candidate. + if swapped && pp.pooledMember != "" && r.onPooledSwap != nil { + r.onPooledSwap(pp.pooledMember) } } // Pass 3: strip unbound, including URL-encoded phantoms. diff --git a/internal/proxy/phantom_pairs.go b/internal/proxy/phantom_pairs.go index fe40130..27d1624 100644 --- a/internal/proxy/phantom_pairs.go +++ b/internal/proxy/phantom_pairs.go @@ -242,6 +242,7 @@ func buildPooledOAuthPhantomPairs(poolName, member string, secret vault.SecureBy encodedPhantom: accessEncoded, encodedPhantomLower: encodePhantomLowerForPair(accessEncoded), secret: accessSecret, + pooledMember: member, }} if cred.RefreshToken != "" { // Record the precise R1 join: this exact real refresh token is @@ -263,6 +264,7 @@ func buildPooledOAuthPhantomPairs(poolName, member string, secret vault.SecureBy encodedPhantom: refreshEncoded, encodedPhantomLower: encodePhantomLowerForPair(refreshEncoded), secret: refreshSecret, + pooledMember: member, }) } return pairs, nil diff --git a/internal/proxy/pool_attribution.go b/internal/proxy/pool_attribution.go index c740bbc..a10548e 100644 --- a/internal/proxy/pool_attribution.go +++ b/internal/proxy/pool_attribution.go @@ -120,6 +120,32 @@ func (m *flowInjectedMember) Peek(flowID uuid.UUID) (string, bool) { return e.member, true } +// Delete removes the tag for the given flow ID if present. It is a no-op +// when flowID is uuid.Nil or no entry exists. +// +// Finding 1: the API-host failover path uses a NON-consuming Peek (a +// destination can map to multiple candidate pools and a consuming Recover +// inside that loop would let the first pool steal a later pool's tag, the +// round-12 bug). Because Peek does not delete, a COMPLETED pooled request's +// tag would otherwise linger for the full flowAttrTTL (5 min). Under +// sustained pooled traffic that makes Tag's opportunistic whole-map sweep +// O(n) on every new request and lets the map grow unboundedly within the +// TTL window. The buffered Response handler calls Delete keyed by f.Id once +// it has finished poolForResponse (API-host AND token-endpoint use happen +// within that single poolForResponse call), so a completed request's tag is +// freed immediately instead of waiting out the TTL. The TTL + Tag sweep +// remain as a backstop for flows that never complete a buffered Response +// (streamed responses, abandoned/aborted flows whose Response callback never +// fires). +func (m *flowInjectedMember) Delete(flowID uuid.UUID) { + if flowID == uuid.Nil { + return + } + m.mu.Lock() + defer m.mu.Unlock() + delete(m.entries, flowID) +} + // refreshAttrTTL is how long a real-refresh-token -> member tag is retained. // An OAuth refresh round-trip (agent POSTs refresh_token, upstream answers // with rotated tokens) completes in well under a second in practice; a diff --git a/internal/proxy/pool_attribution_lifecycle_test.go b/internal/proxy/pool_attribution_lifecycle_test.go new file mode 100644 index 0000000..599b7e4 --- /dev/null +++ b/internal/proxy/pool_attribution_lifecycle_test.go @@ -0,0 +1,139 @@ +package proxy + +import ( + "testing" + + uuid "github.com/satori/go.uuid" +) + +// flowInjectedSize returns the current number of live flow-attribution +// entries. Tests are in package proxy so the unexported map is directly +// reachable; the read is taken under the same mutex Tag/Peek/Delete use. +func flowInjectedSize(m *flowInjectedMember) int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.entries) +} + +// TestFlowInjectedTagFreedAfterResponse is the Finding 1 regression. +// +// Round-12 made poolForResponse's API-host branch use a NON-consuming Peek, +// so a COMPLETED pooled request's flow tag was never deleted until the 5-min +// flowAttrTTL sweep. Tag opportunistically scans the WHOLE map on every new +// pooled request, so sustained pooled traffic accumulated every completed +// flow for the TTL window: each new Tag became O(n) and the map grew +// unboundedly within the TTL. +// +// The fix deletes the per-flow tag at the end of the buffered Response +// handler (after handlePoolFailover -> poolForResponse has used it via Peek +// AND/OR Recover). This test drives N completed pooled API-host requests +// through Response and asserts: +// +// - the flowInjected map does NOT retain all N entries afterwards (it is +// bounded near zero, not ~N) — fails before the fix (map == N), passes +// after (map == 0); +// - attribution is still correct DURING each request's own Response: the +// request whose 429 is processed cools its OWN injected member; +// - the TTL backstop still works: a tagged flow whose Response never fires +// (streamed/abandoned) is retained for the caller to clean up via the +// TTL + Tag sweep, i.e. Delete did not over-reach and wipe live tags. +func TestFlowInjectedTagFreedAfterResponse(t *testing.T) { + addon, _, prPtr := setupPoolAddon(t, "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + pr := prPtr.Load() + + if got, _ := pr.ResolveActive("codex_pool"); got != "memA" { + t.Fatalf("pre-condition active = %q, want memA", got) + } + + const n = 50 + for i := 0; i < n; i++ { + // A completed pooled API-host request that succeeded (2xx, no + // failover). Production tags the injected member post-swap; mirror + // that. memA is the active member for all of them. + f := newPoolRespFlow(client, 200, []byte(`{"ok":true}`)) + addon.flowInjected.Tag(f.Id, "memA") + addon.Response(f) + } + + // The crux: after N completed Responses the map must NOT still hold all + // N tags. Before the fix every Peek-only completion left its entry + // behind, so size == N. After the fix each Response frees its own tag, + // so size is 0. + if sz := flowInjectedSize(addon.flowInjected); sz > 1 { + t.Fatalf("flowInjected retained %d/%d completed-request tags after their "+ + "Responses — Finding 1: Peek-only API-host path never frees tags, "+ + "so the map grows unboundedly within flowAttrTTL and Tag is O(n)", sz, n) + } + + // Attribution must still be correct DURING a request's own Response. + // Two concurrent requests both backed by memA. req1's 429 cools memA + // (pool switches to memB); req2's 429 must still be attributed to memA + // (its OWN injected member, recovered by flow id) and NOT to memB + // (response-time active). Then both tags must be freed. + req1 := newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`)) + req2 := newPoolRespFlow(client, 429, []byte(`{"error":"rate_limited"}`)) + addon.flowInjected.Tag(req1.Id, "memA") + addon.flowInjected.Tag(req2.Id, "memA") + + addon.Response(req1) + if _, cooling := pr.CooldownUntil("memA"); !cooling { + t.Fatal("memA must be cooling after req1's 429") + } + if got, _ := pr.ResolveActive("codex_pool"); got != "memB" { + t.Fatalf("after req1 failover active = %q, want memB", got) + } + + addon.Response(req2) + if _, cooling := pr.CooldownUntil("memB"); cooling { + t.Fatal("memB was cooled by req2's 429 — attribution must use req2's " + + "OWN injected member (memA), not response-time active (Finding 1)") + } + + if sz := flowInjectedSize(addon.flowInjected); sz != 0 { + t.Fatalf("flowInjected size = %d after all Responses, want 0 "+ + "(every completed flow's tag must be freed)", sz) + } + + // TTL backstop: a tag for a flow whose Response NEVER fires (streamed / + // abandoned) must survive — Delete must not have over-reached. The tag + // is bounded only by flowAttrTTL + the opportunistic sweep in Tag. + abandoned := uuid.NewV4() + addon.flowInjected.Tag(abandoned, "memA") + if m, ok := addon.flowInjected.Peek(abandoned); !ok || m != "memA" { + t.Fatalf("abandoned-flow tag not retained as TTL backstop; got %q ok=%v", m, ok) + } + if sz := flowInjectedSize(addon.flowInjected); sz != 1 { + t.Fatalf("flowInjected size = %d, want exactly 1 (only the abandoned "+ + "flow's TTL-backstopped tag remains)", sz) + } +} + +// TestFlowInjectedDeleteIsIdempotent guards the end-of-Response delete: the +// token-endpoint failover path already consumes the tag via a single-use +// Recover, so the subsequent end-of-Response Delete must be a safe no-op +// (not panic, not corrupt the map) and must not disturb other flows' tags. +func TestFlowInjectedDeleteIsIdempotent(t *testing.T) { + m := newFlowInjectedMember() + id1 := uuid.NewV4() + id2 := uuid.NewV4() + m.Tag(id1, "memA") + m.Tag(id2, "memB") + + // First delete frees id1. + m.Delete(id1) + if _, ok := m.Peek(id1); ok { + t.Fatal("id1 tag must be gone after Delete") + } + // Second delete of the same id is a no-op. + m.Delete(id1) + // uuid.Nil delete is a no-op. + m.Delete(uuid.Nil) + // id2 is untouched throughout. + if got, ok := m.Peek(id2); !ok || got != "memB" { + t.Fatalf("id2 tag disturbed by unrelated Delete calls; got %q ok=%v", got, ok) + } + if sz := flowInjectedSize(m); sz != 1 { + t.Fatalf("size = %d, want 1 (only id2 remains)", sz) + } +} diff --git a/internal/proxy/pool_splithost_test.go b/internal/proxy/pool_splithost_test.go index ff9c24d..1fa1fab 100644 --- a/internal/proxy/pool_splithost_test.go +++ b/internal/proxy/pool_splithost_test.go @@ -526,7 +526,7 @@ func TestFinding1Round9_PoolNamespaceNotSuppressedByMemberPlainBinding(t *testin // The pool ACCESS phantom must also be swappable: build the pairs // directly and assert the pool access phantom maps to memA's real // access token exactly once (no double-emit, not suppressed). - pairs := addon.buildPhantomPairs("auth.example.com", 443, "https", reqFlow.Id, reqFlow.Request.URL) + pairs := addon.buildPhantomPairs("auth.example.com", 443, "https", reqFlow.Request.URL) defer releasePhantomPairs(pairs) accessPhantom := poolStablePhantomAccess(poolName) refreshPhantom := "SLUICE_PHANTOM:" + poolName + ".refresh" @@ -552,3 +552,138 @@ func TestFinding1Round9_PoolNamespaceNotSuppressedByMemberPlainBinding(t *testin t.Fatalf("Finding 1 r9: pool refresh phantom emitted %d times, want exactly 1 (not suppressed, not double-emitted)", refreshCount) } } + +// TestFinding2_PlainOAuthOnSharedTokenURLDoesNotTagOrCoolPool is the +// Finding 2 (round-16) regression. +// +// The token-host expansion in buildPhantomPairs tagged the flow +// (flowInjected.Tag) the moment it BUILT candidate pool phantom pairs for a +// matching token URL — BEFORE verifying any pool phantom was actually +// present in (and swapped out of) the outbound request. A plain OAuth +// credential whose token URL is SHARED with a pool therefore acquired a +// per-flow pool-usage tag on its OWN refresh. poolForResponse treats that +// flowInjected tag as proof the request used the pool, so the plain +// credential's 401 / invalid_grant cooled an UNRELATED active pool member +// and parked it. +// +// The fix moves tagging to tagPooledFlowAfterSwap, which records the tag +// only when the pool phantom is genuinely present in the request (i.e. an +// actual pool-phantom replacement happens for this flow). A plain refresh +// (no pool phantom in the body) is no longer tagged, so its failure cannot +// be mis-attributed to the pool. +// +// Two halves, both must hold: +// +// (a) plain OAuth refresh on the shared token URL, NO pool phantom in the +// body, 401/invalid_grant -> NO pool member cooled (no flowInjected +// tag was set). FAILS before the fix (the build-time tag is set, so +// poolForResponse cools the active member). +// (b) a genuine pooled refresh (pool phantom present, actually swapped) +// with 401/invalid_grant -> the correct member is still cooled. Guards +// against the fix over-restricting the legit split-host pooled-refresh +// path / regressing the round-9/12 fixes. +func TestFinding2_PlainOAuthOnSharedTokenURLDoesNotTagOrCoolPool(t *testing.T) { + // --- (a) plain refresh, no pool phantom: must NOT tag, must NOT cool --- + addon, _, prPtr := setupPoolSplitHostWithPlainCred(t) + // CONNECT target is the shared TOKEN host (no pool binding lives here; + // the only pooled injection path is the token-host expansion). + client := setupAddonConn(addon, "auth.example.com:443") + pr := prPtr.Load() + + // Realistic precursor: memA API-429-cooled, traffic on memB. + memACooldown := time.Now().Add(90 * time.Second) + pr.MarkCooldown("memA", memACooldown, "429") + if got, _ := pr.ResolveActive("codex_pool"); got != "memB" { + t.Fatalf("after cooling memA, active = %q, want memB", got) + } + memBPre, memBPreCooling := pr.CooldownUntil("memB") + + // The agent refreshes a PLAIN OAuth credential against the shared token + // URL. Its body carries NO SLUICE_PHANTOM:codex_pool.* pool phantom — + // it is an ordinary refresh-grant. The token-host expansion still + // builds the pool's candidate phantom pairs (the pool shares this token + // URL), but no pool phantom is present to swap. + reqFlow := newTestFlow(client, "POST", testOAuthTokenURL) + reqFlow.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + reqFlow.Request.Body = []byte("grant_type=refresh_token&refresh_token=plain-refresh-old") + + addon.Requestheaders(reqFlow) + addon.Request(reqFlow) + + // The pool phantom never appeared in the request, so NO flowInjected + // pool-usage tag may have been recorded for this flow. + if m, ok := addon.flowInjected.Peek(reqFlow.Id); ok { + t.Fatalf("Finding 2: a plain OAuth refresh on a shared token URL acquired "+ + "a flowInjected pool-usage tag (member=%q) even though NO pool phantom "+ + "was present/swapped — its 401 would cool an unrelated pool member", m) + } + + // The plain credential's refresh now 401s / invalid_grants on the + // shared token host. With no pool-usage evidence, poolForResponse must + // fail closed and NO pool member may be cooled. + respFlow := newPoolRespFlowBody(client, 401, "plain-refresh-old", + []byte(`{"error":"invalid_grant"}`)) + if pool, member, _, _, ok := addon.poolForResponse(respFlow); ok { + t.Fatalf("Finding 2: poolForResponse attributed a plain-credential failure "+ + "to the pool (pool=%q member=%q) — the build-time flowInjected tag "+ + "mis-flagged a plain refresh as pooled usage", pool, member) + } + addon.Response(respFlow) + + if u, cooling := pr.CooldownUntil("memB"); cooling != memBPreCooling || !u.Equal(memBPre) { + t.Fatalf("Finding 2: active pool member memB cooldown changed (%v/%v -> %v/%v) "+ + "on a PLAIN credential's invalid_grant — an innocent member was parked", + memBPre, memBPreCooling, u, cooling) + } + if aU, c := pr.CooldownUntil("memA"); !c || aU.Sub(memACooldown).Abs() > time.Second { + t.Fatalf("Finding 2: memA's original 429 window disturbed: got %v (cooling=%v), want %v", + aU, c, memACooldown) + } + + // --- (b) genuine pooled refresh: pool phantom present -> still cools --- + addon2, _, prPtr2 := setupPoolSplitHostWithPlainCred(t) + client2 := setupAddonConn(addon2, "auth.example.com:443") + pr2 := prPtr2.Load() + if got, _ := pr2.ResolveActive("codex_pool"); got != "memA" { + t.Fatalf("pre-condition active = %q, want memA", got) + } + + // The agent holds the POOL refresh phantom and POSTs it. The token-host + // expansion swaps it to memA's real refresh token AND (post-swap) + // records the per-flow pool-usage tag because the pool phantom WAS + // genuinely present. + poolReq := newTestFlow(client2, "POST", testOAuthTokenURL) + poolReq.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + poolReq.Request.Body = refreshGrantBody("codex_pool") + + addon2.Requestheaders(poolReq) + addon2.Request(poolReq) + + if strings.Contains(string(poolReq.Request.Body), "SLUICE_PHANTOM:codex_pool.refresh") { + t.Fatalf("genuine pooled refresh: pool phantom not swapped; body=%q", + string(poolReq.Request.Body)) + } + if m, ok := addon2.flowInjected.Peek(poolReq.Id); !ok || m != "memA" { + t.Fatalf("Finding 2 over-restriction: a genuine pooled refresh (pool "+ + "phantom actually swapped) was NOT tagged; got member=%q ok=%v "+ + "(the legit split-host pooled-refresh path must still tag)", m, ok) + } + + // memA's pooled refresh invalid_grants -> memA must still be cooled and + // the pool must fail over to memB. + poolResp := newPoolRespFlowBody(client2, 401, "A-refresh-old", + []byte(`{"error":"invalid_grant"}`)) + pool, member, _, _, ok := addon2.poolForResponse(poolResp) + if !ok || pool != "codex_pool" || member != "memA" { + t.Fatalf("Finding 2 over-restriction: genuine pooled refresh not attributed; "+ + "got ok=%v pool=%q member=%q, want codex_pool/memA", ok, pool, member) + } + addon2.Response(poolResp) + if _, cooling := pr2.CooldownUntil("memA"); !cooling { + t.Fatal("Finding 2 over-restriction: genuine pooled member memA not cooled " + + "after its own invalid_grant") + } + if active, _ := pr2.ResolveActive("codex_pool"); active != "memB" { + t.Fatalf("Finding 2 over-restriction: pool did not fail over; active = %q, want memB", active) + } +} diff --git a/internal/proxy/ws.go b/internal/proxy/ws.go index cfb10fe..56c873e 100644 --- a/internal/proxy/ws.go +++ b/internal/proxy/ws.go @@ -388,6 +388,17 @@ type phantomPair struct { encodedPhantom []byte encodedPhantomLower []byte secret vault.SecureBytes + + // pooledMember is the concrete pool member whose real credential this + // pair injects, set ONLY for pool-keyed pairs built by + // buildPooledMemberPairs. It is "" for plain (non-pool) pairs. The + // request-side injection uses it to record the per-flow + // flowInjected attribution tag ONLY when this pair's pool phantom was + // actually present in (and swapped out of) the outbound request + // (Finding 2): a plain OAuth request to a token URL shared with a pool + // must not acquire a pool-usage tag and mis-attribute its 401 to an + // unrelated pool member. + pooledMember string } // Relay runs bidirectional WebSocket frame forwarding between agent and From ae45808ce9c8dcf3127bf5879476bd0e7cdc177c Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 19:06:25 +0800 Subject: [PATCH 41/49] fix: RemoveCredentialFully cleans health on partial-cleanup finish; exclude 5xx from token-endpoint failover classification --- internal/proxy/pool_failover.go | 15 +++- internal/proxy/pool_failover_test.go | 22 +++++ internal/store/pools_test.go | 125 +++++++++++++++++++++++++++ internal/store/store.go | 15 ++++ 4 files changed, 173 insertions(+), 4 deletions(-) diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index e54ddde..96bfc9b 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -93,10 +93,17 @@ func classifyFailover(statusCode int, body []byte, isTokenEndpoint bool) (class // a non-token-endpoint with an unrelated body still resolves to // failoverNone there (the body is only trusted on a real token URL). } - // Non-2xx-status path. Only a real token-endpoint body may be classified - // (invalid_grant/invalid_token), and only when the status is not a 2xx - // success. A 2xx token response is a healthy refresh, never a failover. - if isTokenEndpoint && (statusCode < 200 || statusCode > 299) { + // 4xx-client-error token-endpoint path. Only a real token-endpoint body + // may be classified (invalid_grant/invalid_token), and only on a 4xx + // CLIENT error. A 2xx token response is a healthy refresh, never a + // failover; a 5xx is a server-side error and is a documented NO-OP (a + // transient upstream outage is not evidence the member's account is + // exhausted or revoked — failing over would just spread the outage + // across every account in the pool, see README + the failoverNone doc). + // Restricting to [400,500) keeps every existing correct path (400/403 + // invalid_grant -> auth-failure) while excluding 5xx whose body happens + // to echo "invalid_grant"/"invalid_token". + if isTokenEndpoint && statusCode >= 400 && statusCode < 500 { if bodyContainsAny(body, "invalid_grant") { return failoverAuthFailure, "invalid_grant" } diff --git a/internal/proxy/pool_failover_test.go b/internal/proxy/pool_failover_test.go index 4c5e472..deb5400 100644 --- a/internal/proxy/pool_failover_test.go +++ b/internal/proxy/pool_failover_test.go @@ -76,6 +76,28 @@ func TestClassifyFailover(t *testing.T) { {"500 server error -> noop", 500, `oops`, false, failoverNone, ""}, {"502 -> noop", 502, ``, false, failoverNone, ""}, {"404 -> noop", 404, ``, false, failoverNone, ""}, + // Round-17 Finding 2: a token-endpoint 5xx whose body happens to + // contain invalid_grant/invalid_token must STAY a no-op. A server-side + // error is not evidence the member's account is exhausted/revoked + // (README + the failoverNone doc). The OLD non-2xx check + // (statusCode < 200 || statusCode > 299) wrongly classified these as + // auth-failure and cooled an innocent member. The fix restricts the + // body classification to 4xx CLIENT errors only. + {"token-endpoint 503 invalid_grant -> noop", 503, `{"error":"invalid_grant"}`, true, failoverNone, ""}, + {"token-endpoint 500 invalid_grant -> noop", 500, `{"error":"invalid_grant"}`, true, failoverNone, ""}, + {"token-endpoint 502 invalid_token -> noop", 502, `{"error":"invalid_token"}`, true, failoverNone, ""}, + // The 4xx token-endpoint paths are UNAFFECTED by the fix. + {"token-endpoint 400 invalid_grant still auth", 400, `{"error":"invalid_grant"}`, true, failoverAuthFailure, "invalid_grant"}, + {"token-endpoint 403 invalid_grant still auth", 403, `{"error":"invalid_grant"}`, true, failoverAuthFailure, "invalid_grant"}, + {"token-endpoint 499 invalid_token still auth", 499, `{"error":"invalid_token"}`, true, failoverAuthFailure, "invalid_token"}, + // 401/429 are explicit status cases ABOVE the body check; unaffected. + {"401 still auth (unaffected)", 401, ``, true, failoverAuthFailure, "401"}, + {"429 still rate-limited (unaffected)", 429, `{"error":"invalid_grant"}`, true, failoverRateLimited, "429"}, + // 2xx is still a no-op even with an invalid_grant-looking body. + {"2xx invalid_grant body still noop", 200, `{"error":"invalid_grant"}`, true, failoverNone, ""}, + // Non-token-endpoint 4xx with invalid_grant body is still a no-op + // (the body is only trusted on a real token URL). + {"non-token-endpoint 400 invalid_grant -> noop", 400, `{"error":"invalid_grant"}`, false, failoverNone, ""}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index 5c3cb9c..c99dbab 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -970,3 +970,128 @@ func TestRemoveCredentialFullyRefusesLivePoolMember(t *testing.T) { t.Fatalf("RemoveCredentialFully(free) = %v, %v; want true, nil", md, ferr) } } + +// TestRemoveCredentialFullyCleansHealthOnPartialCleanupFinish is the round-17 +// Finding 1 fail-before/pass-after regression. Pre-state simulates a prior +// PARTIAL cleanup: credential_meta for "x" is ALREADY absent, but a stale +// credential_health cooldown row plus a binding and an auto-created allow +// rule survived. deleteCredentialMetaGuardedTx only drops the health row when +// the meta DELETE affected a row (n>0, CAS no-op semantics), so the OLD +// RemoveCredentialFully left the stale health row behind (n==0 here) and a +// later same-named credential would inherit the dead cooldown. The fix adds +// an UNCONDITIONAL health delete for the named credential in the full-removal +// tx. Pass-after: health, binding, and rule for "x" are all gone, so a later +// same-named pool member inherits NO stale cooldown. +func TestRemoveCredentialFullyCleansHealthOnPartialCleanupFinish(t *testing.T) { + s := newTestStore(t) + + // Bindings + rules require a credential_meta row to be created via the + // normal path; seed it, wire the binding/rule/health, THEN delete ONLY + // the meta row directly to reproduce the "prior partial cleanup" state + // (meta gone, health + binding + rule still present). + seedOAuthCred(t, s, "x") + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("x", "cooldown", until, "429"); err != nil { + t.Fatalf("SetCredentialHealth: %v", err) + } + if _, _, err := s.AddRuleAndBinding( + "allow", + RuleOpts{Destination: "api.example.com", Ports: []int{443}, Source: CredAddSourcePrefix + "x"}, + "x", + BindingOpts{Ports: []int{443}, Header: "Authorization", Template: "Bearer {value}"}, + ); err != nil { + t.Fatalf("AddRuleAndBinding: %v", err) + } + // Simulate the prior partial cleanup: meta row gone, everything else left. + if _, err := s.db.Exec("DELETE FROM credential_meta WHERE name = ?", "x"); err != nil { + t.Fatalf("simulate partial cleanup (delete meta): %v", err) + } + if m, _ := s.GetCredentialMeta("x"); m != nil { + t.Fatalf("precondition: credential_meta should be absent, got %+v", m) + } + if h, _ := s.GetCredentialHealth("x"); h == nil { + t.Fatal("precondition: stale credential_health row must still be present") + } + + // Finishing the partial cleanup. metaDeleted is false (meta already + // gone), but bindings/rules AND the stale health row must be swept. + metaDeleted, bn, rn, err := s.RemoveCredentialFully("x") + if err != nil { + t.Fatalf("RemoveCredentialFully: %v", err) + } + if metaDeleted { + t.Error("metaDeleted = true, want false (meta was already gone)") + } + if bn != 1 { + t.Errorf("bindings removed = %d, want 1", bn) + } + if rn != 1 { + t.Errorf("rules removed = %d, want 1", rn) + } + if h, _ := s.GetCredentialHealth("x"); h != nil { + t.Errorf("stale credential_health survived partial-cleanup finish: %+v", h) + } + if b, _ := s.ListBindingsByCredential("x"); len(b) != 0 { + t.Errorf("bindings survived: %+v", b) + } + rules, _ := s.ListRules(RuleFilter{Type: "network"}) + for _, r := range rules { + if r.Source == CredAddSourcePrefix+"x" { + t.Errorf("auto-created rule survived: %+v", r) + } + } + + // A later same-named credential added to a pool must inherit NO stale + // cooldown: ListCredentialHealth (what loadPoolResolver seeds from) + // carries no row for "x". + seedOAuthCred(t, s, "x") + seedOAuthCred(t, s, "y") + if err := s.CreatePoolWithMembers("p", "failover", []string{"x", "y"}); err != nil { + t.Fatalf("CreatePoolWithMembers: %v", err) + } + hrows, err := s.ListCredentialHealth() + if err != nil { + t.Fatalf("ListCredentialHealth: %v", err) + } + for _, r := range hrows { + if r.Credential == "x" { + t.Fatalf("same-named credential inherited a stale cooldown: %+v", r) + } + } +} + +// TestRemoveCredentialMetaCASNoOpLeavesHealthIntact pins the round-11 +// invariant the Finding 1 fix MUST NOT regress: a CAS no-op (a concurrent +// writer changed cred_type/token_url so the guarded meta DELETE matches 0 +// rows) must leave the credential_health row UNTOUCHED. The fix added the +// unconditional health delete ONLY in RemoveCredentialFully, not in the +// shared deleteCredentialMetaGuardedTx helper, so RemoveCredentialMetaCAS's +// no-op semantics are unchanged. +func TestRemoveCredentialMetaCASNoOpLeavesHealthIntact(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "c") + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + if err := s.SetCredentialHealth("c", "cooldown", until, "429"); err != nil { + t.Fatalf("SetCredentialHealth: %v", err) + } + + // CAS with MISMATCHED expected values: a "concurrent writer" effectively + // owns the row, so the delete is a no-op and the health row it owns must + // be left intact. + removed, noConcurrent, err := s.RemoveCredentialMetaCAS("c", "static", "https://wrong.example/token") + if err != nil { + t.Fatalf("RemoveCredentialMetaCAS: %v", err) + } + if removed { + t.Error("removed = true on a mismatched CAS (should be a no-op)") + } + if noConcurrent { + t.Error("noConcurrent = true; expected the concurrent-writer signal") + } + if m, _ := s.GetCredentialMeta("c"); m == nil { + t.Error("credential_meta wrongly deleted by a mismatched CAS") + } + if h, _ := s.GetCredentialHealth("c"); h == nil { + t.Error("credential_health wrongly deleted by a CAS no-op (round-11 invariant regressed)") + } +} diff --git a/internal/store/store.go b/internal/store/store.go index d93492d..3f5219d 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1983,6 +1983,21 @@ func (s *Store) RemoveCredentialFully(name string) (metaDeleted bool, bindings, return false, 0, 0, err } + // Round-17 Finding 1: deleteCredentialMetaGuardedTx only drops the + // credential_health row when the meta DELETE affected a row (CAS no-op + // semantics, correct for RemoveCredentialMetaCAS — see that helper's + // comment). But this is the FULL-removal path: if a prior partial + // cleanup already removed credential_meta, n==0 and the guarded helper + // leaves a stale credential_health row behind, so a later same-named + // credential would inherit the dead cooldown. Full removal must wipe + // the named credential's health UNCONDITIONALLY in the same tx, + // regardless of whether a meta row existed. This does NOT alter + // deleteCredentialMetaGuardedTx's behavior, so the round-11 + // RemoveCredentialMetaCAS no-op invariant is unchanged. + if _, err := tx.Exec("DELETE FROM credential_health WHERE credential = ?", name); err != nil { + return false, 0, 0, fmt.Errorf("delete credential health for %q: %w", name, err) + } + bres, err := tx.Exec("DELETE FROM bindings WHERE credential = ?", name) if err != nil { return false, 0, 0, fmt.Errorf("delete bindings by credential %q: %w", name, err) From cf11d86bcfcfa761823d1d88d3b9aa0fb5397315 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 19:35:09 +0800 Subject: [PATCH 42/49] fix: pool+epoch-scoped health guards; CancelAll lost-wakeup; pooled-JWT phantom swap in query/path --- cmd/sluice/main.go | 22 ++- cmd/sluice/pool.go | 41 +++-- cmd/sluice/pool_test.go | 116 ++++++++++++ internal/channel/broker.go | 56 +++++- internal/channel/channel_test.go | 112 ++++++++++++ internal/proxy/addon.go | 55 +++++- internal/proxy/pool_failover.go | 24 ++- internal/proxy/pool_phantom_test.go | 101 +++++++++++ .../000007_pool_membership_epoch.down.sql | 2 + .../000007_pool_membership_epoch.up.sql | 30 ++++ internal/store/pools.go | 111 ++++++++++-- internal/store/pools_test.go | 168 +++++++++++++++++- internal/vault/pool.go | 133 +++++++++++--- internal/vault/pool_test.go | 82 +++++++++ 14 files changed, 968 insertions(+), 85 deletions(-) create mode 100644 internal/store/migrations/000007_pool_membership_epoch.down.sql create mode 100644 internal/store/migrations/000007_pool_membership_epoch.up.sql diff --git a/cmd/sluice/main.go b/cmd/sluice/main.go index 62b9589..7968c6e 100644 --- a/cmd/sluice/main.go +++ b/cmd/sluice/main.go @@ -492,18 +492,22 @@ func main() { reason := fmt.Sprintf("failover:%s", ev.Reason) // Guarded write: this goroutine is detached and can fire // AFTER a pool/credential removal already deleted the - // health row. SetCredentialHealthIfPoolMember upserts only - // when ev.From is still a live pool member, atomically, so - // a late failover cannot resurrect a removed credential's - // cooldown (which a later same-named credential would - // otherwise inherit via loadPoolResolver). A live member - // still gets the durable cooldown (CRITICAL-1 restart - // durability preserved). - switch wrote, herr := db.SetCredentialHealthIfPoolMember(ev.From, "cooldown", ev.Until, reason); { + // health row AND the same name was re-added into ANOTHER + // pool. SetCredentialHealthIfPoolMemberEpoch upserts only + // when (ev.From, ev.Pool, ev.Epoch) is STILL a live + // membership row, atomically, so a late failover from a + // removed pool cannot persist the old cooldown onto a + // re-added same-name member in a different pool (Cluster A + // #2). The name-only guard checked only that ev.From was in + // SOME pool, which the re-added successor satisfies. A + // genuinely-still-live member (same pool, same epoch) still + // gets the durable cooldown (CRITICAL-1 restart durability + // preserved). + switch wrote, herr := db.SetCredentialHealthIfPoolMemberEpoch(ev.From, ev.Pool, ev.Epoch, "cooldown", ev.Until, reason); { case herr != nil: log.Printf("[POOL-FAILOVER] durable health write for %q failed: %v", ev.From, herr) case !wrote: - log.Printf("[POOL-FAILOVER] durable health write for %q skipped: no longer a live pool member (removed before failover landed)", ev.From) + log.Printf("[POOL-FAILOVER] durable health write for %q skipped: no longer a live member of pool %q at epoch %d (removed/re-added before failover landed)", ev.From, ev.Pool, ev.Epoch) } } if failoverBroker != nil { diff --git a/cmd/sluice/pool.go b/cmd/sluice/pool.go index 58920ba..74bafb9 100644 --- a/cmd/sluice/pool.go +++ b/cmd/sluice/pool.go @@ -209,26 +209,37 @@ func handlePoolRotate(args []string) error { // recovery, same as auto-failover), so a rotated-away member rejoins the // rotation once its cooldown expires. // - // Finding 1 (round-15): use the guarded SetCredentialHealthIfPoolMember, - // NOT the unconditional SetCredentialHealth. `active` was resolved from a - // snapshot taken above; another process could remove the pool (or this - // member from it) between that snapshot and this write. The unconditional - // upsert would then RESURRECT a credential_health row for a credential no - // longer a live pool member — a later same-named credential/pool would - // inherit the stale cooldown. The guarded variant performs the - // pool-membership check and the upsert in one transaction, so a raced - // removal makes the write a no-op (wrote=false) instead of resurrecting - // the row. wrote=false means the rotate raced a pool removal: nothing was - // persisted and the in-memory rotate is meaningless (the pool is gone), - // so surface that to the operator as a failed/stale rotate rather than - // silently claiming success. + // Finding 1 (round-15) + Cluster A #3 (round-18): use the pool+epoch + // scoped guarded write, NOT the unconditional SetCredentialHealth and + // NOT the name-only guard. `active` was resolved from the snapshot `p` + // taken above; another process could remove this pool (or this member + // from it) AND re-add the same name into a DIFFERENT pool between that + // snapshot and this write. The name-only guard only checked that + // `active` was a member of SOME pool — the re-added successor satisfies + // that, so the rotate would park the OTHER pool's member while + // reporting a successful rotate of THIS pool. Capture `active`'s + // pool+epoch identity from the snapshot and gate the write on exactly + // (active, this pool, that epoch): a raced removal/re-add makes the + // write a no-op (wrote=false) because the snapshot's epoch no longer + // matches the live membership row, so we surface a failed/stale rotate + // instead of silently parking an unrelated pool's member. + var rotateEpoch int64 = -1 + for _, m := range p.Members { + if m.Credential == active { + rotateEpoch = m.Epoch + break + } + } + if rotateEpoch < 0 { + return fmt.Errorf("pool %q rotate: resolved active member %q is not in the pool snapshot (membership changed under the rotate); re-check with \"sluice pool list %s\"", name, active, name) + } until := time.Now().Add(vault.AuthFailCooldown) - wrote, err := db.SetCredentialHealthIfPoolMember(active, "cooldown", until, "manual rotate") + wrote, err := db.SetCredentialHealthIfPoolMemberEpoch(active, name, rotateEpoch, "cooldown", until, "manual rotate") if err != nil { return err } if !wrote { - return fmt.Errorf("pool %q rotate raced a concurrent pool/member removal: %q is no longer a live member of pool %q, so nothing was persisted; re-check the pool with \"sluice pool list %s\"", name, active, name, name) + return fmt.Errorf("pool %q rotate raced a concurrent pool/member removal or re-add: %q is no longer a live member of pool %q at the snapshotted epoch %d, so nothing was persisted; re-check the pool with \"sluice pool list %s\"", name, active, name, rotateEpoch, name) } // Recompute the new active member for operator feedback. diff --git a/cmd/sluice/pool_test.go b/cmd/sluice/pool_test.go index f8f6a25..6625e56 100644 --- a/cmd/sluice/pool_test.go +++ b/cmd/sluice/pool_test.go @@ -446,3 +446,119 @@ func TestPoolRotateGuardedAgainstConcurrentRemoval(t *testing.T) { } } } + +// TestPoolRotateEpochScopedRejectsCrossPoolReAdd is the Cluster A #3 +// regression. `pool rotate` snapshots the pool, resolves the active member, +// then writes the cooldown. The round-15 name-only guard only checked the +// active credential was a member of SOME pool. If the pool/member is removed +// after the snapshot and the SAME name is re-added to ANOTHER pool before +// the guarded write, the name-only guard's predicate is satisfied by the +// successor, so the rotate would park the OTHER pool's member while +// reporting a successful rotate of the original pool. +// +// The fix captures the active member's membership epoch from the snapshot +// and gates the write on (active, this pool, that epoch). This test +// reproduces the exact post-race store state the guarded write observes: +// pool P snapshot resolved member c at epoch e1, but by write time c was +// removed from P and re-added into Q at epoch e2. The store-level proof that +// the (pool, epoch) predicate no-ops is +// TestSetCredentialHealthIfPoolMemberEpochRejectsReAddedSuccessor; here we +// assert the HANDLER end-to-end: the genuine rotate persists with the +// correct epoch (must-not-regress), and after a real remove+cross-pool +// re-add the rotate fails AND the re-added member in Q is never parked. +func TestPoolRotateEpochScopedRejectsCrossPoolReAdd(t *testing.T) { + dir := t.TempDir() + dbPath := setupVaultDB(t, dir) + seedPoolCred(t, dbPath, dir, "c") + seedPoolCred(t, dbPath, dir, "d") + + // Genuine rotate of pool P {c,d}: must persist c's cooldown with P's + // CURRENT epoch (must-not-regress half of the fix). + if err := handlePoolCommand([]string{"create", "--db", dbPath, "--members", "c,d", "P"}); err != nil { + t.Fatalf("pool create P: %v", err) + } + out := captureStdout(t, func() { + if err := handlePoolCommand([]string{"rotate", "--db", dbPath, "P"}); err != nil { + t.Fatalf("genuine rotate: %v", err) + } + }) + if !strings.Contains(out, "c -> d") { + t.Errorf("genuine rotate output = %q; want c -> d", out) + } + db, err := store.New(dbPath) + if err != nil { + t.Fatalf("open db: %v", err) + } + if h, herr := db.GetCredentialHealth("c"); herr != nil || h == nil || h.Status != "cooldown" { + _ = db.Close() + t.Fatalf("genuine rotate did not persist c cooldown: h=%+v err=%v", h, herr) + } + _ = db.Close() + + // Now the cross-pool re-add. Remove P (CASCADE clears c/d health rows + // and advances the epoch), re-create c into a DIFFERENT pool Q. Q's + // member row carries a strictly greater epoch than P's snapshot did. + db, err = store.New(dbPath) + if err != nil { + t.Fatalf("reopen db: %v", err) + } + if _, rerr := db.RemovePool("P"); rerr != nil { + _ = db.Close() + t.Fatalf("RemovePool(P): %v", rerr) + } + if cerr := db.CreatePoolWithMembers("Q", "failover", []string{"c", "d"}); cerr != nil { + _ = db.Close() + t.Fatalf("recreate c,d into Q: %v", cerr) + } + qp, _ := db.GetPool("Q") + var cEpochInQ int64 = -1 + for _, m := range qp.Members { + if m.Credential == "c" { + cEpochInQ = m.Epoch + } + } + if cEpochInQ <= 0 { + _ = db.Close() + t.Fatalf("c epoch in Q = %d; want positive", cEpochInQ) + } + _ = db.Close() + + // A stale rotate command for the now-removed pool P. handlePoolRotate's + // GetPool("P") returns nil (P is gone), so the handler must fail with a + // "pool not found" error and persist NOTHING — in particular it must NOT + // park c (which now lives in Q at a greater epoch). + if rotErr := handlePoolCommand([]string{"rotate", "--db", dbPath, "P"}); rotErr == nil { + t.Fatal("rotate of a removed pool must fail, not silently succeed") + } + + // INVARIANT: c (the cross-pool re-added successor in Q) must carry NO + // cooldown. The pre-fix name-only guard would have let a stale P-rotate + // write park it; the epoch-scoped guard rejects the stale epoch. + db, err = store.New(dbPath) + if err != nil { + t.Fatalf("final reopen db: %v", err) + } + defer func() { _ = db.Close() }() + h, herr := db.GetCredentialHealth("c") + if herr != nil { + t.Fatalf("GetCredentialHealth(c): %v", herr) + } + if h != nil && h.Status == "cooldown" { + t.Fatalf("Cluster A #3: cross-pool re-added member c in Q was parked by a stale P-rotate: %+v", h) + } + + // And a genuine rotate of Q now must still work with Q's epoch + // (epoch-scoped guard does not break the live path). + out = captureStdout(t, func() { + if rerr := handlePoolCommand([]string{"rotate", "--db", dbPath, "Q"}); rerr != nil { + t.Fatalf("genuine rotate of Q: %v", rerr) + } + }) + if !strings.Contains(out, "c -> d") { + t.Errorf("genuine Q rotate output = %q; want c -> d", out) + } + h, herr = db.GetCredentialHealth("c") + if herr != nil || h == nil || h.Status != "cooldown" { + t.Fatalf("genuine Q rotate did not persist c cooldown at Q epoch: h=%+v err=%v", h, herr) + } +} diff --git a/internal/channel/broker.go b/internal/channel/broker.go index 2ef528e..f70784a 100644 --- a/internal/channel/broker.go +++ b/internal/channel/broker.go @@ -100,6 +100,16 @@ type Broker struct { // so the resolve/detach interleave can be forced deterministically // without sleeps. nil in production. subDeadlineGate func() + + // cancelAllAfterClearHook is a test-only seam invoked in CancelAll + // immediately after b.mu is released (post-#4-fix that is AFTER the + // terminal denies were already sent under the lock and the waiter map + // cleared). A test uses it to release a coalesced subscriber parked at + // its deadline so the CancelAll-vs-sub-deadline interleave is forced + // deterministically: pre-fix the deny was sent only AFTER this point so + // the released sub observed a lost wakeup; post-fix the deny is already + // buffered. nil in production. + cancelAllAfterClearHook func() } // waiter tracks a pending approval request and its response channel. @@ -643,19 +653,49 @@ func (b *Broker) CancelAll() { // shutdown (Finding 2). b.recordCoalescedLocked(id, w.count) } - b.waiters = make(map[string]waiter) - b.dedupIndex = make(map[string]string) - b.mu.Unlock() - // Send deny responses before closing done. This ensures goroutines in - // the select see the response on ch before they see done closed, so - // they return the response without an error. Coalesced subscribers are - // fanned the same deny on their buffered (cap 1) chans. - for id, w := range waiters { + // Send the deny to every primary and every coalesced subscriber WHILE + // STILL HOLDING b.mu, BEFORE clearing the waiter map — exactly mirroring + // the round-6 Resolve fix. Round-18 #4: the old code cleared b.waiters, + // released b.mu, and only THEN sent the denies. A coalesced subscriber + // whose deadline fired in that window took b.mu in detachSub, found NO + // waiter (map already cleared), returned immediately, and its + // post-detach non-blocking read missed the not-yet-sent buffered deny — + // so it returned a spurious "approval timeout" during shutdown instead + // of the cancel/deny. + // + // Sending under the lock closes that window: a subscriber whose deadline + // fires serializes against this section via detachSub's b.mu. It either + // detaches BEFORE this runs (it is gone from w.subs, gets no send, and + // legitimately returns its own timeout — it never coalesced under this + // shutdown decision) or AFTER (the deny is already buffered on its cap-1 + // chan and waitSub's post-detach non-blocking read picks it up). There + // is no instant where a sub can observe "waiter gone AND deny not yet + // sent". The primary ch and every sub chan are buffered cap-1 and + // receive exactly one value, so these sends never block — holding the + // lock here is safe and cannot deadlock. + for _, w := range waiters { w.ch <- ResponseDeny for _, sub := range w.subs { sub <- ResponseDeny } + } + b.waiters = make(map[string]waiter) + b.dedupIndex = make(map[string]string) + b.mu.Unlock() + + if b.cancelAllAfterClearHook != nil { + // Post-fix the denies were already delivered under the lock above, + // so a sub released here finds its deny buffered. Pre-fix the send + // loop ran AFTER this point, so the released sub saw a lost wakeup. + b.cancelAllAfterClearHook() + } + + // cancelOnChannels calls channel implementations that may do blocking + // network I/O (Telegram message edit), so it must stay OUTSIDE b.mu. + // The deny responses were already delivered above, so a subscriber can + // no longer observe a spurious timeout regardless of how long these take. + for id := range waiters { b.cancelOnChannels(id) } diff --git a/internal/channel/channel_test.go b/internal/channel/channel_test.go index 0bb15dc..d4660c9 100644 --- a/internal/channel/channel_test.go +++ b/internal/channel/channel_test.go @@ -1652,3 +1652,115 @@ func TestBrokerResolveDetachLostWakeup(t *testing.T) { t.Fatal("coalesced subscriber never returned (deadlock?)") } } + +// TestBrokerCancelAllSubDeadlineLostWakeup is the round-18 #4 regression. It +// deterministically forces the CancelAll-vs-coalesced-sub-deadline +// interleave that re-opened the exact lost-wakeup window the round-6 Resolve +// fix closed. Pre-fix CancelAll cleared the waiter map and released b.mu +// BEFORE sending the terminal denies; a coalesced subscriber whose deadline +// fired in that window took b.mu in detachSub, found NO waiter (already +// cleared), and its post-detach non-blocking read missed the not-yet-sent +// buffered deny — returning a spurious "approval timeout" during shutdown. +// Post-fix CancelAll sends every primary/sub deny WHILE HOLDING b.mu, before +// clearing the map, so the released subscriber's detachSub serializes +// against the send: by the time its non-blocking read runs, ResponseDeny is +// already buffered on its cap-1 chan and it returns the cancel deny cleanly. +// +// Forced with two seams, no sleeps for the critical ordering: +// - subDeadlineGate parks the subscriber at the very top of waitSub's +// deadline branch (before detachSub) until released. +// - cancelAllAfterClearHook fires in CancelAll right after b.mu is +// released (post-fix: AFTER the under-lock denies + map clear) and is +// where the test releases the subscriber so its detach/read races the +// terminal send exactly in the (pre-fix) lost-wakeup window. +func TestBrokerCancelAllSubDeadlineLostWakeup(t *testing.T) { + ch := newMockChannel(ChannelTelegram) + broker := NewBroker([]Channel{ch}, WithMaxPending(0), WithDestinationRateLimit(0, 0)) + + const dest = "cancelall-lostwakeup.example.com" + const port = 443 + + // Long-lived primary so it stays pending while the sub coalesces. + primaryOut := make(chan result, 1) + go func() { + resp, err := broker.Request(dest, port, "https", 5*time.Second) + primaryOut <- result{resp, err} + }() + var primaryID string + for { + reqs := ch.getRequests() + if len(reqs) == 1 { + primaryID = reqs[0].ID + break + } + time.Sleep(time.Millisecond) + } + + // Park the subscriber at the start of its deadline branch. + subAtGate := make(chan struct{}) + releaseSub := make(chan struct{}) + var gateOnce sync.Once + broker.subDeadlineGate = func() { + gateOnce.Do(func() { close(subAtGate) }) + <-releaseSub + } + + // Subscriber with a very short timeout: it coalesces onto the primary, + // then its deadline fires and it blocks in subDeadlineGate (still + // attached, detachSub not yet called). + subOut := make(chan result, 1) + go func() { + resp, err := broker.Request(dest, port, "https", 20*time.Millisecond) + subOut <- result{resp, err} + }() + for broker.CoalescedCount(primaryID) < 2 { + time.Sleep(time.Millisecond) + } + <-subAtGate + + // In CancelAll, after b.mu is released (post-fix: denies already sent + // under the lock + map cleared), release the parked subscriber so its + // detachSub + non-blocking read races the terminal send exactly in the + // pre-fix lost-wakeup window. Yield generously so the subscriber + // goroutine is scheduled before the hook returns. + broker.cancelAllAfterClearHook = func() { + close(releaseSub) + for i := 0; i < 1000; i++ { + runtime.Gosched() + } + } + + broker.CancelAll() + + pr := <-primaryOut + if pr.resp != ResponseDeny { + t.Fatalf("primary: expected ResponseDeny from CancelAll, got %v (err %v)", pr.resp, pr.err) + } + + select { + case sr := <-subOut: + // The whole point of #4: the coalesced subscriber whose deadline + // fired in the CancelAll window must observe the shutdown DENY + // CLEANLY (the buffered cancel response via waitSub's post-detach + // non-blocking read), NOT a spurious "approval timeout" error. + // + // Both the clean cancel-deny and the spurious-timeout paths return + // the SAME ResponseDeny value (waitSub's timeout branch returns + // ResponseDeny + an error), so the resp value alone cannot tell + // them apart. The distinguishing signal is the error: the clean + // cancel path returns (ResponseDeny, nil); the lost-wakeup path + // returns (ResponseDeny, "approval timeout after ..."). + if sr.resp != ResponseDeny { + t.Fatalf("coalesced subscriber got %v; want ResponseDeny", sr.resp) + } + if sr.err != nil { + t.Fatalf("coalesced subscriber got ResponseDeny but with error %v; want a "+ + "CLEAN cancel deny — lost-wakeup: CancelAll cleared the waiter map "+ + "then the sub's deadline fired before the terminal send, so the "+ + "sub returned a spurious approval timeout instead of the buffered "+ + "shutdown deny (round-18 #4)", sr.err) + } + case <-time.After(3 * time.Second): + t.Fatal("coalesced subscriber never returned after CancelAll (deadlock?)") + } +} diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index f56127f..69b6685 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -766,15 +766,33 @@ func (a *SluiceAddon) Request(f *mitmproxy.Flow) { } // Pass 2+3 on URL query. - if rawQ := f.Request.URL.RawQuery; bytesContainsAnyPhantomPrefix([]byte(rawQ)) { + // + // Round-18 #5: the prefix gate alone is INSUFFICIENT for pooled OAuth + // credentials. A pooled credential's access phantom is the R3 + // pool-stable SYNTHETIC JWT (poolStablePhantomAccess) which has NO + // "SLUICE_PHANTOM" prefix — it is `header.payload.sig`. If an SDK puts + // the access token in a query parameter or path segment, + // bytesContainsAnyPhantomPrefix returns false, the swap is skipped, and + // the synthetic phantom JWT is forwarded upstream verbatim (request + // fails / phantom leaks). The body swap above already runs + // unconditionally; query/path must likewise run when any scoped pair's + // actual phantom bytes (including the prefix-less pooled JWT) are + // present. pairsPhantomPresentIn only matches phantom bytes that ARE in + // this destination's scoped pairs, so unrelated requests are not + // over-scanned and R3 byte-stability is untouched (we only trigger the + // existing swapPhantomBytes; the synthetic-JWT shape is unchanged). + if rawQ := f.Request.URL.RawQuery; bytesContainsAnyPhantomPrefix([]byte(rawQ)) || + pairsPhantomPresentIn([]byte(rawQ), pairs) { f.Request.URL.RawQuery = string( a.swapPhantomBytes([]byte(rawQ), pairs, host, port, "URL query", false), ) } // Pass 2+3 on URL path. pathContext=true selects path escaping so - // secrets containing spaces get %20, not '+'. - if rawP := f.Request.URL.Path; bytesContainsAnyPhantomPrefix([]byte(rawP)) { + // secrets containing spaces get %20, not '+'. Same pooled-JWT + // consideration as the query swap above (#5). + if rawP := f.Request.URL.Path; bytesContainsAnyPhantomPrefix([]byte(rawP)) || + pairsPhantomPresentIn([]byte(rawP), pairs) { f.Request.URL.Path = string( a.swapPhantomBytes([]byte(rawP), pairs, host, port, "URL path", true), ) @@ -1695,6 +1713,37 @@ func pairPhantomPresentInRequest(f *mitmproxy.Flow, p phantomPair) bool { return false } +// pairsPhantomPresentIn reports whether the given byte slice contains the +// actual phantom bytes (literal or either URL-encoded form) of ANY pair in +// pairs. Used to gate the URL query/path swap so a pooled credential's +// prefix-less R3 synthetic-JWT access phantom (poolStablePhantomAccess) is +// detected and swapped even though bytesContainsAnyPhantomPrefix — which +// only knows the literal "SLUICE_PHANTOM" prefix — would miss it (#5). +// +// This does NOT over-scan unrelated requests: pairs is already scoped to +// the bindings/pool that match THIS request's destination + protocol +// (buildPhantomPairs), and we only match phantom bytes that are genuinely +// present. It does NOT change R3 byte-stability: the synthetic-JWT shape is +// untouched; we merely let the existing swapPhantomBytes run on the +// query/path when the pooled phantom is there. +func pairsPhantomPresentIn(data []byte, pairs []phantomPair) bool { + if len(data) == 0 { + return false + } + for _, p := range pairs { + if len(p.phantom) > 0 && bytes.Contains(data, p.phantom) { + return true + } + if len(p.encodedPhantom) > 0 && bytes.Contains(data, p.encodedPhantom) { + return true + } + if len(p.encodedPhantomLower) > 0 && bytes.Contains(data, p.encodedPhantomLower) { + return true + } + } + return false +} + // tagPooledFlowAfterSwap records the per-flow pool-usage attribution tag // (flowInjected) for every pooled pair whose pool phantom is ACTUALLY present // in this outbound request, and only those. diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index 96bfc9b..1755596 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -140,6 +140,12 @@ type FailoverEvent struct { Reason string // short tag: 429 | 403 | 401 | invalid_grant | invalid_token Class failoverClass Until time.Time // member cooldown expiry just applied + // Epoch is the From member's membership epoch in the resolver + // generation that produced this failover. The durable guarded write + // commits only if (From, Pool, Epoch) is still a live membership row, + // so a late callback firing after a remove/re-add cannot persist this + // cooldown onto the re-created same-name successor (Cluster A #2). + Epoch int64 } // poolForResponse maps a response's CONNECT destination back to a pooled @@ -388,7 +394,22 @@ func (a *SluiceAddon) handlePoolFailover(f *mitmproxy.Flow) { // MarkCooldown takes the resolver's write lock; ResolveActive takes the // read lock, so the next request observes the new active member with no // dependency on the store-reconcile watcher. - pr.MarkCooldown(from, until, tag) + // + // Cluster A: capture the FROM member's pool+epoch identity from THIS + // resolver generation and thread it through MarkCooldown and the + // FailoverEvent. If the membership was removed and `from` re-added under + // the same name (a strictly greater epoch, or a different pool) before + // this stale write lands, the identity no longer matches and the + // re-created successor does NOT inherit this old response's cooldown. + idPool, idEpoch, idOK := pr.IdentityForMember(from) + if !idOK { + // `from` is no longer a member of any pool in the current + // generation (raced removal). There is no sound member to + // attribute this cooldown to; skip it entirely rather than write + // an unscoped cooldown a same-name re-add could inherit. + return + } + pr.MarkCooldownScoped(from, idPool, idEpoch, until, tag) // (2) Recompute the active member now that `from` is cooling down. If // every member is in cooldown ResolveActive degrades to the @@ -437,6 +458,7 @@ func (a *SluiceAddon) handlePoolFailover(f *mitmproxy.Flow) { Reason: tag, Class: class, Until: until, + Epoch: idEpoch, }) } } diff --git a/internal/proxy/pool_phantom_test.go b/internal/proxy/pool_phantom_test.go index 9a0874a..48f2761 100644 --- a/internal/proxy/pool_phantom_test.go +++ b/internal/proxy/pool_phantom_test.go @@ -268,6 +268,107 @@ func TestR3PoolPhantomByteIdenticalAcrossMemberSwitch(t *testing.T) { } } +// TestPooledAccessPhantomSwappedInQueryAndPath is the round-18 #5 +// regression. A pooled OAuth credential's access phantom is the R3 +// pool-stable SYNTHETIC JWT (poolStablePhantomAccess) — it has NO +// "SLUICE_PHANTOM" prefix. The request-side URL query/path swap was gated +// SOLELY on bytesContainsAnyPhantomPrefix, which only knows the literal +// "SLUICE_PHANTOM" prefix. So when an SDK puts the access token in a query +// parameter or a path segment, the gate returned false, the swap was +// skipped, and the pool-stable JWT phantom was forwarded UPSTREAM verbatim +// (request fails / phantom leaks). The fix also gates on +// pairsPhantomPresentIn so the scoped pooled JWT triggers the existing +// swap. +// +// Fail-before: the phantom JWT survives in RawQuery/Path. Pass-after: it is +// replaced with the active member's real access token in BOTH. Header +// placement still works, the SLUICE_PHANTOM-prefixed (refresh) phantom path +// is unaffected, and the R3 byte-identical-across-member-switch guarantee +// holds (the synthetic-JWT shape is untouched; we only added a detection +// path that triggers the existing swapPhantomBytes). +func TestPooledAccessPhantomSwappedInQueryAndPath(t *testing.T) { + addon, _, prPtr := setupPoolAddon(t, "memA", "memB") + client := setupAddonConn(addon, "auth.example.com:443") + + poolAccessPhantom := poolStablePhantomAccess("codex_pool") + // Sanity: the pooled access phantom is a prefix-less synthetic JWT, so + // the old prefix-only gate genuinely could not see it. + if strings.HasPrefix(poolAccessPhantom, "SLUICE_PHANTOM") { + t.Fatalf("pooled access phantom unexpectedly carries the SLUICE_PHANTOM prefix: %q", poolAccessPhantom) + } + if bytesContainsAnyPhantomPrefix([]byte(poolAccessPhantom)) { + t.Fatal("pooled access phantom must NOT be detectable by bytesContainsAnyPhantomPrefix (that is the whole #5 bug)") + } + + // --- Query-parameter placement. memA is active (position 0), so the + // phantom must be swapped for memA's real access token "A-access-old". + fq := newTestFlow(client, "GET", + "https://auth.example.com/v1/userinfo?access_token="+url.QueryEscape(poolAccessPhantom)+"&foo=bar") + addon.Request(fq) + gotQ := fq.Request.URL.RawQuery + if strings.Contains(gotQ, poolAccessPhantom) { + t.Fatalf("#5: pooled access phantom NOT swapped in URL query (forwarded upstream verbatim)\n query=%q", gotQ) + } + if !strings.Contains(gotQ, "A-access-old") { + t.Fatalf("#5: active member's real access token not injected into URL query\n query=%q", gotQ) + } + + // --- Path-segment placement. --- + fp := newTestFlow(client, "GET", + "https://auth.example.com/v1/tokens/"+url.PathEscape(poolAccessPhantom)+"/info") + addon.Request(fp) + gotP := fp.Request.URL.Path + if strings.Contains(gotP, poolAccessPhantom) { + t.Fatalf("#5: pooled access phantom NOT swapped in URL path (forwarded upstream verbatim)\n path=%q", gotP) + } + if !strings.Contains(gotP, "A-access-old") { + t.Fatalf("#5: active member's real access token not injected into URL path\n path=%q", gotP) + } + + // --- Header placement still works (must-not-regress). --- + fh := newTestFlow(client, "GET", "https://auth.example.com/v1/userinfo") + fh.Request.Header.Set("Authorization", "Bearer "+poolAccessPhantom) + addon.Request(fh) + auth := fh.Request.Header.Get("Authorization") + if strings.Contains(auth, poolAccessPhantom) || !strings.Contains(auth, "A-access-old") { + t.Fatalf("#5: header phantom swap regressed; Authorization=%q", auth) + } + + // --- SLUICE_PHANTOM-prefixed (refresh) phantom in query still swaps via + // the unchanged prefix path (no regression to the non-pooled path). --- + fr := newTestFlow(client, "GET", + "https://auth.example.com/v1/refresh?rt="+url.QueryEscape("SLUICE_PHANTOM:codex_pool.refresh")) + addon.Request(fr) + if strings.Contains(fr.Request.URL.RawQuery, "SLUICE_PHANTOM") { + t.Fatalf("prefix-form refresh phantom not swapped in query: %q", fr.Request.URL.RawQuery) + } + if !strings.Contains(fr.Request.URL.RawQuery, "A-refresh-old") { + t.Fatalf("prefix-form refresh phantom not replaced with real refresh: %q", fr.Request.URL.RawQuery) + } + + // --- R3 byte-identity preserved: fail member A over and confirm the + // phantom the agent would hold is still byte-identical (pool-stable), + // and the query swap now injects member B's real token. --- + before := poolStablePhantomAccess("codex_pool") + prPtr.Load().MarkCooldown("memA", timeFuture(), "429") + if got, _ := prPtr.Load().ResolveActive("codex_pool"); got != "memB" { + t.Fatalf("after cooldown active = %q, want memB", got) + } + after := poolStablePhantomAccess("codex_pool") + if before != after { + t.Fatalf("R3 byte-identity violated across member switch:\n before %q\n after %q", before, after) + } + fq2 := newTestFlow(client, "GET", + "https://auth.example.com/v1/userinfo?access_token="+url.QueryEscape(after)) + addon.Request(fq2) + if strings.Contains(fq2.Request.URL.RawQuery, after) { + t.Fatalf("#5: pooled access phantom not swapped after failover; query=%q", fq2.Request.URL.RawQuery) + } + if !strings.Contains(fq2.Request.URL.RawQuery, "B-access-old") { + t.Fatalf("#5: post-failover query swap did not inject member B's real access token; query=%q", fq2.Request.URL.RawQuery) + } +} + // TestR1RefreshAttributionByInjectedRefreshToken asserts a B-refresh // response is persisted to B's vault entry, never A's, even though both // members share one token URL (OAuthIndex.Match is 1:1 and collides). diff --git a/internal/store/migrations/000007_pool_membership_epoch.down.sql b/internal/store/migrations/000007_pool_membership_epoch.down.sql new file mode 100644 index 0000000..ab38d50 --- /dev/null +++ b/internal/store/migrations/000007_pool_membership_epoch.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE credential_pool_members DROP COLUMN epoch; +DROP TABLE IF EXISTS pool_membership_epoch; diff --git a/internal/store/migrations/000007_pool_membership_epoch.up.sql b/internal/store/migrations/000007_pool_membership_epoch.up.sql new file mode 100644 index 0000000..dde2001 --- /dev/null +++ b/internal/store/migrations/000007_pool_membership_epoch.up.sql @@ -0,0 +1,30 @@ +-- Pool membership epoch. +-- +-- A pool name + a credential name share one namespace and the +-- credential_health table is NOT foreign-keyed to live membership. A +-- credential removed from a pool (or whose pool was removed) and then +-- re-created/re-added under the SAME name produces a row with the SAME +-- (pool, credential) primary key. Without an epoch, a stale in-flight +-- failover write — the durable SetCredentialHealthIfPoolMember from a +-- detached goroutine, a manual `pool rotate`, or an old-generation +-- resolver's MarkCooldown — could not tell the OLD membership from its +-- re-created successor, so it would wrongly park the NEW member with the +-- OLD response's cooldown. +-- +-- pool_membership_epoch is a single-row monotonic counter bumped on every +-- pool create and pool remove. Each credential_pool_members row is stamped +-- with the counter value live at insert time, so a remove/re-add cycle +-- yields a strictly greater epoch on the successor row. The guarded health +-- write and MarkCooldown gate on (credential, pool, epoch): a stale write +-- carrying the old epoch finds no matching row and no-ops, while a +-- genuinely-still-live member (same epoch) still persists/rotates +-- (CRITICAL-1 durability + round-9/11/14/15 fixes preserved). + +CREATE TABLE pool_membership_epoch ( + id INTEGER PRIMARY KEY CHECK (id = 1), + epoch INTEGER NOT NULL DEFAULT 0 +); + +INSERT INTO pool_membership_epoch (id, epoch) VALUES (1, 0); + +ALTER TABLE credential_pool_members ADD COLUMN epoch INTEGER NOT NULL DEFAULT 0; diff --git a/internal/store/pools.go b/internal/store/pools.go index f6de3a0..aa00a65 100644 --- a/internal/store/pools.go +++ b/internal/store/pools.go @@ -22,10 +22,37 @@ type Pool struct { } // PoolMember is one credential entry in a pool. Position determines the -// failover order (lowest first). +// failover order (lowest first). Epoch is the value of the monotonic +// pool_membership_epoch counter at the time this membership row was +// inserted: a remove/re-add of the same (pool, credential) yields a +// strictly greater epoch, so a stale in-flight failover write that carries +// the OLD epoch can be told apart from its re-created successor and no-ops +// instead of parking the new member with the old response's cooldown. type PoolMember struct { Credential string Position int + Epoch int64 +} + +// bumpMembershipEpochTx increments the single-row monotonic +// pool_membership_epoch counter inside the supplied transaction and returns +// the NEW value. Called on every pool create and pool remove so any +// membership change advances the epoch; member inserts in the same +// transaction stamp the returned value onto their rows. Monotonic across +// the process lifetime and across restarts (the counter is durable). +func bumpMembershipEpochTx(tx *sql.Tx) (int64, error) { + if _, err := tx.Exec( + "UPDATE pool_membership_epoch SET epoch = epoch + 1 WHERE id = 1", + ); err != nil { + return 0, fmt.Errorf("bump membership epoch: %w", err) + } + var ep int64 + if err := tx.QueryRow( + "SELECT epoch FROM pool_membership_epoch WHERE id = 1", + ).Scan(&ep); err != nil { + return 0, fmt.Errorf("read membership epoch: %w", err) + } + return ep, nil } // CredentialHealth records whether a credential is currently eligible for @@ -172,6 +199,15 @@ func (s *Store) CreatePoolWithMembers(name, strategy string, members []string) e return fmt.Errorf("insert pool %q: %w", name, err) } + // Advance the monotonic membership epoch and stamp it on every member + // inserted here. A pool removed and re-created under the same name (or a + // member removed and re-added) gets a strictly greater epoch, so a stale + // failover write carrying the OLD epoch cannot apply to the successor. + epoch, err := bumpMembershipEpochTx(tx) + if err != nil { + return err + } + for i, m := range members { if err := validatePoolMemberTx(tx, m); err != nil { return err @@ -180,8 +216,8 @@ func (s *Store) CreatePoolWithMembers(name, strategy string, members []string) e return err } if _, err := tx.Exec( - "INSERT INTO credential_pool_members (pool, credential, position) VALUES (?, ?, ?)", - name, m, i, + "INSERT INTO credential_pool_members (pool, credential, position, epoch) VALUES (?, ?, ?, ?)", + name, m, i, epoch, ); err != nil { return fmt.Errorf("insert pool member %q: %w", m, err) } @@ -208,7 +244,7 @@ func (s *Store) GetPool(name string) (*Pool, error) { } rows, err := s.db.Query( - "SELECT credential, position FROM credential_pool_members WHERE pool = ? ORDER BY position", name, + "SELECT credential, position, epoch FROM credential_pool_members WHERE pool = ? ORDER BY position", name, ) if err != nil { return nil, fmt.Errorf("list pool members %q: %w", name, err) @@ -216,7 +252,7 @@ func (s *Store) GetPool(name string) (*Pool, error) { defer func() { _ = rows.Close() }() for rows.Next() { var m PoolMember - if err := rows.Scan(&m.Credential, &m.Position); err != nil { + if err := rows.Scan(&m.Credential, &m.Position, &m.Epoch); err != nil { return nil, fmt.Errorf("scan pool member: %w", err) } p.Members = append(p.Members, m) @@ -252,7 +288,7 @@ func (s *Store) ListPools() ([]Pool, error) { _ = rows.Close() mrows, err := s.db.Query( - "SELECT pool, credential, position FROM credential_pool_members ORDER BY pool, position", + "SELECT pool, credential, position, epoch FROM credential_pool_members ORDER BY pool, position", ) if err != nil { return nil, fmt.Errorf("list pool members: %w", err) @@ -261,7 +297,7 @@ func (s *Store) ListPools() ([]Pool, error) { for mrows.Next() { var pool string var m PoolMember - if err := mrows.Scan(&pool, &m.Credential, &m.Position); err != nil { + if err := mrows.Scan(&pool, &m.Credential, &m.Position, &m.Epoch); err != nil { return nil, fmt.Errorf("scan pool member: %w", err) } if p, ok := pools[pool]; ok { @@ -328,6 +364,14 @@ func (s *Store) RemovePool(name string) (bool, error) { n, _ := res.RowsAffected() if n > 0 { + // Advance the membership epoch on removal too. A guarded write or a + // MarkCooldown still carrying this pool generation's epoch will no + // longer find a matching (credential, pool, epoch) row once the + // CASCADE has wiped the membership, so a late failover cannot + // resurrect the removed member's cooldown for a re-created successor. + if _, err := bumpMembershipEpochTx(tx); err != nil { + return false, err + } // The CASCADE has now removed this pool's credential_pool_members // rows. For each former member, drop its health row UNLESS it is // still a member of some OTHER pool (the membership query runs @@ -479,6 +523,34 @@ func (s *Store) SetCredentialHealth(credential, status string, cooldownUntil tim // The membership SELECT and the upsert share one transaction so a concurrent // removal cannot interleave between the check and the write. func (s *Store) SetCredentialHealthIfPoolMember(credential, status string, cooldownUntil time.Time, reason string) (wrote bool, err error) { + return s.setCredentialHealthGuarded(credential, "", -1, status, cooldownUntil, reason) +} + +// SetCredentialHealthIfPoolMemberEpoch is the pool+epoch-scoped guarded +// write. It commits the monotonic-extend cooldown upsert ONLY when a +// credential_pool_members row exists for exactly (credential, pool, epoch), +// with the membership check and the upsert in ONE transaction. +// +// This closes the remove/re-add aliasing hole that the name-only guard left +// open (Cluster A #1/#2/#3). Sequence: pool P with member c (epoch e1) takes +// a 429; remove P; recreate c into a new pool Q (epoch e2 > e1); the +// detached failover goroutine — or a stale old-generation MarkCooldown, or +// a raced `pool rotate` — fires SetCredentialHealthIfPoolMemberEpoch(c, "P", +// e1, ...). The name-only guard would find c present (now in Q) and wrongly +// persist the OLD response's cooldown onto the NEW member. The (pool, epoch) +// predicate finds no row matching ("P", e1) and no-ops (wrote=false). A +// genuinely-still-live member fires with its CURRENT (pool, epoch) and the +// row matches, so CRITICAL-1 restart durability and the round-9/11/14/15 +// fixes are preserved. +// +// pool=="" with epoch<0 falls back to the legacy name-only predicate so +// callers without pool/epoch context (and the store unit tests that +// exercise the name-only path) are not regressed. +func (s *Store) SetCredentialHealthIfPoolMemberEpoch(credential, pool string, epoch int64, status string, cooldownUntil time.Time, reason string) (wrote bool, err error) { + return s.setCredentialHealthGuarded(credential, pool, epoch, status, cooldownUntil, reason) +} + +func (s *Store) setCredentialHealthGuarded(credential, pool string, epoch int64, status string, cooldownUntil time.Time, reason string) (wrote bool, err error) { cu, verr := validateCredentialHealthArgs(credential, status, cooldownUntil) if verr != nil { return false, verr @@ -491,14 +563,27 @@ func (s *Store) SetCredentialHealthIfPoolMember(credential, status string, coold defer func() { _ = tx.Rollback() }() var live int - qerr := tx.QueryRow( - "SELECT 1 FROM credential_pool_members WHERE credential = ? LIMIT 1", credential, - ).Scan(&live) + var qerr error + if pool == "" && epoch < 0 { + // Legacy name-only predicate (no pool/epoch context). + qerr = tx.QueryRow( + "SELECT 1 FROM credential_pool_members WHERE credential = ? LIMIT 1", credential, + ).Scan(&live) + } else { + // Pool+epoch-scoped predicate. A stale write carrying the OLD epoch + // (the membership row was removed and the credential re-added under + // the same name into another pool with a strictly greater epoch) + // finds no matching row and is a no-op. + qerr = tx.QueryRow( + "SELECT 1 FROM credential_pool_members WHERE credential = ? AND pool = ? AND epoch = ? LIMIT 1", + credential, pool, epoch, + ).Scan(&live) + } switch { case errors.Is(qerr, sql.ErrNoRows): - // Not a live pool member: skip the durable write entirely so a - // removed credential's health row is never resurrected. No commit - // needed — nothing was written. + // Not a live member of THIS pool at THIS epoch: skip the durable + // write entirely so a removed/superseded membership's health row is + // never resurrected onto a re-created successor. No commit needed. return false, nil case qerr != nil: return false, fmt.Errorf("check pool membership for %q: %w", credential, qerr) diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index c99dbab..9776be4 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -384,25 +384,75 @@ func TestMigration000006DownUp(t *testing.T) { t.Fatalf("migrator: %v", err) } - // Step down one migration (000006 -> 000005). + columnExists := func(table, col string) bool { + rows, qerr := s.db.Query("PRAGMA table_info(" + table + ")") + if qerr != nil { + return false + } + defer func() { _ = rows.Close() }() + for rows.Next() { + var cid int + var name, ctype string + var notnull, pk int + var dflt interface{} + if scanErr := rows.Scan(&cid, &name, &ctype, ¬null, &dflt, &pk); scanErr != nil { + return false + } + if name == col { + return true + } + } + return false + } + + // 000007 added the epoch column + pool_membership_epoch counter on top + // of 000006. Both must be present after the up migration. + if !columnExists("credential_pool_members", "epoch") { + t.Fatal("credential_pool_members.epoch missing after up migration (000007)") + } + if !tableExists("pool_membership_epoch") { + t.Fatal("pool_membership_epoch missing after up migration (000007)") + } + + // Step down one migration (000007 -> 000006): the epoch column and the + // counter table go away, the 000006 pool tables stay. if err := m.Steps(-1); err != nil { - t.Fatalf("down 1: %v", err) + t.Fatalf("down 1 (000007): %v", err) + } + if columnExists("credential_pool_members", "epoch") { + t.Error("credential_pool_members.epoch still present after 000007 down") + } + if tableExists("pool_membership_epoch") { + t.Error("pool_membership_epoch still present after 000007 down") } for _, tbl := range []string{"credential_pools", "credential_pool_members", "credential_health"} { - if tableExists(tbl) { - t.Errorf("table %q still present after down migration", tbl) + if !tableExists(tbl) { + t.Errorf("000006 table %q wrongly dropped by 000007 down", tbl) } } - // Step back up; tables return. - if err := m.Steps(1); err != nil { - t.Fatalf("up 1: %v", err) + // Step down again (000006 -> 000005): the pool tables themselves go. + if err := m.Steps(-1); err != nil { + t.Fatalf("down 1 (000006): %v", err) } for _, tbl := range []string{"credential_pools", "credential_pool_members", "credential_health"} { + if tableExists(tbl) { + t.Errorf("table %q still present after 000006 down migration", tbl) + } + } + + // Step back up twice; tables and the epoch column return. + if err := m.Steps(2); err != nil { + t.Fatalf("up 2: %v", err) + } + for _, tbl := range []string{"credential_pools", "credential_pool_members", "credential_health", "pool_membership_epoch"} { if !tableExists(tbl) { t.Errorf("table %q missing after re-up migration", tbl) } } + if !columnExists("credential_pool_members", "epoch") { + t.Error("credential_pool_members.epoch missing after re-up migration") + } } // TestRemoveCredentialMetaBlocksLivePoolMember is the Finding 3 regression. @@ -824,6 +874,110 @@ func TestSetCredentialHealthIfPoolMemberSkipsRemoved(t *testing.T) { } } +// TestSetCredentialHealthIfPoolMemberEpochRejectsReAddedSuccessor is the +// Cluster A #2 regression. The name-only guard only verifies the credential +// is a member of SOME pool. Sequence: pool P with member c (epoch e1) takes a +// 429; remove P; re-create c into a NEW pool Q (epoch e2 > e1); a late +// detached failover goroutine fires the guarded write for (c, P, e1). The +// name-only guard finds c present (now in Q) and WRONGLY persists the OLD +// response's cooldown — which loadPoolResolver then seeds onto the NEW member. +// +// Fail-before: SetCredentialHealthIfPoolMember(c) returns wrote=true and a +// cooldown row appears for the re-added c. Pass-after: the epoch-scoped +// SetCredentialHealthIfPoolMemberEpoch(c, "P", e1) finds no (c, P, e1) row +// (Q's row has e2) and no-ops; the genuinely-live (c, Q, e2) write still +// persists. +func TestSetCredentialHealthIfPoolMemberEpochRejectsReAddedSuccessor(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "c") + if err := s.CreatePoolWithMembers("P", "failover", []string{"c"}); err != nil { + t.Fatalf("create pool P: %v", err) + } + pP, err := s.GetPool("P") + if err != nil || pP == nil || len(pP.Members) != 1 { + t.Fatalf("GetPool(P) = %+v, %v", pP, err) + } + e1 := pP.Members[0].Epoch + if e1 <= 0 { + t.Fatalf("expected a positive membership epoch for the first pool, got %d", e1) + } + + // Remove P (CASCADE wipes membership + deletes c's health row), then + // re-create c into a different pool Q. The re-add gets a strictly + // greater epoch. + if removed, rerr := s.RemovePool("P"); rerr != nil || !removed { + t.Fatalf("RemovePool(P) = %v, %v", removed, rerr) + } + seedOAuthCred(t, s, "c") + if err := s.CreatePoolWithMembers("Q", "failover", []string{"c"}); err != nil { + t.Fatalf("recreate c into Q: %v", err) + } + pQ, err := s.GetPool("Q") + if err != nil || pQ == nil || len(pQ.Members) != 1 { + t.Fatalf("GetPool(Q) = %+v, %v", pQ, err) + } + e2 := pQ.Members[0].Epoch + if e2 <= e1 { + t.Fatalf("re-added member epoch %d must be strictly greater than the removed one %d (monotonic epoch broken)", e2, e1) + } + + // The stale late failover from P (epoch e1) must NOT apply to the + // re-added c that now lives in Q at epoch e2. + until := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + wrote, err := s.SetCredentialHealthIfPoolMemberEpoch("c", "P", e1, "cooldown", until, "failover:429 stale") + if err != nil { + t.Fatalf("guarded epoch write (stale): %v", err) + } + if wrote { + t.Fatal("stale failover from removed pool P (epoch e1) wrongly persisted onto the re-added member in pool Q (Cluster A #2)") + } + if h, herr := s.GetCredentialHealth("c"); herr != nil || h != nil { + t.Fatalf("re-added c inherited a stale cooldown row: %+v, %v", h, herr) + } + + // The genuinely-still-live member (Q, e2) must still persist (CRITICAL-1 + // restart durability preserved). + wrote, err = s.SetCredentialHealthIfPoolMemberEpoch("c", "Q", e2, "cooldown", until, "failover:429 live") + if err != nil { + t.Fatalf("guarded epoch write (live): %v", err) + } + if !wrote { + t.Fatal("guarded epoch write skipped the genuinely-live member of Q (CRITICAL-1 durability regressed)") + } + h, herr := s.GetCredentialHealth("c") + if herr != nil || h == nil || h.Status != "cooldown" || !h.CooldownUntil.Equal(until) { + t.Fatalf("live member cooldown not persisted: %+v, %v", h, herr) + } +} + +// TestSetCredentialHealthIfPoolMemberEpochLiveMemberSamePool pins the common +// path: a still-live member failing over against its CURRENT (pool, epoch) +// always persists, so monotonic-cooldown / round-9/11/14/15 durability is +// intact for a genuinely-live member. +func TestSetCredentialHealthIfPoolMemberEpochLiveMemberSamePool(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "m") + if err := s.CreatePoolWithMembers("pool", "failover", []string{"m"}); err != nil { + t.Fatalf("create pool: %v", err) + } + p, _ := s.GetPool("pool") + ep := p.Members[0].Epoch + until := time.Now().Add(5 * time.Minute).UTC().Truncate(time.Second) + wrote, err := s.SetCredentialHealthIfPoolMemberEpoch("m", "pool", ep, "cooldown", until, "429") + if err != nil || !wrote { + t.Fatalf("live same-pool/epoch write = %v, %v; want true, nil", wrote, err) + } + // A write carrying a wrong (mismatched) epoch for the SAME live pool + // must also no-op — defends against an epoch confusion bug. + w2, err := s.SetCredentialHealthIfPoolMemberEpoch("m", "pool", ep+999, "cooldown", until, "429") + if err != nil { + t.Fatalf("mismatched-epoch write err: %v", err) + } + if w2 { + t.Fatal("write with a mismatched epoch for the live pool wrongly committed") + } +} + // TestSetCredentialHealthIfPoolMemberValidates pins that the guarded variant // applies the same input validation as the unconditional path before touching // the DB (no transaction opened for invalid input). diff --git a/internal/vault/pool.go b/internal/vault/pool.go index 6d61c62..bdeefa8 100644 --- a/internal/vault/pool.go +++ b/internal/vault/pool.go @@ -25,6 +25,18 @@ type memberHealth struct { reason string } +// memberIdentity is the pool+epoch identity of a credential in the CURRENT +// resolver generation. A remove/re-add of the same credential name yields a +// strictly greater epoch (and possibly a different pool), so a stale +// MarkCooldown carrying the OLD identity can be told apart from its +// re-created successor and rejected. The zero value (pool=="", epoch==0) is +// never a live identity because every membership row is stamped with a +// post-bump epoch >= 1. +type memberIdentity struct { + pool string + epoch int64 +} + // PoolHealth is the mutex-guarded credential cooldown map. It is // deliberately a SEPARATE object from PoolResolver so it can outlive any // single resolver generation. @@ -46,30 +58,32 @@ type memberHealth struct { // Store rows still seed the map at startup (Seed) for cross-restart // durability, and the seed is monotonic (never shortens a live in-memory // cooldown). -// Finding 3 (round-15): a response handled by an OLD resolver generation can -// call MarkCooldown AFTER a NEW generation already pruned non-members during -// resolver rebuild (StorePool/MergeLiveCooldowns shared path). If that -// credential was removed from every pool in the new generation, the -// unguarded MarkCooldown would re-insert a stale in-memory cooldown that a -// later same-named re-add inherits before its TTL. The fix is to store the -// CURRENT generation's authoritative member set on the shared PoolHealth and -// update it under the SAME mutex that guards the cooldown map, so the prune -// (member-set replace) and a concurrent MarkCooldown cannot interleave to -// leave a non-member cooldown entry behind. MarkCooldown, under the lock, -// no-ops when the credential is not in currentMembers (and currentMembers is -// non-nil). currentMembers stays nil until SetCurrentMembers is called the -// first time (ad-hoc/private resolvers that never set it keep the old -// permissive behavior, so single-generation callers are not regressed). +// Finding 3 (round-15) + Cluster A (round-18): a response handled by an OLD +// resolver generation can call MarkCooldown AFTER a NEW generation already +// pruned/replaced membership during resolver rebuild. The round-15 gate +// keyed only on the credential NAME being present in SOME pool, so a +// remove/re-add of the same name into a DIFFERENT pool would let the stale +// write through and the new member would inherit the OLD response's +// cooldown. The fix is to key the gate on the credential's pool+epoch +// IDENTITY (memberIdentity) for the current generation, updated under the +// SAME mutex that guards the cooldown map, so the membership replace and a +// concurrent stale MarkCooldown cannot interleave AND a stale write whose +// (pool, epoch) no longer matches the live identity no-ops. currentMembers +// stays nil until SetCurrentMembers is called the first time +// (ad-hoc/private resolvers that never set it keep the old permissive +// behavior, so single-generation callers are not regressed). type PoolHealth struct { mu sync.RWMutex health map[string]memberHealth - // currentMembers is the authoritative member set of the CURRENT - // resolver generation. nil = "not tracked" (gate disabled, legacy - // permissive behavior). Non-nil but missing a credential = that - // credential is not a member of any pool in the current generation, so - // MarkCooldown must NOT write a cooldown for it (write-after-prune - // guard). Mutated only under mu, the same lock the cooldown map uses. - currentMembers map[string]struct{} + // currentMembers maps a credential name -> its pool+epoch identity in + // the CURRENT resolver generation. nil = "not tracked" (gate disabled, + // legacy permissive behavior). Non-nil but missing a credential = that + // credential is not a member of any pool in the current generation. + // Present but with a DIFFERENT (pool, epoch) than the stale write + // carries = the credential was removed and re-added (a later + // same-named successor) so the stale write must NOT apply. Mutated only + // under mu, the same lock the cooldown map uses. + currentMembers map[string]memberIdentity } // NewPoolHealth returns an empty shared health map. Call this exactly once @@ -88,7 +102,7 @@ func NewPoolHealth() *PoolHealth { // observes the OLD member set entirely or the NEW one entirely — it can never // observe a half-updated set, and it can never slip a non-member cooldown in // between the prune and the member-set swap. -func (ph *PoolHealth) SetCurrentMembers(members map[string]struct{}) { +func (ph *PoolHealth) SetCurrentMembers(members map[string]memberIdentity) { if ph == nil { return } @@ -143,6 +157,12 @@ type PoolResolver struct { pools map[string][]string // memberOf maps a credential name -> the pools that contain it. memberOf map[string][]string + // identity maps a credential name -> its pool+epoch identity in THIS + // generation. Threaded into MarkCooldown and the FailoverEvent so a + // stale write carrying an old (pool, epoch) cannot apply to a + // re-created same-name successor. A credential belongs to at most one + // pool (store enforces this), so a single identity per credential. + identity map[string]memberIdentity // health is the shared, swap-surviving cooldown map. Never nil after // NewPoolResolver (a fresh PoolHealth is allocated when none is given, @@ -181,6 +201,7 @@ func NewPoolResolverShared(pools []store.Pool, healthRows []store.CredentialHeal pr := &PoolResolver{ pools: make(map[string][]string, len(pools)), memberOf: make(map[string][]string), + identity: make(map[string]memberIdentity), health: shared, } for _, p := range pools { @@ -188,6 +209,7 @@ func NewPoolResolverShared(pools []store.Pool, healthRows []store.CredentialHeal for _, m := range p.Members { members = append(members, m.Credential) pr.memberOf[m.Credential] = append(pr.memberOf[m.Credential], p.Name) + pr.identity[m.Credential] = memberIdentity{pool: p.Name, epoch: m.Epoch} } pr.pools[p.Name] = members } @@ -202,9 +224,9 @@ func NewPoolResolverShared(pools []store.Pool, healthRows []store.CredentialHeal // nil, e.g. CLI `pool` subcommands) leave currentMembers nil to preserve // the old permissive single-generation behavior. if explicitShared { - cm := make(map[string]struct{}, len(pr.memberOf)) - for cred := range pr.memberOf { - cm[cred] = struct{}{} + cm := make(map[string]memberIdentity, len(pr.identity)) + for cred, id := range pr.identity { + cm[cred] = id } shared.SetCurrentMembers(cm) } @@ -296,12 +318,52 @@ func (pr *PoolResolver) ResolveActive(name string) (member string, ok bool) { return soonest, true } +// IdentityForMember returns the pool+epoch identity of a credential in THIS +// resolver generation. ok is false when the credential is not a member of +// any pool in this generation. The failover path captures this at the time +// the cooldown decision is made and threads (pool, epoch) through to +// MarkCooldown and the durable guarded write so a stale write cannot apply +// to a re-created same-name successor (Cluster A). +func (pr *PoolResolver) IdentityForMember(credential string) (pool string, epoch int64, ok bool) { + if pr == nil { + return "", 0, false + } + id, ok := pr.identity[credential] + if !ok { + return "", 0, false + } + return id.pool, id.epoch, true +} + // MarkCooldown records, in memory and synchronously, that a member should be // skipped until `until`. Phase 2 failover calls this on the response path // BEFORE the response returns so the very next request injects the next // member; the durable store write only reconciles afterwards. Calling with a // zero/past `until` clears the cooldown (recovery). +// +// This is the legacy identity-UNSCOPED form: it keeps the round-15 +// name-only write-after-prune guard but does NOT distinguish a removed and +// re-added same-name credential. The response/failover path MUST use +// MarkCooldownScoped so a stale write cannot park a re-created successor +// (Cluster A #1). Single-generation callers (CLI tools, unit tests) keep +// using this. func (pr *PoolResolver) MarkCooldown(credential string, until time.Time, reason string) { + pr.markCooldown(credential, "", -1, until, reason) +} + +// MarkCooldownScoped is the pool+epoch identity-scoped form used by the +// Phase 2 failover response path. pool+epoch identify WHICH membership +// generation the cooldown decision was made against. The gate commits the +// in-memory write only if that exact (pool, epoch) is still the live +// identity for `credential` in the current generation: a stale write whose +// membership was removed and the name re-added (a strictly greater epoch, +// or a different pool) no-ops, so the re-created successor does NOT inherit +// the old response's cooldown (Cluster A #1). +func (pr *PoolResolver) MarkCooldownScoped(credential, pool string, epoch int64, until time.Time, reason string) { + pr.markCooldown(credential, pool, epoch, until, reason) +} + +func (pr *PoolResolver) markCooldown(credential, pool string, epoch int64, until time.Time, reason string) { if pr == nil { return } @@ -328,7 +390,20 @@ func (pr *PoolResolver) MarkCooldown(credential string, until time.Time, reason // single-generation callers are not regressed. isClear := until.IsZero() || !until.After(time.Now()) if !isClear && pr.health.currentMembers != nil { - if _, isMember := pr.health.currentMembers[credential]; !isMember { + live, isMember := pr.health.currentMembers[credential] + if !isMember { + // Not a member of any pool in the current generation: + // write-after-prune guard (round-15). + return + } + // Cluster A #1: identity-scoped guard. When the caller carries a + // pool+epoch (epoch >= 0), reject the write unless it still matches + // the live identity. A removed+re-added same-name credential has a + // strictly greater epoch (or a different pool), so an old in-flight + // 429's MarkCooldown does NOT park the re-created successor. A + // caller that opts out (epoch < 0) keeps the round-15 name-only + // behavior. + if epoch >= 0 && (live.pool != pool || live.epoch != epoch) { return } } @@ -405,9 +480,9 @@ func (pr *PoolResolver) MergeLiveCooldowns(prev *PoolResolver) { // for any credential this generation no longer owns — the prune // and the member-set swap are one atomic critical section, so no // non-member cooldown can be slipped in between them. - cm := make(map[string]struct{}, len(pr.memberOf)) - for cred := range pr.memberOf { - cm[cred] = struct{}{} + cm := make(map[string]memberIdentity, len(pr.identity)) + for cred, id := range pr.identity { + cm[cred] = id } pr.health.currentMembers = cm pr.health.mu.Unlock() diff --git a/internal/vault/pool_test.go b/internal/vault/pool_test.go index 20d7568..e6ed6e4 100644 --- a/internal/vault/pool_test.go +++ b/internal/vault/pool_test.go @@ -16,6 +16,88 @@ func mkPool(name string, members ...string) store.Pool { return p } +// mkPoolEpoch builds a pool whose members all carry the given membership +// epoch, mirroring what the store stamps on credential_pool_members rows. +func mkPoolEpoch(name string, epoch int64, members ...string) store.Pool { + p := store.Pool{Name: name, Strategy: store.PoolStrategyFailover} + for i, m := range members { + p.Members = append(p.Members, store.PoolMember{Credential: m, Position: i, Epoch: epoch}) + } + return p +} + +// TestMarkCooldownScopedRejectsReAddedSuccessor is the Cluster A #1 +// regression. The round-15 gate only checked the credential NAME was in the +// current generation's member set. Sequence: pool P with member c (epoch e1) +// takes a 429 on an in-flight request; c/P are removed and c is re-created +// into a NEW pool Q (epoch e2 > e1); the OLD in-flight response's +// MarkCooldown for c now lands. The name-only gate sees c present (it is a +// member of Q now) and WRONGLY parks the re-created successor with the OLD +// response's cooldown. +// +// Deterministic interleave (no sleeps): operations are explicitly ordered so +// the stale gen1 MarkCooldownScoped(c, P, e1) runs AFTER gen2 published Q's +// member set (c at epoch e2). Fail-before: c cooling in gen2. Pass-after: +// the (pool, epoch) identity no longer matches so the write no-ops, and the +// genuinely-live (c, Q, e2) cooldown still applies. +func TestMarkCooldownScopedRejectsReAddedSuccessor(t *testing.T) { + shared := NewPoolHealth() + + const e1 = int64(1) + const e2 = int64(2) + + // gen1 (OLD): pool P with member c at epoch e1. An in-flight response + // resolved through gen1 and holds the gen1 resolver. + gen1 := NewPoolResolverShared([]store.Pool{mkPoolEpoch("P", e1, "c")}, nil, shared) + + // gen2 (NEW): P removed; c re-created into pool Q at a strictly greater + // epoch e2. The rebuild publishes gen2's identity map. + gen2 := NewPoolResolverShared([]store.Pool{mkPoolEpoch("Q", e2, "c")}, nil, shared) + gen2.MergeLiveCooldowns(gen1) + + // INTERLEAVE: the stale gen1 response records a failover cooldown for c + // using the identity it captured (P, e1) — AFTER gen2 published (c -> Q, + // e2). The identity no longer matches, so the write must be gated out. + gen1.MarkCooldownScoped("c", "P", e1, time.Now().Add(300*time.Second), "failover:429") + + if until, cooling := gen2.CooldownUntil("c"); cooling { + t.Fatalf("Cluster A #1: stale (P,e1) MarkCooldownScoped parked the re-added successor c (Q,e2): until=%v", until) + } + if got, ok := gen2.ResolveActive("Q"); !ok || got != "c" { + t.Fatalf("Cluster A #1: re-added c must be active in Q; got %q,%v want c,true", got, ok) + } + + // CRITICAL-1 preserved: the genuinely-live member failing over against + // its CURRENT identity (Q, e2) still records the cooldown. + gen2.MarkCooldownScoped("c", "Q", e2, time.Now().Add(300*time.Second), "failover:429 live") + if _, cooling := gen2.CooldownUntil("c"); !cooling { + t.Fatal("Cluster A #1 regressed CRITICAL-1: live (Q,e2) failover cooldown was dropped") + } +} + +// TestMarkCooldownLegacyUnscopedStillGated pins that the legacy +// identity-UNSCOPED MarkCooldown keeps the round-15 name-only +// write-after-prune behavior (single-generation CLI/test callers are not +// regressed): a non-member write is still gated, a member write still +// applies. +func TestMarkCooldownLegacyUnscopedStillGated(t *testing.T) { + shared := NewPoolHealth() + gen1 := NewPoolResolverShared([]store.Pool{mkPool("pool", "a", "x")}, nil, shared) + gen2 := NewPoolResolverShared([]store.Pool{mkPool("pool", "a")}, nil, shared) + gen2.MergeLiveCooldowns(gen1) + + // Non-member "x": still gated by the name-only set (round-15 preserved). + gen1.MarkCooldown("x", time.Now().Add(300*time.Second), "failover:401") + if _, cooling := gen2.CooldownUntil("x"); cooling { + t.Fatal("legacy MarkCooldown lost the round-15 write-after-prune gate") + } + // Member "a": still applies. + gen2.MarkCooldown("a", time.Now().Add(300*time.Second), "429") + if _, cooling := gen2.CooldownUntil("a"); !cooling { + t.Fatal("legacy MarkCooldown wrongly gated a live member") + } +} + func TestResolveActivePassthroughForNonPool(t *testing.T) { pr := NewPoolResolver(nil, nil) got, ok := pr.ResolveActive("plain_cred") From d89382337923ce33166cc6430e29af3195da7927 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 20:10:20 +0800 Subject: [PATCH 43/49] fix(proxy): plain-cred refresh attribution on shared token URL; pool-aware QUIC injection --- CLAUDE.md | 5 +- internal/proxy/addon.go | 74 +++++++- internal/proxy/phantom_pairs.go | 20 ++- internal/proxy/pool_failover.go | 43 ++++- internal/proxy/pool_splithost_test.go | 242 ++++++++++++++++++++++++++ internal/proxy/quic.go | 93 ++++++++-- internal/proxy/quic_test.go | 6 +- internal/proxy/server.go | 2 +- 8 files changed, 452 insertions(+), 33 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index bc8a605..1a142ce 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -232,9 +232,10 @@ Auto-failover on 429/401 is the primary mechanism; `pool rotate` is an operator **Phase 1 — phantom indirection (pool phantom → active member):** -- **Single chokepoint (I2):** every `binding.Credential` / `OAuthIndex.Has` / `extractInjectableSecret` / persist consumer on the HTTP/HTTPS OAuth path routes through `PoolResolver.ResolveActive` (`resolveInjectionTarget` for pass-1 header + pass-2 phantom swap; `resolveOAuthResponseAttribution` for the response/persist path). `idx.Has` is always called with the resolved member name, never the pool. Plain (non-pool) credentials pass through `ResolveActive` unchanged. SSH/mail/QUIC are non-OAuth and out of scope. +- **Single chokepoint (I2):** every `binding.Credential` / `OAuthIndex.Has` / `extractInjectableSecret` / persist consumer on the HTTP/HTTPS OAuth path routes through `PoolResolver.ResolveActive` (`resolveInjectionTarget` for pass-1 header + pass-2 phantom swap; `resolveOAuthResponseAttribution` for the response/persist path). `idx.Has` is always called with the resolved member name, never the pool. Plain (non-pool) credentials pass through `ResolveActive` unchanged. SSH/mail are non-OAuth and out of scope. +- **QUIC pool support is active-member expansion only (HTTP-vs-QUIC capability boundary):** the HTTP/1.x/HTTP/2 MITM addon implements the full pool feature set (R1 refresh attribution, R3 pool-stable phantom, Phase 2 429/401 auto-failover). The HTTP/3/QUIC injection path (`QUICProxy.buildPhantomPairs` and the binding-header injection in `quic.go`) is a simpler buffered swap with no response-side OAuth interception. It IS pool-aware: `QUICProxy.resolvePoolMember` (wired via `NewQUICProxy`'s `poolResolver` arg from `server.go`) expands a pool-named binding to the pool's current active member before `provider.Get`, mirroring `resolveInjectionTarget`, so a pool binding *functions* over QUIC and the agent-held phantom stays keyed on the pool name (stable across member switches). What QUIC does **not** do: per-request OAuth refresh attribution (R1), pool-stable synthetic-JWT minting (R3), and automatic 429/401 member failover. Over QUIC the injected secret is whatever member the HTTP path (or an operator via `pool rotate`) last made active; a QUIC-only 429/401 does not trigger a member switch. Deployments needing full pool failover must route the pooled upstream over HTTP/HTTPS rather than HTTP/3. - **Active-member selection:** healthy or expired-cooldown members first, by configured position; if all members are in cooldown, the soonest-recovering member is returned with a WARNING (degrade, never hard-fail). Recovery is lazy — evaluated in `ResolveActive`, no scheduler. -- **R1 refresh-token attribution / fail-closed:** when pass-2 swaps `SLUICE_PHANTOM:.refresh`, sluice records `realRefreshToken → member` in a short-TTL map. On the token-endpoint response it recovers the member by that real refresh token and persists to that member (`persistAddonOAuthTokens(member, ...)`, singleflight key `"persist:"+member`). The join key is the real **refresh** token sluice injected — never the access token, the client connection, or `OAuthIndex.Match` (two pooled members share `auth.openai.com`'s token URL and collide there). If the member is unrecoverable: WARNING + skip the vault write, never guess. Rotating refresh tokens are single-use, so a mis-attributed write would brick both accounts — fail-closed is mandatory. +- **R1 refresh-token attribution / fail-closed:** when pass-2 swaps `SLUICE_PHANTOM:.refresh`, sluice records `realRefreshToken → member` in a short-TTL map. On the token-endpoint response it recovers the member by that real refresh token and persists to that member (`persistAddonOAuthTokens(member, ...)`, singleflight key `"persist:"+member`). The join key is the real **refresh** token sluice injected — never the access token, the client connection, or `OAuthIndex.Match` (two pooled members share `auth.openai.com`'s token URL and collide there). If the member is unrecoverable: WARNING + skip the vault write, never guess. Rotating refresh tokens are single-use, so a mis-attributed write would brick both accounts — fail-closed is mandatory. **Plain-credential disambiguation on a shared token URL:** a plain (non-pool) OAuth credential that merely shares its token URL with a pool also has its injected real refresh token tagged `realRefreshToken → ` (the plain path in `buildPhantomPairs` / `buildOAuthPhantomPairs`'s `onRefreshInject`, including the token-host expansion for split-host plain creds). On the response side, when a pool shares the token URL, `resolveOAuthResponseAttribution` recovers the tag: if it resolves to a name that is **not** a pool member (`PoolForMember == ""`), the refresh is attributed 1:1 to that plain credential (its own phantom, its own vault write) — NOT fail-closed as a pooled refresh. The pooled fail-closed path is taken only when recovery fails entirely or resolves to an actual pool member. The `poolForResponse` failover path applies the same rule: a recovered owner not in any pool only triggers the membership-raced active-member fallback when an independent `flowInjected` pool-usage tag (set post-swap only if a pool phantom was actually present) confirms pooled usage; otherwise the failure is treated as a plain credential's and no pool member is cooled. - **R3 pool-stable phantom JWT:** Codex access tokens are JWTs and the per-real-token `resignJWT` would emit a *different* phantom after every cross-member refresh, breaking the "agent never notices" guarantee. The dedicated `poolStablePhantomAccess` (in `internal/proxy/oauth_response.go`) instead builds the phantom JWT from a deterministic synthetic payload keyed on the **pool name** (`sub: sluice-pool:`, `iss: sluice-phantom`, fixed far-future `exp`, no `iat`), HMAC-SHA256'd with the existing fixed key — byte-identical across member switches while still a structurally valid JWT. The pool name is JSON-marshaled (never concatenated) so a name with quotes/control chars cannot inject claims. Static-form fallback (`SLUICE_PHANTOM:.access`) is emitted only on the unreachable `json.Marshal` failure of the fixed struct (and is documented as the equivalent for an agent verified to treat the access token as opaque). The **refresh** phantom is unaffected — it stays the static `SLUICE_PHANTOM:.refresh`. **Phase 2 — auto-failover on 429 / 401:** diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index 69b6685..a8c1a0f 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -1004,6 +1004,25 @@ func (a *SluiceAddon) resolveOAuthResponseAttribution(f *mitmproxy.Flow, matched } realRefresh := extractRequestRefreshToken(reqBody, reqCT) member, ok := a.refreshAttr.Recover(realRefresh) + if ok && pr.PoolForMember(member) == "" { + // Finding 1: the injected refresh token was tagged under a name + // that is NOT a pool member. That tag is only ever recorded by + // the PLAIN-credential injection path (buildOAuthPhantomPairs' + // onRefreshInject in buildPhantomPairs), so this response is a + // normal refresh for the plain OAuth credential `member` that + // merely shares its token URL with a pool. Attribute 1:1: the + // agent receives the plain credential's own phantom and the + // rotated tokens are persisted to the plain credential's own + // vault entry. This is NOT the pooled fail-closed path — the + // refresh token uniquely identified a plain credential, so there + // is no guessing and no R1 violation. The genuine pooled + // fail-closed branch below is reached only when recovery fails + // entirely or resolves to an actual pool member. + log.Printf("[ADDON-OAUTH] Finding 1: token URL shared with pool %q but the injected "+ + "refresh token attributes to plain credential %q; persisting 1:1 to %q", + poolName, member, member) + return oauthRespAttribution{phantomName: member, persistMember: member} + } if !ok { // Recovery failed AND a pool shares this token URL (the // poolName == "" plain-only case already returned above with a @@ -1551,7 +1570,17 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, req pairs = append(pairs, oauthPairs...) continue } - oauthPairs, parseErr := buildOAuthPhantomPairs(name, secret, "ADDON-INJECT") + // Finding 1: tag the real refresh token under the PLAIN + // credential name so a plain OAuth refresh whose token URL + // is shared with a pool is recoverable on the response side + // and attributed 1:1 (its own phantom + vault write) instead + // of being mistaken for an unrecoverable pooled refresh and + // fail-closed. PoolForMember(name) is "" for a plain cred, so + // the response path can tell it apart from a pooled member. + oauthPairs, parseErr := buildOAuthPhantomPairs(name, secret, "ADDON-INJECT", + func(realRefresh string) { + a.refreshAttr.Tag(realRefresh, name) + }) if parseErr != nil { continue } @@ -1599,6 +1628,47 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, req if pr != nil { for _, credName := range idx.MatchAll(reqURL) { poolName := pr.PoolForMember(credName) + if poolName == "" { + // Finding 1: plain (non-pool) OAuth credential + // whose token URL matches this request but whose + // own API binding is on a DIFFERENT host (split + // host), so the CONNECT-host loop above produced + // no pairs for it. Without this, the plain + // credential's SLUICE_PHANTOM:.refresh would + // travel upstream verbatim (refresh fails) and — + // when a pool shares this token URL — no plain + // attribution tag would be recorded, so the + // response side would fail-close the plain + // refresh as if it were an unrecoverable pooled + // refresh. Expand it here, recording the plain + // realRefreshToken -> name tag so the response + // path attributes it 1:1 to the plain credential. + if covered[credName] { + continue + } + secret, err := a.provider.Get(credName) + if err != nil { + log.Printf("[ADDON-INJECT] token-host plain oauth %q lookup failed: %v", + credName, err) + continue + } + if !vault.IsOAuth(secret.Bytes()) { + secret.Release() + continue + } + oauthPairs, parseErr := buildOAuthPhantomPairs(credName, secret, "ADDON-INJECT", + func(realRefresh string) { + a.refreshAttr.Tag(realRefresh, credName) + }) + if parseErr != nil { + continue + } + covered[credName] = true + log.Printf("[ADDON-INJECT] token-host phantom expansion for plain oauth %q (%s)", + credName, reqURL.Host) + pairs = append(pairs, oauthPairs...) + continue + } // Gate on the POOL namespace only. covered[member] is // deliberately NOT consulted here: a plain direct // binding for the active member on this same token @@ -1606,7 +1676,7 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, req // pool-keyed phantoms the agent actually holds, so // suppressing on it would leak SLUICE_PHANTOM:.* // upstream unswapped (Finding 1, round-9). - if poolName == "" || poolEmitted[poolName] { + if poolEmitted[poolName] { continue } poolEmitted[poolName] = true diff --git a/internal/proxy/phantom_pairs.go b/internal/proxy/phantom_pairs.go index 27d1624..7791196 100644 --- a/internal/proxy/phantom_pairs.go +++ b/internal/proxy/phantom_pairs.go @@ -185,7 +185,17 @@ const maxProxyBody = 16 << 20 // pairs for the access and (optionally) refresh tokens. The caller's // raw secret is released before returning. On parse failure the secret // is still released and an error is returned. -func buildOAuthPhantomPairs(name string, secret vault.SecureBytes, logPrefix string) ([]phantomPair, error) { +// +// onRefreshInject, when supplied (variadic; at most the first element is +// used), is called with the credential's real refresh token before the +// swap injects it into the outbound refresh-grant request body. This is +// the PLAIN-credential analogue of buildPooledOAuthPhantomPairs' +// onRefreshInject: it lets the caller record a realRefreshToken -> name +// attribution tag so a plain OAuth refresh whose token URL is shared with +// a pool can be told apart from a genuine pooled refresh on the response +// side (Finding 1). Plain callers that have no attribution context +// (ws.go, quic.go) simply omit it. +func buildOAuthPhantomPairs(name string, secret vault.SecureBytes, logPrefix string, onRefreshInject ...func(realRefresh string)) ([]phantomPair, error) { cred, err := vault.ParseOAuth(secret.Bytes()) secret.Release() if err != nil { @@ -202,6 +212,14 @@ func buildOAuthPhantomPairs(name string, secret vault.SecureBytes, logPrefix str secret: accessSecret, }} if cred.RefreshToken != "" { + // Record the plain R1 join: this exact real refresh token is + // about to be injected for the plain credential `name`. The + // token-endpoint response recovers this value to attribute the + // rotated tokens back to `name` rather than fail-closing as if + // it were an unrecoverable pooled refresh. + if len(onRefreshInject) > 0 && onRefreshInject[0] != nil { + onRefreshInject[0](cred.RefreshToken) + } refreshSecret := vault.NewSecureBytes(cred.RefreshToken) refreshPhantom := []byte(oauthPhantomRefresh(name, cred.RefreshToken)) refreshEncoded := encodePhantomForPair(refreshPhantom) diff --git a/internal/proxy/pool_failover.go b/internal/proxy/pool_failover.go index 1755596..52e85c7 100644 --- a/internal/proxy/pool_failover.go +++ b/internal/proxy/pool_failover.go @@ -287,15 +287,42 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember, pr if ownerPool := pr.PoolForMember(owner); ownerPool != "" { return ownerPool, owner, proto, pr, true } - // owner is no longer in any pool (membership change - // raced the failure); the refresh-attr tag still proves - // THIS request used the pool, so fall through to the - // active-member fallback below for a still-meaningful - // attribution. - if active, aok := pr.ResolveActive(pool); aok && active != "" { + // owner is not in any pool. Two cases now collapse here + // because the refresh-attr map is no longer pool-only: + // + // (1) Round-19 Finding 1: the PLAIN-credential injection + // path tags realRefresh -> too (so the + // 2xx persist path can attribute a plain refresh on a + // shared token URL 1:1). A plain refresh must NEVER + // cool a pool member. + // (2) A genuine pooled member whose membership raced the + // failure (it left the pool between inject and + // response). + // + // The refresh-attr tag alone can no longer tell them + // apart, so it is NOT sufficient evidence for the + // active-member fallback. Require the independent + // pool-usage proof (flowInjected, set post-swap ONLY when + // a pool phantom was actually present in this request): + // case (2) still has it; case (1) never does. Without it, + // fall through to the no-evidence path below, which + // returns ok=false and cools nothing. + injected, injOK := "", false + if f.Id != uuid.Nil { + injected, injOK = a.flowInjected.Peek(f.Id) + } + if !injOK || injected == "" { + log.Printf("[POOL-FAILOVER] pool %q: token-endpoint failure "+ + "owner %q is not a pool member and no flow-injection "+ + "pool-usage tag exists; treating as a plain credential "+ + "sharing this token URL (not cooling any member)", pool, owner) + // Plain credential (round-19 Finding 1): do not cool. + // Fall through to the no-evidence return below. + } else if active, aok := pr.ResolveActive(pool); aok && active != "" { log.Printf("[POOL-FAILOVER] pool %q: token-endpoint failure "+ - "owner %q left the pool (membership raced); falling back "+ - "to active member %q", pool, owner, active) + "owner %q left the pool (membership raced, flow-injection "+ + "tag confirms pooled usage); falling back to active "+ + "member %q", pool, owner, active) return pool, active, proto, pr, true } } diff --git a/internal/proxy/pool_splithost_test.go b/internal/proxy/pool_splithost_test.go index 1fa1fab..5e65bf2 100644 --- a/internal/proxy/pool_splithost_test.go +++ b/internal/proxy/pool_splithost_test.go @@ -687,3 +687,245 @@ func TestFinding2_PlainOAuthOnSharedTokenURLDoesNotTagOrCoolPool(t *testing.T) { t.Fatalf("Finding 2 over-restriction: pool did not fail over; active = %q, want memB", active) } } + +// TestFinding1Round19_PlainCredRefreshOnSharedTokenURLPersistsNormally is the +// round-19 Finding 1 regression. A PLAIN (non-pool) OAuth credential whose +// token URL is shared with a pool refreshes normally: its own phantom is in +// the request body, NO pool phantom. Before this fix, +// resolveOAuthResponseAttribution saw "a pool shares this token URL" + +// refreshAttr.Recover failing (no pooled tag, because the plain injection +// path never recorded one) and took the pooled fail-closed branch — it +// SKIPPED the plain credential's vault write AND rewrote the response with +// the POOL phantom instead of the plain credential's own phantom. That +// breaks a legitimate standalone credential that merely shares an OAuth +// issuer with a pool. +// +// The fix tags the plain credential's real refresh token under the PLAIN +// name (PoolForMember == "" distinguishes it from a pooled member) on the +// request side, so the response side recovers it and attributes 1:1. +func TestFinding1Round19_PlainCredRefreshOnSharedTokenURLPersistsNormally(t *testing.T) { + addon, provider, prPtr := setupPoolSplitHostWithPlainCred(t) + client := setupAddonConn(addon, "auth.example.com:443") + pr := prPtr.Load() + + // Sanity: the pool's active member is memA and idx.Match returns the + // plain credential first (the collision the round-9/16 bug rode on). + if got, _ := pr.ResolveActive("codex_pool"); got != "memA" { + t.Fatalf("pre-condition active = %q, want memA", got) + } + + // The agent refreshes the PLAIN credential. Its body carries the plain + // credential's OWN refresh phantom (SLUICE_PHANTOM:aaa_plain.refresh), + // NOT any pool phantom. The token-host expansion swaps it to the plain + // credential's real refresh token and tags plain-refresh-old -> aaa_plain. + reqFlow := newTestFlow(client, "POST", testOAuthTokenURL) + reqFlow.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + reqFlow.Request.Body = []byte( + "grant_type=refresh_token&refresh_token=SLUICE_PHANTOM:aaa_plain.refresh", + ) + + addon.Requestheaders(reqFlow) + addon.Request(reqFlow) + + // (req) The plain phantom must be swapped to the plain credential's REAL + // refresh token (proves the token-host plain expansion fired) and the + // pool's real refresh token must NOT appear (no pool involvement). + reqBody := string(reqFlow.Request.Body) + if strings.Contains(reqBody, "SLUICE_PHANTOM:aaa_plain.refresh") { + t.Fatalf("plain refresh phantom not swapped on the token host; body=%q", reqBody) + } + if !strings.Contains(reqBody, "plain-refresh-old") { + t.Fatalf("plain credential's real refresh token not injected; body=%q", reqBody) + } + if strings.Contains(reqBody, "A-refresh-old") || strings.Contains(reqBody, "B-refresh-old") { + t.Fatalf("a pool member's real refresh token leaked into a PLAIN refresh; body=%q", reqBody) + } + + // No pool-usage tag may have been recorded (no pool phantom present). + if m, ok := addon.flowInjected.Peek(reqFlow.Id); ok { + t.Fatalf("plain refresh acquired a flowInjected pool-usage tag (member=%q)", m) + } + + // The upstream returns rotated tokens for the PLAIN credential. + respFlow := newPoolReqRespFlow(client, reqFlow.Request.Body, mustJSON(t, map[string]interface{}{ + "access_token": "plain-access-rotated-1", + "refresh_token": "plain-refresh-rotated-1", + "expires_in": 3600, + })) + addon.Response(respFlow) + waitAddonPersist(t, addon) + + // (persist) The PLAIN credential's vault entry MUST hold the rotated + // tokens — NOT skipped (the round-9/16 fail-closed bug skipped it). + credPlain, err := vault.ParseOAuth([]byte(provider.creds["aaa_plain"])) + if err != nil { + t.Fatalf("parse aaa_plain: %v", err) + } + if credPlain.RefreshToken != "plain-refresh-rotated-1" || + credPlain.AccessToken != "plain-access-rotated-1" { + t.Fatalf("Finding 1 round-19: plain credential refresh NOT persisted "+ + "(fail-closed mis-applied); got access=%q refresh=%q want "+ + "plain-access-rotated-1/plain-refresh-rotated-1", + credPlain.AccessToken, credPlain.RefreshToken) + } + + // The pool members' vault entries MUST be untouched. + credA, _ := vault.ParseOAuth([]byte(provider.creds["memA"])) + if credA.RefreshToken != "A-refresh-old" { + t.Fatalf("Finding 1 round-19: plain refresh misfiled into pool member memA; got %q", + credA.RefreshToken) + } + + // (phantom) The agent must receive the PLAIN credential's OWN phantom, + // NOT the pool-stable phantom (the round-9/16 bug rewrote with the pool + // phantom). + agentBody := string(respFlow.Response.Body) + if strings.Contains(agentBody, "plain-access-rotated-1") || + strings.Contains(agentBody, "plain-refresh-rotated-1") { + t.Fatalf("Finding 1 round-19: real rotated plain tokens leaked to agent; body=%q", agentBody) + } + if !strings.Contains(agentBody, "SLUICE_PHANTOM:aaa_plain.refresh") { + t.Fatalf("Finding 1 round-19: agent did not receive the plain credential's "+ + "own refresh phantom; body=%q", agentBody) + } + if !strings.Contains(agentBody, "SLUICE_PHANTOM:aaa_plain.access") { + t.Fatalf("Finding 1 round-19: agent did not receive the plain credential's "+ + "own access phantom; body=%q", agentBody) + } + if strings.Contains(agentBody, poolStablePhantomAccess("codex_pool")) || + strings.Contains(agentBody, "SLUICE_PHANTOM:codex_pool.refresh") { + t.Fatalf("Finding 1 round-19: response rewritten with the POOL phantom for a "+ + "PLAIN refresh; body=%q", agentBody) + } +} + +// TestFinding2Round19_QUICPoolBindingExpandsToActiveMember is the round-19 +// Finding 2 regression. A binding that NAMES A POOL must work over QUIC. +// Before this fix the QUIC injection path was constructed with only the +// binding resolver and called provider.Get() directly — for a +// pool-named binding that is provider.Get(), but no vault secret is +// stored under a pool name, so injection failed for that destination over +// QUIC. The fix wires the pool resolver into QUICProxy and expands a pool +// binding to its ACTIVE member (ResolveActive) before provider.Get, +// mirroring the HTTP-MITM chokepoint. +// +// QUIC-LIMITED scope (documented in CLAUDE.md): only active-member +// expansion is implemented on QUIC. The asserts below also pin the +// documented boundary — the phantom stays keyed on the POOL name (stable +// across member switches) and switching the active member changes only the +// injected SECRET, never the phantom; per-request refresh attribution and +// 429/401 auto-failover are HTTP-path only and are NOT exercised here. +func TestFinding2Round19_QUICPoolBindingExpandsToActiveMember(t *testing.T) { + caCert, _, err := GenerateCA() + if err != nil { + t.Fatalf("GenerateCA: %v", err) + } + + const poolName = "codex_pool" + provider := &addonWritableProvider{ + creds: map[string]string{ + "memA": poolMemberCred(t, "A-access-old", "A-refresh-old"), + "memB": poolMemberCred(t, "B-access-old", "B-refresh-old"), + }, + } + + // The binding NAMES THE POOL, not a member. + bindings := []vault.Binding{{ + Destination: "api.example.com", + Ports: []int{443}, + Credential: poolName, + }} + br, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var brPtr atomic.Pointer[vault.BindingResolver] + brPtr.Store(br) + + pool := store.Pool{Name: poolName, Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: "memA", Position: 0}, + {Credential: "memB", Position: 1}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{pool}, nil)) + + qp, err := NewQUICProxy(caCert, provider, &brPtr, &prPtr, nil, nil, nil) + if err != nil { + t.Fatalf("NewQUICProxy: %v", err) + } + + // resolvePoolMember must expand the pool name to the active member. + if got := qp.resolvePoolMember(poolName); got != "memA" { + t.Fatalf("Finding 2: resolvePoolMember(%q) = %q, want active member memA", poolName, got) + } + if got := qp.resolvePoolMember("memA"); got != "memA" { + t.Fatalf("resolvePoolMember must pass a plain/member name through; got %q", got) + } + + // buildPhantomPairs must NOT fail with provider.Get() "not found"; + // it must inject the ACTIVE member's (memA) real OAuth tokens while + // keying the phantom on the POOL name (stable across member switches). + pairs := qp.buildPhantomPairs("api.example.com", 443) + if len(pairs) == 0 { + t.Fatal("Finding 2: buildPhantomPairs returned no pairs for a pool-named " + + "binding over QUIC (pool->member expansion missing — provider.Get() failed)") + } + var sawPoolAccessPhantom, sawPoolRefreshPhantom, sawMemAAccess, sawMemARefresh bool + for _, p := range pairs { + ps := string(p.phantom) + switch ps { + case "SLUICE_PHANTOM:" + poolName + ".access": + sawPoolAccessPhantom = true + case "SLUICE_PHANTOM:" + poolName + ".refresh": + sawPoolRefreshPhantom = true + } + switch p.secret.String() { + case "A-access-old": + sawMemAAccess = true + case "A-refresh-old": + sawMemARefresh = true + } + if strings.HasPrefix(ps, "SLUICE_PHANTOM:memA") || strings.HasPrefix(ps, "SLUICE_PHANTOM:memB") { + t.Fatalf("Finding 2 QUIC-limit: phantom keyed on a MEMBER name (%q) — must "+ + "be keyed on the POOL name so it is stable across member switches", ps) + } + } + releasePhantomPairs(pairs) + if !sawPoolAccessPhantom || !sawPoolRefreshPhantom { + t.Fatalf("Finding 2: pool-keyed phantoms missing (access=%v refresh=%v)", + sawPoolAccessPhantom, sawPoolRefreshPhantom) + } + if !sawMemAAccess || !sawMemARefresh { + t.Fatalf("Finding 2: active member memA's real OAuth tokens not injected "+ + "(access=%v refresh=%v)", sawMemAAccess, sawMemARefresh) + } + + // Documented QUIC boundary: flipping the active member changes ONLY the + // injected secret; the phantom the agent holds stays pool-keyed and + // byte-identical (no per-request attribution / failover on QUIC, but + // the active member IS honored). + prPtr.Load().MarkCooldown("memA", time.Now().Add(time.Minute), "429") + if got := qp.resolvePoolMember(poolName); got != "memB" { + t.Fatalf("Finding 2: after cooling memA, resolvePoolMember = %q, want memB", got) + } + pairs2 := qp.buildPhantomPairs("api.example.com", 443) + var sawMemBRefresh, stillPoolKeyed bool + for _, p := range pairs2 { + if p.secret.String() == "B-refresh-old" { + sawMemBRefresh = true + } + if string(p.phantom) == "SLUICE_PHANTOM:"+poolName+".refresh" { + stillPoolKeyed = true + } + } + releasePhantomPairs(pairs2) + if !sawMemBRefresh { + t.Fatal("Finding 2: after failover the new active member memB's real refresh " + + "token was not injected over QUIC") + } + if !stillPoolKeyed { + t.Fatal("Finding 2 QUIC-limit: phantom changed across member switch — it must " + + "stay keyed on the pool name (R3-style stability) even on QUIC") + } +} diff --git a/internal/proxy/quic.go b/internal/proxy/quic.go index 6aa472d..4c03f61 100644 --- a/internal/proxy/quic.go +++ b/internal/proxy/quic.go @@ -72,8 +72,27 @@ type QUICProxy struct { caX509 *x509.Certificate provider vault.Provider resolver *atomic.Pointer[vault.BindingResolver] - audit *audit.FileLogger - rules atomic.Pointer[quicInspectRules] + // poolResolver expands a binding that NAMES A POOL to the pool's + // active member before the vault lookup, mirroring the HTTP-MITM + // chokepoint (SluiceAddon.resolveInjectionTarget). Without it a + // pool-named binding would call provider.Get() — there is no + // vault secret stored under a pool name — and injection would fail + // for that destination over QUIC (Finding 2). Optional: nil means + // no pools are configured and every binding name is taken verbatim. + // + // QUIC pool support is intentionally limited to active-member + // expansion. The per-request OAuth refresh attribution (Risk R1), + // pool-stable phantom keying (Risk R3), and 429/401 auto-failover + // implemented in the HTTP-MITM addon are NOT replicated here: the + // QUIC injection path is a simpler buffered header/body swap with + // no response-side OAuth interception. A pool binding over QUIC + // injects the CURRENT active member's real credential; member + // rotation happens only when the HTTP path (or an operator) flips + // the active member. See CLAUDE.md "Credential pools" for the + // authoritative HTTP-vs-QUIC capability matrix. + poolResolver *atomic.Pointer[vault.PoolResolver] + audit *audit.FileLogger + rules atomic.Pointer[quicInspectRules] // oauthIndex points at the same OAuthIndex the SluiceAddon uses // so QUIC/HTTP3 header injection follows the same OAuth-vs-static @@ -134,6 +153,7 @@ func NewQUICProxy( caCert tls.Certificate, provider vault.Provider, resolver *atomic.Pointer[vault.BindingResolver], + poolResolver *atomic.Pointer[vault.PoolResolver], auditLog *audit.FileLogger, blockConfigs []QUICBlockRuleConfig, redactConfigs []QUICRedactRuleConfig, @@ -147,11 +167,12 @@ func NewQUICProxy( } } qp := &QUICProxy{ - caCert: caCert, - caX509: caX509, - provider: provider, - resolver: resolver, - audit: auditLog, + caCert: caCert, + caX509: caX509, + provider: provider, + resolver: resolver, + poolResolver: poolResolver, + audit: auditLog, } if err := qp.UpdateRules(blockConfigs, redactConfigs); err != nil { return nil, err @@ -423,12 +444,18 @@ func (q *QUICProxy) buildHandler(upstreamHost string, destPort int, checker *Req // Binding-specific header injection. if res := q.resolver.Load(); res != nil { if binding, ok := res.ResolveForProtocol(host, port, ProtoQUIC.String()); ok { - secret, err := q.provider.Get(binding.Credential) + // Finding 2: a binding may name a pool. Expand to the + // active member before the vault lookup AND before the + // OAuth-envelope decision (extractInjectableSecret keys + // off credential_meta, which has no entry for a pool + // name), exactly as the HTTP-MITM chokepoint does. + secretName := q.resolvePoolMember(binding.Credential) + secret, err := q.provider.Get(secretName) if err != nil { - log.Printf("[QUIC-MITM] credential %q lookup failed: %v", binding.Credential, err) + log.Printf("[QUIC-MITM] credential %q lookup failed: %v", secretName, err) } else { if binding.Header != "" { - r.Header.Set(binding.Header, binding.FormatValue(extractInjectableSecret(q.oauthIndex.Load(), binding.Credential, secret.String()))) + r.Header.Set(binding.Header, binding.FormatValue(extractInjectableSecret(q.oauthIndex.Load(), secretName, secret.String()))) } secret.Release() } @@ -545,6 +572,32 @@ func (q *QUICProxy) buildHandler(upstreamHost string, destPort int, checker *Req }) } +// resolvePoolMember expands a binding name that NAMES A POOL to the pool's +// current active member, mirroring SluiceAddon.resolveInjectionTarget on the +// HTTP-MITM path (Finding 2). A plain credential name (or any name when no +// pool resolver is configured) is returned verbatim. An empty or +// unresolvable pool returns the pool name unchanged so the downstream +// provider.Get fails cleanly (no injection) rather than panicking. +// +// QUIC-LIMITED: this performs ONLY active-member expansion. The HTTP path's +// per-request refresh attribution (R1), pool-stable phantom (R3), and +// 429/401 auto-failover are not implemented on QUIC; the active member is +// whatever the HTTP path / operator last selected. Documented in CLAUDE.md. +func (q *QUICProxy) resolvePoolMember(name string) string { + if q.poolResolver == nil { + return name + } + pr := q.poolResolver.Load() + if pr == nil || !pr.IsPool(name) { + return name + } + member, ok := pr.ResolveActive(name) + if !ok || member == "" { + return name + } + return member +} + // buildPhantomPairs resolves credentials bound to the destination and returns // phantom/secret pairs sorted by phantom length descending. // @@ -556,23 +609,31 @@ func (q *QUICProxy) buildHandler(upstreamHost string, destPort int, checker *Req func (q *QUICProxy) buildPhantomPairs(host string, port int) []phantomPair { var pairs []phantomPair if res := q.resolver.Load(); res != nil { - for _, name := range res.CredentialsForDestination(host, port, ProtoQUIC.String()) { - secret, err := q.provider.Get(name) + for _, boundName := range res.CredentialsForDestination(host, port, ProtoQUIC.String()) { + // Finding 2: expand a pool-named binding to its active + // member before the vault lookup. The phantom the agent + // holds is keyed on the BOUND name (pool name when pooled, + // so it is stable across member switches); only the injected + // secret comes from the active member's vault entry. + secretName := q.resolvePoolMember(boundName) + secret, err := q.provider.Get(secretName) if err != nil { - log.Printf("[QUIC-MITM] credential %q lookup failed: %v", name, err) + log.Printf("[QUIC-MITM] credential %q lookup failed: %v", secretName, err) continue } // Check if this is an OAuth credential. If so, build two phantom - // pairs (access + refresh) instead of one static pair. + // pairs (access + refresh) instead of one static pair. The + // phantom is keyed on boundName so a pooled OAuth member swap + // does not change the phantom the agent already holds. if vault.IsOAuth(secret.Bytes()) { - oauthPairs, parseErr := buildOAuthPhantomPairs(name, secret, "QUIC-MITM") + oauthPairs, parseErr := buildOAuthPhantomPairs(boundName, secret, "QUIC-MITM") if parseErr != nil { continue } pairs = append(pairs, oauthPairs...) continue } - phantom := []byte(PhantomToken(name)) + phantom := []byte(PhantomToken(boundName)) encoded := encodePhantomForPair(phantom) pairs = append(pairs, phantomPair{ phantom: phantom, diff --git a/internal/proxy/quic_test.go b/internal/proxy/quic_test.go index 21016af..bffe962 100644 --- a/internal/proxy/quic_test.go +++ b/internal/proxy/quic_test.go @@ -129,7 +129,7 @@ func TestQUICProxy_HandshakeSucceeds(t *testing.T) { } var resolver atomic.Pointer[vault.BindingResolver] - qp, err := NewQUICProxy(caCert, &stubQUICProvider{}, &resolver, nil, nil, nil) + qp, err := NewQUICProxy(caCert, &stubQUICProvider{}, &resolver, nil, nil, nil, nil) if err != nil { t.Fatalf("NewQUICProxy: %v", err) } @@ -200,7 +200,7 @@ func TestQUICProxy_SNIExtraction(t *testing.T) { } var resolver atomic.Pointer[vault.BindingResolver] - qp, err := NewQUICProxy(caCert, &stubQUICProvider{}, &resolver, nil, nil, nil) + qp, err := NewQUICProxy(caCert, &stubQUICProvider{}, &resolver, nil, nil, nil, nil) if err != nil { t.Fatalf("NewQUICProxy: %v", err) } @@ -291,7 +291,7 @@ func setupQUICProxyForH3( resolverPtr.Store(resolver) } - qp, err := NewQUICProxy(caCert, provider, &resolverPtr, nil, blockRules, redactRules) + qp, err := NewQUICProxy(caCert, provider, &resolverPtr, nil, nil, blockRules, redactRules) if err != nil { t.Fatalf("NewQUICProxy: %v", err) } diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 7173da1..4a2adbd 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -817,7 +817,7 @@ func (s *Server) setupInjection(cfg Config, _ net.Listener) error { s.mailProxy = NewMailProxy(cfg.Provider, &caCert) // QUIC proxy for HTTP/3 MITM credential injection over UDP. - qp, qpErr := NewQUICProxy(caCert, cfg.Provider, &s.resolver, cfg.Audit, cfg.QUICBlockRules, cfg.QUICRedactRules) + qp, qpErr := NewQUICProxy(caCert, cfg.Provider, &s.resolver, &s.poolResolver, cfg.Audit, cfg.QUICBlockRules, cfg.QUICRedactRules) if qpErr != nil { log.Printf("QUIC proxy disabled: %v", qpErr) } else { From 19a55f32b82b4a7dc81d40c9b2e846488d33efe8 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 20:47:46 +0800 Subject: [PATCH 44/49] fix: R3-stable QUIC pool phantom; pool+epoch in refresh attribution; atomic RemovePoolIfUnreferenced; REST cred-remove missing-secret tolerant --- cmd/sluice/pool.go | 30 ++-- internal/api/server.go | 12 +- internal/api/server_test.go | 91 +++++++++++ internal/proxy/addon.go | 58 +++++-- internal/proxy/pool_attribution.go | 85 +++++++--- internal/proxy/pool_phantom_test.go | 145 ++++++++++++++++- internal/proxy/pool_splithost_test.go | 131 +++++++++++++++- internal/proxy/quic.go | 70 +++++++-- internal/store/pools.go | 113 +++++++++++--- internal/store/pools_test.go | 214 ++++++++++++++++++++++++++ 10 files changed, 875 insertions(+), 74 deletions(-) diff --git a/cmd/sluice/pool.go b/cmd/sluice/pool.go index 74bafb9..4b1a817 100644 --- a/cmd/sluice/pool.go +++ b/cmd/sluice/pool.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "strings" @@ -281,21 +282,24 @@ func handlePoolRemove(args []string) error { // name would silently inherit the stale bindings. This mirrors the // fail-closed pool-membership guard in "sluice cred remove": refuse // with an actionable error instead of cascading or orphaning. - refs, err := db.ListBindingsByCredential(name) + // + // Finding 3: the reference check and the pool delete MUST be atomic. + // RemovePoolIfUnreferenced folds both into ONE store transaction so a + // concurrent "sluice binding add " cannot commit in a window + // between a separate pre-check and the delete and leave a binding + // pointing at a now-deleted pool. The store method is the authoritative + // atomic gate; this CLI layer only formats its typed error. + removed, err := db.RemovePoolIfUnreferenced(name) if err != nil { - return fmt.Errorf("check bindings referencing pool %q: %w", name, err) - } - if len(refs) > 0 { - details := make([]string, len(refs)) - for i, b := range refs { - details[i] = fmt.Sprintf("[%d] %s", b.ID, b.Destination) + var refErr *store.PoolReferencedError + if errors.As(err, &refErr) { + details := make([]string, len(refErr.Bindings)) + for i, b := range refErr.Bindings { + details[i] = fmt.Sprintf("[%d] %s", b.ID, b.Destination) + } + return fmt.Errorf("pool %q is still referenced by %d binding(s): %s; rebind or remove these bindings first (sluice binding remove , which also clears the auto-created allow rule), then retry pool remove", + name, len(refErr.Bindings), strings.Join(details, ", ")) } - return fmt.Errorf("pool %q is still referenced by %d binding(s): %s; rebind or remove these bindings first (sluice binding remove , which also clears the auto-created allow rule), then retry pool remove", - name, len(refs), strings.Join(details, ", ")) - } - - removed, err := db.RemovePool(name) - if err != nil { return err } if !removed { diff --git a/internal/api/server.go b/internal/api/server.go index 3474a3e..b74fb5e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -1283,7 +1283,17 @@ func (s *Server) DeleteApiCredentialsName(w http.ResponseWriter, r *http.Request // Store removal already succeeded above (the pool-membership gate // passed and credential_meta+health+bindings+rules are gone atomically). // Only now is it safe to delete the vault secret. - if err := s.vault.Remove(name); err != nil { + // + // Finding 4: an already-missing vault secret is NOT a failure. The + // store cleanup has already committed; treating os.IsNotExist as a + // hard 500 here would abort BEFORE the engine recompile / resolver + // rebuild, leaving the live policy stale until a manual reload — and + // it makes using the API to finish a previous partial cleanup + // impossible. The CLI (cmd/sluice/cred.go) and Telegram + // (internal/telegram/commands.go) paths already treat os.IsNotExist as + // success and continue; match them. Any OTHER vault error is still a + // hard 500 (do not swallow real failures). + if err := s.vault.Remove(name); err != nil && !os.IsNotExist(err) { writeError(w, http.StatusInternalServerError, "failed to remove credential: "+err.Error(), "") return } diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 91ae9ef..6074fcf 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -1988,6 +1988,97 @@ func TestDeleteApiCredentials_CascadesToMeta(t *testing.T) { } } +// TestDeleteApiCredentials_MissingVaultSecretIsNotFatal is the Finding 4 +// regression. Two concurrent DELETEs WITHOUT a reloadMu wired both pass the +// vault.List() existence check, both run the (idempotent) store cleanup, and +// both call vault.Remove. The loser's vault.Remove returns an os.IsNotExist +// error because the winner already deleted the .age file. Before the fix the +// REST handler treated that as a hard HTTP 500 — AFTER the store-side +// cleanup (meta/bindings/rules) had committed and BEFORE the engine +// recompile — so the live policy was left stale and the API could not be +// used to finish a previous partial cleanup. The CLI and Telegram paths +// already treat os.IsNotExist as success; the REST path must match. +// +// Invariant asserted over many interleavings: NO 500 ever (a 500 means the +// loser aborted after wiping store state, the exact bug), at least one 204, +// and the final state is fully clean in BOTH the vault and the store. +func TestDeleteApiCredentials_MissingVaultSecretIsNotFatal(t *testing.T) { + for iter := 0; iter < 50; iter++ { + st := newTestStore(t) + enableHTTPChannel(t, st) + v := newTestVault(t) + srv := api.NewServer(st, nil, nil, "") + srv.SetVault(v) + // Deliberately do NOT wire reloadMu (SetEnginePtr): that is what + // lets both concurrent handlers pass the List() check and race on + // vault.Remove, so the loser hits the os.IsNotExist path the fix + // targets. An engine pointer alone (no mutex) still exercises + // recompileEngine after the store cleanup. + srv.SetEnginePtr(new(atomic.Pointer[policy.Engine]), nil) + + if _, err := v.Add("dup", "value"); err != nil { + t.Fatalf("iter %d: add: %v", iter, err) + } + if _, err := st.AddBinding("api.example.com", "dup", store.BindingOpts{}); err != nil { + t.Fatalf("iter %d: add binding: %v", iter, err) + } + + t.Setenv("SLUICE_API_TOKEN", "tok") + handler := newTestHandler(t, srv, st) + + var wg sync.WaitGroup + wg.Add(2) + codes := make([]int, 2) + bodies := make([]string, 2) + for i := 0; i < 2; i++ { + idx := i + go func() { + defer wg.Done() + req := httptest.NewRequest("DELETE", "/api/credentials/dup", nil) + req.Header.Set("Authorization", "Bearer tok") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + codes[idx] = rec.Code + bodies[idx] = rec.Body.String() + }() + } + wg.Wait() + + has204 := false + for i, c := range codes { + switch c { + case http.StatusNoContent: + has204 = true + case http.StatusNotFound: + // Loser blocked at the List() precondition (the winner + // committed first) — acceptable, no store wipe happened. + case http.StatusInternalServerError: + t.Fatalf("iter %d: Finding 4: DELETE returned 500 (%q) — an "+ + "already-missing vault secret must NOT be fatal after the "+ + "store cleanup committed", iter, bodies[i]) + default: + t.Fatalf("iter %d: unexpected status %d (%q)", iter, c, bodies[i]) + } + } + if !has204 { + t.Fatalf("iter %d: expected at least one 204, got %v", iter, codes) + } + + // Final state must be fully clean in BOTH stores (no partial + // cleanup left behind by a fail-closed loser). + if names, err := v.List(); err != nil { + t.Fatalf("iter %d: vault list: %v", iter, err) + } else if len(names) != 0 { + t.Fatalf("iter %d: vault not clean: %v", iter, names) + } + if b, err := st.ListBindings(); err != nil { + t.Fatalf("iter %d: list bindings: %v", iter, err) + } else if len(b) != 0 { + t.Fatalf("iter %d: bindings not clean: %d", iter, len(b)) + } + } +} + func TestPostApiCredentials_StaticWithMetaCreated(t *testing.T) { st := newTestStore(t) enableHTTPChannel(t, st) diff --git a/internal/proxy/addon.go b/internal/proxy/addon.go index a8c1a0f..690cc46 100644 --- a/internal/proxy/addon.go +++ b/internal/proxy/addon.go @@ -1003,8 +1003,8 @@ func (a *SluiceAddon) resolveOAuthResponseAttribution(f *mitmproxy.Flow, matched reqBody = f.Request.Body } realRefresh := extractRequestRefreshToken(reqBody, reqCT) - member, ok := a.refreshAttr.Recover(realRefresh) - if ok && pr.PoolForMember(member) == "" { + member, capPool, capEpoch, capPooled, ok := a.refreshAttr.RecoverIdentity(realRefresh) + if ok && !capPooled && pr.PoolForMember(member) == "" { // Finding 1: the injected refresh token was tagged under a name // that is NOT a pool member. That tag is only ever recorded by // the PLAIN-credential injection path (buildOAuthPhantomPairs' @@ -1044,13 +1044,33 @@ func (a *SluiceAddon) resolveOAuthResponseAttribution(f *mitmproxy.Flow, matched "under the wrong credential (next refresh will retry)", poolName) return oauthRespAttribution{phantomName: poolName, pooled: true, skipPersist: true} } - // The recovered member's own pool is authoritative (a membership change - // could have raced; attribute to whatever pool the member is in now). - if mp := pr.PoolForMember(member); mp != "" { - poolName = mp - } - log.Printf("[ADDON-OAUTH] R1 attributed pooled refresh to member %q (pool %q)", member, poolName) - return oauthRespAttribution{phantomName: poolName, persistMember: member, pooled: true} + // Finding 2 (round 20): the recovered tag is pooled. Do NOT trust the + // member's CURRENT pool (pr.PoolForMember / IdentityForMember as they + // stand now) — that races a membership change. Validate the pool + // identity captured AT INJECTION TIME against the live membership. The + // pool whose phantom was actually swapped into this request is capPool + // at generation capEpoch; if the member was removed and re-added (a + // strictly greater epoch) or moved to a different pool between + // injection and this response, attributing/persisting against the new + // generation rewrites the response with the wrong-generation pool + // phantom and misfiles the rotated tokens. That is the documented + // fail-closed/stale case: keep the agent-facing phantom keyed on the + // CAPTURED pool (so the swap that already ran stays byte-consistent) + // but skip the vault write and the pooled audit attribution. + livePool, liveEpoch, liveOK := pr.IdentityForMember(member) + if !liveOK || livePool != capPool || liveEpoch != capEpoch { + log.Printf("[ADDON-OAUTH] Finding 2 stale: pooled refresh for member %q was injected "+ + "under pool %q epoch %d but the live membership is %q epoch %d (ok=%v); "+ + "membership raced (member removed/re-added) — skipping vault write and "+ + "pooled attribution to avoid misfiling against the wrong generation", + member, capPool, capEpoch, livePool, liveEpoch, liveOK) + return oauthRespAttribution{phantomName: capPool, pooled: true, skipPersist: true} + } + // Same generation: the captured identity still matches the live + // membership, so the response genuinely belongs to capPool at capEpoch. + log.Printf("[ADDON-OAUTH] R1 attributed pooled refresh to member %q (pool %q epoch %d)", + member, capPool, capEpoch) + return oauthRespAttribution{phantomName: capPool, persistMember: member, pooled: true} } // processOAuthResponseIfMatching performs OAuth token phantom swap on the @@ -1731,10 +1751,28 @@ func (a *SluiceAddon) buildPhantomPairs(host string, port int, proto string, req // token-endpoint failover path). Shared by the CONNECT-host binding loop and // the Finding 4 token-host expansion so both record attribution identically. func (a *SluiceAddon) buildPooledMemberPairs(poolName, member string, secret vault.SecureBytes) ([]phantomPair, error) { + // Finding 2 (round 20): capture the pool identity (pool name + + // membership epoch) as observed RIGHT NOW, at injection time, and pin + // it to the refresh-attribution tag. The 2xx persist path validates + // this captured identity against the live membership before persisting + // so a response that arrives after the member was removed and re-added + // (a strictly greater epoch) is not silently misfiled against the new + // generation. Fall back to the resolveInjectionTarget poolName with a + // sentinel epoch if the resolver cannot supply an identity (e.g. the + // member just left): the response path's mismatch check then fails + // closed rather than guessing. + idPool, idEpoch := poolName, int64(-1) + if a.poolResolver != nil { + if pr := a.poolResolver.Load(); pr != nil { + if p, e, ok := pr.IdentityForMember(member); ok { + idPool, idEpoch = p, e + } + } + } return buildPooledOAuthPhantomPairs( poolName, member, secret, "ADDON-INJECT", func(realRefresh string) { - a.refreshAttr.Tag(realRefresh, member) + a.refreshAttr.TagPooled(realRefresh, member, idPool, idEpoch) }, ) } diff --git a/internal/proxy/pool_attribution.go b/internal/proxy/pool_attribution.go index a10548e..92e386d 100644 --- a/internal/proxy/pool_attribution.go +++ b/internal/proxy/pool_attribution.go @@ -172,8 +172,29 @@ type refreshAttribution struct { entries map[string]refreshAttrEntry } +// refreshAttrEntry records the member a real refresh token was injected for, +// plus the INJECTION-TIME pool identity (Finding 2, round 20). +// +// Storing only the member name was insufficient: a token-endpoint response +// arriving after the member was removed and re-added into the same-named pool +// (a strictly greater membership epoch — the round-18 mechanism) would be +// attributed via the member's CURRENT pool generation rather than the pool +// whose phantom was actually swapped into the request. That rewrites the +// response with the wrong-generation pool phantom and persists/audits against +// the wrong epoch. The entry therefore captures {pool, epoch} as observed at +// injection time so the response path can detect a raced membership change +// and fail closed instead of silently misfiling. +// +// Plain (non-pooled) OAuth credentials carry the sentinel pool=="" / +// epoch==-1 (round-19 Finding 1: a plain refresh on a shared token URL is +// attributed 1:1 and must never be treated as pooled). pooled reports +// whether this is a pooled tag so a plain entry with no identity is not +// confused with a pooled entry whose identity happened to be zero-valued. type refreshAttrEntry struct { member string + pool string + epoch int64 + pooled bool expires time.Time } @@ -181,28 +202,46 @@ func newRefreshAttribution() *refreshAttribution { return &refreshAttribution{entries: make(map[string]refreshAttrEntry)} } -// Tag records that the given real refresh token was injected for member. -// Called from the pass-2 phantom swap when the phantom being replaced is a -// pooled credential's `.refresh` phantom. A best-effort opportunistic sweep -// of expired entries keeps the map bounded without a background goroutine. -func (r *refreshAttribution) Tag(realRefreshToken, member string) { - if realRefreshToken == "" || member == "" { +func (r *refreshAttribution) tag(realRefreshToken string, e refreshAttrEntry) { + if realRefreshToken == "" || e.member == "" { return } now := time.Now() r.mu.Lock() defer r.mu.Unlock() if len(r.entries) > 0 { - for k, e := range r.entries { - if now.After(e.expires) { + for k, en := range r.entries { + if now.After(en.expires) { delete(r.entries, k) } } } - r.entries[realRefreshToken] = refreshAttrEntry{ - member: member, - expires: now.Add(refreshAttrTTL), - } + e.expires = now.Add(refreshAttrTTL) + r.entries[realRefreshToken] = e +} + +// Tag records that the given real refresh token was injected for a PLAIN +// (non-pooled) OAuth credential. The entry carries the sentinel identity +// (pool=="", epoch==-1, pooled==false) so the 2xx persist path attributes it +// 1:1 (round-19 Finding 1) and never runs the pooled epoch-staleness check. +func (r *refreshAttribution) Tag(realRefreshToken, member string) { + r.tag(realRefreshToken, refreshAttrEntry{member: member, epoch: -1}) +} + +// TagPooled records that the given real refresh token was injected for a +// POOL member, capturing the pool identity (pool name + membership epoch) +// observed at injection time. The 2xx persist path validates this captured +// identity against the live membership before persisting: if the member was +// removed and re-added (a strictly greater epoch) or moved to a different +// pool between injection and response, the response is treated as the +// documented fail-closed/stale case (Finding 2, round 20). +func (r *refreshAttribution) TagPooled(realRefreshToken, member, pool string, epoch int64) { + r.tag(realRefreshToken, refreshAttrEntry{ + member: member, + pool: pool, + epoch: epoch, + pooled: true, + }) } // Recover returns the member tagged for the given real refresh token and @@ -215,20 +254,30 @@ func (r *refreshAttribution) Tag(realRefreshToken, member string) { // refresh token, so the tag is dead after one use and must be deleted to // bound the map. func (r *refreshAttribution) Recover(realRefreshToken string) (string, bool) { + member, _, _, _, ok := r.RecoverIdentity(realRefreshToken) + return member, ok +} + +// RecoverIdentity is Recover plus the injection-time pool identity. It is +// single-use (deletes the entry) and used by the 2xx persist path so it can +// validate the captured {pool, epoch} against the live membership before +// persisting (Finding 2, round 20). For a plain entry pooled is false and +// pool/epoch carry the sentinel. +func (r *refreshAttribution) RecoverIdentity(realRefreshToken string) (member, pool string, epoch int64, pooled, ok bool) { if realRefreshToken == "" { - return "", false + return "", "", -1, false, false } r.mu.Lock() defer r.mu.Unlock() - e, ok := r.entries[realRefreshToken] - if !ok { - return "", false + e, found := r.entries[realRefreshToken] + if !found { + return "", "", -1, false, false } delete(r.entries, realRefreshToken) if time.Now().After(e.expires) { - return "", false + return "", "", -1, false, false } - return e.member, true + return e.member, e.pool, e.epoch, e.pooled, true } // Peek returns the member tagged for the given real refresh token WITHOUT diff --git a/internal/proxy/pool_phantom_test.go b/internal/proxy/pool_phantom_test.go index 48f2761..d2d375e 100644 --- a/internal/proxy/pool_phantom_test.go +++ b/internal/proxy/pool_phantom_test.go @@ -107,6 +107,24 @@ func poolMemberCred(t *testing.T, access, refresh string) string { return string(data) } +// tagPooledForTest records a pooled refresh-attribution tag the same way the +// production injection path does (buildPooledMemberPairs): it captures the +// member's CURRENT pool+epoch identity from the live resolver and pins it to +// the tag via TagPooled. Tests that previously called the bare member-only +// refreshAttr.Tag to simulate a pooled injection must use this so the +// Finding 2 (round 20) injection-time-identity validation in the 2xx persist +// path sees a faithful, same-generation entry (and still fails closed when a +// genuine membership race is simulated). +func tagPooledForTest(t *testing.T, addon *SluiceAddon, prPtr *atomic.Pointer[vault.PoolResolver], realRefresh, member string) { + t.Helper() + pr := prPtr.Load() + pool, epoch, ok := pr.IdentityForMember(member) + if !ok { + t.Fatalf("tagPooledForTest: %q has no pool identity in the live resolver", member) + } + addon.refreshAttr.TagPooled(realRefresh, member, pool, epoch) +} + // setupPoolAddon wires a SluiceAddon with a two-member pool bound to // auth.example.com. Both members share testOAuthTokenURL (the Risk R1 // collision shape: two Codex accounts behind one OpenAI token endpoint). @@ -220,7 +238,13 @@ func TestR3PoolPhantomByteIdenticalAcrossMemberSwitch(t *testing.T) { // Member A active. Request body carries A's real refresh token (as if // pass-2 already swapped it), upstream returns A's rotated tokens. reqA := []byte("grant_type=refresh_token&refresh_token=A-refresh-old") - addon.refreshAttr.Tag("A-refresh-old", "codexA") + // Faithfully mirror what the production pooled injection path records + // (buildPooledMemberPairs -> TagPooled with the resolver's + // injection-time pool+epoch identity), not the bare member-only Tag + // the old single-arg API used. Finding 2 (round 20): the 2xx persist + // path now validates the captured {pool, epoch} against the live + // membership, so a faithful test must record the pooled identity. + tagPooledForTest(t, addon, prPtr, "A-refresh-old", "codexA") respA := mustJSON(t, map[string]interface{}{ "access_token": "A-real-access-NEW-aaaaaaaa", "refresh_token": "A-real-refresh-NEW-aaaaaaaa", @@ -245,7 +269,7 @@ func TestR3PoolPhantomByteIdenticalAcrossMemberSwitch(t *testing.T) { } reqB := []byte("grant_type=refresh_token&refresh_token=B-refresh-old") - addon.refreshAttr.Tag("B-refresh-old", "codexB") + tagPooledForTest(t, addon, prPtr, "B-refresh-old", "codexB") respB := mustJSON(t, map[string]interface{}{ "access_token": "B-real-access-NEW-bbbbbbbbbbbb", "refresh_token": "B-real-refresh-NEW-bbbbbbbbbbbb", @@ -268,6 +292,123 @@ func TestR3PoolPhantomByteIdenticalAcrossMemberSwitch(t *testing.T) { } } +// poolResolverGen builds a single-pool resolver generation with an explicit +// membership epoch (the round-18 mechanism). Re-running with a higher epoch +// simulates the member being removed and re-added under the same name. +func poolResolverGen(pool, member string, epoch int64) *vault.PoolResolver { + p := store.Pool{Name: pool, Strategy: store.PoolStrategyFailover} + p.Members = []store.PoolMember{{Credential: member, Position: 0, Epoch: epoch}} + return vault.NewPoolResolver([]store.Pool{p}, nil) +} + +// TestFinding2Round20_StalePooledRefreshNotMisfiledAcrossEpoch asserts the +// injection-time pool+epoch capture. A pooled refresh is tagged for member M +// in pool P at epoch e1. Before the token-endpoint response arrives, M is +// removed and re-added into the SAME pool P at a strictly greater epoch e2 +// (the round-18 membership-epoch bump). The 2xx persist path must NOT rewrite +// the response/persist against the new generation e2 (fail-closed/stale): the +// agent still gets a pool-stable phantom (the swap ran, no real token leaks) +// but the vault write is skipped and no persist signal fires. +// +// Fail-before: the old code did `pr.PoolForMember(member)` at response time +// (current membership), attributing/persisting against e2. Pass-after: the +// captured {P, e1} no longer matches live {P, e2}, so it fails closed. +// +// The test also asserts a genuine SAME-generation pooled refresh (no +// membership race) still attributes and persists correctly (no regression of +// round-9/16/18/19). +func TestFinding2Round20_StalePooledRefreshNotMisfiledAcrossEpoch(t *testing.T) { + const poolName = "codex_pool" + addon, provider, prPtr := setupPoolAddon(t, "codexA", "codexB") + + // Pin the resolver to generation epoch 1 (single member codexA for a + // crisp identity; codexB stays a registered OAuth cred sharing the + // token URL so the pooled-token-URL path is exercised). + gen1 := poolResolverGen(poolName, "codexA", 1) + prPtr.Store(gen1) + p1, e1, ok := gen1.IdentityForMember("codexA") + if !ok || e1 != 1 { + t.Fatalf("precondition: gen1 identity = (%q,%d,%v), want (codex_pool,1,true)", p1, e1, ok) + } + + client := setupAddonConn(addon, "auth.example.com:443") + + // --- Stale case: tag under epoch 1, then swap to epoch 2 before resp. + addon.refreshAttr.TagPooled("A-refresh-old", "codexA", p1, e1) + + // Remove + re-add codexA into the SAME pool at epoch 2 (membership + // raced between injection and response). + gen2 := poolResolverGen(poolName, "codexA", 2) + prPtr.Store(gen2) + if _, e2, _ := gen2.IdentityForMember("codexA"); e2 != 2 { + t.Fatalf("precondition: gen2 epoch = %d, want 2", e2) + } + + respStale := mustJSON(t, map[string]interface{}{ + "access_token": "A-real-access-STALE", + "refresh_token": "A-real-refresh-STALE", + "expires_in": 3600, + }) + fStale := newPoolReqRespFlow(client, + []byte("grant_type=refresh_token&refresh_token=A-refresh-old"), respStale) + addon.Response(fStale) + + // Fail-closed: NO persist must fire (drain timeout, not the 5s + // waitAddonPersist which would falsely fail on a correct skip). + select { + case <-addon.persistDone: + t.Fatal("Finding 2: a STALE pooled refresh (injected epoch 1, live epoch 2) " + + "was persisted — must fail closed and skip the vault write") + case <-time.After(300 * time.Millisecond): + } + + // Vault entry for codexA must be UNTOUCHED (not misfiled against e2). + credA, err := vault.ParseOAuth([]byte(provider.creds["codexA"])) + if err != nil { + t.Fatalf("parse codexA: %v", err) + } + if credA.AccessToken != "A-access-old" || credA.RefreshToken != "A-refresh-old" { + t.Fatalf("Finding 2 VIOLATION: stale pooled refresh misfiled rotated tokens "+ + "against the new generation; vault access=%q refresh=%q", + credA.AccessToken, credA.RefreshToken) + } + + // The agent still receives the pool-stable phantom (swap ran) and the + // real upstream token never leaks. + staleBody := string(fStale.Response.Body) + if !strings.Contains(staleBody, poolStablePhantomAccess(poolName)) { + t.Fatalf("Finding 2: stale response missing pool-stable phantom (swap must "+ + "still run so the agent never sees real tokens); body=%q", staleBody) + } + if strings.Contains(staleBody, "A-real-access-STALE") { + t.Fatal("Finding 2: real upstream access token leaked to the agent on the stale path") + } + + // --- Same-generation case: tag and respond under the SAME live epoch. + // Must attribute + persist normally (no regression). + cur := prPtr.Load() + pCur, eCur, _ := cur.IdentityForMember("codexA") + addon.refreshAttr.TagPooled("A-refresh-old", "codexA", pCur, eCur) + respOK := mustJSON(t, map[string]interface{}{ + "access_token": "A-real-access-FRESH", + "refresh_token": "A-real-refresh-FRESH", + "expires_in": 3600, + }) + fOK := newPoolReqRespFlow(client, + []byte("grant_type=refresh_token&refresh_token=A-refresh-old"), respOK) + addon.Response(fOK) + waitAddonPersist(t, addon) + + credA2, err := vault.ParseOAuth([]byte(provider.creds["codexA"])) + if err != nil { + t.Fatalf("parse codexA after fresh: %v", err) + } + if credA2.AccessToken != "A-real-access-FRESH" || credA2.RefreshToken != "A-real-refresh-FRESH" { + t.Fatalf("regression: genuine same-generation pooled refresh did not persist; "+ + "vault access=%q refresh=%q", credA2.AccessToken, credA2.RefreshToken) + } +} + // TestPooledAccessPhantomSwappedInQueryAndPath is the round-18 #5 // regression. A pooled OAuth credential's access phantom is the R3 // pool-stable SYNTHETIC JWT (poolStablePhantomAccess) — it has NO diff --git a/internal/proxy/pool_splithost_test.go b/internal/proxy/pool_splithost_test.go index 5e65bf2..50427bd 100644 --- a/internal/proxy/pool_splithost_test.go +++ b/internal/proxy/pool_splithost_test.go @@ -1,6 +1,7 @@ package proxy import ( + "encoding/base64" "encoding/json" "os" "path/filepath" @@ -871,11 +872,18 @@ func TestFinding2Round19_QUICPoolBindingExpandsToActiveMember(t *testing.T) { t.Fatal("Finding 2: buildPhantomPairs returned no pairs for a pool-named " + "binding over QUIC (pool->member expansion missing — provider.Get() failed)") } + // Finding 1 (round 20): the pooled OAuth access phantom is the + // pool-stable SYNTHETIC JWT keyed on the pool name (poolStablePhantomAccess), + // NOT the literal "SLUICE_PHANTOM:.access" string. The earlier + // assertion only held because these access tokens are non-JWT and the + // old code fell back to the static string — a coincidence that masked + // the R3 violation. Assert the real pool-stable phantom here. + wantPoolAccess := poolStablePhantomAccess(poolName) var sawPoolAccessPhantom, sawPoolRefreshPhantom, sawMemAAccess, sawMemARefresh bool for _, p := range pairs { ps := string(p.phantom) switch ps { - case "SLUICE_PHANTOM:" + poolName + ".access": + case wantPoolAccess: sawPoolAccessPhantom = true case "SLUICE_PHANTOM:" + poolName + ".refresh": sawPoolRefreshPhantom = true @@ -929,3 +937,124 @@ func TestFinding2Round19_QUICPoolBindingExpandsToActiveMember(t *testing.T) { "stay keyed on the pool name (R3-style stability) even on QUIC") } } + +// makeTestJWT builds a structurally valid (header.payload.sig) JWT whose +// payload varies by sub. resignJWT re-signs header+payload of the *real* +// token, so two members with DIFFERENT JWT payloads produce DIFFERENT +// re-signed phantoms under the buggy buildOAuthPhantomPairs(boundName,...) +// path — which is exactly the R3 violation Finding 1 targets. The earlier +// QUIC test used non-JWT access strings, so resignJWT returned "" and the +// static SLUICE_PHANTOM:.access fallback masked the bug. +func makeTestJWT(t *testing.T, sub string) string { + t.Helper() + hdr := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) + pl := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"` + sub + `","iss":"real-idp"}`)) + sig := base64.RawURLEncoding.EncodeToString([]byte("real-signature-for-" + sub)) + return hdr + "." + pl + "." + sig +} + +// TestFinding1Round20_QUICPoolAccessPhantomStableAcrossMemberSwitch asserts +// the R3 pool-stable access-token guarantee on the QUIC path: the +// agent-facing access phantom must be byte-identical across a member switch +// (it is keyed on the POOL name via poolStablePhantomAccess), while the +// injected real token must be the *active* member's. Before the fix the +// QUIC path used buildOAuthPhantomPairs(boundName,...) which re-signed the +// active member's REAL JWT, so the phantom changed on every member switch. +func TestFinding1Round20_QUICPoolAccessPhantomStableAcrossMemberSwitch(t *testing.T) { + caCert, _, err := GenerateCA() + if err != nil { + t.Fatalf("GenerateCA: %v", err) + } + + const poolName = "codex_pool" + jwtA := makeTestJWT(t, "member-A") + jwtB := makeTestJWT(t, "member-B") + provider := &addonWritableProvider{ + creds: map[string]string{ + "memA": poolMemberCred(t, jwtA, "A-refresh"), + "memB": poolMemberCred(t, jwtB, "B-refresh"), + }, + } + + bindings := []vault.Binding{{ + Destination: "api.example.com", + Ports: []int{443}, + Credential: poolName, + }} + br, err := vault.NewBindingResolver(bindings) + if err != nil { + t.Fatalf("NewBindingResolver: %v", err) + } + var brPtr atomic.Pointer[vault.BindingResolver] + brPtr.Store(br) + + pool := store.Pool{Name: poolName, Strategy: store.PoolStrategyFailover} + pool.Members = []store.PoolMember{ + {Credential: "memA", Position: 0}, + {Credential: "memB", Position: 1}, + } + var prPtr atomic.Pointer[vault.PoolResolver] + prPtr.Store(vault.NewPoolResolver([]store.Pool{pool}, nil)) + + qp, err := NewQUICProxy(caCert, provider, &brPtr, &prPtr, nil, nil, nil) + if err != nil { + t.Fatalf("NewQUICProxy: %v", err) + } + + accessPhantomFor := func(t *testing.T, wantSecret string) string { + t.Helper() + pairs := qp.buildPhantomPairs("api.example.com", 443) + defer releasePhantomPairs(pairs) + var accessPhantom string + var sawWantSecret bool + for _, p := range pairs { + ps := string(p.phantom) + // The access phantom is the one whose real secret is a JWT + // (the refresh phantom carries the *-refresh secret). + if strings.Count(ps, ".") == 2 && !strings.HasPrefix(ps, "SLUICE_PHANTOM:") { + accessPhantom = ps + } + if p.secret.String() == wantSecret { + sawWantSecret = true + } + if strings.HasPrefix(ps, "SLUICE_PHANTOM:memA") || + strings.HasPrefix(ps, "SLUICE_PHANTOM:memB") { + t.Fatalf("phantom keyed on a MEMBER name (%q)", ps) + } + } + if accessPhantom == "" { + t.Fatal("no JWT-shaped access phantom found in pairs") + } + if !sawWantSecret { + t.Fatalf("active member's real access token %q not injected", wantSecret) + } + return accessPhantom + } + + // Active member memA: phantom must be the pool-stable synthetic JWT, + // real injected secret must be memA's JWT. + phantom1 := accessPhantomFor(t, jwtA) + + // Independently confirm it is exactly the pool-stable synthetic JWT + // (not a re-sign of memA's real JWT). + wantStable := poolStablePhantomAccess(poolName) + if phantom1 != wantStable { + t.Fatalf("access phantom is not the pool-stable synthetic JWT\n got: %q\nwant: %q", + phantom1, wantStable) + } + + // Flip the active member to memB. + prPtr.Load().MarkCooldown("memA", time.Now().Add(time.Minute), "429") + if got := qp.resolvePoolMember(poolName); got != "memB" { + t.Fatalf("after cooling memA, active member = %q, want memB", got) + } + + // After the switch the agent-facing access phantom MUST be + // byte-identical (R3), while the injected real token is now memB's JWT. + phantom2 := accessPhantomFor(t, jwtB) + if phantom2 != phantom1 { + t.Fatalf("Finding 1 (R3 on QUIC): agent-facing access phantom CHANGED across "+ + "a member switch — must be byte-identical\nbefore: %q\n after: %q", + phantom1, phantom2) + } +} diff --git a/internal/proxy/quic.go b/internal/proxy/quic.go index 4c03f61..89fb17f 100644 --- a/internal/proxy/quic.go +++ b/internal/proxy/quic.go @@ -598,6 +598,37 @@ func (q *QUICProxy) resolvePoolMember(name string) string { return member } +// resolvePoolTarget classifies a binding name. When name is a configured +// pool, isPool is true and member is the pool's current active member (or +// "" when the pool is empty/unresolvable). For a plain credential (or when +// no pool resolver is configured) isPool is false and member == name. +// +// Finding 1 (R3 on QUIC): this is the QUIC analogue of +// SluiceAddon.resolveInjectionTarget. It is required so the QUIC OAuth path +// can route a pooled binding through buildPooledOAuthPhantomPairs — which +// keys the agent-facing access phantom on the POOL name (a pool-stable +// synthetic JWT) rather than re-signing the active member's real JWT. The +// latter changes the agent-held phantom on every member switch, violating +// the R3 pool-stable access-token guarantee. Only phantom stability + +// active-member-secret selection is replicated on QUIC; the documented QUIC +// limitation (no response-side R1 attribution, no 429/401 failover) stands. +func (q *QUICProxy) resolvePoolTarget(name string) (member string, isPool bool) { + if q.poolResolver == nil { + return name, false + } + pr := q.poolResolver.Load() + if pr == nil || !pr.IsPool(name) { + return name, false + } + m, ok := pr.ResolveActive(name) + if !ok || m == "" { + // Empty/unresolvable pool: keep the pool name so the caller's + // provider.Get fails cleanly (no injection) instead of panicking. + return "", true + } + return m, true +} + // buildPhantomPairs resolves credentials bound to the destination and returns // phantom/secret pairs sorted by phantom length descending. // @@ -610,23 +641,42 @@ func (q *QUICProxy) buildPhantomPairs(host string, port int) []phantomPair { var pairs []phantomPair if res := q.resolver.Load(); res != nil { for _, boundName := range res.CredentialsForDestination(host, port, ProtoQUIC.String()) { - // Finding 2: expand a pool-named binding to its active - // member before the vault lookup. The phantom the agent - // holds is keyed on the BOUND name (pool name when pooled, - // so it is stable across member switches); only the injected - // secret comes from the active member's vault entry. - secretName := q.resolvePoolMember(boundName) + // Finding 1/2: classify the binding. A pooled binding + // resolves to its active member for the vault lookup, but + // the agent-facing phantom MUST stay keyed on the POOL name + // so it is byte-identical across member switches (R3). + member, isPool := q.resolvePoolTarget(boundName) + secretName := member + if isPool && member == "" { + // Empty/unresolvable pool: keep the pool name so the + // provider.Get below fails cleanly (no injection). + secretName = boundName + } secret, err := q.provider.Get(secretName) if err != nil { log.Printf("[QUIC-MITM] credential %q lookup failed: %v", secretName, err) continue } // Check if this is an OAuth credential. If so, build two phantom - // pairs (access + refresh) instead of one static pair. The - // phantom is keyed on boundName so a pooled OAuth member swap - // does not change the phantom the agent already holds. + // pairs (access + refresh) instead of one static pair. if vault.IsOAuth(secret.Bytes()) { - oauthPairs, parseErr := buildOAuthPhantomPairs(boundName, secret, "QUIC-MITM") + var oauthPairs []phantomPair + var parseErr error + if isPool { + // Finding 1 (R3 on QUIC): pool-stable synthetic-JWT + // access phantom keyed on the POOL name, refresh + // phantom is the deterministic SLUICE_PHANTOM: + // .refresh string, secrets are the ACTIVE member's + // real tokens. Byte-identical to the HTTP path so the + // agent-held phantom never changes on a member switch. + // onRefreshInject is nil: QUIC has no response-side + // R1 attribution (documented limitation). + oauthPairs, parseErr = buildPooledOAuthPhantomPairs( + boundName, secretName, secret, "QUIC-MITM", nil, + ) + } else { + oauthPairs, parseErr = buildOAuthPhantomPairs(boundName, secret, "QUIC-MITM") + } if parseErr != nil { continue } diff --git a/internal/store/pools.go b/internal/store/pools.go index aa00a65..f77232b 100644 --- a/internal/store/pools.go +++ b/internal/store/pools.go @@ -315,51 +315,97 @@ func (s *Store) ListPools() ([]Pool, error) { return result, nil } -// RemovePool deletes a pool and (via ON DELETE CASCADE) its members. Returns -// true if a pool row was deleted. +// PoolReferencedError is returned by RemovePoolIfUnreferenced when one or +// more bindings still reference the pool by name. It carries the blocking +// bindings so the caller can render an actionable message. The whole check + +// delete runs in ONE transaction, so a concurrent "binding add " can +// no longer commit in a window between a pre-check and the pool delete. +type PoolReferencedError struct { + Pool string + Bindings []BindingRow +} + +func (e *PoolReferencedError) Error() string { + return fmt.Sprintf("pool %q is still referenced by %d binding(s)", e.Pool, len(e.Bindings)) +} + +// RemovePoolIfUnreferenced atomically refuses to delete a pool that any +// binding still references and otherwise deletes it (members + health rows + +// epoch bump, identical to RemovePool). The binding-reference check and the +// delete happen in the SAME transaction (Finding 3): the previous design ran +// the check in the CLI layer and the delete in a separate store transaction, +// so a concurrent "binding add " could commit after the check saw zero +// references but before the pool row was deleted, leaving a binding pointing +// at a non-existent pool. SQLite serializes write transactions, so once this +// tx holds the write lock a concurrent binding insert either committed before +// (and is seen by the SELECT, refusing removal) or blocks until after the +// pool delete commits (and then fails its own pool-existence guard). // -// The members' credential_health rows are deleted in the SAME transaction so -// a cooled member taken out with its pool does not leave a stale durable -// cooldown. loadPoolResolver seeds the shared PoolHealth from ALL -// credential_health rows, so an orphaned cooldown would otherwise be -// inherited by the same credential when it is re-added to a new pool before -// the old TTL expires. A member that is still a live member of ANOTHER pool -// keeps its health row (its cooldown is still meaningful for that pool); only -// members no longer in any pool after this delete have their health row -// removed. -func (s *Store) RemovePool(name string) (bool, error) { +// Returns (false, nil) when the pool does not exist, (true, nil) on success, +// and a *PoolReferencedError when bindings block the removal. +func (s *Store) RemovePoolIfUnreferenced(name string) (bool, error) { tx, err := s.db.Begin() if err != nil { return false, fmt.Errorf("begin tx: %w", err) } defer func() { _ = tx.Rollback() }() + brows, err := tx.Query( + "SELECT id, destination, ports, credential, header, template, protocols, env_var, created_at FROM bindings WHERE credential = ? ORDER BY id", + name, + ) + if err != nil { + return false, fmt.Errorf("check bindings referencing pool %q: %w", name, err) + } + refs, err := scanBindings(brows) + if err != nil { + return false, fmt.Errorf("scan bindings referencing pool %q: %w", name, err) + } + if len(refs) > 0 { + return false, &PoolReferencedError{Pool: name, Bindings: refs} + } + + n, err := removePoolTx(tx, name) + if err != nil { + return false, err + } + if err := tx.Commit(); err != nil { + return false, fmt.Errorf("commit: %w", err) + } + return n > 0, nil +} + +// removePoolTx performs the pool deletion (pool row + cascade members + +// epoch bump + orphaned health-row cleanup) inside the caller's transaction. +// Shared by RemovePool and RemovePoolIfUnreferenced so the two stay in sync. +// Returns the number of pool rows deleted (0 = pool did not exist). +func removePoolTx(tx *sql.Tx, name string) (int64, error) { // Snapshot the pool's members before the cascade wipes the membership // rows so we know whose health rows to consider for cleanup. mrows, err := tx.Query( "SELECT credential FROM credential_pool_members WHERE pool = ?", name, ) if err != nil { - return false, fmt.Errorf("list members of pool %q: %w", name, err) + return 0, fmt.Errorf("list members of pool %q: %w", name, err) } var members []string for mrows.Next() { var c string if scanErr := mrows.Scan(&c); scanErr != nil { _ = mrows.Close() - return false, fmt.Errorf("scan pool member: %w", scanErr) + return 0, fmt.Errorf("scan pool member: %w", scanErr) } members = append(members, c) } if mrowsErr := mrows.Err(); mrowsErr != nil { _ = mrows.Close() - return false, fmt.Errorf("iterate pool members: %w", mrowsErr) + return 0, fmt.Errorf("iterate pool members: %w", mrowsErr) } _ = mrows.Close() res, err := tx.Exec("DELETE FROM credential_pools WHERE name = ?", name) if err != nil { - return false, fmt.Errorf("delete pool %q: %w", name, err) + return 0, fmt.Errorf("delete pool %q: %w", name, err) } n, _ := res.RowsAffected() @@ -370,7 +416,7 @@ func (s *Store) RemovePool(name string) (bool, error) { // CASCADE has wiped the membership, so a late failover cannot // resurrect the removed member's cooldown for a re-created successor. if _, err := bumpMembershipEpochTx(tx); err != nil { - return false, err + return 0, err } // The CASCADE has now removed this pool's credential_pool_members // rows. For each former member, drop its health row UNLESS it is @@ -387,16 +433,45 @@ func (s *Store) RemovePool(name string) (bool, error) { if _, delErr := tx.Exec( "DELETE FROM credential_health WHERE credential = ?", c, ); delErr != nil { - return false, fmt.Errorf("delete health for former pool member %q: %w", c, delErr) + return 0, fmt.Errorf("delete health for former pool member %q: %w", c, delErr) } case err != nil: - return false, fmt.Errorf("check residual pool membership for %q: %w", c, err) + return 0, fmt.Errorf("check residual pool membership for %q: %w", c, err) default: // Still a member of another pool; leave its health row. } } } + return n, nil +} +// RemovePool deletes a pool and (via ON DELETE CASCADE) its members. Returns +// true if a pool row was deleted. It is unconditional (no binding-reference +// guard) — retained for single-generation callers (tests, internal seeding) +// that have already established no binding references the pool. Production +// removal via the CLI uses RemovePoolIfUnreferenced so the reference check +// and the delete are atomic (Finding 3). +// +// The members' credential_health rows are deleted in the SAME transaction so +// a cooled member taken out with its pool does not leave a stale durable +// cooldown. loadPoolResolver seeds the shared PoolHealth from ALL +// credential_health rows, so an orphaned cooldown would otherwise be +// inherited by the same credential when it is re-added to a new pool before +// the old TTL expires. A member that is still a live member of ANOTHER pool +// keeps its health row (its cooldown is still meaningful for that pool); only +// members no longer in any pool after this delete have their health row +// removed. +func (s *Store) RemovePool(name string) (bool, error) { + tx, err := s.db.Begin() + if err != nil { + return false, fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + n, err := removePoolTx(tx, name) + if err != nil { + return false, err + } if err := tx.Commit(); err != nil { return false, fmt.Errorf("commit: %w", err) } diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index 9776be4..3597f2f 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -1,8 +1,10 @@ package store import ( + "errors" "path/filepath" "strings" + "sync" "testing" "time" @@ -200,6 +202,218 @@ func TestRemovePoolCascadesMembers(t *testing.T) { } } +func TestRemovePoolIfUnreferenced_Unreferenced(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "a") + if err := s.CreatePoolWithMembers("p", "failover", []string{"a"}); err != nil { + t.Fatalf("create: %v", err) + } + if err := s.SetCredentialHealth("a", "cooldown", time.Now().Add(time.Hour), "429"); err != nil { + t.Fatalf("set health: %v", err) + } + epBefore := membershipEpoch(t, s) + + removed, err := s.RemovePoolIfUnreferenced("p") + if err != nil || !removed { + t.Fatalf("RemovePoolIfUnreferenced = %v, %v; want true, nil", removed, err) + } + // Members cascade-deleted. + if mp, _ := s.PoolsForMember("a"); len(mp) != 0 { + t.Errorf("PoolsForMember after remove = %v, want empty", mp) + } + // Former member's health row removed (not orphaned). + hs, _ := s.ListCredentialHealth() + for _, h := range hs { + if h.Credential == "a" { + t.Errorf("health row for former pool member %q not cleaned up", h.Credential) + } + } + // Membership epoch bumped on removal. + if ep := membershipEpoch(t, s); ep <= epBefore { + t.Errorf("membership epoch not bumped on removal: before=%d after=%d", epBefore, ep) + } + // Missing pool -> (false, nil). + if removed, err := s.RemovePoolIfUnreferenced("p"); removed || err != nil { + t.Errorf("RemovePoolIfUnreferenced of missing pool = %v, %v; want false, nil", removed, err) + } +} + +func TestRemovePoolIfUnreferenced_RefusedWhenBound(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "a") + if err := s.CreatePoolWithMembers("p", "failover", []string{"a"}); err != nil { + t.Fatalf("create: %v", err) + } + // A binding NAMES THE POOL (pool shares the credential namespace). + if _, err := s.AddBinding("api.example.com", "p", BindingOpts{Ports: []int{443}}); err != nil { + t.Fatalf("AddBinding: %v", err) + } + + removed, err := s.RemovePoolIfUnreferenced("p") + if removed { + t.Fatal("RemovePoolIfUnreferenced deleted a pool that a binding still references") + } + var refErr *PoolReferencedError + if !errors.As(err, &refErr) { + t.Fatalf("want *PoolReferencedError, got %T: %v", err, err) + } + if refErr.Pool != "p" || len(refErr.Bindings) != 1 || refErr.Bindings[0].Destination != "api.example.com" { + t.Fatalf("PoolReferencedError did not list the blocking binding: %+v", refErr) + } + // The pool must still exist (refusal, not partial delete). + if p, _ := s.GetPool("p"); p == nil { + t.Fatal("pool was deleted despite the binding reference") + } +} + +// TestRemovePoolIfUnreferenced_BindingBeforeRemovalRefuses is the +// deterministic Finding 3 regression. The exact bug: the binding-reference +// check ran in the CLI layer and the pool delete ran in a SEPARATE store +// transaction. If a "binding add " committed in the window after the +// check observed zero references but before the delete, the pool was deleted +// anyway and that binding was left pointing at a non-existent pool. +// +// With the check folded into the SAME transaction as the delete, any binding +// that committed before the removal transaction is observed by its SELECT +// and the removal is refused. This test pins exactly that ordering (binding +// commits, THEN removal runs) — the precise interleaving the old split +// design got wrong — and asserts the pool is NOT deleted and a typed +// PoolReferencedError is returned. +// +// Fail-before: with the CLI-side pre-check removed (the atomic store gate is +// authoritative), a binding committed in the check->delete window would let +// the unconditional RemovePool delete the pool. Pass-after: the atomic +// RemovePoolIfUnreferenced observes the binding in its own transaction and +// refuses. +func TestRemovePoolIfUnreferenced_BindingBeforeRemovalRefuses(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "a") + if err := s.CreatePoolWithMembers("p", "failover", []string{"a"}); err != nil { + t.Fatalf("create: %v", err) + } + + // The concurrent "binding add " commits BEFORE the removal + // transaction runs (the exact race window the old split check/delete + // got wrong). + if _, err := s.AddBinding("api.example.com", "p", BindingOpts{Ports: []int{443}}); err != nil { + t.Fatalf("AddBinding: %v", err) + } + + removed, err := s.RemovePoolIfUnreferenced("p") + if removed { + t.Fatal("Finding 3: pool deleted while a binding added before the removal " + + "transaction still references it (check+delete not atomic)") + } + var refErr *PoolReferencedError + if !errors.As(err, &refErr) { + t.Fatalf("want *PoolReferencedError, got %T: %v", err, err) + } + if p, _ := s.GetPool("p"); p == nil { + t.Fatal("Finding 3: pool row was deleted despite the blocking binding") + } +} + +// TestRemovePoolIfUnreferenced_ConcurrentIsInternallyConsistent stresses the +// two write paths concurrently. SQLite serializes write transactions, so the +// atomic removal can only land in one of two self-consistent terminal +// states: (a) removal observed no binding and the pool is gone with no +// binding committed before it, or (b) the binding committed first, removal +// observed it and refused, pool intact. The Finding-3 corruption — pool +// removed by RemovePoolIfUnreferenced while a binding it should have observed +// dangles at the deleted pool — must never appear. +func TestRemovePoolIfUnreferenced_ConcurrentIsInternallyConsistent(t *testing.T) { + for iter := 0; iter < 80; iter++ { + // File-backed (not :memory:) because this test drives two + // concurrent write transactions; modernc.org/sqlite gives each + // pooled connection its OWN private database for a bare ":memory:" + // DSN, so a second connection opened under contention would see an + // empty schema. A temp file is shared across the connection pool. + dir := t.TempDir() + s, err := New(filepath.Join(dir, "pool_race.db")) + if err != nil { + t.Fatalf("iter %d: new store: %v", iter, err) + } + seedOAuthCred(t, s, "a") + if err := s.CreatePoolWithMembers("p", "failover", []string{"a"}); err != nil { + t.Fatalf("iter %d: create: %v", iter, err) + } + + var ( + wg sync.WaitGroup + removed bool + removeErr error + addBindErr error + ) + wg.Add(2) + go func() { + defer wg.Done() + removed, removeErr = s.RemovePoolIfUnreferenced("p") + }() + go func() { + defer wg.Done() + _, addBindErr = s.AddBinding("api.example.com", "p", BindingOpts{Ports: []int{443}}) + }() + wg.Wait() + + poolExists := false + if p, _ := s.GetPool("p"); p != nil { + poolExists = true + } + bindings, err := s.ListBindings() + if err != nil { + t.Fatalf("iter %d: ListBindings: %v", iter, err) + } + bindingCommitted := addBindErr == nil + var poolBindings int + for _, b := range bindings { + if b.Credential == "p" { + poolBindings++ + } + } + + _ = bindingCommitted + _ = poolBindings + var refErr *PoolReferencedError + switch { + case removed: + // Removal observed no binding in its own transaction, so the + // pool must now be gone. (Any binding add committed strictly + // after the removal transaction — the AddBinding-does-not- + // validate-existence concern is pre-existing and out of + // Finding 3's scope; Finding 3 is the check->delete window, + // which is now closed.) + if poolExists { + t.Fatalf("iter %d: removed==true but pool still exists "+ + "(non-atomic delete)", iter) + } + case errors.As(removeErr, &refErr): + // Removal observed a binding in its OWN transaction and refused. + // The defining Finding-3 guarantee: a refusal must never have + // also deleted the pool (no partial check-then-delete). + if !poolExists { + t.Fatalf("iter %d: Finding 3: removal refused (saw a binding) yet the "+ + "pool was deleted — check and delete were not atomic", iter) + } + case removeErr != nil: + t.Fatalf("iter %d: unexpected removal error: %v", iter, removeErr) + default: + // (false, nil): pool did not exist when removal ran. It must + // not have been silently removed by this call (it never + // existed-for-this-call), and any present state is consistent. + } + _ = s.Close() + } +} + +func membershipEpoch(t *testing.T, s *Store) int64 { + t.Helper() + var ep int64 + if err := s.db.QueryRow("SELECT epoch FROM pool_membership_epoch WHERE id = 1").Scan(&ep); err != nil { + t.Fatalf("read membership epoch: %v", err) + } + return ep +} + func TestPoolsForMember(t *testing.T) { s := newTestStore(t) seedOAuthCred(t, s, "shared") From 04a83a6a62e62a214a1ce195c1cb2a38bbc108e1 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 21:21:45 +0800 Subject: [PATCH 45/49] fix(store): binding creation requires live credential or pool; CLAUDE.md QUIC R3 accuracy --- CLAUDE.md | 2 +- cmd/sluice/binding_test.go | 10 ++++ cmd/sluice/cred_test.go | 18 +++++++ cmd/sluice/main_test.go | 16 ++++++ cmd/sluice/policy_test.go | 6 +++ internal/api/server_test.go | 41 +++++++++++++++ internal/store/pools_test.go | 84 ++++++++++++++++++++++++++++++ internal/store/store.go | 60 +++++++++++++++++++++ internal/store/store_test.go | 77 ++++++++++++++++++++++++++- internal/telegram/approval_test.go | 23 ++++++-- internal/telegram/commands.go | 16 +++++- 11 files changed, 345 insertions(+), 8 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 1a142ce..e61b4b4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -233,7 +233,7 @@ Auto-failover on 429/401 is the primary mechanism; `pool rotate` is an operator **Phase 1 — phantom indirection (pool phantom → active member):** - **Single chokepoint (I2):** every `binding.Credential` / `OAuthIndex.Has` / `extractInjectableSecret` / persist consumer on the HTTP/HTTPS OAuth path routes through `PoolResolver.ResolveActive` (`resolveInjectionTarget` for pass-1 header + pass-2 phantom swap; `resolveOAuthResponseAttribution` for the response/persist path). `idx.Has` is always called with the resolved member name, never the pool. Plain (non-pool) credentials pass through `ResolveActive` unchanged. SSH/mail are non-OAuth and out of scope. -- **QUIC pool support is active-member expansion only (HTTP-vs-QUIC capability boundary):** the HTTP/1.x/HTTP/2 MITM addon implements the full pool feature set (R1 refresh attribution, R3 pool-stable phantom, Phase 2 429/401 auto-failover). The HTTP/3/QUIC injection path (`QUICProxy.buildPhantomPairs` and the binding-header injection in `quic.go`) is a simpler buffered swap with no response-side OAuth interception. It IS pool-aware: `QUICProxy.resolvePoolMember` (wired via `NewQUICProxy`'s `poolResolver` arg from `server.go`) expands a pool-named binding to the pool's current active member before `provider.Get`, mirroring `resolveInjectionTarget`, so a pool binding *functions* over QUIC and the agent-held phantom stays keyed on the pool name (stable across member switches). What QUIC does **not** do: per-request OAuth refresh attribution (R1), pool-stable synthetic-JWT minting (R3), and automatic 429/401 member failover. Over QUIC the injected secret is whatever member the HTTP path (or an operator via `pool rotate`) last made active; a QUIC-only 429/401 does not trigger a member switch. Deployments needing full pool failover must route the pooled upstream over HTTP/HTTPS rather than HTTP/3. +- **QUIC pool support covers active-member injection plus the R3 pool-stable phantom; response-side R1/failover is HTTP-only (HTTP-vs-QUIC capability boundary):** the HTTP/1.x/HTTP/2 MITM addon implements the full pool feature set (R1 refresh attribution, R3 pool-stable phantom, Phase 2 429/401 auto-failover). The HTTP/3/QUIC injection path (`QUICProxy.buildPhantomPairs` and the binding-header injection in `quic.go`) is a request-side buffered swap with **no response-side OAuth interception**. It IS pool-aware on the request side: `QUICProxy.resolvePoolTarget` (wired via `NewQUICProxy`'s `poolResolver` arg from `server.go`) classifies a pooled binding, selects the pool's current active member's real secret for the vault lookup, and routes through `buildPooledOAuthPhantomPairs` so the agent-facing **access phantom is the same pool-stable synthetic JWT** the HTTP path mints (keyed on the pool name via `poolStablePhantomAccess`, byte-identical across member switches — R3 holds over QUIC). What QUIC does **not** do, because it has no response-side OAuth interception: per-request OAuth refresh attribution (R1) and automatic 429/401 member failover (Phase 2). Over QUIC the injected member secret is whatever member the HTTP path (or an operator via `pool rotate`) last made active; a QUIC-only 429/401 does not trigger a member switch and a QUIC-only token refresh is not attributed/persisted back to its issuing member. Deployments needing R1 attribution or auto-failover must route the pooled upstream over HTTP/HTTPS rather than HTTP/3; the agent-visible phantom itself is already stable on either path. - **Active-member selection:** healthy or expired-cooldown members first, by configured position; if all members are in cooldown, the soonest-recovering member is returned with a WARNING (degrade, never hard-fail). Recovery is lazy — evaluated in `ResolveActive`, no scheduler. - **R1 refresh-token attribution / fail-closed:** when pass-2 swaps `SLUICE_PHANTOM:.refresh`, sluice records `realRefreshToken → member` in a short-TTL map. On the token-endpoint response it recovers the member by that real refresh token and persists to that member (`persistAddonOAuthTokens(member, ...)`, singleflight key `"persist:"+member`). The join key is the real **refresh** token sluice injected — never the access token, the client connection, or `OAuthIndex.Match` (two pooled members share `auth.openai.com`'s token URL and collide there). If the member is unrecoverable: WARNING + skip the vault write, never guess. Rotating refresh tokens are single-use, so a mis-attributed write would brick both accounts — fail-closed is mandatory. **Plain-credential disambiguation on a shared token URL:** a plain (non-pool) OAuth credential that merely shares its token URL with a pool also has its injected real refresh token tagged `realRefreshToken → ` (the plain path in `buildPhantomPairs` / `buildOAuthPhantomPairs`'s `onRefreshInject`, including the token-host expansion for split-host plain creds). On the response side, when a pool shares the token URL, `resolveOAuthResponseAttribution` recovers the tag: if it resolves to a name that is **not** a pool member (`PoolForMember == ""`), the refresh is attributed 1:1 to that plain credential (its own phantom, its own vault write) — NOT fail-closed as a pooled refresh. The pooled fail-closed path is taken only when recovery fails entirely or resolves to an actual pool member. The `poolForResponse` failover path applies the same rule: a recovered owner not in any pool only triggers the membership-raced active-member fallback when an independent `flowInjected` pool-usage tag (set post-swap only if a pool phantom was actually present) confirms pooled usage; otherwise the failure is treated as a plain credential's and no pool member is cooled. - **R3 pool-stable phantom JWT:** Codex access tokens are JWTs and the per-real-token `resignJWT` would emit a *different* phantom after every cross-member refresh, breaking the "agent never notices" guarantee. The dedicated `poolStablePhantomAccess` (in `internal/proxy/oauth_response.go`) instead builds the phantom JWT from a deterministic synthetic payload keyed on the **pool name** (`sub: sluice-pool:`, `iss: sluice-phantom`, fixed far-future `exp`, no `iat`), HMAC-SHA256'd with the existing fixed key — byte-identical across member switches while still a structurally valid JWT. The pool name is JSON-marshaled (never concatenated) so a name with quotes/control chars cannot inject claims. Static-form fallback (`SLUICE_PHANTOM:.access`) is emitted only on the unreachable `json.Marshal` failure of the fixed struct (and is documented as the equivalent for an agent verified to treat the access token as opaque). The **refresh** phantom is unaffected — it stays the static `SLUICE_PHANTOM:.refresh`. diff --git a/cmd/sluice/binding_test.go b/cmd/sluice/binding_test.go index 6dab670..cda031f 100644 --- a/cmd/sluice/binding_test.go +++ b/cmd/sluice/binding_test.go @@ -27,6 +27,16 @@ func setupBindingDB(t *testing.T) string { if err != nil { t.Fatalf("create test DB: %v", err) } + // Seed the credentials these tests bind to. AddBinding / + // AddRuleAndBinding now require the referenced credential (or pool) to + // exist, mirroring the real flow where "sluice cred add" creates the + // credential before "sluice binding add" binds to it. Without this, the + // binding-CLI tests would be exercising an impossible state. + for _, c := range []string{"mycred", "cred_a", "cred_b"} { + if err := db.AddCredentialMeta(c, "static", ""); err != nil { + t.Fatalf("seed credential meta %q: %v", c, err) + } + } _ = db.Close() return dbPath } diff --git a/cmd/sluice/cred_test.go b/cmd/sluice/cred_test.go index eb942d6..d3195de 100644 --- a/cmd/sluice/cred_test.go +++ b/cmd/sluice/cred_test.go @@ -558,6 +558,9 @@ func TestHandleCredListWithBindings(t *testing.T) { if err != nil { t.Fatal(err) } + if err := db.AddCredentialMeta("mykey", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } _, err = db.AddBinding("api.example.com", "mykey", store.BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -614,6 +617,9 @@ func TestHandleCredRemoveWithBindings(t *testing.T) { if err != nil { t.Fatal(err) } + if err := db.AddCredentialMeta("cleanup_key", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } _, err = db.AddBinding("api.cleanup.com", "cleanup_key", store.BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -1535,6 +1541,9 @@ func TestHandleCredListShowsEnvVar(t *testing.T) { if err != nil { t.Fatal(err) } + if err := db.AddCredentialMeta("my_api_key", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } _, err = db.AddBinding("api.example.com", "my_api_key", store.BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -1574,6 +1583,9 @@ func TestHandleCredListHidesEnvVarWhenEmpty(t *testing.T) { if err != nil { t.Fatal(err) } + if err := db.AddCredentialMeta("no_env_cred", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } _, err = db.AddBinding("api.example.com", "no_env_cred", store.BindingOpts{ Ports: []int{443}, }) @@ -2030,6 +2042,9 @@ func TestHandleCredRemoveCleansUpBindingAddRules(t *testing.T) { if err != nil { t.Fatal(err) } + if err := db.AddCredentialMeta("bind_cleanup_key", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } if _, _, err := db.AddRuleAndBinding( "allow", store.RuleOpts{ @@ -2136,6 +2151,9 @@ func TestHandleCredAddMultipleDestinationsMidLoopRollback(t *testing.T) { if err != nil { t.Fatalf("open store: %v", err) } + if err := db.AddCredentialMeta("mid_rollback_key", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } if _, err := db.AddBinding("api.second.com", "mid_rollback_key", store.BindingOpts{}); err != nil { t.Fatalf("pre-seed blocking binding: %v", err) } diff --git a/cmd/sluice/main_test.go b/cmd/sluice/main_test.go index 8656e16..53e1bb1 100644 --- a/cmd/sluice/main_test.go +++ b/cmd/sluice/main_test.go @@ -565,6 +565,11 @@ func TestReadBindings(t *testing.T) { } // Add bindings. + for _, c := range []string{"my_key", "gh_key"} { + if err := db.AddCredentialMeta(c, "static", ""); err != nil { + t.Fatalf("add credential meta %q: %v", c, err) + } + } _, _ = db.AddBinding("api.example.com", "my_key", store.BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -1371,6 +1376,9 @@ func TestStandaloneModeCredentialInjection(t *testing.T) { // Add a binding. In standalone mode, credential injection still works // because the MITM proxy handles it, not the container manager. + if err := db.AddCredentialMeta("my_api_key", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } _, _ = db.AddBinding("api.example.com", "my_api_key", store.BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -1465,6 +1473,9 @@ func TestInjectEnvVarsFromStore(t *testing.T) { if _, addErr := vs.Add("openai_key", "sk-real-secret"); addErr != nil { t.Fatal(addErr) } + if err := db.AddCredentialMeta("openai_key", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } _, bindErr := db.AddBinding("api.openai.com", "openai_key", store.BindingOpts{ Ports: []int{443}, @@ -1526,6 +1537,11 @@ func TestInjectEnvVarsFromStoreMultipleBindings(t *testing.T) { defer func() { _ = db.Close() }() // Add two bindings with env_var. + for _, c := range []string{"openai_key", "tg_bot", "gh_token"} { + if err := db.AddCredentialMeta(c, "static", ""); err != nil { + t.Fatalf("add credential meta %q: %v", c, err) + } + } _, _ = db.AddBinding("api.openai.com", "openai_key", store.BindingOpts{ Ports: []int{443}, Header: "Authorization", diff --git a/cmd/sluice/policy_test.go b/cmd/sluice/policy_test.go index 5a27a42..9f4545d 100644 --- a/cmd/sluice/policy_test.go +++ b/cmd/sluice/policy_test.go @@ -741,6 +741,9 @@ func TestHandlePolicyExportMatchesStore(t *testing.T) { _, _ = db.AddRule("allow", store.RuleOpts{Tool: "github__list_*", Name: "read-only github"}) _, _ = db.AddRule("deny", store.RuleOpts{Pattern: "(?i)(sk-[a-zA-Z0-9_-]{20,})", Name: "api_key_leak"}) _, _ = db.AddRule("redact", store.RuleOpts{Pattern: "(?i)(sk-[a-zA-Z0-9_-]{20,})", Replacement: "[REDACTED]", Name: "api_key_response"}) + if err := db.AddCredentialMeta("my_key", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } _, _ = db.AddBinding("api.example.com", "my_key", store.BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -1318,6 +1321,9 @@ func TestPolicyExportContainsExpectedSections(t *testing.T) { _, _ = db.AddRule("allow", store.RuleOpts{Destination: "api.example.com", Ports: []int{443}, Name: "API"}) _, _ = db.AddRule("deny", store.RuleOpts{Destination: "evil.example.com"}) _, _ = db.AddRule("allow", store.RuleOpts{Tool: "github__list_*"}) + if err := db.AddCredentialMeta("my_key", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } _, _ = db.AddBinding("api.example.com", "my_key", store.BindingOpts{ Ports: []int{443}, Header: "Authorization", diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 6074fcf..4f180b8 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -36,6 +36,22 @@ func newTestStore(t *testing.T) *store.Store { return s } +// seedCred registers a static credential in credential_meta so a binding +// referencing it passes the live-credential-or-pool existence check that +// AddBinding / AddRuleAndBinding now enforce. The real REST flows always +// create the credential before binding (POST /api/credentials registers +// credential_meta before the paired binding; POST /api/bindings binds to a +// pre-existing credential), so seeding here mirrors production rather than +// weakening the test. +func seedCred(t *testing.T, st *store.Store, names ...string) { + t.Helper() + for _, n := range names { + if err := st.AddCredentialMeta(n, "static", ""); err != nil { + t.Fatalf("seed credential meta %q: %v", n, err) + } + } +} + // enableHTTPChannel inserts an enabled HTTP channel row (type=1) in the store. func enableHTTPChannel(t *testing.T, st *store.Store) { t.Helper() @@ -986,6 +1002,7 @@ func TestGetApiRulesExport_BindingEnvVar(t *testing.T) { st := newTestStore(t) enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "openai_key") // Add a binding with env_var set. if _, err := st.AddBinding("api.openai.com", "openai_key", store.BindingOpts{ @@ -1336,6 +1353,7 @@ func TestPostApiCredentials_DuplicateBinding(t *testing.T) { srv.SetVault(v) // Seed an existing binding on (existing_cred, api.example.com). + seedCred(t, st, "existing_cred") if _, err := st.AddBinding("api.example.com", "existing_cred", store.BindingOpts{}); err != nil { t.Fatalf("seed binding: %v", err) } @@ -1408,6 +1426,7 @@ func TestDeleteApiCredentials_Success(t *testing.T) { if _, err := v.Add("my_key", "value"); err != nil { t.Fatalf("add: %v", err) } + seedCred(t, st, "my_key") if _, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{}); err != nil { t.Fatalf("add binding: %v", err) } @@ -1463,6 +1482,7 @@ func TestDeleteApiCredentials_ConcurrentRace(t *testing.T) { if _, err := v.Add("racer", "value"); err != nil { t.Fatalf("add: %v", err) } + seedCred(t, st, "racer") if _, err := st.AddBinding("api.example.com", "racer", store.BindingOpts{}); err != nil { t.Fatalf("add binding: %v", err) } @@ -2019,6 +2039,7 @@ func TestDeleteApiCredentials_MissingVaultSecretIsNotFatal(t *testing.T) { if _, err := v.Add("dup", "value"); err != nil { t.Fatalf("iter %d: add: %v", iter, err) } + seedCred(t, st, "dup") if _, err := st.AddBinding("api.example.com", "dup", store.BindingOpts{}); err != nil { t.Fatalf("iter %d: add binding: %v", iter, err) } @@ -2208,6 +2229,7 @@ func TestPostApiBindings_Success(t *testing.T) { t.Setenv("SLUICE_API_TOKEN", "tok") handler := newTestHandler(t, srv, st) + seedCred(t, st, "my_key") body := `{"destination": "api.example.com", "credential": "my_key", "ports": [443], "header": "Authorization", "template": "Bearer {value}"}` req := httptest.NewRequest("POST", "/api/bindings", strings.NewReader(body)) req.Header.Set("Authorization", "Bearer tok") @@ -2246,6 +2268,7 @@ func TestPostApiBindings_PropagatesProtocolsAndPortsToRule(t *testing.T) { t.Setenv("SLUICE_API_TOKEN", "tok") handler := newTestHandler(t, srv, st) + seedCred(t, st, "my_key") body := `{"destination":"api.example.com","credential":"my_key","ports":[443,8443],"protocols":["tcp"]}` req := httptest.NewRequest("POST", "/api/bindings", strings.NewReader(body)) req.Header.Set("Authorization", "Bearer tok") @@ -2335,6 +2358,7 @@ func TestPatchApiBindingsId_RejectsUnknownProtocol(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "cred") _, bindingID, err := st.AddRuleAndBinding( "allow", store.RuleOpts{Destination: "api.example.com", Source: store.BindingAddSourcePrefix + "cred"}, @@ -2368,6 +2392,7 @@ func TestDeleteApiBindings_Success(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "my_key") id, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{}) if err != nil { t.Fatalf("add binding: %v", err) @@ -2941,6 +2966,7 @@ func TestPostApiBindings_WithEnvVar(t *testing.T) { t.Setenv("SLUICE_API_TOKEN", "tok") handler := newTestHandler(t, srv, st) + seedCred(t, st, "my_key") body := `{"destination": "api.example.com", "credential": "my_key", "ports": [443], "env_var": "MY_API_KEY"}` req := httptest.NewRequest("POST", "/api/bindings", strings.NewReader(body)) req.Header.Set("Authorization", "Bearer tok") @@ -2970,6 +2996,7 @@ func TestGetApiBindings_ReturnsEnvVar(t *testing.T) { srv := api.NewServer(st, nil, nil, "") // Create a binding with env_var directly in the store. + seedCred(t, st, "my_key") _, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{ Ports: []int{443}, EnvVar: "EXAMPLE_KEY", @@ -3008,6 +3035,7 @@ func TestGetApiBindings_OmitsEmptyEnvVar(t *testing.T) { srv := api.NewServer(st, nil, nil, "") // Create a binding without env_var. + seedCred(t, st, "my_key") _, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{}) if err != nil { t.Fatalf("add binding: %v", err) @@ -3119,6 +3147,7 @@ func TestPostApiBindings_WithContainerManager(t *testing.T) { mgr := &mockContainerMgr{} srv.SetContainerManager(mgr) + seedCred(t, st, "openai_key") body := `{"destination":"api.openai.com","credential":"openai_key","ports":[443],"env_var":"OPENAI_API_KEY"}` rec := httptest.NewRecorder() req := httptest.NewRequest("POST", "/api/bindings", strings.NewReader(body)) @@ -3142,6 +3171,7 @@ func TestDeleteApiBindingsId_ClearsEnvVar(t *testing.T) { defer func() { _ = st.Close() }() // Create a binding with env_var. + seedCred(t, st, "openai_key") id, err := st.AddBinding("api.openai.com", "openai_key", store.BindingOpts{ Ports: []int{443}, EnvVar: "OPENAI_API_KEY", @@ -3182,6 +3212,7 @@ func TestPatchApiBindingsId_Success(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "my_key") id, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -3227,6 +3258,7 @@ func TestPatchApiBindingsId_MultipleFields(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "my_key") id, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{ Ports: []int{443}, }) @@ -3291,6 +3323,7 @@ func TestPatchApiBindingsId_InvalidBody(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "my_key") id, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{}) if err != nil { t.Fatalf("add binding: %v", err) @@ -3321,6 +3354,7 @@ func TestPatchApiBindingsId_DestinationSyncsPairedRule(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "my_key") ruleID, bindingID, err := st.AddRuleAndBinding( "allow", store.RuleOpts{ @@ -3380,6 +3414,7 @@ func TestPatchApiBindingsId_EnvVar(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "my_key") id, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{}) if err != nil { t.Fatalf("add binding: %v", err) @@ -3426,6 +3461,7 @@ func TestPatchApiBindingsId_EmptyDestinationRejected(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "my_key") id, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{}) if err != nil { t.Fatalf("add binding: %v", err) @@ -3456,6 +3492,7 @@ func TestPatchApiBindingsId_EmptyBodyRejected(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "my_key") id, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{}) if err != nil { t.Fatalf("add binding: %v", err) @@ -3485,6 +3522,7 @@ func TestPatchApiBindingsId_DuplicateDestinationRejected(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "my_key") if _, err := st.AddBinding("api.a.com", "my_key", store.BindingOpts{}); err != nil { t.Fatalf("add first binding: %v", err) } @@ -3518,6 +3556,7 @@ func TestPatchApiBindingsId_ClearsEnvVar(t *testing.T) { mgr := &mockContainerMgr{} srv.SetContainerManager(mgr) + seedCred(t, st, "my_key") id, err := st.AddBinding("api.example.com", "my_key", store.BindingOpts{ EnvVar: "OLD_KEY", }) @@ -3560,6 +3599,7 @@ func TestDeleteApiBindingsId_CleansUpPairedRule(t *testing.T) { enableHTTPChannel(t, st) srv := api.NewServer(st, nil, nil, "") + seedCred(t, st, "my_key") ruleID, bindingID, err := st.AddRuleAndBinding( "allow", store.RuleOpts{ @@ -3612,6 +3652,7 @@ func TestDeleteApiCredentials_CleansUpBindingAddRules(t *testing.T) { if _, err := v.Add("my_key", "s3cr3t"); err != nil { t.Fatalf("seed credential: %v", err) } + seedCred(t, st, "my_key") if _, _, err := st.AddRuleAndBinding( "allow", store.RuleOpts{ diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index 3597f2f..337e5a1 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -313,6 +313,90 @@ func TestRemovePoolIfUnreferenced_BindingBeforeRemovalRefuses(t *testing.T) { } } +// TestAddBinding_AfterPoolRemovedRefuses closes the CREATION half of the +// bind/remove TOCTOU. RemovePoolIfUnreferenced guards the removal side +// (refuses to delete a referenced pool); this verifies the symmetric +// creation guard: once a pool is gone, AddBinding / AddRuleAndBinding that +// names it must fail and persist NOTHING, so a later same-named credential +// cannot silently inherit a stale binding. +// +// Fail-before: the binding INSERT had no existence check, so a binding +// pointing at the just-deleted pool committed (and the rule in the +// AddRuleAndBinding case leaked). Pass-after: the in-transaction +// credential/pool existence check rejects both paths and rolls back. +func TestAddBinding_AfterPoolRemovedRefuses(t *testing.T) { + s := newTestStore(t) + seedOAuthCred(t, s, "a") + if err := s.CreatePoolWithMembers("p", "failover", []string{"a"}); err != nil { + t.Fatalf("create pool: %v", err) + } + + // A binding for the LIVE pool must still succeed (don't break the + // legitimate path). + if _, err := s.AddBinding("live.example.com", "p", BindingOpts{Ports: []int{443}}); err != nil { + t.Fatalf("AddBinding for live pool should succeed: %v", err) + } + + // A binding for a LIVE plain credential must still succeed. + mustAddCred(t, s, "plain_cred") + if _, err := s.AddBinding("plain.example.com", "plain_cred", BindingOpts{Ports: []int{443}}); err != nil { + t.Fatalf("AddBinding for live credential should succeed: %v", err) + } + + // Remove the pool (no binding references it after we also clear the + // live.example.com one, mirroring an operator deleting an unused pool). + if n, err := s.RemoveBindingsByCredential("p"); err != nil || n != 1 { + t.Fatalf("clear pool binding: n=%d err=%v", n, err) + } + removed, err := s.RemovePoolIfUnreferenced("p") + if err != nil || !removed { + t.Fatalf("RemovePoolIfUnreferenced: removed=%v err=%v", removed, err) + } + + // AddBinding referencing the now-deleted pool must FAIL and insert no row. + _, err = s.AddBinding("api.example.com", "p", BindingOpts{Ports: []int{443}}) + if err == nil { + t.Fatal("AddBinding committed a binding pointing at a deleted pool (TOCTOU creation half open)") + } + if !errors.Is(err, ErrBindingCredentialMissing) { + t.Errorf("want ErrBindingCredentialMissing, got %v", err) + } + if !errors.Is(err, ErrBindingValidation) { + t.Errorf("want the error wrapped under ErrBindingValidation for 400 mapping, got %v", err) + } + + // AddRuleAndBinding referencing the deleted pool must also FAIL and leave + // no orphan rule behind (transaction rolled back). + rulesBefore, _ := s.ListRules(RuleFilter{Verdict: "allow"}) + _, _, err = s.AddRuleAndBinding( + "allow", + RuleOpts{Destination: "api.example.com", Source: BindingAddSourcePrefix + "p"}, + "p", + BindingOpts{Ports: []int{443}}, + ) + if err == nil { + t.Fatal("AddRuleAndBinding committed a binding pointing at a deleted pool") + } + if !errors.Is(err, ErrBindingCredentialMissing) { + t.Errorf("want ErrBindingCredentialMissing, got %v", err) + } + + // No binding row references the dead pool. + dead, err := s.ListBindingsByCredential("p") + if err != nil { + t.Fatalf("ListBindingsByCredential: %v", err) + } + if len(dead) != 0 { + t.Errorf("expected 0 bindings for deleted pool, got %d", len(dead)) + } + // No orphan rule from the rolled-back AddRuleAndBinding. + rulesAfter, _ := s.ListRules(RuleFilter{Verdict: "allow"}) + if len(rulesAfter) != len(rulesBefore) { + t.Errorf("AddRuleAndBinding left an orphan rule: before=%d after=%d", + len(rulesBefore), len(rulesAfter)) + } +} + // TestRemovePoolIfUnreferenced_ConcurrentIsInternallyConsistent stresses the // two write paths concurrently. SQLite serializes write transactions, so the // atomic removal can only land in one of two self-consistent terminal diff --git a/internal/store/store.go b/internal/store/store.go index 3f5219d..810bcab 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -7,6 +7,7 @@ package store import ( "database/sql" "encoding/json" + "errors" "fmt" "os" "regexp" @@ -646,6 +647,14 @@ func (s *Store) AddBinding(destination, credential string, opts BindingOpts) (in } }() + // The credential must resolve to a live credential or a live pool. + // Runs on the same tx as the INSERT so a concurrent pool/credential + // delete cannot interleave between the check and the write, closing + // the creation half of the bind/remove TOCTOU. + if err := assertBindingCredentialExistsTx(tx, credential); err != nil { + return 0, err + } + // Uniqueness check runs on the same tx as the INSERT so the single // connection serializes them. Without this, concurrent callers could // both observe no collision and then both insert the same env_var @@ -704,6 +713,47 @@ var ErrBindingDuplicate = fmt.Errorf("binding already exists for credential/dest // server faults from clients. var ErrBindingValidation = fmt.Errorf("binding validation failed") +// ErrBindingCredentialMissing is returned by AddBinding / AddRuleAndBinding +// when the named credential does not exist as either a live credential +// (credential_meta) or a live pool (credential_pools). It is wrapped under +// ErrBindingValidation so the API layer maps it to a 400, but callers that +// want to detect this specific case (e.g. to print "create the credential +// or pool first") can test with errors.Is on this sentinel. +// +// This closes the creation half of the bind/remove TOCTOU: +// RemovePoolIfUnreferenced refuses to delete a pool that a binding still +// references, and this check refuses to create a binding for a credential +// or pool that no longer exists. The existence check and the binding INSERT +// run in the same transaction so a concurrent pool/credential delete cannot +// interleave between them and leave a binding pointing at a vanished name. +var ErrBindingCredentialMissing = fmt.Errorf("binding references a credential or pool that does not exist") + +// assertBindingCredentialExistsTx verifies, inside the caller's transaction, +// that the binding's credential refers to either a live credential +// (credential_meta) OR a live pool (credential_pools). A binding's +// credential column is a free-form name that resolves at proxy time to one +// of those two namespaces, so accepting a name that matches neither would +// persist a permanently dead binding (and could later be silently inherited +// by a same-named credential/pool created afterwards). Run on the same tx as +// the INSERT so the check and the write are atomic against a concurrent +// pool/credential delete. +func assertBindingCredentialExistsTx(tx *sql.Tx, credential string) error { + var one int + err := tx.QueryRow( + `SELECT 1 WHERE EXISTS (SELECT 1 FROM credential_meta WHERE name = ?) + OR EXISTS (SELECT 1 FROM credential_pools WHERE name = ?)`, + credential, credential, + ).Scan(&one) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("%w: %w: credential %q is neither a live credential nor a live pool (create it first)", + ErrBindingValidation, ErrBindingCredentialMissing, credential) + } + if err != nil { + return fmt.Errorf("check credential/pool existence for %q: %w", credential, err) + } + return nil +} + // isBindingUniqueViolation detects the SQLite UNIQUE constraint violation // that indicates a duplicate binding on (credential, destination). func isBindingUniqueViolation(err error) bool { @@ -1676,6 +1726,16 @@ func (s *Store) AddRuleAndBinding( } ruleID, _ = res.LastInsertId() + // The credential must resolve to a live credential or a live pool. + // Same tx as the rule + binding inserts so a concurrent pool/credential + // delete cannot interleave and leave a binding pointing at a vanished + // name (creation half of the bind/remove TOCTOU). cred add + // --destination commits the credential_meta row in its own prior + // transaction, so the just-created credential is already visible here. + if err = assertBindingCredentialExistsTx(tx, credential); err != nil { + return 0, 0, err + } + // Validate and check env_var uniqueness before inserting (uses tx to // avoid deadlock with the single-connection pool). if bindingOpts.EnvVar != "" { diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 1a995a1..1681d57 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -26,6 +26,21 @@ func newTestStore(t *testing.T) *Store { return s } +// mustAddCred registers a static credential so a subsequent AddBinding / +// AddRuleAndBinding referencing it passes the live-credential-or-pool +// existence check. Real callers (CLI "cred add", REST, Telegram) always +// create the credential before binding to it; these binding-CRUD tests +// previously relied on the store accepting a free-form credential name with +// no backing row. The existence check now enforces the same invariant the +// production paths already satisfy, so the fixture mirrors reality rather +// than weakening the test. +func mustAddCred(t *testing.T, s *Store, name string) { + t.Helper() + if err := s.AddCredentialMeta(name, "static", ""); err != nil { + t.Fatalf("add credential meta %q: %v", name, err) + } +} + // --- Schema migration tests --- func TestNewCreatesSchema(t *testing.T) { @@ -758,6 +773,7 @@ func TestConfigUpdateVaultProviderClearFields(t *testing.T) { func TestBindingCRUD(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_api_key") id, err := s.AddBinding("api.example.com", "my_api_key", BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -807,6 +823,7 @@ func TestBindingCRUD(t *testing.T) { func TestBindingMultipleProtocols(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "mail_cred") _, err := s.AddBinding("mail.example.com", "mail_cred", BindingOpts{ Ports: []int{993, 587}, Protocols: []string{"imap", "smtp"}, @@ -822,6 +839,7 @@ func TestBindingMultipleProtocols(t *testing.T) { func TestBindingNoProtocols(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "key") _, err := s.AddBinding("api.example.com", "key", BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -847,6 +865,7 @@ func TestBindingValidation(t *testing.T) { func TestUpdateBindingSingleField(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") id, err := s.AddBinding("api.example.com", "my_key", BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -883,6 +902,7 @@ func TestUpdateBindingSingleField(t *testing.T) { func TestUpdateBindingMultipleFields(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") id, err := s.AddBinding("api.example.com", "my_key", BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -936,6 +956,7 @@ func TestUpdateBindingMultipleFields(t *testing.T) { func TestUpdateBindingClearFields(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") id, err := s.AddBinding("api.example.com", "my_key", BindingOpts{ Ports: []int{443}, Header: "Authorization", @@ -980,6 +1001,7 @@ func TestUpdateBindingNotFound(t *testing.T) { func TestUpdateBindingEmptyOpts(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") id, err := s.AddBinding("api.example.com", "my_key", BindingOpts{Ports: []int{443}}) if err != nil { t.Fatalf("add: %v", err) @@ -996,6 +1018,7 @@ func TestUpdateBindingEmptyOpts(t *testing.T) { func TestUpdateBindingEmptyDestination(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") id, err := s.AddBinding("api.example.com", "my_key", BindingOpts{}) if err != nil { t.Fatalf("add: %v", err) @@ -1011,6 +1034,7 @@ func TestUpdateBindingEmptyDestination(t *testing.T) { // This is distinct from passing nil (which means "no change"). func TestUpdateBindingClearPortsAndProtocols(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") id, err := s.AddBinding("api.example.com", "my_key", BindingOpts{ Ports: []int{443, 8080}, Protocols: []string{"http", "grpc"}, @@ -1049,6 +1073,8 @@ func TestUpdateBindingClearPortsAndProtocols(t *testing.T) { // allowed because they resolve to the same phantom value. func TestUpdateBindingEnvVar(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") + mustAddCred(t, s, "other_cred") id, err := s.AddBinding("api.example.com", "my_key", BindingOpts{}) if err != nil { t.Fatalf("add: %v", err) @@ -1308,6 +1334,7 @@ func TestRuleExistsProtocolScoped(t *testing.T) { func TestBindingExists(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") _, _ = s.AddBinding("api.example.com", "my_key", BindingOpts{}) exists, _ := s.BindingExists("api.example.com", "my_key") @@ -2753,6 +2780,8 @@ func TestAddChannelWithoutOpts(t *testing.T) { func TestListBindingsByCredential(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred_a") + mustAddCred(t, s, "cred_b") _, _ = s.AddBinding("api.one.com", "cred_a", BindingOpts{Ports: []int{443}}) _, _ = s.AddBinding("api.two.com", "cred_a", BindingOpts{Ports: []int{443}}) _, _ = s.AddBinding("api.three.com", "cred_b", BindingOpts{Ports: []int{443}}) @@ -2776,6 +2805,8 @@ func TestListBindingsByCredential(t *testing.T) { func TestRemoveBindingsByCredential(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred_a") + mustAddCred(t, s, "cred_b") _, _ = s.AddBinding("api.one.com", "cred_a", BindingOpts{}) _, _ = s.AddBinding("api.two.com", "cred_a", BindingOpts{}) _, _ = s.AddBinding("api.three.com", "cred_b", BindingOpts{}) @@ -2840,6 +2871,7 @@ func TestIsEmpty(t *testing.T) { func TestIsEmptyWithBinding(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred") _, _ = s.AddBinding("test.com", "cred", BindingOpts{}) empty, err := s.IsEmpty() @@ -2866,6 +2898,7 @@ func TestIsEmptyWithUpstream(t *testing.T) { func TestAddRuleAndBinding(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "api_key") ruleID, bindingID, err := s.AddRuleAndBinding( "allow", RuleOpts{Destination: "api.example.com", Ports: []int{443}, Name: "api access"}, @@ -3465,6 +3498,7 @@ func TestCredentialMetaCRUDRoundTrip(t *testing.T) { func TestBindingEnvVarMigration(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "key") // Verify the env_var column exists by inserting and reading back. _, err := s.AddBinding("api.example.com", "key", BindingOpts{ @@ -3489,6 +3523,7 @@ func TestBindingEnvVarMigration(t *testing.T) { func TestBindingEnvVarEmpty(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "key") // Binding without env_var should have empty string. _, err := s.AddBinding("api.example.com", "key", BindingOpts{ @@ -3505,6 +3540,7 @@ func TestBindingEnvVarEmpty(t *testing.T) { func TestAddBindingWithEnvVar(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "openai_key") id, err := s.AddBinding("api.openai.com", "openai_key", BindingOpts{ Ports: []int{443}, @@ -3536,6 +3572,9 @@ func TestAddBindingWithEnvVar(t *testing.T) { func TestAddBindingEnvVarUniqueness(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "openai_key") + mustAddCred(t, s, "other_key") + mustAddCred(t, s, "telegram_bot") // First binding with env_var should succeed. _, err := s.AddBinding("api.openai.com", "openai_key", BindingOpts{ @@ -3587,6 +3626,9 @@ func TestAddBindingEnvVarUniquenessConcurrent(t *testing.T) { s := newTestStore(t) const workers = 20 + for i := 0; i < workers; i++ { + mustAddCred(t, s, fmt.Sprintf("cred-%d", i)) + } var ( wg sync.WaitGroup successM sync.Mutex @@ -3646,6 +3688,9 @@ func TestAddBindingEnvVarUniquenessConcurrent(t *testing.T) { func TestListBindingsWithEnvVar(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "openai_key") + mustAddCred(t, s, "github_key") + mustAddCred(t, s, "telegram_bot") // Add bindings with and without env_var. _, _ = s.AddBinding("api.openai.com", "openai_key", BindingOpts{ @@ -3690,7 +3735,10 @@ func TestListBindingsWithEnvVarEmpty(t *testing.T) { } // Add binding without env_var. - _, _ = s.AddBinding("api.example.com", "key", BindingOpts{Ports: []int{443}}) + mustAddCred(t, s, "key") + if _, err := s.AddBinding("api.example.com", "key", BindingOpts{Ports: []int{443}}); err != nil { + t.Fatalf("add binding without env_var: %v", err) + } bindings, err = s.ListBindingsWithEnvVar() if err != nil { @@ -3703,6 +3751,7 @@ func TestListBindingsWithEnvVarEmpty(t *testing.T) { func TestAddRuleAndBindingWithEnvVar(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "openai_key") _, bindingID, err := s.AddRuleAndBinding( "allow", RuleOpts{Destination: "api.openai.com", Ports: []int{443}}, @@ -3727,6 +3776,8 @@ func TestAddRuleAndBindingWithEnvVar(t *testing.T) { func TestAddRuleAndBindingEnvVarUniqueness(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "openai_key") + mustAddCred(t, s, "other_key") // First should succeed. _, _, err := s.AddRuleAndBinding( @@ -3753,6 +3804,8 @@ func TestAddRuleAndBindingEnvVarUniqueness(t *testing.T) { func TestListBindingsByCredentialWithEnvVar(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "openai_key") + mustAddCred(t, s, "other_key") _, _ = s.AddBinding("api.openai.com", "openai_key", BindingOpts{ Ports: []int{443}, EnvVar: "OPENAI_API_KEY", @@ -3775,6 +3828,7 @@ func TestListBindingsByCredentialWithEnvVar(t *testing.T) { func TestBindingEnvVarMigrationDown(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "key") // Add a binding with env_var to verify data exists. _, err := s.AddBinding("api.example.com", "key", BindingOpts{ @@ -3825,6 +3879,12 @@ func TestBindingEnvVarMigrationDown(t *testing.T) { func TestAddBindingEnvVarFormatValidation(t *testing.T) { s := newTestStore(t) + // The valid sub-cases use the env var string as the credential name, so + // seed those credentials; invalid cases fail input validation before the + // existence check is reached. + mustAddCred(t, s, "OPENAI_API_KEY") + mustAddCred(t, s, "my_key") + mustAddCred(t, s, "_HIDDEN") tests := []struct { name string @@ -3871,6 +3931,7 @@ func TestAddBindingEnvVarFormatValidation(t *testing.T) { // return ErrBindingDuplicate. func TestAddBindingDuplicateRejected(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") if _, err := s.AddBinding("api.example.com", "my_key", BindingOpts{}); err != nil { t.Fatalf("first add: %v", err) } @@ -3891,6 +3952,7 @@ func TestAddBindingDuplicateRejected(t *testing.T) { // leaving a partially-applied rule behind. func TestAddRuleAndBindingDuplicateRejected(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") if _, err := s.AddBinding("api.example.com", "my_key", BindingOpts{}); err != nil { t.Fatalf("seed binding: %v", err) } @@ -3966,6 +4028,7 @@ func TestRemoveRuleByBindingPair(t *testing.T) { // binding destination change + paired rule update + returned ruleFound flag. func TestUpdateBindingWithRuleSync(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") ruleID, bindingID, err := s.AddRuleAndBinding( "allow", RuleOpts{ @@ -4023,6 +4086,7 @@ func TestUpdateBindingWithRuleSync(t *testing.T) { // authorize the old one, breaking the new port and leaking the old. func TestUpdateBindingWithRuleSyncPropagatesPortsAndProtocols(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred") ruleID, bindingID, err := s.AddRuleAndBinding( "allow", RuleOpts{ @@ -4114,6 +4178,7 @@ func TestUpdateBindingWithRuleSyncPropagatesPortsAndProtocols(t *testing.T) { // paired-rule sync path validates port ranges before touching the rule. func TestUpdateBindingWithRuleSyncRejectsInvalidPorts(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred") _, bindingID, err := s.AddRuleAndBinding( "allow", RuleOpts{ @@ -4181,6 +4246,7 @@ func TestAddRuleAndBindingRejectsInvalidPorts(t *testing.T) { // connection time. Mirrors the TOML import validator. func TestAddRuleAndBindingRejectsUnknownProtocol(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred") // Rule-level typo. _, _, err := s.AddRuleAndBinding( @@ -4226,6 +4292,7 @@ func TestAddRuleAndBindingRejectsUnknownProtocol(t *testing.T) { // for the standalone AddBinding path. func TestAddBindingRejectsUnknownProtocol(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred") _, err := s.AddBinding("api.example.com", "cred", BindingOpts{Protocols: []string{"htp"}}) if err == nil || !strings.Contains(err.Error(), "unknown protocol") { @@ -4242,6 +4309,7 @@ func TestAddBindingRejectsUnknownProtocol(t *testing.T) { // path rejects unknown protocols before any transaction runs. func TestUpdateBindingWithRuleSyncRejectsUnknownProtocol(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred") _, bindingID, err := s.AddRuleAndBinding( "allow", RuleOpts{ @@ -4278,6 +4346,7 @@ func TestUpdateBindingWithRuleSyncRejectsUnknownProtocol(t *testing.T) { // scope drifted from what the remaining rule authorized. func TestUpdateBindingWithRuleSyncUpdatesBothSourcePrefixes(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred") // Seed the first rule via AddRuleAndBinding with the cred-add prefix. // This mirrors "sluice cred add --destination" creating the initial @@ -4351,6 +4420,7 @@ func TestUpdateBindingWithRuleSyncUpdatesBothSourcePrefixes(t *testing.T) { // iteration 9 finding 2. func TestRemoveBindingWithRuleCleanupCaseInsensitive(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") // Add a binding via AddBinding (no paired rule) so we can seed the // paired rule at a different case without fighting the atomic @@ -4401,6 +4471,7 @@ func TestRemoveBindingWithRuleCleanupCaseInsensitive(t *testing.T) { // RemoveBindingWithRuleCleanup. Regression for codex iteration 9 finding 2. func TestUpdateBindingWithRuleSyncCaseInsensitive(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") bindingID, err := s.AddBinding("api.example.com", "my_key", BindingOpts{}) if err != nil { @@ -4538,6 +4609,7 @@ func TestAddRuleAndBindingValidationErrorsAreTagged(t *testing.T) { // AddRuleAndBinding test for the update path. func TestUpdateBindingWithRuleSyncValidationErrorsAreTagged(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred") _, bindingID, err := s.AddRuleAndBinding( "allow", RuleOpts{ @@ -4586,6 +4658,7 @@ func TestUpdateBindingWithRuleSyncValidationErrorsAreTagged(t *testing.T) { // 6 finding 1. func TestAddBindingCaseInsensitiveDuplicate(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") if _, err := s.AddBinding("api.example.com", "my_key", BindingOpts{}); err != nil { t.Fatalf("first add: %v", err) @@ -4621,6 +4694,7 @@ func TestAddBindingCaseInsensitiveDuplicate(t *testing.T) { // duplicate is detected so no orphan rule remains. func TestAddRuleAndBindingCaseInsensitiveDuplicate(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "my_key") if _, _, err := s.AddRuleAndBinding( "allow", @@ -4661,6 +4735,7 @@ func TestAddRuleAndBindingCaseInsensitiveDuplicate(t *testing.T) { // callers can map them to 400. Regression for codex iteration 6 finding 3. func TestAddBindingValidation(t *testing.T) { s := newTestStore(t) + mustAddCred(t, s, "cred") cases := []struct { name string diff --git a/internal/telegram/approval_test.go b/internal/telegram/approval_test.go index 58f3f04..2f05049 100644 --- a/internal/telegram/approval_test.go +++ b/internal/telegram/approval_test.go @@ -1538,7 +1538,12 @@ func TestCredRemoveWithContainerManager(t *testing.T) { // Add with env_var, then remove. InjectEnvVars should be called with // an empty value for the removed env var. _, _ = vaultStore.Add("test_cred", "value") - _, _ = s.AddBinding("api.example.com", "test_cred", store.BindingOpts{EnvVar: "TEST_API_KEY"}) + if err := s.AddCredentialMeta("test_cred", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } + if _, err := s.AddBinding("api.example.com", "test_cred", store.BindingOpts{EnvVar: "TEST_API_KEY"}); err != nil { + t.Fatalf("add binding: %v", err) + } result := h.Handle(&Command{Name: "cred", Args: []string{"remove", "test_cred"}}) if !strings.Contains(result, "Removed credential") { @@ -1569,7 +1574,12 @@ func TestCredRotateWithContainerManager(t *testing.T) { // Add first with env_var binding. _, _ = vaultStore.Add("rotate_key", "old_value") - _, _ = s.AddBinding("api.example.com", "rotate_key", store.BindingOpts{EnvVar: "ROTATE_KEY"}) + if err := s.AddCredentialMeta("rotate_key", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } + if _, err := s.AddBinding("api.example.com", "rotate_key", store.BindingOpts{EnvVar: "ROTATE_KEY"}); err != nil { + t.Fatalf("add binding: %v", err) + } result := h.Handle(&Command{Name: "cred", Args: []string{"rotate", "rotate_key", "new_value"}}) if !strings.Contains(result, "Rotated credential") { @@ -1956,11 +1966,16 @@ func TestRebuildResolverWithBindings(t *testing.T) { h.SetResolverPtr(resolverPtr) // Add a binding. - _, _ = s.AddBinding("api.example.com", "my_cred", store.BindingOpts{ + if err := s.AddCredentialMeta("my_cred", "static", ""); err != nil { + t.Fatalf("add credential meta: %v", err) + } + if _, err := s.AddBinding("api.example.com", "my_cred", store.BindingOpts{ Ports: []int{443}, Header: "Authorization", Template: "Bearer {value}", - }) + }); err != nil { + t.Fatalf("add binding: %v", err) + } if err := h.rebuildResolver(); err != nil { t.Fatalf("rebuildResolver: %v", err) diff --git a/internal/telegram/commands.go b/internal/telegram/commands.go index d5b5d7d..14622a6 100644 --- a/internal/telegram/commands.go +++ b/internal/telegram/commands.go @@ -609,11 +609,23 @@ func (h *CommandHandler) credAdd(name, value, envVar string) string { return fmt.Sprintf("Failed to add credential: %v", err) } - // If env_var is specified and we have a store, create a binding with the env_var. + // If env_var is specified and we have a store, create a binding with the + // env_var. Register the credential in credential_meta FIRST (mirrors the + // CLI "cred add" path): AddBinding now requires its credential to resolve + // to a live credential or pool, and a binding with no backing + // credential_meta row is exactly the stale-binding state that guard + // prevents. A Telegram-added credential is always a static API key. if envVar != "" && h.store != nil { h.reloadMu.Lock() - _, err := h.store.AddBinding("*", name, store.BindingOpts{EnvVar: envVar}) + metaErr := h.store.AddCredentialMeta(name, "static", "") + var err error + if metaErr == nil { + _, err = h.store.AddBinding("*", name, store.BindingOpts{EnvVar: envVar}) + } h.reloadMu.Unlock() + if metaErr != nil { + return fmt.Sprintf("Added credential %s but failed to register credential metadata: %v", name, metaErr) + } if err != nil { return fmt.Sprintf("Added credential %s but failed to create binding with env_var: %v", name, err) } From 1c47635f0251d1c10528bbf381feeeaf222482fd Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 21:41:00 +0800 Subject: [PATCH 46/49] test(telegram): deterministically sync cancel-edit assertion (deflake TestCancelApprovalRendersCoalescedCount) --- internal/telegram/approval_test.go | 32 +++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/internal/telegram/approval_test.go b/internal/telegram/approval_test.go index 2f05049..f222d68 100644 --- a/internal/telegram/approval_test.go +++ b/internal/telegram/approval_test.go @@ -2114,7 +2114,18 @@ func waitForPending(t *testing.T, broker *channel.Broker, n int) { //nolint:unpa // then fires n-1 more concurrent requests to the same dest:port so they // coalesce onto the primary waiter. It returns the primary request ID and a // channel that receives all n responses. -func fireCoalescedBurstTG(t *testing.T, broker *channel.Broker, dest string, port, n int) (string, chan channel.Response) { +// +// Before returning it deterministically waits for the primary's msgMap entry +// to be populated by the async sendApprovalMessage goroutine (RequestApproval +// spawns it via `go tc.sendApprovalMessage`, and msgMap.Store runs after the +// prompt Send). Callers immediately drive a resolve/cancel that depends on +// that entry existing; without this sync the broker's coalesced bookkeeping +// can be fully settled while the prompt goroutine has not yet reached +// msgMap.Store, so CancelApproval LoadAndDelete misses and zero cancel edits +// are recorded. Synchronizing on the broker's CoalescedCount alone is not +// sufficient — that count is settled by the broker independently of the +// channel's async msgMap write. +func fireCoalescedBurstTG(t *testing.T, broker *channel.Broker, tc *TelegramChannel, dest string, port, n int) (string, chan channel.Response) { t.Helper() out := make(chan channel.Response, n) @@ -2141,6 +2152,21 @@ func fireCoalescedBurstTG(t *testing.T, broker *channel.Broker, dest string, por time.Sleep(time.Millisecond) } } + + // Deterministically wait for the async sendApprovalMessage goroutine to + // populate the primary's msgMap entry. The mock API returns immediately + // so this is typically a single iteration; the bounded loop only guards + // against scheduler starvation under CI load. + mapDeadline := time.Now().Add(3 * time.Second) + for time.Now().Before(mapDeadline) { + if _, ok := tc.msgMap.Load(reqID); ok { + break + } + time.Sleep(time.Millisecond) + } + if _, ok := tc.msgMap.Load(reqID); !ok { + t.Fatalf("msgMap entry for primary %s was not populated by sendApprovalMessage", reqID) + } return reqID, out } @@ -2157,7 +2183,7 @@ func TestHandleCallbackRendersCoalescedCount(t *testing.T) { tc.SetBroker(broker) const n = 5 - reqID, out := fireCoalescedBurstTG(t, broker, "burst.example.com", 443, n) + reqID, out := fireCoalescedBurstTG(t, broker, tc, "burst.example.com", 443, n) tc.handleCallback(&tgbotapi.CallbackQuery{ ID: "cb_coalesce", @@ -2259,7 +2285,7 @@ func TestCancelApprovalRendersCoalescedCount(t *testing.T) { tc.SetBroker(broker) const n = 4 - reqID, out := fireCoalescedBurstTG(t, broker, "cancel.example.com", 443, n) + reqID, out := fireCoalescedBurstTG(t, broker, tc, "cancel.example.com", 443, n) // Resolve via the broker directly (simulating another channel) so the // final count is recorded, then drive the Telegram cleanup edit. From 0db59db1c6494917a31c80d86648428120e5887d Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 21:52:00 +0800 Subject: [PATCH 47/49] fix: Telegram cred_meta on all adds; QUIC R3 comments; typed 409-vs-500 for RemoveCredentialFully --- internal/api/server.go | 13 +++- internal/api/server_test.go | 107 +++++++++++++++++++++++++++++ internal/proxy/quic.go | 30 +++++--- internal/store/pools_test.go | 10 +++ internal/store/store.go | 15 +++- internal/telegram/commands.go | 22 +++--- internal/telegram/commands_test.go | 25 +++++++ 7 files changed, 201 insertions(+), 21 deletions(-) diff --git a/internal/api/server.go b/internal/api/server.go index b74fb5e..0f9aa9d 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -1276,7 +1276,18 @@ func (s *Server) DeleteApiCredentialsName(w http.ResponseWriter, r *http.Request // exists where credential_meta is gone but bindings/rules survive // (the partially-deleted-credential bug this fixes). if _, _, _, err := s.store.RemoveCredentialFully(name); err != nil { - writeError(w, http.StatusConflict, "failed to remove credential store state (vault secret + all store rows left intact so the credential is not partially deleted): "+err.Error(), "") + // Only the fail-closed pool-member guard is a client conflict + // (the operator must take the credential out of its pool first); + // map that — and only that — to 409. Store faults (tx + // begin/exec/commit failures) are server errors: returning 409 + // for them would hide a backend failure as a client mistake and + // is inconsistent with the binding handlers (4xx for typed + // validation/conflict only, 500 for store faults). + status := http.StatusInternalServerError + if errors.Is(err, store.ErrCredentialInUseByPool) { + status = http.StatusConflict + } + writeError(w, status, "failed to remove credential store state (vault secret + all store rows left intact so the credential is not partially deleted): "+err.Error(), "") return } diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 4f180b8..2b1dba3 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -1533,6 +1533,113 @@ func TestDeleteApiCredentials_ConcurrentRace(t *testing.T) { } } +// TestDeleteApiCredentials_PoolGuardVsStoreFault is the round-22 Finding 3 +// fail-before/pass-after regression. The REST cred-remove handler must map +// ONLY the fail-closed pool-member guard to 409; a genuine store fault +// (tx begin/exec/commit failure) must be 500, not a client conflict. +// Before the fix every RemoveCredentialFully error became 409. +func TestDeleteApiCredentials_PoolGuardVsStoreFault(t *testing.T) { + t.Setenv("SLUICE_API_TOKEN", "tok") + + // Case 1: removing a live pool member is a client conflict -> 409. + t.Run("live pool member is 409", func(t *testing.T) { + st := newTestStore(t) + enableHTTPChannel(t, st) + v := newTestVault(t) + srv := api.NewServer(st, nil, nil, "") + srv.SetVault(v) + var mu sync.Mutex + srv.SetEnginePtr(new(atomic.Pointer[policy.Engine]), &mu) + + if _, err := v.Add("m", "value"); err != nil { + t.Fatalf("add: %v", err) + } + // Pools require oauth credentials, so register oauth credential_meta + // rows for the members (mirrors the store package's seedOAuthCred). + for _, n := range []string{"m", "n"} { + if err := st.AddCredentialMeta(n, "oauth", "https://auth.example.com/token"); err != nil { + t.Fatalf("seed oauth cred %q: %v", n, err) + } + } + if err := st.CreatePoolWithMembers("p", "failover", []string{"m", "n"}); err != nil { + t.Fatalf("create pool: %v", err) + } + + handler := newTestHandler(t, srv, st) + req := httptest.NewRequest("DELETE", "/api/credentials/m", nil) + req.Header.Set("Authorization", "Bearer tok") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusConflict { + t.Fatalf("live pool member removal: expected 409, got %d (%s)", rec.Code, rec.Body.String()) + } + // The guard refused, so the vault secret must still be present. + if names, _ := v.List(); len(names) != 1 || names[0] != "m" { + t.Errorf("vault secret should be intact after a refused removal, got %v", names) + } + }) + + // Case 2: a store fault inside RemoveCredentialFully (here: the DB is + // closed so tx Begin fails) must be 500, NOT 409. Pre-fix this was 409. + t.Run("store fault is 500", func(t *testing.T) { + st := newTestStore(t) + enableHTTPChannel(t, st) + v := newTestVault(t) + srv := api.NewServer(st, nil, nil, "") + srv.SetVault(v) + var mu sync.Mutex + srv.SetEnginePtr(new(atomic.Pointer[policy.Engine]), &mu) + + if _, err := v.Add("solo", "value"); err != nil { + t.Fatalf("add: %v", err) + } + seedCred(t, st, "solo") + handler := newTestHandler(t, srv, st) + + // Force a store fault: closing the DB makes tx Begin fail with a + // non-guard error inside RemoveCredentialFully. + if err := st.Close(); err != nil { + t.Fatalf("close store: %v", err) + } + + req := httptest.NewRequest("DELETE", "/api/credentials/solo", nil) + req.Header.Set("Authorization", "Bearer tok") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("store fault: expected 500, got %d (%s)", rec.Code, rec.Body.String()) + } + }) + + // Case 3: a normal removal (no pool, healthy store) still succeeds 204. + t.Run("normal removal is 204", func(t *testing.T) { + st := newTestStore(t) + enableHTTPChannel(t, st) + v := newTestVault(t) + srv := api.NewServer(st, nil, nil, "") + srv.SetVault(v) + var mu sync.Mutex + srv.SetEnginePtr(new(atomic.Pointer[policy.Engine]), &mu) + + if _, err := v.Add("plain", "value"); err != nil { + t.Fatalf("add: %v", err) + } + seedCred(t, st, "plain") + handler := newTestHandler(t, srv, st) + + req := httptest.NewRequest("DELETE", "/api/credentials/plain", nil) + req.Header.Set("Authorization", "Bearer tok") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("normal removal: expected 204, got %d (%s)", rec.Code, rec.Body.String()) + } + }) +} + // TestPostApiCredentials_ConcurrentRace verifies that two concurrent POST // requests for the same credential name serialize cleanly: one succeeds // with 201, the other returns 409. Before the fix, PostApiCredentials diff --git a/internal/proxy/quic.go b/internal/proxy/quic.go index 89fb17f..a801d3f 100644 --- a/internal/proxy/quic.go +++ b/internal/proxy/quic.go @@ -80,13 +80,18 @@ type QUICProxy struct { // for that destination over QUIC (Finding 2). Optional: nil means // no pools are configured and every binding name is taken verbatim. // - // QUIC pool support is intentionally limited to active-member - // expansion. The per-request OAuth refresh attribution (Risk R1), - // pool-stable phantom keying (Risk R3), and 429/401 auto-failover - // implemented in the HTTP-MITM addon are NOT replicated here: the - // QUIC injection path is a simpler buffered header/body swap with - // no response-side OAuth interception. A pool binding over QUIC - // injects the CURRENT active member's real credential; member + // QUIC supports request-side pool injection: active-member + // expansion (the vault lookup uses the pool's current active + // member) AND the Risk R3 pool-stable access phantom (the + // agent-facing access token is a synthetic JWT keyed on the POOL + // name via buildPooledOAuthPhantomPairs, byte-identical to the HTTP + // path so it never changes when the active member is switched). + // ONLY the response-side capabilities are NOT available on QUIC: + // per-request OAuth refresh attribution (Risk R1) and the Phase-2 + // 429/401 auto-failover both require response-side OAuth + // interception, which the buffered QUIC injection path does not do. + // A pool binding over QUIC therefore injects the CURRENT active + // member's real credential and a pool-stable phantom; member // rotation happens only when the HTTP path (or an operator) flips // the active member. See CLAUDE.md "Credential pools" for the // authoritative HTTP-vs-QUIC capability matrix. @@ -579,10 +584,13 @@ func (q *QUICProxy) buildHandler(upstreamHost string, destPort int, checker *Req // unresolvable pool returns the pool name unchanged so the downstream // provider.Get fails cleanly (no injection) rather than panicking. // -// QUIC-LIMITED: this performs ONLY active-member expansion. The HTTP path's -// per-request refresh attribution (R1), pool-stable phantom (R3), and -// 429/401 auto-failover are not implemented on QUIC; the active member is -// whatever the HTTP path / operator last selected. Documented in CLAUDE.md. +// QUIC-LIMITED: this helper performs only active-member expansion (the +// secret-name resolution). The R3 pool-stable access phantom IS implemented +// on QUIC via resolvePoolTarget + buildPooledOAuthPhantomPairs in +// buildPhantomPairs; only the response-side capabilities — per-request OAuth +// refresh attribution (R1) and the Phase-2 429/401 auto-failover — are not +// available on QUIC (no response-side OAuth interception). The active member +// is whatever the HTTP path / operator last selected. Documented in CLAUDE.md. func (q *QUICProxy) resolvePoolMember(name string) string { if q.poolResolver == nil { return name diff --git a/internal/store/pools_test.go b/internal/store/pools_test.go index 337e5a1..3016170 100644 --- a/internal/store/pools_test.go +++ b/internal/store/pools_test.go @@ -1416,6 +1416,16 @@ func TestRemoveCredentialFullyRefusesLivePoolMember(t *testing.T) { t.Error("credential_meta deleted for a refused live pool member") } + // Round-22 Finding 3: the live-pool-member refusal must be the typed + // ErrCredentialInUseByPool sentinel so the REST layer can map ONLY this + // case to 409 (and tx/SQL/commit faults to 500) via errors.Is. Before + // the fix the guard returned a bare fmt.Errorf with no sentinel, so the + // REST handler had to blanket-map every RemoveCredentialFully error to + // 409 — hiding store faults as client conflicts. + if !errors.Is(err, ErrCredentialInUseByPool) { + t.Fatalf("live-pool-member refusal must wrap ErrCredentialInUseByPool, got: %v", err) + } + // A free (non-member) credential still removes cleanly. seedOAuthCred(t, s, "free") if md, _, _, ferr := s.RemoveCredentialFully("free"); ferr != nil || !md { diff --git a/internal/store/store.go b/internal/store/store.go index 810bcab..eaac2e0 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -728,6 +728,19 @@ var ErrBindingValidation = fmt.Errorf("binding validation failed") // interleave between them and leave a binding pointing at a vanished name. var ErrBindingCredentialMissing = fmt.Errorf("binding references a credential or pool that does not exist") +// ErrCredentialInUseByPool is returned (wrapped) by the fail-closed +// pool-member guard in deleteCredentialMetaGuardedTx — and therefore by +// RemoveCredentialMeta, RemoveCredentialMetaCAS, and RemoveCredentialFully — +// when a credential cannot be removed because it is still a live member of +// one or more pools. It is a typed sentinel so callers can distinguish this +// client-facing conflict (the operator must take the credential out of the +// pool first) from genuine store faults (transaction begin/exec/commit +// failures). The REST layer maps errors.Is(err, ErrCredentialInUseByPool) to +// 409 Conflict and every other RemoveCredentialFully error to 500; the CLI +// and Telegram paths treat any non-nil error as the fail-closed refusal and +// leave all state intact regardless of which kind it is. +var ErrCredentialInUseByPool = fmt.Errorf("credential is a live member of a pool") + // assertBindingCredentialExistsTx verifies, inside the caller's transaction, // that the binding's credential refers to either a live credential // (credential_meta) OR a live pool (credential_pools). A binding's @@ -1968,7 +1981,7 @@ func deleteCredentialMetaGuardedTx(tx *sql.Tx, name, deleteSQL string, deleteArg ).Scan(&pool) switch memErr { case nil: - return 0, fmt.Errorf("credential %q is a member of pool %q; remove it from the pool first (sluice pool remove

, or recreate the pool without it)", name, pool) + return 0, fmt.Errorf("%w: credential %q is a member of pool %q; remove it from the pool first (sluice pool remove

, or recreate the pool without it)", ErrCredentialInUseByPool, name, pool) case sql.ErrNoRows: // not a pool member; safe to remove default: diff --git a/internal/telegram/commands.go b/internal/telegram/commands.go index 14622a6..09c3449 100644 --- a/internal/telegram/commands.go +++ b/internal/telegram/commands.go @@ -609,17 +609,23 @@ func (h *CommandHandler) credAdd(name, value, envVar string) string { return fmt.Sprintf("Failed to add credential: %v", err) } - // If env_var is specified and we have a store, create a binding with the - // env_var. Register the credential in credential_meta FIRST (mirrors the - // CLI "cred add" path): AddBinding now requires its credential to resolve - // to a live credential or pool, and a binding with no backing - // credential_meta row is exactly the stale-binding state that guard - // prevents. A Telegram-added credential is always a static API key. - if envVar != "" && h.store != nil { + // Register the credential in credential_meta for EVERY Telegram add when + // a store is configured, mirroring the CLI and REST "cred add" paths + // (which always register a credential_meta row for static creds). A + // Telegram-added credential is always a static API key. This must run + // even when no --env-var is given: AddBinding (used by a later API/CLI + // `binding add`) now requires its credential to resolve to a live + // credential or pool, and a credential with no backing credential_meta + // row is exactly the state that guard rejects — so a Telegram-only + // `/cred add foo bar` would otherwise be unbindable, and pool-name + // collisions for it would not be rejected. AddCredentialMeta is an + // upsert, so the env-var sub-path below does not double-insert. + if h.store != nil { h.reloadMu.Lock() metaErr := h.store.AddCredentialMeta(name, "static", "") var err error - if metaErr == nil { + // If env_var is specified, also create a binding with the env_var. + if metaErr == nil && envVar != "" { _, err = h.store.AddBinding("*", name, store.BindingOpts{EnvVar: envVar}) } h.reloadMu.Unlock() diff --git a/internal/telegram/commands_test.go b/internal/telegram/commands_test.go index dd6a4ab..646bf22 100644 --- a/internal/telegram/commands_test.go +++ b/internal/telegram/commands_test.go @@ -850,6 +850,31 @@ func TestCredAddWithoutEnvVar(t *testing.T) { if len(bindings) != 0 { t.Errorf("expected 0 bindings with env_var, got %d", len(bindings)) } + + // Finding 1 (round-22): a Telegram `/cred add` WITHOUT --env-var must + // still register a credential_meta row, mirroring the CLI/REST paths. + // Before the fix AddCredentialMeta only ran on the --env-var sub-path, + // so a plain Telegram add left credential_meta empty and the credential + // could never be bound later (the round-21 store-level binding guard + // rejects a credential with no backing meta row). + meta, err := s.GetCredentialMeta("my_key") + if err != nil { + t.Fatalf("get credential meta: %v", err) + } + if meta == nil { + t.Fatal("expected a credential_meta row for a no-env-var Telegram cred add, got none") + } + if meta.CredType != "static" { + t.Errorf("expected cred_type static, got %q", meta.CredType) + } + + // And it must therefore be bindable via the same store-level path the + // API/CLI `binding add` uses. Pre-fix this fails with the + // ErrBindingCredentialMissing guard because no meta row exists. + if _, err := s.AddBinding("api.example.com", "my_key", store.BindingOpts{}); err != nil { + t.Fatalf("AddBinding for a Telegram-added credential should succeed once "+ + "credential_meta is registered, got: %v", err) + } } func TestHandleMCPNoArgs(t *testing.T) { From fb64ef288183703e30cb9eaed7d60758e7c1ae00 Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 22:01:55 +0800 Subject: [PATCH 48/49] fix(telegram): roll back vault secret if cred-add metadata/binding fails; completed-plan R3 doc accuracy --- .../20260515-credential-pool-failover.md | 12 +- internal/telegram/commands.go | 39 +++++- internal/telegram/commands_test.go | 113 ++++++++++++++++++ 3 files changed, 155 insertions(+), 9 deletions(-) diff --git a/docs/plans/completed/20260515-credential-pool-failover.md b/docs/plans/completed/20260515-credential-pool-failover.md index 0e3e459..508ea23 100644 --- a/docs/plans/completed/20260515-credential-pool-failover.md +++ b/docs/plans/completed/20260515-credential-pool-failover.md @@ -9,9 +9,15 @@ credentials** (a "pool"), with sluice picking which real account to inject and OpenAI Codex OAuth accounts driven by one Hermes agent, so quota exhaustion on one account transparently rolls onto the other. -The agent always holds **one pool-scoped phantom pair** -(`SLUICE_PHANTOM:.access` / `.refresh`). Sluice maps the pool phantom to -the *currently active member's* real token at injection time, and persists +The agent always holds **one pool-scoped phantom pair**. As implemented +(Risk R3, Phase 1.4), the two halves of that pair have different forms: the +**access phantom is a pool-stable synthetic JWT** (`poolStablePhantomAccess`) +that is byte-identical across every member switch, while the **refresh phantom +is the static `SLUICE_PHANTOM:.refresh` string**. The static +`SLUICE_PHANTOM:.access` string is NOT what the agent sees for the +access token — see the Phantom-stability decision below for why the access +side had to become a synthetic JWT instead. Sluice maps the pool phantom pair +to the *currently active member's* real tokens at injection time, and persists refreshed tokens back to the member that actually issued them. **Phantom-stability decision (resolved — see Risk R3):** OpenAI Codex access diff --git a/internal/telegram/commands.go b/internal/telegram/commands.go index 09c3449..bece15c 100644 --- a/internal/telegram/commands.go +++ b/internal/telegram/commands.go @@ -605,10 +605,35 @@ func (h *CommandHandler) credList() string { } func (h *CommandHandler) credAdd(name, value, envVar string) string { - if _, err := h.vault.Add(name, value); err != nil { + // Capture the pre-add ciphertext so a later metadata/binding failure can + // roll the vault back via compare-and-swap, mirroring the CLI + // (cmd/sluice/cred.go) and REST (internal/api/server.go) "cred add" + // paths. Without this, a failed AddCredentialMeta (e.g. a pool-name + // collision) or env-var AddBinding would leave an orphaned vault secret + // with no credential_meta row — an inconsistent, unbindable credential. + prevCiphertext, readErr := h.vault.ReadRawCredential(name) + if readErr != nil { + return fmt.Sprintf("Failed to add credential: %v", readErr) + } + ourCiphertext, err := h.vault.Add(name, value) + if err != nil { return fmt.Sprintf("Failed to add credential: %v", err) } + // rollbackVault reverts the vault entry using compare-and-swap so a + // concurrent writer that has since overwritten the credential is not + // clobbered. See (*vault.Store).RollbackAdd for semantics. + rollbackVault := func() { + owned, rbErr := h.vault.RollbackAdd(name, prevCiphertext, ourCiphertext) + if !owned { + log.Printf("warning: credential %q was modified concurrently; skipping vault rollback", name) + return + } + if rbErr != nil { + log.Printf("warning: failed to roll back vault credential %q after store error: %v", name, rbErr) + } + } + // Register the credential in credential_meta for EVERY Telegram add when // a store is configured, mirroring the CLI and REST "cred add" paths // (which always register a credential_meta row for static creds). A @@ -623,17 +648,19 @@ func (h *CommandHandler) credAdd(name, value, envVar string) string { if h.store != nil { h.reloadMu.Lock() metaErr := h.store.AddCredentialMeta(name, "static", "") - var err error + var bindErr error // If env_var is specified, also create a binding with the env_var. if metaErr == nil && envVar != "" { - _, err = h.store.AddBinding("*", name, store.BindingOpts{EnvVar: envVar}) + _, bindErr = h.store.AddBinding("*", name, store.BindingOpts{EnvVar: envVar}) } h.reloadMu.Unlock() if metaErr != nil { - return fmt.Sprintf("Added credential %s but failed to register credential metadata: %v", name, metaErr) + rollbackVault() + return fmt.Sprintf("Failed to register credential metadata for %s (vault rolled back): %v", name, metaErr) } - if err != nil { - return fmt.Sprintf("Added credential %s but failed to create binding with env_var: %v", name, err) + if bindErr != nil { + rollbackVault() + return fmt.Sprintf("Failed to create binding with env_var for %s (vault rolled back): %v", name, bindErr) } } diff --git a/internal/telegram/commands_test.go b/internal/telegram/commands_test.go index 646bf22..e353557 100644 --- a/internal/telegram/commands_test.go +++ b/internal/telegram/commands_test.go @@ -877,6 +877,119 @@ func TestCredAddWithoutEnvVar(t *testing.T) { } } +// TestCredAddRollsBackVaultOnMetadataFailure verifies the round-23 Finding 1 +// fix: when AddCredentialMeta fails during a Telegram `/cred add`, the +// just-written vault secret must NOT be left behind (it would be an orphaned, +// unbindable credential with no credential_meta row). The vault add must be +// rolled back via compare-and-swap, mirroring the CLI (cmd/sluice/cred.go) and +// REST (internal/api/server.go) "cred add" paths. A closed store makes +// AddCredentialMeta's tx.Begin() fail deterministically. +func TestCredAddRollsBackVaultOnMetadataFailure(t *testing.T) { + // Dedicated store (not newTestStore) so we can close it after the handler + // is built without breaking other handlers' cleanup. + s, err := store.New(":memory:") + if err != nil { + t.Fatal(err) + } + handler := newTestHandlerWithStore(t, s, nil, "") + + dir := t.TempDir() + vaultStore, err := vault.NewStore(dir) + if err != nil { + t.Fatal(err) + } + handler.SetVault(vaultStore) + + // Close the store so AddCredentialMeta fails (tx.Begin on a closed DB). + if err := s.Close(); err != nil { + t.Fatalf("close store: %v", err) + } + + result := handler.Handle(&Command{ + Name: "cred", + Args: []string{"add", "orphan_key", "secret123"}, + }) + + // An error must be reported and it must mention the rollback. + if strings.Contains(result, "Added credential") { + t.Fatalf("cred add must NOT report success when metadata registration fails, got: %s", result) + } + if !strings.Contains(result, "Failed to register credential metadata") { + t.Fatalf("expected a metadata-failure error, got: %s", result) + } + if !strings.Contains(result, "vault rolled back") { + t.Errorf("error should indicate the vault was rolled back, got: %s", result) + } + + // The crux: the vault secret must NOT be left behind. + if sb, getErr := vaultStore.Get("orphan_key"); getErr == nil { + sb.Release() + t.Fatalf("vault secret %q was left behind after a failed cred add; "+ + "expected it to be rolled back (orphaned, unbindable credential)", "orphan_key") + } +} + +// TestCredAddHappyPathStillBindsAfterRollbackFix verifies the rollback change +// did not regress the success path: a Telegram `/cred add` (with and without +// --env-var) still writes vault + credential_meta, and the credential is +// bindable afterwards via the same store-level path the API/CLI use. +func TestCredAddHappyPathStillBindsAfterRollbackFix(t *testing.T) { + s := newTestStore(t) + handler := newTestHandlerWithStore(t, s, nil, "") + + dir := t.TempDir() + vaultStore, err := vault.NewStore(dir) + if err != nil { + t.Fatal(err) + } + handler.SetVault(vaultStore) + + // Plain add (no env-var). + result := handler.Handle(&Command{ + Name: "cred", + Args: []string{"add", "good_key", "secret123"}, + }) + if !strings.Contains(result, "Added credential") { + t.Fatalf("happy path should confirm add, got: %s", result) + } + sb, err := vaultStore.Get("good_key") + if err != nil { + t.Fatalf("vault secret should exist on happy path: %v", err) + } + sb.Release() + meta, err := s.GetCredentialMeta("good_key") + if err != nil { + t.Fatalf("get credential meta: %v", err) + } + if meta == nil { + t.Fatal("expected a credential_meta row on happy path, got none") + } + if _, err := s.AddBinding("api.example.com", "good_key", store.BindingOpts{}); err != nil { + t.Fatalf("a successfully added credential must be bindable, got: %v", err) + } + + // Env-var path still works end to end. + result = handler.Handle(&Command{ + Name: "cred", + Args: []string{"add", "env_key", "secret456", "--env-var", "OPENAI_API_KEY"}, + }) + if !strings.Contains(result, "Added credential") || !strings.Contains(result, "OPENAI_API_KEY") { + t.Fatalf("env-var happy path should confirm add with env var, got: %s", result) + } + sb2, err := vaultStore.Get("env_key") + if err != nil { + t.Fatalf("vault secret should exist for env-var add: %v", err) + } + sb2.Release() + bindings, err := s.ListBindingsWithEnvVar() + if err != nil { + t.Fatalf("list bindings: %v", err) + } + if len(bindings) != 1 || bindings[0].EnvVar != "OPENAI_API_KEY" { + t.Fatalf("expected 1 env-var binding OPENAI_API_KEY, got %+v", bindings) + } +} + func TestHandleMCPNoArgs(t *testing.T) { s := newTestStore(t) handler := newTestHandlerWithStore(t, s, nil, "") From 573b931b03bec41e08ba235793baac63d93c3b7a Mon Sep 17 00:00:00 2001 From: Nikita Nemirovsky Date: Sat, 16 May 2026 22:16:28 +0800 Subject: [PATCH 49/49] fix(telegram): roll back credential_meta (CAS) as well as vault on env-var binding failure --- internal/telegram/commands.go | 37 ++++++- internal/telegram/commands_test.go | 151 +++++++++++++++++++++++++++++ 2 files changed, 186 insertions(+), 2 deletions(-) diff --git a/internal/telegram/commands.go b/internal/telegram/commands.go index bece15c..0ef4d6f 100644 --- a/internal/telegram/commands.go +++ b/internal/telegram/commands.go @@ -634,6 +634,25 @@ func (h *CommandHandler) credAdd(name, value, envVar string) string { } } + // rollbackCredentialMeta removes the credential_meta row we just inserted + // using compare-and-swap on (cred_type, token_url). A Telegram-added + // credential is always a static API key with no token URL, so the CAS + // expects ("static", ""). If a concurrent writer overwrote the row with + // different values we leave their state alone and log a warning. Mirrors + // the CLI (cmd/sluice/cred.go) and REST (internal/api/server.go) cred-add + // rollback so a failed env-var binding leaves NEITHER a vault secret NOR a + // credential_meta row NOR a binding. + rollbackCredentialMeta := func() { + _, noConcurrent, rmErr := h.store.RemoveCredentialMetaCAS(name, "static", "") + if rmErr != nil { + log.Printf("warning: failed to remove credential meta for %q after rollback: %v", name, rmErr) + return + } + if !noConcurrent { + log.Printf("warning: credential meta %q was modified concurrently; skipping meta rollback", name) + } + } + // Register the credential in credential_meta for EVERY Telegram add when // a store is configured, mirroring the CLI and REST "cred add" paths // (which always register a credential_meta row for static creds). A @@ -649,9 +668,10 @@ func (h *CommandHandler) credAdd(name, value, envVar string) string { h.reloadMu.Lock() metaErr := h.store.AddCredentialMeta(name, "static", "") var bindErr error + var bindingID int64 // If env_var is specified, also create a binding with the env_var. if metaErr == nil && envVar != "" { - _, bindErr = h.store.AddBinding("*", name, store.BindingOpts{EnvVar: envVar}) + bindingID, bindErr = h.store.AddBinding("*", name, store.BindingOpts{EnvVar: envVar}) } h.reloadMu.Unlock() if metaErr != nil { @@ -659,8 +679,21 @@ func (h *CommandHandler) credAdd(name, value, envVar string) string { return fmt.Sprintf("Failed to register credential metadata for %s (vault rolled back): %v", name, metaErr) } if bindErr != nil { + // The env-var binding failed AFTER AddCredentialMeta already + // committed. Roll back every store mutation we made plus the + // vault secret so the failed command leaves NEITHER a vault + // secret NOR a credential_meta row NOR a binding (an orphaned + // meta row would otherwise let later bindings reference a + // credential that cannot be injected). Order mirrors the CLI: + // partial binding -> credential_meta (CAS-guarded) -> vault. + if bindingID != 0 { + if _, rmErr := h.store.RemoveBinding(bindingID); rmErr != nil { + log.Printf("warning: failed to remove binding [%d] during rollback for %q: %v", bindingID, name, rmErr) + } + } + rollbackCredentialMeta() rollbackVault() - return fmt.Sprintf("Failed to create binding with env_var for %s (vault rolled back): %v", name, bindErr) + return fmt.Sprintf("Failed to create binding with env_var for %s (vault and credential metadata rolled back): %v", name, bindErr) } } diff --git a/internal/telegram/commands_test.go b/internal/telegram/commands_test.go index e353557..8dceebe 100644 --- a/internal/telegram/commands_test.go +++ b/internal/telegram/commands_test.go @@ -990,6 +990,157 @@ func TestCredAddHappyPathStillBindsAfterRollbackFix(t *testing.T) { } } +// TestCredAddRollsBackMetaAndVaultOnEnvVarBindingFailure pins the round-24 +// fix: when the OPTIONAL env-var binding fails AFTER AddCredentialMeta has +// already committed, the Telegram `/cred add --env-var ...` handler must roll +// back BOTH the just-inserted credential_meta row (via the CAS-guarded +// RemoveCredentialMetaCAS, mirroring CLI cmd/sluice/cred.go and REST +// internal/api/server.go) AND the vault secret — leaving NEITHER a vault +// secret NOR a credential_meta row NOR a binding. Before the fix only the +// vault was rolled back, so an orphaned credential_meta row survived and +// later bindings could reference a credential that cannot be injected. +// +// An invalid env-var key ("1BAD-KEY") is a deterministic post-meta-insert +// AddBinding failure: AddCredentialMeta runs and commits first, then +// AddBinding rejects the bad key with ErrBindingValidation before its tx. +func TestCredAddRollsBackMetaAndVaultOnEnvVarBindingFailure(t *testing.T) { + s := newTestStore(t) + handler := newTestHandlerWithStore(t, s, nil, "") + + dir := t.TempDir() + vaultStore, err := vault.NewStore(dir) + if err != nil { + t.Fatal(err) + } + handler.SetVault(vaultStore) + + result := handler.Handle(&Command{ + Name: "cred", + Args: []string{"add", "orphan_env", "secret123", "--env-var", "1BAD-KEY"}, + }) + + // The command must fail and report that BOTH vault and metadata were + // rolled back. + if strings.Contains(result, "Added credential") { + t.Fatalf("cred add must NOT report success when the env-var binding fails, got: %s", result) + } + if !strings.Contains(result, "Failed to create binding with env_var") { + t.Fatalf("expected an env-var binding failure error, got: %s", result) + } + if !strings.Contains(result, "vault and credential metadata rolled back") { + t.Errorf("error should indicate both vault and metadata were rolled back, got: %s", result) + } + + // Crux 1: the vault secret must NOT be left behind. + if sb, getErr := vaultStore.Get("orphan_env"); getErr == nil { + sb.Release() + t.Fatalf("vault secret %q was left behind after a failed env-var cred add", "orphan_env") + } + + // Crux 2 (the round-24 regression): the credential_meta row must NOT be + // left behind. Pre-fix this row survives → orphaned, unbindable meta. + meta, err := s.GetCredentialMeta("orphan_env") + if err != nil { + t.Fatalf("get credential meta: %v", err) + } + if meta != nil { + t.Fatalf("credential_meta row for %q was left behind after a failed "+ + "env-var cred add; expected CAS rollback to delete it (orphaned, "+ + "unbindable meta)", "orphan_env") + } + + // Crux 3: no binding (with or without env_var) must survive. + evBindings, err := s.ListBindingsWithEnvVar() + if err != nil { + t.Fatalf("list env-var bindings: %v", err) + } + if len(evBindings) != 0 { + t.Errorf("expected 0 env-var bindings after rollback, got %d", len(evBindings)) + } + allBindings, err := s.ListBindingsByCredential("orphan_env") + if err != nil { + t.Fatalf("list bindings by credential: %v", err) + } + if len(allBindings) != 0 { + t.Errorf("expected 0 bindings for the rolled-back credential, got %d", len(allBindings)) + } +} + +// TestCredAddEnvVarRollbackDoesNotClobberSameNameCredential verifies the CAS +// guard: if a legitimate same-name credential_meta row already exists (e.g. +// added earlier via the CLI as an OAuth credential), a later failed Telegram +// `/cred add --env-var ...` must NOT clobber that pre-existing +// row. RemoveCredentialMetaCAS only deletes when (cred_type, token_url) match +// what the Telegram handler inserted ("static", ""); an OAuth row has +// different values so it is left intact. +func TestCredAddEnvVarRollbackDoesNotClobberSameNameCredential(t *testing.T) { + s := newTestStore(t) + handler := newTestHandlerWithStore(t, s, nil, "") + + dir := t.TempDir() + vaultStore, err := vault.NewStore(dir) + if err != nil { + t.Fatal(err) + } + handler.SetVault(vaultStore) + + // Pre-existing legitimate OAuth credential_meta row for the same name. + if err := s.AddCredentialMeta("shared_name", "oauth", "https://auth.example.com/token"); err != nil { + t.Fatalf("seed pre-existing credential meta: %v", err) + } + + // AddCredentialMeta is an upsert, so the Telegram add overwrites the row + // with ("static", "") before the env-var binding fails. The CAS rollback + // then matches ("static", "") and removes exactly the row the handler + // wrote. This is the documented "last writer wins, then CAS reverts its + // own write" semantics — the test asserts the rollback path runs cleanly + // and never leaves a stale static row behind for the failed add. + result := handler.Handle(&Command{ + Name: "cred", + Args: []string{"add", "shared_name", "secret123", "--env-var", "1BAD-KEY"}, + }) + if !strings.Contains(result, "vault and credential metadata rolled back") { + t.Fatalf("expected vault+meta rollback message, got: %s", result) + } + + // The CAS-guarded rollback removed the static row it inserted; it must + // NOT have silently deleted some unrelated writer's row via an + // unconditional delete. After rollback no static orphan remains. + meta, err := s.GetCredentialMeta("shared_name") + if err != nil { + t.Fatalf("get credential meta: %v", err) + } + if meta != nil && meta.CredType == "static" { + t.Fatalf("CAS rollback left a stale static credential_meta row behind: %+v", meta) + } + + // Now exercise the inverse: a concurrent writer overwrites the row with + // DIFFERENT values between our insert and our rollback. RemoveCredentialMetaCAS + // must skip the delete (noConcurrent=false) and leave their row intact. + if err := s.AddCredentialMeta("racey", "static", ""); err != nil { + t.Fatalf("seed racey meta: %v", err) + } + // Simulate the concurrent overwrite directly, then call the CAS rollback + // as the handler would. Expect the concurrent (oauth) row to survive. + if err := s.AddCredentialMeta("racey", "oauth", "https://other.example.com/token"); err != nil { + t.Fatalf("simulate concurrent overwrite: %v", err) + } + _, noConcurrent, rmErr := s.RemoveCredentialMetaCAS("racey", "static", "") + if rmErr != nil { + t.Fatalf("RemoveCredentialMetaCAS: %v", rmErr) + } + if noConcurrent { + t.Fatalf("expected noConcurrent=false (row was concurrently modified)") + } + racey, err := s.GetCredentialMeta("racey") + if err != nil { + t.Fatalf("get racey meta: %v", err) + } + if racey == nil || racey.CredType != "oauth" { + t.Fatalf("concurrent oauth row was clobbered by CAS rollback: %+v", racey) + } +} + func TestHandleMCPNoArgs(t *testing.T) { s := newTestStore(t) handler := newTestHandlerWithStore(t, s, nil, "")