Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp
Returns:
The complete AgentResponse, or None if waiting for user input.
"""
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {})
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}))

updates: list[AgentResponseUpdate] = []
streamed_user_input_requests: list[Content] = []
Expand Down
15 changes: 12 additions & 3 deletions python/packages/core/agent_framework/_workflows/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,14 @@ async def _run_workflow_with_tracing(
self._runner.context.reset_for_new_run()
self._state.clear()

# Store run kwargs in State so executors can access them
# Always store (even empty dict) so retrieval is deterministic
self._state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs or {})
# Store run kwargs in State so executors can access them.
# Only overwrite when new kwargs are explicitly provided or state was
# just cleared (fresh run). On continuation (reset_context=False) with
# no new kwargs, preserve the kwargs from the original run.
if run_kwargs is not None:
self._state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs)
elif reset_context:
self._state.set(WORKFLOW_RUN_KWARGS_KEY, {})
self._state.commit() # Commit immediately so kwargs are available

# Set streaming mode after reset
Expand Down Expand Up @@ -564,6 +569,10 @@ async def _run_core(
initial_executor_fn=initial_executor_fn,
reset_context=reset_context,
streaming=streaming,
# Empty **kwargs (no caller-provided kwargs) is collapsed to None so that
# continuation calls without explicit kwargs preserve the original run's kwargs.
# A non-empty kwargs dict (even one with empty values like {"key": {}})
# is passed through and will overwrite stored kwargs.
run_kwargs=kwargs if kwargs else None,
):
if event.type == "output" and not self._should_yield_output_event(event):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any])

try:
# Get kwargs from parent workflow's State to propagate to subworkflow
parent_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {}
parent_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})

# Run the sub-workflow and collect all events, passing parent kwargs
result = await self.workflow.run(input_data, **parent_kwargs)
Expand Down
209 changes: 209 additions & 0 deletions python/packages/core/tests/workflow/test_workflow_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,215 @@ async def test_kwargs_with_complex_nested_data() -> None:
assert received.get("complex_data") == complex_data


async def test_kwargs_preserved_on_response_continuation() -> None:
"""Test that run kwargs are preserved when continuing a paused workflow with run(responses=...).

Regression test for #4293: kwargs were overwritten to {} on continuation calls.
"""

class _ApprovalCapturingAgent(BaseAgent):
"""Agent that pauses for approval on first call and captures kwargs on every call."""

captured_kwargs: list[dict[str, Any]]
_asked: bool

def __init__(self) -> None:
super().__init__(name="approval_agent", description="Test agent")
self.captured_kwargs = []
self._asked = False

def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
self.captured_kwargs.append(dict(kwargs))
if not self._asked:
self._asked = True

async def _pause() -> AgentResponse:
call = Content.from_function_call(call_id="c1", name="do_thing", arguments="{}")
req = Content.from_function_approval_request(id="r1", function_call=call)
return AgentResponse(messages=[Message("assistant", [req])])

return _pause()

async def _done() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", ["done"])])

return _done()

from agent_framework import WorkflowBuilder

agent = _ApprovalCapturingAgent()
workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build()

# Initial run with kwargs — workflow should pause for approval
result = await workflow.run("go", custom_data={"token": "abc"})
request_events = result.get_request_info_events()
assert len(request_events) == 1

# Continue with responses only — no new kwargs
approval = request_events[0]
await workflow.run(
responses={approval.request_id: approval.data.to_function_approval_response(True)}
)

# Both calls should have received the original kwargs
assert len(agent.captured_kwargs) == 2
assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"}
assert agent.captured_kwargs[1].get("custom_data") == {"token": "abc"}, (
f"kwargs should be preserved on continuation, got: {agent.captured_kwargs[1]}"
)


async def test_kwargs_overridden_on_response_continuation() -> None:
"""Test that explicitly provided kwargs override prior kwargs on continuation."""

class _ApprovalCapturingAgent(BaseAgent):
captured_kwargs: list[dict[str, Any]]
_asked: bool

def __init__(self) -> None:
super().__init__(name="approval_agent", description="Test agent")
self.captured_kwargs = []
self._asked = False

def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
self.captured_kwargs.append(dict(kwargs))
if not self._asked:
self._asked = True

async def _pause() -> AgentResponse:
call = Content.from_function_call(call_id="c1", name="do_thing", arguments="{}")
req = Content.from_function_approval_request(id="r1", function_call=call)
return AgentResponse(messages=[Message("assistant", [req])])

return _pause()

async def _done() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", ["done"])])

return _done()

from agent_framework import WorkflowBuilder

agent = _ApprovalCapturingAgent()
workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build()

result = await workflow.run("go", custom_data={"token": "abc"})
request_events = result.get_request_info_events()
approval = request_events[0]

# Continue with responses AND new kwargs — should override
await workflow.run(
responses={approval.request_id: approval.data.to_function_approval_response(True)},
custom_data={"token": "xyz"},
)

assert len(agent.captured_kwargs) == 2
assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"}
assert agent.captured_kwargs[1].get("custom_data") == {"token": "xyz"}


async def test_kwargs_empty_value_passed_on_continuation() -> None:
"""Test that explicitly passing a kwarg with an empty value on continuation overrides prior kwargs.

This exercises the boundary where the caller provides kwargs (e.g., custom_data={})
that differ from the original run. Because the kwargs dict is non-empty (it has a key),
it passes the `kwargs if kwargs else None` gate and the `is not None` check, so it
overwrites the previously stored kwargs.
"""

class _ApprovalCapturingAgent(BaseAgent):
captured_kwargs: list[dict[str, Any]]
_asked: bool

def __init__(self) -> None:
super().__init__(name="approval_agent", description="Test agent")
self.captured_kwargs = []
self._asked = False

def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
self.captured_kwargs.append(dict(kwargs))
if not self._asked:
self._asked = True

async def _pause() -> AgentResponse:
call = Content.from_function_call(call_id="c1", name="do_thing", arguments="{}")
req = Content.from_function_approval_request(id="r1", function_call=call)
return AgentResponse(messages=[Message("assistant", [req])])

return _pause()

async def _done() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", ["done"])])

return _done()

from agent_framework import WorkflowBuilder

agent = _ApprovalCapturingAgent()
workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build()

# Initial run with non-empty kwargs
result = await workflow.run("go", custom_data={"token": "abc"})
request_events = result.get_request_info_events()
assert len(request_events) == 1

# Continue with custom_data={} — explicitly clearing the value.
# kwargs={"custom_data": {}} is truthy (has a key), so run_kwargs is set.
approval = request_events[0]
await workflow.run(
responses={approval.request_id: approval.data.to_function_approval_response(True)},
custom_data={},
)

assert len(agent.captured_kwargs) == 2
assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"}
# The continuation explicitly set custom_data={}, overriding the original
assert agent.captured_kwargs[1].get("custom_data") == {}


async def test_kwargs_reset_context_stores_empty_dict() -> None:
"""Test that reset_context=True with no kwargs stores an empty dict.

This exercises the `elif reset_context` branch that ensures WORKFLOW_RUN_KWARGS_KEY
is always populated after a fresh run, even when no kwargs are provided.
"""
agent = _KwargsCapturingAgent(name="reset_ctx_test")

workflow = SequentialBuilder(participants=[agent]).build()

# Run with no kwargs and reset_context=True (the default for a fresh run)
async for event in workflow.run("test", stream=True):
if event.type == "status" and event.state == WorkflowRunState.IDLE:
break

assert len(agent.captured_kwargs) >= 1
# The only kwarg should be the framework-injected 'options' (no user-provided kwargs)
received = agent.captured_kwargs[0]
assert "custom_data" not in received
assert received.get("options") is None


async def test_kwargs_preserved_across_workflow_reruns() -> None:
"""Test that kwargs are correctly isolated between workflow runs."""
agent = _KwargsCapturingAgent(name="rerun_test")
Expand Down