diff --git a/lib/hooks.ts b/lib/hooks.ts index e31cf110..e08345ec 100644 --- a/lib/hooks.ts +++ b/lib/hooks.ts @@ -46,6 +46,14 @@ const INTERNAL_AGENT_SIGNATURES = [ "Summarize what was done in this conversation", ] +function cloneMessages(messages: WithParts[]): WithParts[] { + return structuredClone(messages) +} + +function commitMessages(target: WithParts[], source: WithParts[]): void { + target.splice(0, target.length, ...source) +} + export function createSystemPromptHandler( state: SessionState, logger: Logger, @@ -105,53 +113,68 @@ export function createChatMessageTransformHandler( hostPermissions: HostPermissionSnapshot, ) { return async (input: {}, output: { messages: WithParts[] }) => { - const receivedMessages = Array.isArray(output.messages) ? output.messages.length : 0 - const messages = filterMessagesInPlace(output.messages) - if (messages.length !== receivedMessages) { - logger.warn("Skipping messages with unexpected shape during chat transform", { - received: receivedMessages, - usable: messages.length, - }) - } + // Fail-open: catch transform failures so DCP bugs do not abort the session. + try { + if (!Array.isArray(output.messages)) { + throw new Error("Chat transform output.messages is not an array") + } - await checkSession(client, state, logger, output.messages, config.manualMode.enabled) + const workingMessages = cloneMessages(output.messages) + const receivedMessages = workingMessages.length + const messages = filterMessagesInPlace(workingMessages) + if (messages.length !== receivedMessages) { + logger.warn("Skipping messages with unexpected shape during chat transform", { + received: receivedMessages, + usable: messages.length, + }) + } - syncCompressPermissionState(state, config, hostPermissions, output.messages) + await checkSession(client, state, logger, workingMessages, config.manualMode.enabled) - if (state.isSubAgent && !config.experimental.allowSubAgents) { - return - } + syncCompressPermissionState(state, config, hostPermissions, workingMessages) - stripHallucinations(output.messages) - cacheSystemPromptTokens(state, output.messages) - assignMessageRefs(state, output.messages) - syncCompressionBlocks(state, logger, output.messages) - syncToolCache(state, config, logger, output.messages) - buildToolIdList(state, output.messages) - prune(state, logger, config, output.messages) - await injectExtendedSubAgentResults( - client, - state, - logger, - output.messages, - config.experimental.allowSubAgents, - ) - const compressionPriorities = buildPriorityMap(config, state, output.messages) - prompts.reload() - injectCompressNudges( - state, - config, - logger, - output.messages, - prompts.getRuntimePrompts(), - compressionPriorities, - ) - injectMessageIds(state, config, output.messages, compressionPriorities) - applyPendingManualTrigger(state, output.messages, logger) - stripStaleMetadata(output.messages) + if (state.isSubAgent && !config.experimental.allowSubAgents) { + commitMessages(output.messages, workingMessages) + return + } + + stripHallucinations(workingMessages) + cacheSystemPromptTokens(state, workingMessages) + assignMessageRefs(state, workingMessages) + syncCompressionBlocks(state, logger, workingMessages) + syncToolCache(state, config, logger, workingMessages) + buildToolIdList(state, workingMessages) + prune(state, logger, config, workingMessages) + await injectExtendedSubAgentResults( + client, + state, + logger, + workingMessages, + config.experimental.allowSubAgents, + ) + const compressionPriorities = buildPriorityMap(config, state, workingMessages) + prompts.reload() + injectCompressNudges( + state, + config, + logger, + workingMessages, + prompts.getRuntimePrompts(), + compressionPriorities, + ) + injectMessageIds(state, config, workingMessages, compressionPriorities) + applyPendingManualTrigger(state, workingMessages, logger) + stripStaleMetadata(workingMessages) - if (state.sessionId) { - await logger.saveContext(state.sessionId, output.messages) + if (state.sessionId) { + await logger.saveContext(state.sessionId, workingMessages) + } + + commitMessages(output.messages, workingMessages) + } catch (err) { + logger.error("DCP chat transform failed; continuing without mutations", { + error: err instanceof Error ? err.message : String(err), + }) } } } diff --git a/lib/message-ids.ts b/lib/message-ids.ts index da003999..2e48f2db 100644 --- a/lib/message-ids.ts +++ b/lib/message-ids.ts @@ -136,7 +136,7 @@ export function assignMessageRefs(state: SessionState, messages: WithParts[]): n } const existingRef = state.messageIds.byRawId.get(rawMessageId) - if (existingRef) { + if (existingRef !== undefined) { if (state.messageIds.byRef.get(existingRef) !== rawMessageId) { state.messageIds.byRef.set(existingRef, rawMessageId) } diff --git a/tests/hooks-permission.test.ts b/tests/hooks-permission.test.ts index 491584c2..a49800ac 100644 --- a/tests/hooks-permission.test.ts +++ b/tests/hooks-permission.test.ts @@ -87,6 +87,10 @@ function buildMessage(id: string, role: "user" | "assistant", text: string): Wit } } +function dcpMessageIdTag(ref = "m0001"): string { + return `<${"dcp-message-id"}>${ref}` +} + test("system prompt handler caches full model context for percentage thresholds", async () => { const state = createSessionState() const handler = createSystemPromptHandler(state, new Logger(false), buildConfig("deny"), { @@ -177,6 +181,123 @@ test("chat message transform drops messages without info instead of crashing", a assert.equal(output.messages.length, 0) }) +test("chat message transform leaves original messages untouched when a late transform fails", async () => { + const state = createSessionState() + const logger = new Logger(false) + let loggedError = "" + logger.error = ((message: string) => { + loggedError = message + return Promise.resolve() + }) as Logger["error"] + const config = buildConfig("allow") + const originalText = `alpha ${dcpMessageIdTag()} omega` + const output = { + messages: [buildMessage("assistant-1", "assistant", originalText)], + } + const originalMessages = output.messages + const handler = createChatMessageTransformHandler( + { session: { get: async () => ({}) } } as any, + state, + logger, + config, + { + reload() { + throw new Error("reload failed") + }, + getRuntimePrompts() { + return {} as any + }, + } as any, + { global: undefined, agents: {} }, + ) + + await handler({}, output) + + assert.equal(output.messages, originalMessages) + assert.equal((output.messages[0]?.parts[0] as any).text, originalText) + assert.equal(loggedError, "DCP chat transform failed; continuing without mutations") +}) + +test("chat message transform leaves original messages untouched when cloning fails", async () => { + const originalStructuredClone = globalThis.structuredClone + const state = createSessionState() + const logger = new Logger(false) + let loggedError = "" + logger.error = ((message: string) => { + loggedError = message + return Promise.resolve() + }) as Logger["error"] + const output = { + messages: [buildMessage("assistant-1", "assistant", "alpha")], + } + const originalMessages = output.messages + + globalThis.structuredClone = (() => { + throw new Error("clone failed") + }) as typeof structuredClone + + try { + const handler = createChatMessageTransformHandler( + { session: { get: async () => ({}) } } as any, + state, + logger, + buildConfig("allow"), + { + reload() {}, + getRuntimePrompts() { + return {} as any + }, + } as any, + { global: undefined, agents: {} }, + ) + + await handler({}, output) + } finally { + globalThis.structuredClone = originalStructuredClone + } + + assert.equal(output.messages, originalMessages) + assert.equal((output.messages[0]?.parts[0] as any).text, "alpha") + assert.equal(loggedError, "DCP chat transform failed; continuing without mutations") +}) + +test("chat message transform commits pre-skip filtering for disabled subagents", async () => { + const state = createSessionState() + state.sessionId = "session-1" + state.isSubAgent = true + const config = buildConfig("allow") + const originalText = `alpha ${dcpMessageIdTag()} omega` + const output = { + messages: [ + { role: "assistant", parts: [] } as any, + buildMessage("assistant-1", "assistant", originalText), + ], + } + const originalMessages = output.messages + const handler = createChatMessageTransformHandler( + { session: { get: async () => ({}) } } as any, + state, + new Logger(false), + config, + { + reload() { + throw new Error("later transforms should not run") + }, + getRuntimePrompts() { + return {} as any + }, + } as any, + { global: undefined, agents: {} }, + ) + + await handler({}, output as any) + + assert.equal(output.messages, originalMessages) + assert.equal(output.messages.length, 1) + assert.equal((output.messages[0]?.parts[0] as any).text, originalText) + assert.equal(state.messageIds.byRawId.size, 0) +}) + test("command execute exits after effective permission resolves to deny", async () => { let sessionMessagesCalls = 0 const output = { parts: [] as any[] } diff --git a/tests/message-ids.test.ts b/tests/message-ids.test.ts index f128b766..cefb6e55 100644 --- a/tests/message-ids.test.ts +++ b/tests/message-ids.test.ts @@ -1,7 +1,7 @@ import assert from "node:assert/strict" import test from "node:test" import { Logger } from "../lib/logger" -import { assignMessageRefs } from "../lib/message-ids" +import { assignMessageRefs, MESSAGE_REF_MAX_INDEX, formatMessageRef } from "../lib/message-ids" import { checkSession, createSessionState, type WithParts } from "../lib/state" function textPart(messageID: string, sessionID: string, id: string, text: string) { @@ -87,3 +87,43 @@ test("checkSession resets message id aliases after native compaction", async () assert.equal(state.messageIds.byRef.get("m0002"), "msg-user-follow-up") assert.equal(state.messageIds.nextRef, 3) }) + +test("assignMessageRefs throws when alias capacity is exhausted", () => { + const sessionID = `ses_message_ids_exhausted_${Date.now()}` + const state = createSessionState() + state.messageIds.nextRef = MESSAGE_REF_MAX_INDEX + 1 + + const messages: WithParts[] = [ + { + info: { + id: "msg-a", + role: "assistant", + sessionID, + agent: "assistant", + time: { created: 1 }, + } as WithParts["info"], + parts: [textPart("msg-a", sessionID, "msg-a-part", "A")], + }, + { + info: { + id: "msg-b", + role: "assistant", + sessionID, + agent: "assistant", + time: { created: 2 }, + } as WithParts["info"], + parts: [textPart("msg-b", sessionID, "msg-b-part", "B")], + }, + ] + + assert.throws( + () => assignMessageRefs(state, messages), + new RegExp( + `Message ID alias capacity exceeded. Cannot allocate more than ${formatMessageRef(MESSAGE_REF_MAX_INDEX)} aliases in this session.`, + ), + ) + assert.equal(state.messageIds.byRawId.size, 0) + assert.equal(state.messageIds.byRef.size, 0) + assert.equal(state.messageIds.byRawId.get("msg-a"), undefined) + assert.equal(state.messageIds.byRawId.get("msg-b"), undefined) +})