From 9c8e01da5ec381c760a8d6fde85d4d5a5a4e3bd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 15:54:38 +0200 Subject: [PATCH 1/7] Add user stop handling and stop metadata Refactor and extend abort/stop behavior to support targeted user stops. Introduces userStopRequest/plan/result types and handleUserStop/executeUserStopPlan to resolve and execute room-wide, active-turn, or queued-turn stops. Replaces direct abortRoom calls with handleUserStop in command and message handlers. Adds assistantStopMetadata and propagates it into streamingState and UI message metadata (including response status mapping). Tracks room run targets (source/initial events) and binds streaming state to room runs. Implements queue operations to drain or remove pending items by source event and finalizes stopped queue items, preserving ACK reaction removal and session notifications. Adjusts streaming finish logic to treat cancelled vs stop reasons appropriately. Includes unit tests for plan resolution, queued removal, and metadata emission. --- bridges/ai/abort_helpers.go | 207 +++++++++++++++++++++++-- bridges/ai/abort_helpers_test.go | 148 ++++++++++++++++++ bridges/ai/commands_parity.go | 10 +- bridges/ai/handlematrix.go | 11 +- bridges/ai/pending_queue.go | 70 +++++++++ bridges/ai/room_runs.go | 60 +++++++ bridges/ai/streaming_error_handling.go | 12 +- bridges/ai/streaming_init.go | 1 + bridges/ai/streaming_state.go | 2 + bridges/ai/turn_data.go | 3 + bridges/ai/ui_message_metadata.go | 11 ++ 11 files changed, 512 insertions(+), 23 deletions(-) create mode 100644 bridges/ai/abort_helpers_test.go diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index 45607429b..e5890f701 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -3,31 +3,204 @@ package ai import ( "context" "fmt" + "strings" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/id" ) -func formatAbortNotice(stopped int) string { - if stopped <= 0 { - return "Agent was aborted." +type stopPlanKind string + +const ( + stopPlanKindNoMatch stopPlanKind = "no-match" + stopPlanKindRoomWide stopPlanKind = "room-wide" + stopPlanKindActive stopPlanKind = "active-turn" + stopPlanKindQueued stopPlanKind = "queued-turn" +) + +type userStopRequest struct { + Portal *bridgev2.Portal + Meta *PortalMetadata + ReplyTo id.EventID + RequestedByEventID id.EventID + RequestedVia string +} + +type userStopPlan struct { + Kind stopPlanKind + Scope string + TargetKind string + TargetEventID id.EventID +} + +type userStopResult struct { + Plan userStopPlan + ActiveStopped bool + QueuedStopped int + SubagentsStopped int +} + +func stopLabel(count int, singular string) string { + if count == 1 { + return singular + } + return singular + "s" +} + +func formatAbortNotice(result userStopResult) string { + switch result.Plan.Kind { + case stopPlanKindNoMatch: + return "No matching active or queued turn found for that reply." + case stopPlanKindActive: + if result.SubagentsStopped > 0 { + return fmt.Sprintf("Stopped that turn. Stopped %d %s.", result.SubagentsStopped, stopLabel(result.SubagentsStopped, "sub-agent")) + } + return "Stopped that turn." + case stopPlanKindQueued: + if result.QueuedStopped <= 1 { + return "Stopped that queued turn." + } + return fmt.Sprintf("Stopped %d queued %s.", result.QueuedStopped, stopLabel(result.QueuedStopped, "turn")) + case stopPlanKindRoomWide: + parts := make([]string, 0, 3) + if result.ActiveStopped { + parts = append(parts, "stopped the active turn") + } + if result.QueuedStopped > 0 { + parts = append(parts, fmt.Sprintf("removed %d queued %s", result.QueuedStopped, stopLabel(result.QueuedStopped, "turn"))) + } + if result.SubagentsStopped > 0 { + parts = append(parts, fmt.Sprintf("stopped %d %s", result.SubagentsStopped, stopLabel(result.SubagentsStopped, "sub-agent"))) + } + if len(parts) == 0 { + return "No active or queued turns to stop." + } + suffix := "" + if len(parts) > 1 { + suffix = " " + strings.Join(parts[1:], ". ") + "." + } + return strings.ToUpper(parts[0][:1]) + parts[0][1:] + "." + suffix + default: + return "No active or queued turns to stop." + } +} + +func (oc *AIClient) pendingQueueHasSourceEvent(roomID id.RoomID, sourceEventID id.EventID) bool { + if oc == nil || roomID == "" || sourceEventID == "" { + return false + } + queue := oc.getQueueSnapshot(roomID) + if queue == nil { + return false + } + for _, item := range queue.items { + if item.pending.sourceEventID() == sourceEventID { + return true + } + } + return false +} + +func buildStopMetadata(plan userStopPlan, req userStopRequest) *assistantStopMetadata { + return &assistantStopMetadata{ + Reason: "user_stop", + Scope: plan.Scope, + TargetKind: plan.TargetKind, + TargetEventID: plan.TargetEventID.String(), + RequestedByEventID: req.RequestedByEventID.String(), + RequestedVia: strings.TrimSpace(req.RequestedVia), + } +} + +func (oc *AIClient) resolveUserStopPlan(req userStopRequest) userStopPlan { + if req.Portal == nil || req.Portal.MXID == "" { + return userStopPlan{Kind: stopPlanKindNoMatch} + } + if req.ReplyTo == "" { + return userStopPlan{ + Kind: stopPlanKindRoomWide, + Scope: "room", + TargetKind: "all", + } + } + + _, sourceEventID, initialEventID, _ := oc.roomRunTarget(req.Portal.MXID) + if initialEventID != "" && req.ReplyTo == initialEventID { + return userStopPlan{ + Kind: stopPlanKindActive, + Scope: "turn", + TargetKind: "placeholder_event", + TargetEventID: req.ReplyTo, + } + } + if sourceEventID != "" && req.ReplyTo == sourceEventID { + return userStopPlan{ + Kind: stopPlanKindActive, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: req.ReplyTo, + } } - label := "sub-agents" - if stopped == 1 { - label = "sub-agent" + if oc.pendingQueueHasSourceEvent(req.Portal.MXID, req.ReplyTo) { + return userStopPlan{ + Kind: stopPlanKindQueued, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: req.ReplyTo, + } + } + return userStopPlan{ + Kind: stopPlanKindNoMatch, + Scope: "turn", + TargetEventID: req.ReplyTo, } - return fmt.Sprintf("Agent was aborted. Stopped %d %s.", stopped, label) } -func (oc *AIClient) abortRoom(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata) int { - if portal == nil { - return 0 +func (oc *AIClient) finalizeStoppedQueueItems(ctx context.Context, items []pendingQueueItem) int { + for _, item := range items { + if item.pending.Meta != nil && item.pending.Meta.AckReactionRemoveAfter { + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) + } + oc.sendQueueRejectedStatus(ctx, item.pending.Portal, item.pending.Event, item.pending.StatusEvents, "Stopped.") } - oc.cancelRoomRun(portal.MXID) - oc.clearPendingQueue(portal.MXID) - stopped := oc.stopSubagentRuns(portal.MXID) - if meta != nil { - meta.AbortedLastRun = true - oc.savePortalQuiet(ctx, portal, "abort") + return len(items) +} + +func (oc *AIClient) executeUserStopPlan(ctx context.Context, req userStopRequest, plan userStopPlan) userStopResult { + result := userStopResult{Plan: plan} + if req.Portal == nil || req.Portal.MXID == "" { + return result } - return stopped + roomID := req.Portal.MXID + switch plan.Kind { + case stopPlanKindRoomWide: + if oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) { + result.ActiveStopped = oc.cancelRoomRun(roomID) + } + result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.drainPendingQueue(roomID)) + result.SubagentsStopped = oc.stopSubagentRuns(roomID) + case stopPlanKindActive: + if oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) { + result.ActiveStopped = oc.cancelRoomRun(roomID) + if result.ActiveStopped { + result.SubagentsStopped = oc.stopSubagentRuns(roomID) + } + } + case stopPlanKindQueued: + result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.removePendingQueueBySourceEvent(roomID, plan.TargetEventID)) + } + + if req.Meta != nil && (result.ActiveStopped || result.QueuedStopped > 0 || result.SubagentsStopped > 0) { + req.Meta.AbortedLastRun = true + oc.savePortalQuiet(ctx, req.Portal, "stop") + } + if req.Meta != nil && result.QueuedStopped > 0 { + oc.notifySessionMutation(ctx, req.Portal, req.Meta, false) + } + return result +} + +func (oc *AIClient) handleUserStop(ctx context.Context, req userStopRequest) userStopResult { + plan := oc.resolveUserStopPlan(req) + return oc.executeUserStopPlan(ctx, req, plan) } diff --git a/bridges/ai/abort_helpers_test.go b/bridges/ai/abort_helpers_test.go new file mode 100644 index 000000000..ce350392b --- /dev/null +++ b/bridges/ai/abort_helpers_test.go @@ -0,0 +1,148 @@ +package ai + +import ( + "context" + "testing" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/id" + + bridgesdk "github.com/beeper/agentremote/sdk" +) + +func TestResolveUserStopPlanRoomWideWithoutReply(t *testing.T) { + oc := &AIClient{} + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}} + req := userStopRequest{Portal: portal, RequestedVia: "command"} + + plan := oc.resolveUserStopPlan(req) + if plan.Kind != stopPlanKindRoomWide { + t.Fatalf("expected room-wide stop, got %#v", plan) + } + if plan.TargetKind != "all" || plan.Scope != "room" { + t.Fatalf("unexpected room-wide stop plan: %#v", plan) + } +} + +func TestResolveUserStopPlanMatchesActiveReplyTargets(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + activeRoomRuns: map[id.RoomID]*roomRunState{ + roomID: { + sourceEvent: id.EventID("$user"), + initialEvent: id.EventID("$assistant"), + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + placeholderPlan := oc.resolveUserStopPlan(userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$assistant"), + }) + if placeholderPlan.Kind != stopPlanKindActive || placeholderPlan.TargetKind != "placeholder_event" { + t.Fatalf("expected placeholder-targeted active stop, got %#v", placeholderPlan) + } + + sourcePlan := oc.resolveUserStopPlan(userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$user"), + }) + if sourcePlan.Kind != stopPlanKindActive || sourcePlan.TargetKind != "source_event" { + t.Fatalf("expected source-targeted active stop, got %#v", sourcePlan) + } +} + +func TestResolveUserStopPlanMatchesQueuedReplyTarget(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + pendingQueues: map[id.RoomID]*pendingQueue{ + roomID: { + items: []pendingQueueItem{{ + pending: pendingMessage{SourceEventID: id.EventID("$queued")}, + }}, + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + plan := oc.resolveUserStopPlan(userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$queued"), + }) + if plan.Kind != stopPlanKindQueued || plan.TargetKind != "source_event" { + t.Fatalf("expected queued stop plan, got %#v", plan) + } +} + +func TestExecuteUserStopPlanRemovesOnlyTargetedQueuedTurn(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + pendingQueues: map[id.RoomID]*pendingQueue{ + roomID: { + items: []pendingQueueItem{ + {pending: pendingMessage{SourceEventID: id.EventID("$one")}}, + {pending: pendingMessage{SourceEventID: id.EventID("$two")}}, + }, + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + result := oc.executeUserStopPlan(context.Background(), userStopRequest{ + Portal: portal, + }, userStopPlan{ + Kind: stopPlanKindQueued, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: id.EventID("$one"), + }) + if result.QueuedStopped != 1 { + t.Fatalf("expected one queued turn to stop, got %#v", result) + } + snapshot := oc.getQueueSnapshot(roomID) + if snapshot == nil || len(snapshot.items) != 1 { + t.Fatalf("expected one queued item to remain, got %#v", snapshot) + } + if got := snapshot.items[0].pending.sourceEventID(); got != id.EventID("$two") { + t.Fatalf("expected remaining queued event $two, got %q", got) + } +} + +func TestBuildStreamUIMessageIncludesStopMetadata(t *testing.T) { + oc := &AIClient{} + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + turn := conv.StartTurn(context.Background(), nil, &bridgesdk.SourceRef{EventID: "$user", SenderID: "@user:test"}) + turn.SetID("turn-stop") + state := &streamingState{ + turn: turn, + finishReason: "stop", + stop: &assistantStopMetadata{ + Reason: "user_stop", + Scope: "turn", + TargetKind: "source_event", + TargetEventID: "$user", + RequestedByEventID: "$stop", + RequestedVia: "command", + }, + responseID: "resp_123", + completedAtMs: 1, + } + + ui := oc.buildStreamUIMessage(state, nil, nil) + metadata, ok := ui["metadata"].(map[string]any) + if !ok { + t.Fatalf("expected metadata map, got %T", ui["metadata"]) + } + stop, ok := metadata["stop"].(map[string]any) + if !ok { + t.Fatalf("expected nested stop metadata, got %#v", metadata["stop"]) + } + if stop["reason"] != "user_stop" || stop["requested_via"] != "command" { + t.Fatalf("unexpected stop metadata: %#v", stop) + } + if metadata["response_status"] != "cancelled" { + t.Fatalf("expected cancelled response status for stopped turn, got %#v", metadata["response_status"]) + } +} diff --git a/bridges/ai/commands_parity.go b/bridges/ai/commands_parity.go index bafb22ded..cdbaf9dee 100644 --- a/bridges/ai/commands_parity.go +++ b/bridges/ai/commands_parity.go @@ -65,6 +65,12 @@ func fnStop(ce *commands.Event) { if !ok { return } - stopped := client.abortRoom(ce.Ctx, ce.Portal, meta) - ce.Reply("%s", formatAbortNotice(stopped)) + result := client.handleUserStop(ce.Ctx, userStopRequest{ + Portal: ce.Portal, + Meta: meta, + ReplyTo: ce.ReplyTo, + RequestedByEventID: ce.EventID, + RequestedVia: "command", + }) + ce.Reply("%s", formatAbortNotice(result)) } diff --git a/bridges/ai/handlematrix.go b/bridges/ai/handlematrix.go index 4604d084f..d1864069d 100644 --- a/bridges/ai/handlematrix.go +++ b/bridges/ai/handlematrix.go @@ -135,8 +135,15 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri return &bridgev2.MatrixMessageResponse{Pending: false}, nil } if commandAuthorized && airuntime.IsAbortTriggerText(commandBody) { - stopped := oc.abortRoom(ctx, portal, meta) - oc.sendSystemNotice(ctx, portal, formatAbortNotice(stopped)) + replyCtx := extractInboundReplyContext(msg.Event) + result := oc.handleUserStop(ctx, userStopRequest{ + Portal: portal, + Meta: meta, + ReplyTo: replyCtx.ReplyTo, + RequestedByEventID: msg.Event.ID, + RequestedVia: "text-trigger", + }) + oc.sendSystemNotice(ctx, portal, formatAbortNotice(result)) logCtx.Debug().Msg("Abort trigger handled") return &bridgev2.MatrixMessageResponse{Pending: false}, nil } diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index aee69c527..ac02e3f4c 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -38,6 +38,16 @@ type pendingQueue struct { lastItem *pendingQueueItem } +func (pm pendingMessage) sourceEventID() id.EventID { + if pm.SourceEventID != "" { + return pm.SourceEventID + } + if pm.Event != nil { + return pm.Event.ID + } + return "" +} + type pendingQueueDispatchCandidate struct { items []pendingQueueItem summaryPrompt string @@ -85,6 +95,66 @@ func (oc *AIClient) clearPendingQueue(roomID id.RoomID) { } } +func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { + if oc == nil || roomID == "" { + return nil + } + oc.pendingQueuesMu.Lock() + queue := oc.pendingQueues[roomID] + if queue == nil { + oc.pendingQueuesMu.Unlock() + return nil + } + delete(oc.pendingQueues, roomID) + oc.pendingQueuesMu.Unlock() + + queue.mu.Lock() + items := slices.Clone(queue.items) + queue.items = nil + queue.summaryLines = nil + queue.droppedCount = 0 + queue.lastItem = nil + queue.mu.Unlock() + + oc.stopQueueTyping(roomID) + return items +} + +func (oc *AIClient) removePendingQueueBySourceEvent(roomID id.RoomID, sourceEventID id.EventID) []pendingQueueItem { + if oc == nil || roomID == "" || sourceEventID == "" { + return nil + } + oc.pendingQueuesMu.Lock() + queue := oc.pendingQueues[roomID] + if queue == nil { + oc.pendingQueuesMu.Unlock() + return nil + } + queue.mu.Lock() + removed := make([]pendingQueueItem, 0, 1) + kept := queue.items[:0] + for _, item := range queue.items { + if item.pending.sourceEventID() == sourceEventID { + removed = append(removed, item) + continue + } + kept = append(kept, item) + } + clear(queue.items[len(kept):]) + queue.items = kept + empty := len(queue.items) == 0 && queue.droppedCount == 0 + if empty { + delete(oc.pendingQueues, roomID) + } + queue.mu.Unlock() + oc.pendingQueuesMu.Unlock() + + if empty { + oc.stopQueueTyping(roomID) + } + return removed +} + func (oc *AIClient) enqueuePendingItem(roomID id.RoomID, item pendingQueueItem, settings airuntime.QueueSettings) bool { queue := oc.getPendingQueue(roomID, settings) if queue == nil { diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index 640711645..35083e768 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -13,6 +13,11 @@ type roomRunState struct { cancel context.CancelFunc mu sync.Mutex + state *streamingState + stop *assistantStopMetadata + turnID string + sourceEvent id.EventID + initialEvent id.EventID streaming bool steerQueue []pendingQueueItem statusEvents []*event.Event @@ -97,6 +102,61 @@ func (oc *AIClient) markRoomRunStreaming(roomID id.RoomID, streaming bool) { run.mu.Unlock() } +func (oc *AIClient) bindRoomRunState(roomID id.RoomID, state *streamingState) { + run := oc.getRoomRun(roomID) + if run == nil { + return + } + run.mu.Lock() + run.state = state + if run.stop != nil && state != nil { + state.stop = run.stop + } + if state != nil && state.turn != nil { + run.turnID = state.turn.ID() + run.sourceEvent = state.sourceEventID() + run.initialEvent = state.turn.InitialEventID() + } + run.mu.Unlock() +} + +func (oc *AIClient) roomRunTarget(roomID id.RoomID) (turnID string, sourceEventID, initialEventID id.EventID, state *streamingState) { + run := oc.getRoomRun(roomID) + if run == nil { + return "", "", "", nil + } + run.mu.Lock() + defer run.mu.Unlock() + state = run.state + turnID = run.turnID + sourceEventID = run.sourceEvent + initialEventID = run.initialEvent + if state == nil || state.turn == nil { + return turnID, sourceEventID, initialEventID, state + } + turnID = state.turn.ID() + sourceEventID = state.sourceEventID() + initialEventID = state.turn.InitialEventID() + run.turnID = turnID + run.sourceEvent = sourceEventID + run.initialEvent = initialEventID + return turnID, sourceEventID, initialEventID, state +} + +func (oc *AIClient) markRoomRunStopped(roomID id.RoomID, stop *assistantStopMetadata) bool { + run := oc.getRoomRun(roomID) + if run == nil || stop == nil { + return false + } + run.mu.Lock() + run.stop = stop + if run.state != nil { + run.state.stop = stop + } + run.mu.Unlock() + return true +} + func (oc *AIClient) enqueueSteerQueue(roomID id.RoomID, item pendingQueueItem) bool { run := oc.getRoomRun(roomID) if run == nil { diff --git a/bridges/ai/streaming_error_handling.go b/bridges/ai/streaming_error_handling.go index 14a848fa8..0ea6a9e50 100644 --- a/bridges/ai/streaming_error_handling.go +++ b/bridges/ai/streaming_error_handling.go @@ -40,6 +40,9 @@ func (oc *AIClient) finishStreamingWithFailure( reason string, err error, ) error { + if state != nil && state.stop != nil && reason == "cancelled" { + reason = "stop" + } state.finishReason = reason state.completedAtMs = time.Now().UnixMilli() _ = log @@ -47,12 +50,17 @@ func (oc *AIClient) finishStreamingWithFailure( if writer := state.writer(); writer != nil { writer.MessageMetadata(ctx, oc.buildUIMessageMetadata(state, meta, true)) } - if reason == "cancelled" { + switch reason { + case "cancelled": state.writer().Abort(ctx, "cancelled") if state != nil && state.turn != nil { state.turn.End(msgconv.MapFinishReason(reason)) } - } else { + case "stop": + if state != nil && state.turn != nil { + state.turn.End(msgconv.MapFinishReason(reason)) + } + default: if state != nil && state.turn != nil { state.turn.EndWithError(err.Error()) } diff --git a/bridges/ai/streaming_init.go b/bridges/ai/streaming_init.go index 1833ec075..7f929cfd0 100644 --- a/bridges/ai/streaming_init.go +++ b/bridges/ai/streaming_init.go @@ -128,6 +128,7 @@ func (oc *AIClient) prepareStreamingRun( // Create SDK Turn for writer/emitter/session management. turn := oc.createStreamingTurn(ctx, portal, meta, state, sourceEventID, senderID) state.turn = turn + oc.bindRoomRunState(roomID, state) state.replyTarget = oc.resolveInitialReplyTarget(evt) if state.replyTarget.ThreadRoot != "" { diff --git a/bridges/ai/streaming_state.go b/bridges/ai/streaming_state.go index 18990d95c..f4ba59511 100644 --- a/bridges/ai/streaming_state.go +++ b/bridges/ai/streaming_state.go @@ -69,6 +69,8 @@ type streamingState struct { // Pending MCP approvals to resolve before the turn can continue. pendingMcpApprovals []mcpApprovalRequest pendingMcpApprovalsSeen map[string]bool + + stop *assistantStopMetadata } // sourceEventID returns the triggering user message event ID from the turn's source ref. diff --git a/bridges/ai/turn_data.go b/bridges/ai/turn_data.go index 6448b6da4..bd5e00201 100644 --- a/bridges/ai/turn_data.go +++ b/bridges/ai/turn_data.go @@ -72,6 +72,9 @@ func canonicalResponseStatus(state *streamingState) string { if strings.TrimSpace(state.responseID) == "" { return status } + if state.stop != nil { + return "cancelled" + } switch strings.TrimSpace(state.finishReason) { case "", "stop": diff --git a/bridges/ai/ui_message_metadata.go b/bridges/ai/ui_message_metadata.go index 5c888330a..e96b7fe50 100644 --- a/bridges/ai/ui_message_metadata.go +++ b/bridges/ai/ui_message_metadata.go @@ -13,6 +13,15 @@ type assistantUsageMetadata struct { TotalTokens int64 `json:"total_tokens,omitempty"` } +type assistantStopMetadata struct { + Reason string `json:"reason,omitempty"` + Scope string `json:"scope,omitempty"` + TargetKind string `json:"target_kind,omitempty"` + TargetEventID string `json:"target_event_id,omitempty"` + RequestedByEventID string `json:"requested_by_event_id,omitempty"` + RequestedVia string `json:"requested_via,omitempty"` +} + type assistantTurnMetadata struct { TurnID string `json:"turn_id,omitempty"` AgentID string `json:"agent_id,omitempty"` @@ -28,6 +37,7 @@ type assistantTurnMetadata struct { SourceEventID string `json:"source_event_id,omitempty"` GeneratedFileRefs []GeneratedFileRef `json:"generated_file_refs,omitempty"` Usage *assistantUsageMetadata `json:"usage,omitempty"` + Stop *assistantStopMetadata `json:"stop,omitempty"` } func buildAssistantUsageMetadata(state *streamingState) *assistantUsageMetadata { @@ -70,5 +80,6 @@ func buildAssistantTurnMetadata(state *streamingState, turnID, networkMessageID, SourceEventID: state.sourceEventID().String(), GeneratedFileRefs: agentremote.GeneratedFileRefsFromParts(state.generatedFiles), Usage: buildAssistantUsageMetadata(state), + Stop: state.stop, }) } From a691f3cd3485e65dfa7e620776622b20ec28893d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 16:04:30 +0200 Subject: [PATCH 2/7] Refactor queue checks, abort text, and streaming errors Move pendingQueueHasSourceEvent into pending_queue.go and implement proper locking (pendingQueuesMu and queue.mu) to safely inspect queue items. Simplify drainPendingQueue to delete the queue map entry and return its items directly. Remove the duplicate helper from abort_helpers.go. Improve formatAbortNotice by capitalizing each sentence part and joining them with ". " for clearer messages. Remove redundant run field assignments in roomRunTarget. Adjust finishStreamingWithFailure to fall through from "cancelled" to "stop" so cancelled streams call End like stop cases and remove some redundant nil checks. These changes tidy concurrency handling, clarify abort messaging, and simplify streaming error handling. --- bridges/ai/abort_helpers.go | 22 +++--------------- bridges/ai/pending_queue.go | 32 +++++++++++++++++++------- bridges/ai/room_runs.go | 3 --- bridges/ai/streaming_error_handling.go | 8 +++---- 4 files changed, 30 insertions(+), 35 deletions(-) diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index e5890f701..545209f12 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -75,31 +75,15 @@ func formatAbortNotice(result userStopResult) string { if len(parts) == 0 { return "No active or queued turns to stop." } - suffix := "" - if len(parts) > 1 { - suffix = " " + strings.Join(parts[1:], ". ") + "." + for i := range parts { + parts[i] = strings.ToUpper(parts[i][:1]) + parts[i][1:] } - return strings.ToUpper(parts[0][:1]) + parts[0][1:] + "." + suffix + return strings.Join(parts, ". ") + "." default: return "No active or queued turns to stop." } } -func (oc *AIClient) pendingQueueHasSourceEvent(roomID id.RoomID, sourceEventID id.EventID) bool { - if oc == nil || roomID == "" || sourceEventID == "" { - return false - } - queue := oc.getQueueSnapshot(roomID) - if queue == nil { - return false - } - for _, item := range queue.items { - if item.pending.sourceEventID() == sourceEventID { - return true - } - } - return false -} func buildStopMetadata(plan userStopPlan, req userStopRequest) *assistantStopMetadata { return &assistantStopMetadata{ diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index ac02e3f4c..b252ff48d 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -106,20 +106,36 @@ func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { return nil } delete(oc.pendingQueues, roomID) + items := queue.items oc.pendingQueuesMu.Unlock() - queue.mu.Lock() - items := slices.Clone(queue.items) - queue.items = nil - queue.summaryLines = nil - queue.droppedCount = 0 - queue.lastItem = nil - queue.mu.Unlock() - oc.stopQueueTyping(roomID) return items } +func (oc *AIClient) pendingQueueHasSourceEvent(roomID id.RoomID, sourceEventID id.EventID) bool { + if oc == nil || roomID == "" || sourceEventID == "" { + return false + } + oc.pendingQueuesMu.Lock() + queue := oc.pendingQueues[roomID] + if queue == nil { + oc.pendingQueuesMu.Unlock() + return false + } + queue.mu.Lock() + found := false + for _, item := range queue.items { + if item.pending.sourceEventID() == sourceEventID { + found = true + break + } + } + queue.mu.Unlock() + oc.pendingQueuesMu.Unlock() + return found +} + func (oc *AIClient) removePendingQueueBySourceEvent(roomID id.RoomID, sourceEventID id.EventID) []pendingQueueItem { if oc == nil || roomID == "" || sourceEventID == "" { return nil diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index 35083e768..a12f5af66 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -137,9 +137,6 @@ func (oc *AIClient) roomRunTarget(roomID id.RoomID) (turnID string, sourceEventI turnID = state.turn.ID() sourceEventID = state.sourceEventID() initialEventID = state.turn.InitialEventID() - run.turnID = turnID - run.sourceEvent = sourceEventID - run.initialEvent = initialEventID return turnID, sourceEventID, initialEventID, state } diff --git a/bridges/ai/streaming_error_handling.go b/bridges/ai/streaming_error_handling.go index 0ea6a9e50..6f949bdb3 100644 --- a/bridges/ai/streaming_error_handling.go +++ b/bridges/ai/streaming_error_handling.go @@ -53,15 +53,13 @@ func (oc *AIClient) finishStreamingWithFailure( switch reason { case "cancelled": state.writer().Abort(ctx, "cancelled") - if state != nil && state.turn != nil { - state.turn.End(msgconv.MapFinishReason(reason)) - } + fallthrough case "stop": - if state != nil && state.turn != nil { + if state.turn != nil { state.turn.End(msgconv.MapFinishReason(reason)) } default: - if state != nil && state.turn != nil { + if state.turn != nil { state.turn.EndWithError(err.Error()) } } From 5d7e0b2773454a95d55964a61aca455ac7e01aab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 16:16:40 +0200 Subject: [PATCH 3/7] Make stop metadata atomic and thread-safe Change streamingState.stop to an atomic.Pointer[assistantStopMetadata] and update all callsites to use .Load()/.Store() to avoid races. Fix pending queue drain to lock the queue when accessing items to prevent concurrent access. Improve room run logic to prefer current state.turn when present and store stop metadata atomically when marking a run stopped. Use utf8 + unicode for correct Unicode-aware capitalization in abort notices and update tests to store stop metadata via the new atomic API. --- bridges/ai/abort_helpers.go | 6 ++++-- bridges/ai/abort_helpers_test.go | 20 ++++++++++---------- bridges/ai/pending_queue.go | 2 ++ bridges/ai/room_runs.go | 18 ++++++------------ bridges/ai/streaming_error_handling.go | 2 +- bridges/ai/streaming_state.go | 3 ++- bridges/ai/turn_data.go | 2 +- bridges/ai/ui_message_metadata.go | 2 +- 8 files changed, 27 insertions(+), 28 deletions(-) diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index 545209f12..d88261acc 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "strings" + "unicode" + "unicode/utf8" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/id" @@ -76,7 +78,8 @@ func formatAbortNotice(result userStopResult) string { return "No active or queued turns to stop." } for i := range parts { - parts[i] = strings.ToUpper(parts[i][:1]) + parts[i][1:] + r, size := utf8.DecodeRuneInString(parts[i]) + parts[i] = string(unicode.ToUpper(r)) + parts[i][size:] } return strings.Join(parts, ". ") + "." default: @@ -84,7 +87,6 @@ func formatAbortNotice(result userStopResult) string { } } - func buildStopMetadata(plan userStopPlan, req userStopRequest) *assistantStopMetadata { return &assistantStopMetadata{ Reason: "user_stop", diff --git a/bridges/ai/abort_helpers_test.go b/bridges/ai/abort_helpers_test.go index ce350392b..ff2994de1 100644 --- a/bridges/ai/abort_helpers_test.go +++ b/bridges/ai/abort_helpers_test.go @@ -116,19 +116,19 @@ func TestBuildStreamUIMessageIncludesStopMetadata(t *testing.T) { turn := conv.StartTurn(context.Background(), nil, &bridgesdk.SourceRef{EventID: "$user", SenderID: "@user:test"}) turn.SetID("turn-stop") state := &streamingState{ - turn: turn, - finishReason: "stop", - stop: &assistantStopMetadata{ - Reason: "user_stop", - Scope: "turn", - TargetKind: "source_event", - TargetEventID: "$user", - RequestedByEventID: "$stop", - RequestedVia: "command", - }, + turn: turn, + finishReason: "stop", responseID: "resp_123", completedAtMs: 1, } + state.stop.Store(&assistantStopMetadata{ + Reason: "user_stop", + Scope: "turn", + TargetKind: "source_event", + TargetEventID: "$user", + RequestedByEventID: "$stop", + RequestedVia: "command", + }) ui := oc.buildStreamUIMessage(state, nil, nil) metadata, ok := ui["metadata"].(map[string]any) diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index b252ff48d..2b4f9f35c 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -106,7 +106,9 @@ func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { return nil } delete(oc.pendingQueues, roomID) + queue.mu.Lock() items := queue.items + queue.mu.Unlock() oc.pendingQueuesMu.Unlock() oc.stopQueueTyping(roomID) diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index a12f5af66..274463f0b 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -110,7 +110,7 @@ func (oc *AIClient) bindRoomRunState(roomID id.RoomID, state *streamingState) { run.mu.Lock() run.state = state if run.stop != nil && state != nil { - state.stop = run.stop + state.stop.Store(run.stop) } if state != nil && state.turn != nil { run.turnID = state.turn.ID() @@ -128,16 +128,10 @@ func (oc *AIClient) roomRunTarget(roomID id.RoomID) (turnID string, sourceEventI run.mu.Lock() defer run.mu.Unlock() state = run.state - turnID = run.turnID - sourceEventID = run.sourceEvent - initialEventID = run.initialEvent - if state == nil || state.turn == nil { - return turnID, sourceEventID, initialEventID, state - } - turnID = state.turn.ID() - sourceEventID = state.sourceEventID() - initialEventID = state.turn.InitialEventID() - return turnID, sourceEventID, initialEventID, state + if state != nil && state.turn != nil { + return state.turn.ID(), state.sourceEventID(), state.turn.InitialEventID(), state + } + return run.turnID, run.sourceEvent, run.initialEvent, state } func (oc *AIClient) markRoomRunStopped(roomID id.RoomID, stop *assistantStopMetadata) bool { @@ -148,7 +142,7 @@ func (oc *AIClient) markRoomRunStopped(roomID id.RoomID, stop *assistantStopMeta run.mu.Lock() run.stop = stop if run.state != nil { - run.state.stop = stop + run.state.stop.Store(stop) } run.mu.Unlock() return true diff --git a/bridges/ai/streaming_error_handling.go b/bridges/ai/streaming_error_handling.go index 6f949bdb3..92bb9769b 100644 --- a/bridges/ai/streaming_error_handling.go +++ b/bridges/ai/streaming_error_handling.go @@ -40,7 +40,7 @@ func (oc *AIClient) finishStreamingWithFailure( reason string, err error, ) error { - if state != nil && state.stop != nil && reason == "cancelled" { + if state != nil && state.stop.Load() != nil && reason == "cancelled" { reason = "stop" } state.finishReason = reason diff --git a/bridges/ai/streaming_state.go b/bridges/ai/streaming_state.go index f4ba59511..cb2050450 100644 --- a/bridges/ai/streaming_state.go +++ b/bridges/ai/streaming_state.go @@ -3,6 +3,7 @@ package ai import ( "context" "strings" + "sync/atomic" "time" "github.com/openai/openai-go/v3/packages/param" @@ -70,7 +71,7 @@ type streamingState struct { pendingMcpApprovals []mcpApprovalRequest pendingMcpApprovalsSeen map[string]bool - stop *assistantStopMetadata + stop atomic.Pointer[assistantStopMetadata] } // sourceEventID returns the triggering user message event ID from the turn's source ref. diff --git a/bridges/ai/turn_data.go b/bridges/ai/turn_data.go index bd5e00201..24c138aae 100644 --- a/bridges/ai/turn_data.go +++ b/bridges/ai/turn_data.go @@ -72,7 +72,7 @@ func canonicalResponseStatus(state *streamingState) string { if strings.TrimSpace(state.responseID) == "" { return status } - if state.stop != nil { + if state.stop.Load() != nil { return "cancelled" } diff --git a/bridges/ai/ui_message_metadata.go b/bridges/ai/ui_message_metadata.go index e96b7fe50..b55abf328 100644 --- a/bridges/ai/ui_message_metadata.go +++ b/bridges/ai/ui_message_metadata.go @@ -80,6 +80,6 @@ func buildAssistantTurnMetadata(state *streamingState, turnID, networkMessageID, SourceEventID: state.sourceEventID().String(), GeneratedFileRefs: agentremote.GeneratedFileRefsFromParts(state.generatedFiles), Usage: buildAssistantUsageMetadata(state), - Stop: state.stop, + Stop: state.stop.Load(), }) } From ed72a84b34bd1a29e0a4606458c28578f6f89861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 16:26:27 +0200 Subject: [PATCH 4/7] Centralize ack removal and simplify queue logic Move AckReactionRemoveAfter checks into removePendingAckReactions and remove duplicate guarded calls throughout the codebase, so callers simply invoke the removal and the function decides whether to act. Simplify pending queue management by replacing clearPendingQueue usage with drainPendingQueue and delete the pendingQueueHasSourceEvent helper. Adjust stop-plan handling to speculatively return queued stops and add a fallback in executeUserStopPlan to convert a queued plan to no-match if nothing was drained. Update tests to reflect the new speculative behavior and the fallback. --- bridges/ai/abort_helpers.go | 18 +++++---------- bridges/ai/abort_helpers_test.go | 39 ++++++++++++++++++++------------ bridges/ai/client.go | 6 ++--- bridges/ai/pending_queue.go | 37 +++--------------------------- bridges/ai/room_runs.go | 4 +--- bridges/ai/subagent_registry.go | 6 ++--- 6 files changed, 39 insertions(+), 71 deletions(-) diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index d88261acc..985cbde4e 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -127,26 +127,17 @@ func (oc *AIClient) resolveUserStopPlan(req userStopRequest) userStopPlan { TargetEventID: req.ReplyTo, } } - if oc.pendingQueueHasSourceEvent(req.Portal.MXID, req.ReplyTo) { - return userStopPlan{ - Kind: stopPlanKindQueued, - Scope: "turn", - TargetKind: "source_event", - TargetEventID: req.ReplyTo, - } - } return userStopPlan{ - Kind: stopPlanKindNoMatch, + Kind: stopPlanKindQueued, Scope: "turn", + TargetKind: "source_event", TargetEventID: req.ReplyTo, } } func (oc *AIClient) finalizeStoppedQueueItems(ctx context.Context, items []pendingQueueItem) int { for _, item := range items { - if item.pending.Meta != nil && item.pending.Meta.AckReactionRemoveAfter { - oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) oc.sendQueueRejectedStatus(ctx, item.pending.Portal, item.pending.Event, item.pending.StatusEvents, "Stopped.") } return len(items) @@ -174,6 +165,9 @@ func (oc *AIClient) executeUserStopPlan(ctx context.Context, req userStopRequest } case stopPlanKindQueued: result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.removePendingQueueBySourceEvent(roomID, plan.TargetEventID)) + if result.QueuedStopped == 0 { + result.Plan.Kind = stopPlanKindNoMatch + } } if req.Meta != nil && (result.ActiveStopped || result.QueuedStopped > 0 || result.SubagentsStopped > 0) { diff --git a/bridges/ai/abort_helpers_test.go b/bridges/ai/abort_helpers_test.go index ff2994de1..a1c7c108b 100644 --- a/bridges/ai/abort_helpers_test.go +++ b/bridges/ai/abort_helpers_test.go @@ -54,25 +54,36 @@ func TestResolveUserStopPlanMatchesActiveReplyTargets(t *testing.T) { } } -func TestResolveUserStopPlanMatchesQueuedReplyTarget(t *testing.T) { - roomID := id.RoomID("!room:test") - oc := &AIClient{ - pendingQueues: map[id.RoomID]*pendingQueue{ - roomID: { - items: []pendingQueueItem{{ - pending: pendingMessage{SourceEventID: id.EventID("$queued")}, - }}, - }, - }, - } - portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} +func TestResolveUserStopPlanSpeculativelyReturnsQueued(t *testing.T) { + oc := &AIClient{} + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}} plan := oc.resolveUserStopPlan(userStopRequest{ Portal: portal, - ReplyTo: id.EventID("$queued"), + ReplyTo: id.EventID("$unknown"), }) if plan.Kind != stopPlanKindQueued || plan.TargetKind != "source_event" { - t.Fatalf("expected queued stop plan, got %#v", plan) + t.Fatalf("expected speculative queued stop plan, got %#v", plan) + } +} + +func TestExecuteUserStopPlanFallsBackToNoMatch(t *testing.T) { + oc := &AIClient{} + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}} + + result := oc.executeUserStopPlan(context.Background(), userStopRequest{ + Portal: portal, + }, userStopPlan{ + Kind: stopPlanKindQueued, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: id.EventID("$nonexistent"), + }) + if result.Plan.Kind != stopPlanKindNoMatch { + t.Fatalf("expected no-match fallback, got %#v", result.Plan) + } + if result.QueuedStopped != 0 { + t.Fatalf("expected zero queued stopped, got %d", result.QueuedStopped) } } diff --git a/bridges/ai/client.go b/bridges/ai/client.go index b181d7123..ed4a86660 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -667,7 +667,7 @@ func (oc *AIClient) dispatchOrQueueCore( metaSnapshot := clonePortalMetadata(meta) go func(metaSnapshot *PortalMetadata) { defer func() { - if hasDBMessage && metaSnapshot != nil && metaSnapshot.AckReactionRemoveAfter { + if hasDBMessage { oc.removePendingAckReactions(oc.backgroundContext(ctx), portal, queueItem.pending) } oc.releaseRoom(roomID) @@ -815,9 +815,7 @@ func (oc *AIClient) processPendingQueue(ctx context.Context, roomID id.RoomID) { if err != nil { oc.loggerForContext(ctx).Err(err).Msg("Failed to build prompt for pending queue item") oc.notifyMatrixSendFailure(ctx, item.pending.Portal, item.pending.Event, err) - if item.pending.Meta != nil && item.pending.Meta.AckReactionRemoveAfter { - oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) oc.releaseRoom(roomID) oc.processPendingQueue(oc.backgroundContext(ctx), roomID) return diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index 2b4f9f35c..6b8b684d8 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -86,13 +86,7 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe } func (oc *AIClient) clearPendingQueue(roomID id.RoomID) { - oc.pendingQueuesMu.Lock() - _, existed := oc.pendingQueues[roomID] - delete(oc.pendingQueues, roomID) - oc.pendingQueuesMu.Unlock() - if existed { - oc.stopQueueTyping(roomID) - } + oc.drainPendingQueue(roomID) } func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { @@ -115,29 +109,6 @@ func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { return items } -func (oc *AIClient) pendingQueueHasSourceEvent(roomID id.RoomID, sourceEventID id.EventID) bool { - if oc == nil || roomID == "" || sourceEventID == "" { - return false - } - oc.pendingQueuesMu.Lock() - queue := oc.pendingQueues[roomID] - if queue == nil { - oc.pendingQueuesMu.Unlock() - return false - } - queue.mu.Lock() - found := false - for _, item := range queue.items { - if item.pending.sourceEventID() == sourceEventID { - found = true - break - } - } - queue.mu.Unlock() - oc.pendingQueuesMu.Unlock() - return found -} - func (oc *AIClient) removePendingQueueBySourceEvent(roomID id.RoomID, sourceEventID id.EventID) []pendingQueueItem { if oc == nil || roomID == "" || sourceEventID == "" { return nil @@ -509,9 +480,7 @@ func (oc *AIClient) dispatchQueuedPrompt( metaSnapshot := clonePortalMetadata(item.pending.Meta) go func() { defer func() { - if metaSnapshot != nil && metaSnapshot.AckReactionRemoveAfter { - oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) if item.backlogAfter { followup := item followup.backlogAfter = false @@ -527,7 +496,7 @@ func (oc *AIClient) dispatchQueuedPrompt( } func (oc *AIClient) removePendingAckReactions(ctx context.Context, portal *bridgev2.Portal, pending pendingMessage) { - if portal == nil { + if portal == nil || pending.Meta == nil || !pending.Meta.AckReactionRemoveAfter { return } ids := pending.AckEventIDs diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index 274463f0b..c83c7e7be 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -76,9 +76,7 @@ func (oc *AIClient) clearRoomRun(roomID id.RoomID) { } ctx := oc.backgroundContext(context.Background()) for _, pending := range ackPending { - if pending.Meta != nil && pending.Meta.AckReactionRemoveAfter { - oc.removePendingAckReactions(ctx, pending.Portal, pending) - } + oc.removePendingAckReactions(ctx, pending.Portal, pending) } } diff --git a/bridges/ai/subagent_registry.go b/bridges/ai/subagent_registry.go index 2a7bf9633..4772b5a5d 100644 --- a/bridges/ai/subagent_registry.go +++ b/bridges/ai/subagent_registry.go @@ -43,10 +43,8 @@ func (oc *AIClient) stopSubagentRuns(parent id.RoomID) int { continue } canceled := oc.cancelRoomRun(run.ChildRoomID) - queueSnapshot := oc.getQueueSnapshot(run.ChildRoomID) - hasQueued := queueSnapshot != nil && (len(queueSnapshot.items) > 0 || queueSnapshot.droppedCount > 0) - oc.clearPendingQueue(run.ChildRoomID) - if canceled || hasQueued { + drained := oc.drainPendingQueue(run.ChildRoomID) + if canceled || len(drained) > 0 { stopped++ } } From 02c907a13849285ae692a6321e891e5ae029e9d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 16:48:56 +0200 Subject: [PATCH 5/7] Update pending_queue.go --- bridges/ai/pending_queue.go | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index 6b8b684d8..efdaa03b2 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -239,14 +239,22 @@ func (oc *AIClient) getQueueSnapshot(roomID id.RoomID) *pendingQueue { } queue.mu.Lock() defer queue.mu.Unlock() - clone := *queue - clone.items = slices.Clone(queue.items) - clone.summaryLines = slices.Clone(queue.summaryLines) + clone := &pendingQueue{ + items: slices.Clone(queue.items), + draining: queue.draining, + lastEnqueuedAt: queue.lastEnqueuedAt, + mode: queue.mode, + debounceMs: queue.debounceMs, + cap: queue.cap, + dropPolicy: queue.dropPolicy, + droppedCount: queue.droppedCount, + summaryLines: slices.Clone(queue.summaryLines), + } if queue.lastItem != nil { lastItem := *queue.lastItem clone.lastItem = &lastItem } - return &clone + return clone } func (oc *AIClient) roomHasPendingQueueWork(roomID id.RoomID) bool { From a242d262b5d7d0c4b4b8ebc667b02aa5aa92afee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 17:30:09 +0200 Subject: [PATCH 6/7] Propagate context to queue/stop ops and refine stop handling Pass context through pending-queue and subagent stop helpers (clearPendingQueue, stopSubagentRuns, finalizeStoppedQueueItems) and always finalize/drain pending items when clearing queues. Fix executeUserStopPlan logic to mark active stops before cancelling and fall back to a no-match when an active stop is a no-op. Ensure removePendingAckReactions is always invoked in goroutine cleanup. Adjust finishStreamingWithFailure to properly end turns on cancelled streams without falling through, and prefer explicit stop flag in canonicalResponseStatus. Add tests covering the no-op active stop fallback, cancelled finish behavior, and canonicalResponseStatus preference. --- bridges/ai/abort_helpers.go | 13 ++++++---- bridges/ai/abort_helpers_test.go | 28 +++++++++++++++++++++ bridges/ai/client.go | 6 ++--- bridges/ai/commands_parity.go | 2 +- bridges/ai/internal_dispatch.go | 2 +- bridges/ai/pending_queue.go | 4 +-- bridges/ai/streaming_error_handling.go | 4 ++- bridges/ai/streaming_error_handling_test.go | 27 ++++++++++++++++++++ bridges/ai/subagent_registry.go | 6 +++-- bridges/ai/turn_data.go | 6 ++--- bridges/ai/turn_data_test.go | 9 +++++++ 11 files changed, 88 insertions(+), 19 deletions(-) diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index 985cbde4e..54e668225 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -155,13 +155,16 @@ func (oc *AIClient) executeUserStopPlan(ctx context.Context, req userStopRequest result.ActiveStopped = oc.cancelRoomRun(roomID) } result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.drainPendingQueue(roomID)) - result.SubagentsStopped = oc.stopSubagentRuns(roomID) + result.SubagentsStopped = oc.stopSubagentRuns(ctx, roomID) case stopPlanKindActive: - if oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) { + markedStopped := oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) + if markedStopped { result.ActiveStopped = oc.cancelRoomRun(roomID) - if result.ActiveStopped { - result.SubagentsStopped = oc.stopSubagentRuns(roomID) - } + } + if result.ActiveStopped { + result.SubagentsStopped = oc.stopSubagentRuns(ctx, roomID) + } else { + result.Plan.Kind = stopPlanKindNoMatch } case stopPlanKindQueued: result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.removePendingQueueBySourceEvent(roomID, plan.TargetEventID)) diff --git a/bridges/ai/abort_helpers_test.go b/bridges/ai/abort_helpers_test.go index a1c7c108b..ca9597ee3 100644 --- a/bridges/ai/abort_helpers_test.go +++ b/bridges/ai/abort_helpers_test.go @@ -121,6 +121,34 @@ func TestExecuteUserStopPlanRemovesOnlyTargetedQueuedTurn(t *testing.T) { } } +func TestExecuteUserStopPlanActiveNoOpFallsBackToNoMatch(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + activeRoomRuns: map[id.RoomID]*roomRunState{ + roomID: { + sourceEvent: id.EventID("$user"), + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + result := oc.executeUserStopPlan(context.Background(), userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$user"), + }, userStopPlan{ + Kind: stopPlanKindActive, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: id.EventID("$user"), + }) + if result.Plan.Kind != stopPlanKindNoMatch { + t.Fatalf("expected no-match fallback for no-op active stop, got %#v", result.Plan) + } + if result.ActiveStopped { + t.Fatalf("expected active stop to report false, got %#v", result) + } +} + func TestBuildStreamUIMessageIncludesStopMetadata(t *testing.T) { oc := &AIClient{} conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) diff --git a/bridges/ai/client.go b/bridges/ai/client.go index ed4a86660..01916cfce 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -641,7 +641,7 @@ func (oc *AIClient) dispatchOrQueueCore( queueDecision := airuntime.DecideQueueAction(queueSettings.Mode, roomBusy, false) if queueDecision.Action == airuntime.QueueActionInterruptAndRun { oc.cancelRoomRun(roomID) - oc.clearPendingQueue(roomID) + oc.clearPendingQueue(ctx, roomID) roomBusy = false } if !roomBusy && oc.acquireRoom(roomID) { @@ -667,9 +667,7 @@ func (oc *AIClient) dispatchOrQueueCore( metaSnapshot := clonePortalMetadata(meta) go func(metaSnapshot *PortalMetadata) { defer func() { - if hasDBMessage { - oc.removePendingAckReactions(oc.backgroundContext(ctx), portal, queueItem.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), portal, queueItem.pending) oc.releaseRoom(roomID) oc.processPendingQueue(oc.backgroundContext(ctx), roomID) }() diff --git a/bridges/ai/commands_parity.go b/bridges/ai/commands_parity.go index cdbaf9dee..e3ee4b634 100644 --- a/bridges/ai/commands_parity.go +++ b/bridges/ai/commands_parity.go @@ -45,7 +45,7 @@ func fnReset(ce *commands.Event) { meta.SessionResetAt = time.Now().UnixMilli() client.savePortalQuiet(ce.Ctx, ce.Portal, "session reset") - client.clearPendingQueue(ce.Portal.MXID) + client.clearPendingQueue(ce.Ctx, ce.Portal.MXID) client.cancelRoomRun(ce.Portal.MXID) ce.Reply("%s", formatSystemAck("Session reset.")) diff --git a/bridges/ai/internal_dispatch.go b/bridges/ai/internal_dispatch.go index d4cfd6570..aed0669eb 100644 --- a/bridges/ai/internal_dispatch.go +++ b/bridges/ai/internal_dispatch.go @@ -107,7 +107,7 @@ func (oc *AIClient) dispatchInternalMessage( queueDecision := airuntime.DecideQueueAction(queueSettings.Mode, oc.roomHasActiveRun(portal.MXID), false) if queueDecision.Action == airuntime.QueueActionInterruptAndRun { oc.cancelRoomRun(portal.MXID) - oc.clearPendingQueue(portal.MXID) + oc.clearPendingQueue(ctx, portal.MXID) } if shouldSteer && pending.Type == pendingTypeText { queueItem.prompt = pending.MessageBody diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index efdaa03b2..35650828a 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -85,8 +85,8 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe return queue } -func (oc *AIClient) clearPendingQueue(roomID id.RoomID) { - oc.drainPendingQueue(roomID) +func (oc *AIClient) clearPendingQueue(ctx context.Context, roomID id.RoomID) { + oc.finalizeStoppedQueueItems(ctx, oc.drainPendingQueue(roomID)) } func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { diff --git a/bridges/ai/streaming_error_handling.go b/bridges/ai/streaming_error_handling.go index 92bb9769b..63932df0e 100644 --- a/bridges/ai/streaming_error_handling.go +++ b/bridges/ai/streaming_error_handling.go @@ -53,7 +53,9 @@ func (oc *AIClient) finishStreamingWithFailure( switch reason { case "cancelled": state.writer().Abort(ctx, "cancelled") - fallthrough + if state.turn != nil { + state.turn.End("cancelled") + } case "stop": if state.turn != nil { state.turn.End(msgconv.MapFinishReason(reason)) diff --git a/bridges/ai/streaming_error_handling_test.go b/bridges/ai/streaming_error_handling_test.go index 6ae8fbfa4..770e6aae5 100644 --- a/bridges/ai/streaming_error_handling_test.go +++ b/bridges/ai/streaming_error_handling_test.go @@ -5,10 +5,12 @@ import ( "errors" "testing" + "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" + "github.com/beeper/agentremote/pkg/shared/streamui" bridgesdk "github.com/beeper/agentremote/sdk" ) @@ -84,3 +86,28 @@ func TestStreamFailureErrorUsesAnyMessageTarget(t *testing.T) { } }) } + +func TestFinishStreamingWithFailureCancelledEndsTurnAsCancelled(t *testing.T) { + state := newTestStreamingStateWithTurn() + state.turn.SetSuppressSend(true) + state.writer().TextDelta(context.Background(), "hello") + + err := (&AIClient{}).finishStreamingWithFailure( + context.Background(), + zerolog.Nop(), + nil, + state, + nil, + "cancelled", + context.Canceled, + ) + if err == nil { + t.Fatal("expected wrapped cancellation error") + } + + message := streamui.SnapshotUIMessage(state.turn.UIState()) + metadata, _ := message["metadata"].(map[string]any) + if metadata["finish_reason"] != "cancelled" { + t.Fatalf("expected cancelled finish_reason, got %#v", metadata["finish_reason"]) + } +} diff --git a/bridges/ai/subagent_registry.go b/bridges/ai/subagent_registry.go index 4772b5a5d..a6ecf9bfe 100644 --- a/bridges/ai/subagent_registry.go +++ b/bridges/ai/subagent_registry.go @@ -1,6 +1,7 @@ package ai import ( + "context" "time" "maunium.net/go/mautrix/id" @@ -32,7 +33,7 @@ func (oc *AIClient) listSubagentRunsForParent(parent id.RoomID) []*subagentRun { return runs } -func (oc *AIClient) stopSubagentRuns(parent id.RoomID) int { +func (oc *AIClient) stopSubagentRuns(ctx context.Context, parent id.RoomID) int { if oc == nil || parent == "" { return 0 } @@ -44,7 +45,8 @@ func (oc *AIClient) stopSubagentRuns(parent id.RoomID) int { } canceled := oc.cancelRoomRun(run.ChildRoomID) drained := oc.drainPendingQueue(run.ChildRoomID) - if canceled || len(drained) > 0 { + finalized := oc.finalizeStoppedQueueItems(ctx, drained) + if canceled || finalized > 0 { stopped++ } } diff --git a/bridges/ai/turn_data.go b/bridges/ai/turn_data.go index 24c138aae..2dbb9759c 100644 --- a/bridges/ai/turn_data.go +++ b/bridges/ai/turn_data.go @@ -59,6 +59,9 @@ func canonicalResponseStatus(state *streamingState) string { if state == nil { return "" } + if state.stop.Load() != nil { + return "cancelled" + } status := strings.TrimSpace(state.responseStatus) if state.completedAtMs == 0 { return status @@ -72,9 +75,6 @@ func canonicalResponseStatus(state *streamingState) string { if strings.TrimSpace(state.responseID) == "" { return status } - if state.stop.Load() != nil { - return "cancelled" - } switch strings.TrimSpace(state.finishReason) { case "", "stop": diff --git a/bridges/ai/turn_data_test.go b/bridges/ai/turn_data_test.go index fbd2b7c27..36919e63d 100644 --- a/bridges/ai/turn_data_test.go +++ b/bridges/ai/turn_data_test.go @@ -100,3 +100,12 @@ func TestBuildTurnDataMetadataUsesResponderSnapshot(t *testing.T) { t.Fatalf("did not expect flat prompt_tokens field, got %#v", meta["prompt_tokens"]) } } + +func TestCanonicalResponseStatusPrefersExplicitStopWithoutResponseID(t *testing.T) { + state := testStreamingState("turn-cancelled") + state.stop.Store(&assistantStopMetadata{Reason: "user_stop"}) + + if got := canonicalResponseStatus(state); got != "cancelled" { + t.Fatalf("expected cancelled status from explicit stop, got %q", got) + } +} From bf332afb695b1804f2a43142864941548cd416b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 18:11:45 +0200 Subject: [PATCH 7/7] Fix pending queue locking and lastItem updates Rework pending queue locking and item housekeeping to avoid races and stale pointers. getPendingQueue now locks the queue.mu before releasing pendingQueuesMu and consistently applies settings for both new and existing queues. drainPendingQueue clears queue.items and lastItem when removing a queue. removePendingQueueBySourceEvent reassigns lastItem to the new tail if the removed item was the last. enqueuePendingItem removed a now-redundant explicit lock (the returned queue is already locked). Added a unit test to verify lastItem is cleared/reassigned, and tightened an error assertion in the streaming test to use errors.Is for wrapped cancellations. --- bridges/ai/pending_queue.go | 38 ++++++++++++--------- bridges/ai/queue_status_test.go | 30 ++++++++++++++++ bridges/ai/streaming_error_handling_test.go | 4 +-- 3 files changed, 54 insertions(+), 18 deletions(-) diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index 35650828a..9604c8a97 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -57,7 +57,6 @@ type pendingQueueDispatchCandidate struct { func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSettings) *pendingQueue { oc.pendingQueuesMu.Lock() - defer oc.pendingQueuesMu.Unlock() queue := oc.pendingQueues[roomID] if queue == nil { queue = &pendingQueue{ @@ -68,20 +67,19 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe dropPolicy: settings.DropPolicy, } oc.pendingQueues[roomID] = queue - } else { - queue.mu.Lock() - queue.mode = settings.Mode - if settings.DebounceMs >= 0 { - queue.debounceMs = settings.DebounceMs - } - if settings.Cap > 0 { - queue.cap = settings.Cap - } - if settings.DropPolicy != "" { - queue.dropPolicy = settings.DropPolicy - } - queue.mu.Unlock() } + queue.mu.Lock() + queue.mode = settings.Mode + if settings.DebounceMs >= 0 { + queue.debounceMs = settings.DebounceMs + } + if settings.Cap > 0 { + queue.cap = settings.Cap + } + if settings.DropPolicy != "" { + queue.dropPolicy = settings.DropPolicy + } + oc.pendingQueuesMu.Unlock() return queue } @@ -99,9 +97,11 @@ func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { oc.pendingQueuesMu.Unlock() return nil } - delete(oc.pendingQueues, roomID) queue.mu.Lock() + delete(oc.pendingQueues, roomID) items := queue.items + queue.items = nil + queue.lastItem = nil queue.mu.Unlock() oc.pendingQueuesMu.Unlock() @@ -131,6 +131,13 @@ func (oc *AIClient) removePendingQueueBySourceEvent(roomID id.RoomID, sourceEven } clear(queue.items[len(kept):]) queue.items = kept + if queue.lastItem != nil && queue.lastItem.pending.sourceEventID() == sourceEventID { + queue.lastItem = nil + if len(kept) > 0 { + lastItem := kept[len(kept)-1] + queue.lastItem = &lastItem + } + } empty := len(queue.items) == 0 && queue.droppedCount == 0 if empty { delete(oc.pendingQueues, roomID) @@ -149,7 +156,6 @@ func (oc *AIClient) enqueuePendingItem(roomID id.RoomID, item pendingQueueItem, if queue == nil { return false } - queue.mu.Lock() defer queue.mu.Unlock() for _, existing := range queue.items { diff --git a/bridges/ai/queue_status_test.go b/bridges/ai/queue_status_test.go index b30b60925..2785af968 100644 --- a/bridges/ai/queue_status_test.go +++ b/bridges/ai/queue_status_test.go @@ -190,3 +190,33 @@ func TestDispatchOrQueueQueuesBehindExistingPendingWork(t *testing.T) { t.Fatalf("expected room to remain unacquired while backlog exists") } } + +func TestRemovePendingQueueBySourceEventClearsRemovedLastItem(t *testing.T) { + roomID := id.RoomID("!room:example.com") + first := pendingQueueItem{pending: pendingMessage{SourceEventID: id.EventID("$one")}} + last := pendingQueueItem{pending: pendingMessage{SourceEventID: id.EventID("$two")}} + oc := &AIClient{ + pendingQueues: map[id.RoomID]*pendingQueue{ + roomID: { + items: []pendingQueueItem{first, last}, + lastItem: &last, + }, + }, + } + + removed := oc.removePendingQueueBySourceEvent(roomID, id.EventID("$two")) + if len(removed) != 1 { + t.Fatalf("expected one removed item, got %d", len(removed)) + } + + snapshot := oc.getQueueSnapshot(roomID) + if snapshot == nil { + t.Fatal("expected queue snapshot to remain") + } + if snapshot.lastItem == nil { + t.Fatal("expected lastItem to be reassigned to the new tail") + } + if got := snapshot.lastItem.pending.sourceEventID(); got != id.EventID("$one") { + t.Fatalf("expected lastItem to point at remaining item, got %q", got) + } +} diff --git a/bridges/ai/streaming_error_handling_test.go b/bridges/ai/streaming_error_handling_test.go index 770e6aae5..36435cb0f 100644 --- a/bridges/ai/streaming_error_handling_test.go +++ b/bridges/ai/streaming_error_handling_test.go @@ -101,8 +101,8 @@ func TestFinishStreamingWithFailureCancelledEndsTurnAsCancelled(t *testing.T) { "cancelled", context.Canceled, ) - if err == nil { - t.Fatal("expected wrapped cancellation error") + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected wrapped cancellation error, got %#v", err) } message := streamui.SnapshotUIMessage(state.turn.UIState())