Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/packages/core/agent_framework/_workflows/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ async def _run_impl(
output_events.append(event)

result = self._convert_workflow_events_to_agent_response(response_id, output_events)
# Set the response on the session context so that after_run providers
# (e.g. InMemoryHistoryProvider) can persist the output messages.
session_context._response = result
await self._run_after_providers(session=provider_session, context=session_context)
return result

Expand Down Expand Up @@ -322,12 +325,18 @@ async def _run_stream_impl(
# combine the messages

session_messages: list[Message] = session_context.get_messages(include_input=True)
collected_updates: list[AgentResponseUpdate] = []
async for event in self._run_core(
session_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs
):
updates = self._convert_workflow_event_to_agent_response_updates(response_id, event)
for update in updates:
collected_updates.append(update)
yield update
# Build the final response from collected updates and set it on the session
# context so that after_run providers can persist the output messages.
if collected_updates:
session_context._response = AgentResponse.from_updates(collected_updates)
await self._run_after_providers(session=provider_session, context=session_context)

async def _run_core(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) Microsoft. All rights reserved.

"""Tests for WorkflowAgent session history persistence (GitHub issue #4248).

Validates that WorkflowAgent correctly saves both user input and assistant
response messages to session history via the InMemoryHistoryProvider.
"""

import uuid

import pytest
from typing_extensions import Never

from agent_framework import (
AgentResponse,
AgentResponseUpdate,
AgentSession,
Content,
Message,
WorkflowAgent,
WorkflowBuilder,
WorkflowContext,
executor,
)


@executor
async def simple_response_executor(
messages: list[Message], ctx: WorkflowContext[Never, AgentResponse]
) -> None:
"""Executor that emits a simple assistant response."""
input_text = messages[-1].text if messages else "no input"
response = AgentResponse(
messages=[
Message(
role="assistant",
contents=[Content.from_text(text=f"Response to: {input_text}")],
author_name="test-agent",
)
],
)
await ctx.yield_output(response)


@executor
async def streaming_response_executor(
messages: list[Message], ctx: WorkflowContext[Never, AgentResponseUpdate]
) -> None:
"""Executor that emits a streaming assistant response."""
input_text = messages[-1].text if messages else "no input"
update = AgentResponseUpdate(
contents=[Content.from_text(text=f"Streamed response to: {input_text}")],
role="assistant",
author_name="test-agent",
message_id=str(uuid.uuid4()),
)
await ctx.yield_output(update)


class TestWorkflowAgentSessionHistory:
"""Test that WorkflowAgent persists responses to session history.

Reproduces and validates the fix for GitHub issue #4248:
WorkflowAgent was not saving workflow responses to session history
because session_context._response was never set before calling
_run_after_providers.
"""

async def test_non_streaming_saves_response_to_session(self):
"""Non-streaming run should save both user and assistant messages to session history."""
workflow = WorkflowBuilder(start_executor=simple_response_executor).build()
agent = workflow.as_agent("test-agent")
session = agent.create_session()

await agent.run("Hello", session=session)

# The InMemoryHistoryProvider stores messages in session.state["in_memory"]["messages"]
stored_messages = session.state.get("in_memory", {}).get("messages", [])

# Should have both user input and assistant response
assert len(stored_messages) >= 2, (
f"Expected at least 2 messages (user + assistant), got {len(stored_messages)}: "
f"{[(m.role, m.text) for m in stored_messages]}"
)

roles = [m.role for m in stored_messages]
assert "user" in roles, "User message should be in session history"
assert "assistant" in roles, "Assistant message should be in session history"

# Verify the assistant message content
assistant_msgs = [m for m in stored_messages if m.role == "assistant"]
assert any("Response to: Hello" in (m.text or "") for m in assistant_msgs)

async def test_streaming_saves_response_to_session(self):
"""Streaming run should save both user and assistant messages to session history."""
workflow = WorkflowBuilder(start_executor=streaming_response_executor).build()
agent = workflow.as_agent("test-agent")
session = agent.create_session()

# Consume the stream fully
async for _ in agent.run("Hello", stream=True, session=session):
pass

stored_messages = session.state.get("in_memory", {}).get("messages", [])

assert len(stored_messages) >= 2, (
f"Expected at least 2 messages (user + assistant), got {len(stored_messages)}: "
f"{[(m.role, m.text) for m in stored_messages]}"
)

roles = [m.role for m in stored_messages]
assert "user" in roles, "User message should be in session history"
assert "assistant" in roles, "Assistant message should be in session history"

assistant_msgs = [m for m in stored_messages if m.role == "assistant"]
assert any("Streamed response to: Hello" in (m.text or "") for m in assistant_msgs)

async def test_multi_turn_saves_all_messages(self):
"""Multiple turns should accumulate all messages in session history."""
workflow = WorkflowBuilder(start_executor=simple_response_executor).build()
agent = workflow.as_agent("test-agent")
session = agent.create_session()

# Turn 1
await agent.run("First question", session=session)

# Turn 2
await agent.run("Second question", session=session)

stored_messages = session.state.get("in_memory", {}).get("messages", [])

# Should have 4 messages: user1, assistant1, user2, assistant2
assert len(stored_messages) >= 4, (
f"Expected at least 4 messages (2 user + 2 assistant), got {len(stored_messages)}: "
f"{[(m.role, m.text) for m in stored_messages]}"
)

user_msgs = [m for m in stored_messages if m.role == "user"]
assistant_msgs = [m for m in stored_messages if m.role == "assistant"]

assert len(user_msgs) >= 2, f"Expected at least 2 user messages, got {len(user_msgs)}"
assert len(assistant_msgs) >= 2, f"Expected at least 2 assistant messages, got {len(assistant_msgs)}"

# Verify content of second turn references the input
assert any("Second question" in (m.text or "") for m in user_msgs)
assert any("Response to: Second question" in (m.text or "") for m in assistant_msgs)
Loading