Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/packages/core/agent_framework/_workflows/_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,10 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp

return response

# Parameters that are explicitly passed to agent.run() by AgentExecutor
# and must not appear in **run_kwargs to avoid TypeError from duplicate values.
_RESERVED_RUN_PARAMS: frozenset[str] = frozenset({"session", "stream", "messages"})

@staticmethod
def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]:
"""Prepare kwargs and options for agent.run(), avoiding duplicate option passing.
Expand All @@ -423,8 +427,23 @@ def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, A
`options.additional_function_arguments`. If workflow kwargs include an
`options` key, merge it into the final options object and remove it from
kwargs before spreading `**run_kwargs`.

Reserved parameters (session, stream, messages) that are explicitly
managed by AgentExecutor are stripped from run_kwargs to prevent
``TypeError: got multiple values for keyword argument`` collisions.
"""
run_kwargs = dict(raw_run_kwargs)

# Strip reserved params that AgentExecutor passes explicitly to agent.run().
for key in AgentExecutor._RESERVED_RUN_PARAMS:
if key in run_kwargs:
logger.warning(
"Workflow kwarg '%s' is reserved by AgentExecutor and will be ignored. "
"Remove it from workflow.run() kwargs to silence this warning.",
key,
)
run_kwargs.pop(key)

options_from_workflow = run_kwargs.pop("options", None)
workflow_additional_args = run_kwargs.pop("additional_function_arguments", None)

Expand Down
90 changes: 89 additions & 1 deletion python/packages/core/tests/workflow/test_agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
from collections.abc import AsyncIterable, Awaitable
from typing import Any
from typing import TYPE_CHECKING, Any

import pytest

from agent_framework import (
AgentExecutor,
Expand All @@ -18,6 +21,9 @@
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
from agent_framework.orchestrations import SequentialBuilder

if TYPE_CHECKING:
from _pytest.logging import LogCaptureFixture


class _CountingAgent(BaseAgent):
"""Agent that echoes messages with a counter to verify session state persistence."""
Expand Down Expand Up @@ -251,3 +257,85 @@ async def test_agent_executor_save_and_restore_state_directly() -> None:
# Verify session was restored with correct session_id
restored_session = new_executor._session # type: ignore[reportPrivateUsage]
assert restored_session.session_id == session.session_id


async def test_agent_executor_run_with_session_kwarg_does_not_raise() -> None:
"""Passing session= via workflow.run() should not cause a duplicate-keyword TypeError (#4295)."""
agent = _CountingAgent(id="session_kwarg_agent", name="SessionKwargAgent")
executor = AgentExecutor(agent, id="session_kwarg_exec")
workflow = SequentialBuilder(participants=[executor]).build()

# This previously raised: TypeError: run() got multiple values for keyword argument 'session'
result = await workflow.run("hello", session="user-supplied-value")
assert result is not None
assert agent.call_count == 1


async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -> None:
"""Passing stream= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
agent = _CountingAgent(id="stream_kwarg_agent", name="StreamKwargAgent")
executor = AgentExecutor(agent, id="stream_kwarg_exec")
workflow = SequentialBuilder(participants=[executor]).build()

# stream=True at workflow level triggers streaming mode (returns async iterable)
events = []
async for event in workflow.run("hello", stream=True):
events.append(event)
assert len(events) > 0
assert agent.call_count == 1


@pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"])
async def test_prepare_agent_run_args_strips_reserved_kwargs(
reserved_kwarg: str, caplog: "LogCaptureFixture"
) -> None:
"""_prepare_agent_run_args must remove reserved kwargs and log a warning."""
raw = {reserved_kwarg: "should-be-stripped", "custom_key": "keep-me"}

with caplog.at_level(logging.WARNING):
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)

assert reserved_kwarg not in run_kwargs
assert "custom_key" in run_kwargs
assert options is not None
assert options["additional_function_arguments"]["custom_key"] == "keep-me"
assert any(reserved_kwarg in record.message for record in caplog.records)


async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None:
"""Non-reserved workflow kwargs should pass through unchanged."""
raw = {"custom_param": "value", "another": 42}
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
assert run_kwargs["custom_param"] == "value"
assert run_kwargs["another"] == 42


async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
caplog: "LogCaptureFixture",
) -> None:
"""All reserved kwargs should be stripped when supplied together, each emitting a warning."""
raw = {"session": "x", "stream": True, "messages": [], "custom": 1}

with caplog.at_level(logging.WARNING):
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)

assert "session" not in run_kwargs
assert "stream" not in run_kwargs
assert "messages" not in run_kwargs
assert run_kwargs["custom"] == 1
assert options is not None
assert options["additional_function_arguments"]["custom"] == 1

warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()}
assert warned_keys == {"session", "stream", "messages"}


async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None:
"""Passing messages= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
agent = _CountingAgent(id="messages_kwarg_agent", name="MessagesKwargAgent")
executor = AgentExecutor(agent, id="messages_kwarg_exec")
workflow = SequentialBuilder(participants=[executor]).build()

result = await workflow.run("hello", messages=["stale"])
assert result is not None
assert agent.call_count == 1