diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 257833bb6a..acec8e48e2 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -415,6 +415,10 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp return response + # Parameters that are explicitly passed to agent.run() by AgentExecutor + # and must not appear in **run_kwargs to avoid TypeError from duplicate values. + _RESERVED_RUN_PARAMS: frozenset[str] = frozenset({"session", "stream", "messages"}) + @staticmethod def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]: """Prepare kwargs and options for agent.run(), avoiding duplicate option passing. @@ -423,8 +427,23 @@ def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, A `options.additional_function_arguments`. If workflow kwargs include an `options` key, merge it into the final options object and remove it from kwargs before spreading `**run_kwargs`. + + Reserved parameters (session, stream, messages) that are explicitly + managed by AgentExecutor are stripped from run_kwargs to prevent + ``TypeError: got multiple values for keyword argument`` collisions. """ run_kwargs = dict(raw_run_kwargs) + + # Strip reserved params that AgentExecutor passes explicitly to agent.run(). + for key in AgentExecutor._RESERVED_RUN_PARAMS: + if key in run_kwargs: + logger.warning( + "Workflow kwarg '%s' is reserved by AgentExecutor and will be ignored. " + "Remove it from workflow.run() kwargs to silence this warning.", + key, + ) + run_kwargs.pop(key) + options_from_workflow = run_kwargs.pop("options", None) workflow_additional_args = run_kwargs.pop("additional_function_arguments", None) diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 7c2e6fc356..db53868ee1 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -1,7 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. +import logging from collections.abc import AsyncIterable, Awaitable -from typing import Any +from typing import TYPE_CHECKING, Any + +import pytest from agent_framework import ( AgentExecutor, @@ -18,6 +21,9 @@ from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from agent_framework.orchestrations import SequentialBuilder +if TYPE_CHECKING: + from _pytest.logging import LogCaptureFixture + class _CountingAgent(BaseAgent): """Agent that echoes messages with a counter to verify session state persistence.""" @@ -251,3 +257,85 @@ async def test_agent_executor_save_and_restore_state_directly() -> None: # Verify session was restored with correct session_id restored_session = new_executor._session # type: ignore[reportPrivateUsage] assert restored_session.session_id == session.session_id + + +async def test_agent_executor_run_with_session_kwarg_does_not_raise() -> None: + """Passing session= via workflow.run() should not cause a duplicate-keyword TypeError (#4295).""" + agent = _CountingAgent(id="session_kwarg_agent", name="SessionKwargAgent") + executor = AgentExecutor(agent, id="session_kwarg_exec") + workflow = SequentialBuilder(participants=[executor]).build() + + # This previously raised: TypeError: run() got multiple values for keyword argument 'session' + result = await workflow.run("hello", session="user-supplied-value") + assert result is not None + assert agent.call_count == 1 + + +async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -> None: + """Passing stream= via workflow.run() kwargs should not cause a duplicate-keyword TypeError.""" + agent = _CountingAgent(id="stream_kwarg_agent", name="StreamKwargAgent") + executor = AgentExecutor(agent, id="stream_kwarg_exec") + workflow = SequentialBuilder(participants=[executor]).build() + + # stream=True at workflow level triggers streaming mode (returns async iterable) + events = [] + async for event in workflow.run("hello", stream=True): + events.append(event) + assert len(events) > 0 + assert agent.call_count == 1 + + +@pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"]) +async def test_prepare_agent_run_args_strips_reserved_kwargs( + reserved_kwarg: str, caplog: "LogCaptureFixture" +) -> None: + """_prepare_agent_run_args must remove reserved kwargs and log a warning.""" + raw = {reserved_kwarg: "should-be-stripped", "custom_key": "keep-me"} + + with caplog.at_level(logging.WARNING): + run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) + + assert reserved_kwarg not in run_kwargs + assert "custom_key" in run_kwargs + assert options is not None + assert options["additional_function_arguments"]["custom_key"] == "keep-me" + assert any(reserved_kwarg in record.message for record in caplog.records) + + +async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None: + """Non-reserved workflow kwargs should pass through unchanged.""" + raw = {"custom_param": "value", "another": 42} + run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) + assert run_kwargs["custom_param"] == "value" + assert run_kwargs["another"] == 42 + + +async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once( + caplog: "LogCaptureFixture", +) -> None: + """All reserved kwargs should be stripped when supplied together, each emitting a warning.""" + raw = {"session": "x", "stream": True, "messages": [], "custom": 1} + + with caplog.at_level(logging.WARNING): + run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) + + assert "session" not in run_kwargs + assert "stream" not in run_kwargs + assert "messages" not in run_kwargs + assert run_kwargs["custom"] == 1 + assert options is not None + assert options["additional_function_arguments"]["custom"] == 1 + + warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()} + assert warned_keys == {"session", "stream", "messages"} + + +async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None: + """Passing messages= via workflow.run() kwargs should not cause a duplicate-keyword TypeError.""" + agent = _CountingAgent(id="messages_kwarg_agent", name="MessagesKwargAgent") + executor = AgentExecutor(agent, id="messages_kwarg_exec") + workflow = SequentialBuilder(participants=[executor]).build() + + result = await workflow.run("hello", messages=["stale"]) + assert result is not None + assert agent.call_count == 1