Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 65 additions & 42 deletions lib/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ const INTERNAL_AGENT_SIGNATURES = [
"Summarize what was done in this conversation",
]

function cloneMessages(messages: WithParts[]): WithParts[] {
return structuredClone(messages)
}

function commitMessages(target: WithParts[], source: WithParts[]): void {
target.splice(0, target.length, ...source)
}

export function createSystemPromptHandler(
state: SessionState,
logger: Logger,
Expand Down Expand Up @@ -105,53 +113,68 @@ export function createChatMessageTransformHandler(
hostPermissions: HostPermissionSnapshot,
) {
return async (input: {}, output: { messages: WithParts[] }) => {
const receivedMessages = Array.isArray(output.messages) ? output.messages.length : 0
const messages = filterMessagesInPlace(output.messages)
if (messages.length !== receivedMessages) {
logger.warn("Skipping messages with unexpected shape during chat transform", {
received: receivedMessages,
usable: messages.length,
})
}
// Fail-open: catch transform failures so DCP bugs do not abort the session.
try {
if (!Array.isArray(output.messages)) {
throw new Error("Chat transform output.messages is not an array")
}

await checkSession(client, state, logger, output.messages, config.manualMode.enabled)
const workingMessages = cloneMessages(output.messages)
const receivedMessages = workingMessages.length
const messages = filterMessagesInPlace(workingMessages)
if (messages.length !== receivedMessages) {
logger.warn("Skipping messages with unexpected shape during chat transform", {
received: receivedMessages,
usable: messages.length,
})
}

syncCompressPermissionState(state, config, hostPermissions, output.messages)
await checkSession(client, state, logger, workingMessages, config.manualMode.enabled)

if (state.isSubAgent && !config.experimental.allowSubAgents) {
return
}
syncCompressPermissionState(state, config, hostPermissions, workingMessages)

stripHallucinations(output.messages)
cacheSystemPromptTokens(state, output.messages)
assignMessageRefs(state, output.messages)
syncCompressionBlocks(state, logger, output.messages)
syncToolCache(state, config, logger, output.messages)
buildToolIdList(state, output.messages)
prune(state, logger, config, output.messages)
await injectExtendedSubAgentResults(
client,
state,
logger,
output.messages,
config.experimental.allowSubAgents,
)
const compressionPriorities = buildPriorityMap(config, state, output.messages)
prompts.reload()
injectCompressNudges(
state,
config,
logger,
output.messages,
prompts.getRuntimePrompts(),
compressionPriorities,
)
injectMessageIds(state, config, output.messages, compressionPriorities)
applyPendingManualTrigger(state, output.messages, logger)
stripStaleMetadata(output.messages)
if (state.isSubAgent && !config.experimental.allowSubAgents) {
commitMessages(output.messages, workingMessages)
return
}

stripHallucinations(workingMessages)
cacheSystemPromptTokens(state, workingMessages)
assignMessageRefs(state, workingMessages)
syncCompressionBlocks(state, logger, workingMessages)
syncToolCache(state, config, logger, workingMessages)
buildToolIdList(state, workingMessages)
prune(state, logger, config, workingMessages)
await injectExtendedSubAgentResults(
client,
state,
logger,
workingMessages,
config.experimental.allowSubAgents,
)
const compressionPriorities = buildPriorityMap(config, state, workingMessages)
prompts.reload()
injectCompressNudges(
state,
config,
logger,
workingMessages,
prompts.getRuntimePrompts(),
compressionPriorities,
)
injectMessageIds(state, config, workingMessages, compressionPriorities)
applyPendingManualTrigger(state, workingMessages, logger)
stripStaleMetadata(workingMessages)

if (state.sessionId) {
await logger.saveContext(state.sessionId, output.messages)
if (state.sessionId) {
await logger.saveContext(state.sessionId, workingMessages)
}

commitMessages(output.messages, workingMessages)
} catch (err) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would still allow these functions to mutate messages, if there is an error thrown we should probably use a deep copy or something like that instead of a partially completed mutation wherever the error occured

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I changed the fail-open path to run transforms against a cloned message array and only commit back to output.messages after the pipeline succeeds, preserving the original array identity with splice. If a transform throws, the original messages are left untouched. Added regression coverage for late transform failure, clone failure, and the disabled-subagent pre-skip path.

logger.error("DCP chat transform failed; continuing without mutations", {
error: err instanceof Error ? err.message : String(err),
})
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/message-ids.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ export function assignMessageRefs(state: SessionState, messages: WithParts[]): n
}

const existingRef = state.messageIds.byRawId.get(rawMessageId)
if (existingRef) {
if (existingRef !== undefined) {
if (state.messageIds.byRef.get(existingRef) !== rawMessageId) {
state.messageIds.byRef.set(existingRef, rawMessageId)
}
Expand Down
121 changes: 121 additions & 0 deletions tests/hooks-permission.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ function buildMessage(id: string, role: "user" | "assistant", text: string): Wit
}
}

function dcpMessageIdTag(ref = "m0001"): string {
return `<${"dcp-message-id"}>${ref}</${"dcp-message-id"}>`
}

test("system prompt handler caches full model context for percentage thresholds", async () => {
const state = createSessionState()
const handler = createSystemPromptHandler(state, new Logger(false), buildConfig("deny"), {
Expand Down Expand Up @@ -177,6 +181,123 @@ test("chat message transform drops messages without info instead of crashing", a
assert.equal(output.messages.length, 0)
})

test("chat message transform leaves original messages untouched when a late transform fails", async () => {
const state = createSessionState()
const logger = new Logger(false)
let loggedError = ""
logger.error = ((message: string) => {
loggedError = message
return Promise.resolve()
}) as Logger["error"]
const config = buildConfig("allow")
const originalText = `alpha ${dcpMessageIdTag()} omega`
const output = {
messages: [buildMessage("assistant-1", "assistant", originalText)],
}
const originalMessages = output.messages
const handler = createChatMessageTransformHandler(
{ session: { get: async () => ({}) } } as any,
state,
logger,
config,
{
reload() {
throw new Error("reload failed")
},
getRuntimePrompts() {
return {} as any
},
} as any,
{ global: undefined, agents: {} },
)

await handler({}, output)

assert.equal(output.messages, originalMessages)
assert.equal((output.messages[0]?.parts[0] as any).text, originalText)
assert.equal(loggedError, "DCP chat transform failed; continuing without mutations")
})

test("chat message transform leaves original messages untouched when cloning fails", async () => {
const originalStructuredClone = globalThis.structuredClone
const state = createSessionState()
const logger = new Logger(false)
let loggedError = ""
logger.error = ((message: string) => {
loggedError = message
return Promise.resolve()
}) as Logger["error"]
const output = {
messages: [buildMessage("assistant-1", "assistant", "alpha")],
}
const originalMessages = output.messages

globalThis.structuredClone = (() => {
throw new Error("clone failed")
}) as typeof structuredClone

try {
const handler = createChatMessageTransformHandler(
{ session: { get: async () => ({}) } } as any,
state,
logger,
buildConfig("allow"),
{
reload() {},
getRuntimePrompts() {
return {} as any
},
} as any,
{ global: undefined, agents: {} },
)

await handler({}, output)
} finally {
globalThis.structuredClone = originalStructuredClone
}

assert.equal(output.messages, originalMessages)
assert.equal((output.messages[0]?.parts[0] as any).text, "alpha")
assert.equal(loggedError, "DCP chat transform failed; continuing without mutations")
})

test("chat message transform commits pre-skip filtering for disabled subagents", async () => {
const state = createSessionState()
state.sessionId = "session-1"
state.isSubAgent = true
const config = buildConfig("allow")
const originalText = `alpha ${dcpMessageIdTag()} omega`
const output = {
messages: [
{ role: "assistant", parts: [] } as any,
buildMessage("assistant-1", "assistant", originalText),
],
}
const originalMessages = output.messages
const handler = createChatMessageTransformHandler(
{ session: { get: async () => ({}) } } as any,
state,
new Logger(false),
config,
{
reload() {
throw new Error("later transforms should not run")
},
getRuntimePrompts() {
return {} as any
},
} as any,
{ global: undefined, agents: {} },
)

await handler({}, output as any)

assert.equal(output.messages, originalMessages)
assert.equal(output.messages.length, 1)
assert.equal((output.messages[0]?.parts[0] as any).text, originalText)
assert.equal(state.messageIds.byRawId.size, 0)
})

test("command execute exits after effective permission resolves to deny", async () => {
let sessionMessagesCalls = 0
const output = { parts: [] as any[] }
Expand Down
42 changes: 41 additions & 1 deletion tests/message-ids.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import assert from "node:assert/strict"
import test from "node:test"
import { Logger } from "../lib/logger"
import { assignMessageRefs } from "../lib/message-ids"
import { assignMessageRefs, MESSAGE_REF_MAX_INDEX, formatMessageRef } from "../lib/message-ids"
import { checkSession, createSessionState, type WithParts } from "../lib/state"

function textPart(messageID: string, sessionID: string, id: string, text: string) {
Expand Down Expand Up @@ -87,3 +87,43 @@ test("checkSession resets message id aliases after native compaction", async ()
assert.equal(state.messageIds.byRef.get("m0002"), "msg-user-follow-up")
assert.equal(state.messageIds.nextRef, 3)
})

test("assignMessageRefs throws when alias capacity is exhausted", () => {
const sessionID = `ses_message_ids_exhausted_${Date.now()}`
const state = createSessionState()
state.messageIds.nextRef = MESSAGE_REF_MAX_INDEX + 1

const messages: WithParts[] = [
{
info: {
id: "msg-a",
role: "assistant",
sessionID,
agent: "assistant",
time: { created: 1 },
} as WithParts["info"],
parts: [textPart("msg-a", sessionID, "msg-a-part", "A")],
},
{
info: {
id: "msg-b",
role: "assistant",
sessionID,
agent: "assistant",
time: { created: 2 },
} as WithParts["info"],
parts: [textPart("msg-b", sessionID, "msg-b-part", "B")],
},
]

assert.throws(
() => assignMessageRefs(state, messages),
new RegExp(
`Message ID alias capacity exceeded. Cannot allocate more than ${formatMessageRef(MESSAGE_REF_MAX_INDEX)} aliases in this session.`,
),
)
assert.equal(state.messageIds.byRawId.size, 0)
assert.equal(state.messageIds.byRef.size, 0)
assert.equal(state.messageIds.byRawId.get("msg-a"), undefined)
assert.equal(state.messageIds.byRawId.get("msg-b"), undefined)
})