diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 257833bb6a..fe8be1886b 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -360,7 +360,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp Returns: The complete AgentResponse, or None if waiting for user input. """ - run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {}) + run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})) updates: list[AgentResponseUpdate] = [] streamed_user_input_requests: list[Content] = [] diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index cd7dbb4a68..f545fbe5d8 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -345,9 +345,14 @@ async def _run_workflow_with_tracing( self._runner.context.reset_for_new_run() self._state.clear() - # Store run kwargs in State so executors can access them - # Always store (even empty dict) so retrieval is deterministic - self._state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs or {}) + # Store run kwargs in State so executors can access them. + # Only overwrite when new kwargs are explicitly provided or state was + # just cleared (fresh run). On continuation (reset_context=False) with + # no new kwargs, preserve the kwargs from the original run. + if run_kwargs is not None: + self._state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs) + elif reset_context: + self._state.set(WORKFLOW_RUN_KWARGS_KEY, {}) self._state.commit() # Commit immediately so kwargs are available # Set streaming mode after reset @@ -564,6 +569,10 @@ async def _run_core( initial_executor_fn=initial_executor_fn, reset_context=reset_context, streaming=streaming, + # Empty **kwargs (no caller-provided kwargs) is collapsed to None so that + # continuation calls without explicit kwargs preserve the original run's kwargs. + # A non-empty kwargs dict (even one with empty values like {"key": {}}) + # is passed through and will overwrite stored kwargs. run_kwargs=kwargs if kwargs else None, ): if event.type == "output" and not self._should_yield_output_event(event): diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 0d2c86070c..e9e4196bfd 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -385,7 +385,7 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) try: # Get kwargs from parent workflow's State to propagate to subworkflow - parent_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {} + parent_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) # Run the sub-workflow and collect all events, passing parent kwargs result = await self.workflow.run(input_data, **parent_kwargs) diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index bf1fd00974..379435e124 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -446,6 +446,215 @@ async def test_kwargs_with_complex_nested_data() -> None: assert received.get("complex_data") == complex_data +async def test_kwargs_preserved_on_response_continuation() -> None: + """Test that run kwargs are preserved when continuing a paused workflow with run(responses=...). + + Regression test for #4293: kwargs were overwritten to {} on continuation calls. + """ + + class _ApprovalCapturingAgent(BaseAgent): + """Agent that pauses for approval on first call and captures kwargs on every call.""" + + captured_kwargs: list[dict[str, Any]] + _asked: bool + + def __init__(self) -> None: + super().__init__(name="approval_agent", description="Test agent") + self.captured_kwargs = [] + self._asked = False + + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + self.captured_kwargs.append(dict(kwargs)) + if not self._asked: + self._asked = True + + async def _pause() -> AgentResponse: + call = Content.from_function_call(call_id="c1", name="do_thing", arguments="{}") + req = Content.from_function_approval_request(id="r1", function_call=call) + return AgentResponse(messages=[Message("assistant", [req])]) + + return _pause() + + async def _done() -> AgentResponse: + return AgentResponse(messages=[Message("assistant", ["done"])]) + + return _done() + + from agent_framework import WorkflowBuilder + + agent = _ApprovalCapturingAgent() + workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build() + + # Initial run with kwargs — workflow should pause for approval + result = await workflow.run("go", custom_data={"token": "abc"}) + request_events = result.get_request_info_events() + assert len(request_events) == 1 + + # Continue with responses only — no new kwargs + approval = request_events[0] + await workflow.run( + responses={approval.request_id: approval.data.to_function_approval_response(True)} + ) + + # Both calls should have received the original kwargs + assert len(agent.captured_kwargs) == 2 + assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"} + assert agent.captured_kwargs[1].get("custom_data") == {"token": "abc"}, ( + f"kwargs should be preserved on continuation, got: {agent.captured_kwargs[1]}" + ) + + +async def test_kwargs_overridden_on_response_continuation() -> None: + """Test that explicitly provided kwargs override prior kwargs on continuation.""" + + class _ApprovalCapturingAgent(BaseAgent): + captured_kwargs: list[dict[str, Any]] + _asked: bool + + def __init__(self) -> None: + super().__init__(name="approval_agent", description="Test agent") + self.captured_kwargs = [] + self._asked = False + + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + self.captured_kwargs.append(dict(kwargs)) + if not self._asked: + self._asked = True + + async def _pause() -> AgentResponse: + call = Content.from_function_call(call_id="c1", name="do_thing", arguments="{}") + req = Content.from_function_approval_request(id="r1", function_call=call) + return AgentResponse(messages=[Message("assistant", [req])]) + + return _pause() + + async def _done() -> AgentResponse: + return AgentResponse(messages=[Message("assistant", ["done"])]) + + return _done() + + from agent_framework import WorkflowBuilder + + agent = _ApprovalCapturingAgent() + workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build() + + result = await workflow.run("go", custom_data={"token": "abc"}) + request_events = result.get_request_info_events() + approval = request_events[0] + + # Continue with responses AND new kwargs — should override + await workflow.run( + responses={approval.request_id: approval.data.to_function_approval_response(True)}, + custom_data={"token": "xyz"}, + ) + + assert len(agent.captured_kwargs) == 2 + assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"} + assert agent.captured_kwargs[1].get("custom_data") == {"token": "xyz"} + + +async def test_kwargs_empty_value_passed_on_continuation() -> None: + """Test that explicitly passing a kwarg with an empty value on continuation overrides prior kwargs. + + This exercises the boundary where the caller provides kwargs (e.g., custom_data={}) + that differ from the original run. Because the kwargs dict is non-empty (it has a key), + it passes the `kwargs if kwargs else None` gate and the `is not None` check, so it + overwrites the previously stored kwargs. + """ + + class _ApprovalCapturingAgent(BaseAgent): + captured_kwargs: list[dict[str, Any]] + _asked: bool + + def __init__(self) -> None: + super().__init__(name="approval_agent", description="Test agent") + self.captured_kwargs = [] + self._asked = False + + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + self.captured_kwargs.append(dict(kwargs)) + if not self._asked: + self._asked = True + + async def _pause() -> AgentResponse: + call = Content.from_function_call(call_id="c1", name="do_thing", arguments="{}") + req = Content.from_function_approval_request(id="r1", function_call=call) + return AgentResponse(messages=[Message("assistant", [req])]) + + return _pause() + + async def _done() -> AgentResponse: + return AgentResponse(messages=[Message("assistant", ["done"])]) + + return _done() + + from agent_framework import WorkflowBuilder + + agent = _ApprovalCapturingAgent() + workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build() + + # Initial run with non-empty kwargs + result = await workflow.run("go", custom_data={"token": "abc"}) + request_events = result.get_request_info_events() + assert len(request_events) == 1 + + # Continue with custom_data={} — explicitly clearing the value. + # kwargs={"custom_data": {}} is truthy (has a key), so run_kwargs is set. + approval = request_events[0] + await workflow.run( + responses={approval.request_id: approval.data.to_function_approval_response(True)}, + custom_data={}, + ) + + assert len(agent.captured_kwargs) == 2 + assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"} + # The continuation explicitly set custom_data={}, overriding the original + assert agent.captured_kwargs[1].get("custom_data") == {} + + +async def test_kwargs_reset_context_stores_empty_dict() -> None: + """Test that reset_context=True with no kwargs stores an empty dict. + + This exercises the `elif reset_context` branch that ensures WORKFLOW_RUN_KWARGS_KEY + is always populated after a fresh run, even when no kwargs are provided. + """ + agent = _KwargsCapturingAgent(name="reset_ctx_test") + + workflow = SequentialBuilder(participants=[agent]).build() + + # Run with no kwargs and reset_context=True (the default for a fresh run) + async for event in workflow.run("test", stream=True): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + assert len(agent.captured_kwargs) >= 1 + # The only kwarg should be the framework-injected 'options' (no user-provided kwargs) + received = agent.captured_kwargs[0] + assert "custom_data" not in received + assert received.get("options") is None + + async def test_kwargs_preserved_across_workflow_reruns() -> None: """Test that kwargs are correctly isolated between workflow runs.""" agent = _KwargsCapturingAgent(name="rerun_test")