Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,16 +499,34 @@ async def agent_wrapper(**kwargs: Any) -> str:
# Extract the input from kwargs using the specified arg_name
input_text = kwargs.get(arg_name, "")

# Forward runtime context kwargs, excluding arg_name and conversation_id.
# Extract conversation_id forwarded from parent agent's tool invocation loop
parent_conversation_id = kwargs.get("conversation_id")

# Forward runtime context kwargs, excluding arg_name, conversation_id, and options.
forwarded_kwargs = {k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options")}

# Pass parent's conversation_id via additional_function_arguments so it reaches
# the sub-agent's tools through **kwargs for correlation purposes.
# We do NOT pass it via chat options because the sub-agent's chat client should
# start its own conversation, not try to continue the parent's.
# Merge into any existing options/additional_function_arguments to avoid replacing them.
existing_options = kwargs.get("options")
run_options: dict[str, Any] | None = dict(existing_options) if isinstance(existing_options, dict) else None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're doing all these extra dict calls only for typing, then we really shouldn't be going this, ignore instead or actually raise a issue if the value isn't right, but this is either does nothing, or it will raise a strange error because it can't create the dict. This goes for all of these later on as well.

if parent_conversation_id:
if run_options is None:
run_options = {}
existing_additional = run_options.get("additional_function_arguments") or {}
merged_additional = dict(existing_additional) if isinstance(existing_additional, dict) else {}
merged_additional["parent_conversation_id"] = parent_conversation_id
run_options["additional_function_arguments"] = merged_additional

if stream_callback is None:
# Use non-streaming mode
return (await self.run(input_text, stream=False, **forwarded_kwargs)).text
return (await self.run(input_text, stream=False, options=run_options, **forwarded_kwargs)).text

# Use streaming mode - accumulate updates and create final response
response_updates: list[AgentResponseUpdate] = []
async for update in self.run(input_text, stream=True, **forwarded_kwargs):
async for update in self.run(input_text, stream=True, options=run_options, **forwarded_kwargs):
response_updates.append(update)
if is_async_callback:
await stream_callback(update) # type: ignore[misc]
Expand Down Expand Up @@ -937,7 +955,11 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda

def _finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]:
ctx = ctx_holder["ctx"]
rf = ctx.get("chat_options", {}).get("response_format") if ctx else (options.get("response_format") if options else None)
rf = (
ctx.get("chat_options", {}).get("response_format")
if ctx
else (options.get("response_format") if options else None)
)
return self._finalize_response_updates(updates, response_format=rf)

return (
Expand Down
8 changes: 5 additions & 3 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,11 +1406,11 @@ async def _auto_invoke_function(
parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {})

# Filter out internal framework kwargs before passing to tools.
# conversation_id is an internal tracking ID that should not be forwarded to tools.
# conversation_id is forwarded so agent-as-tool wrappers can correlate sub-agent conversations.
runtime_kwargs: dict[str, Any] = {
key: value
for key, value in (custom_args or {}).items()
if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"}
if key not in {"_function_middleware_pipeline", "middleware"}
}
try:
if not tool._schema_supplied and tool.input_model is not None:
Expand Down Expand Up @@ -2083,7 +2083,7 @@ def get_response(
max_errors: int = self.function_invocation_configuration["max_consecutive_errors_per_request"] # type: ignore[assignment]
additional_function_arguments: dict[str, Any] = {}
if options and (additional_opts := options.get("additional_function_arguments")): # type: ignore[attr-defined]
additional_function_arguments = additional_opts # type: ignore
additional_function_arguments = dict(additional_opts) # type: ignore[call-overload] # defensive copy
execute_function_calls = partial(
_execute_function_calls,
custom_args=additional_function_arguments,
Expand Down Expand Up @@ -2145,6 +2145,7 @@ async def _get_response() -> ChatResponse:

if response.conversation_id is not None:
_update_conversation_id(kwargs, response.conversation_id, mutable_options)
additional_function_arguments["conversation_id"] = response.conversation_id
prepped_messages = []

result = await _process_function_requests(
Expand Down Expand Up @@ -2279,6 +2280,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:

if response.conversation_id is not None:
_update_conversation_id(kwargs, response.conversation_id, mutable_options)
additional_function_arguments["conversation_id"] = response.conversation_id
prepped_messages = []

result = await _process_function_requests(
Expand Down
51 changes: 50 additions & 1 deletion python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import contextlib
from collections.abc import AsyncIterable, MutableSequence
from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
Expand All @@ -23,10 +23,14 @@
Message,
SupportsAgentRun,
SupportsChatGetResponse,
agent_middleware,
tool,
)
from agent_framework._agents import _merge_options, _sanitize_agent_name
from agent_framework._mcp import MCPTool
from agent_framework._middleware import AgentContext

from .conftest import MockChatClient


def test_agent_session_type(agent_session: AgentSession) -> None:
Expand Down Expand Up @@ -707,6 +711,51 @@ async def test_chat_agent_as_tool_name_sanitization(client: SupportsChatGetRespo
assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}"


async def test_chat_agent_as_tool_propagates_conversation_id(client: SupportsChatGetResponse) -> None:
"""Test that as_tool passes parent_conversation_id to sub-agent via additional_function_arguments."""
mock_client: MockChatClient = client # type: ignore[assignment]
captured_options: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_options.update(context.options or {})
await call_next()

mock_client.responses = [
ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]),
]

sub_agent = Agent(client=mock_client, name="sub_agent", middleware=[capture_middleware])
t = sub_agent.as_tool(name="delegate", arg_name="task")

await t.invoke(arguments=t.input_model(task="Test delegation"), conversation_id="conv-parent-123")

additional_args = captured_options.get("additional_function_arguments", {})
assert additional_args.get("parent_conversation_id") == "conv-parent-123"


async def test_chat_agent_as_tool_no_conversation_id_when_absent(client: SupportsChatGetResponse) -> None:
"""Test that as_tool does not inject additional_function_arguments when no conversation_id provided."""
mock_client: MockChatClient = client # type: ignore[assignment]
captured_options: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_options.update(context.options or {})
await call_next()

mock_client.responses = [
ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]),
]

sub_agent = Agent(client=mock_client, name="sub_agent", middleware=[capture_middleware])
t = sub_agent.as_tool(name="delegate", arg_name="task")

await t.invoke(arguments=t.input_model(task="Test delegation"), user_id="user-789")

assert "additional_function_arguments" not in captured_options


async def test_chat_agent_as_mcp_server_basic(client: SupportsChatGetResponse) -> None:
"""Test basic as_mcp_server functionality."""
agent = Agent(client=client, name="TestAgent", description="Test agent for MCP")
Expand Down
60 changes: 60 additions & 0 deletions python/packages/core/tests/core/test_as_tool_kwargs_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,63 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai
# Verify other kwargs were still forwarded
assert captured_kwargs.get("api_token") == "secret-xyz-123"
assert captured_kwargs.get("user_id") == "user-456"

async def test_as_tool_propagates_conversation_id_via_options(self, client: MockChatClient) -> None:
"""Test that parent_conversation_id from parent is passed to sub-agent's additional_function_arguments."""
captured_options: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_options.update(context.options or {})
await call_next()

# Setup mock response
client.responses = [
ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]),
]

sub_agent = Agent(
client=client,
name="sub_agent",
middleware=[capture_middleware],
)

tool = sub_agent.as_tool(name="delegate", arg_name="task")

await tool.invoke(
arguments=tool.input_model(task="Test delegation"),
conversation_id="conv-parent-123",
)

# Verify parent_conversation_id was passed via additional_function_arguments
additional_args = captured_options.get("additional_function_arguments", {})
assert additional_args.get("parent_conversation_id") == "conv-parent-123"

async def test_as_tool_no_conversation_id_when_absent(self, client: MockChatClient) -> None:
"""Test that no parent_conversation_id is injected when parent has none."""
captured_options: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_options.update(context.options or {})
await call_next()

client.responses = [
ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]),
]

sub_agent = Agent(
client=client,
name="sub_agent",
middleware=[capture_middleware],
)

tool = sub_agent.as_tool(name="delegate", arg_name="task")

await tool.invoke(
arguments=tool.input_model(task="Test delegation"),
user_id="user-789",
)

# additional_function_arguments should not be in options when no conversation_id provided
assert "additional_function_arguments" not in captured_options
64 changes: 63 additions & 1 deletion python/packages/core/tests/core/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,4 +1464,66 @@ def test_nested_object_with_const_and_enum():
model(config={"type": "production", "level": "critical"})


# endregion
async def test_auto_invoke_function_forwards_conversation_id() -> None:
"""Test that _auto_invoke_function forwards conversation_id to tools that accept **kwargs."""
from agent_framework._tools import _auto_invoke_function

captured_kwargs: dict[str, Any] = {}

@tool(approval_mode="never_require")
def capturing_tool(query: str, **kwargs: Any) -> str:
"""A tool that captures kwargs."""
captured_kwargs.update(kwargs)
return "ok"

function_call = Content.from_function_call(name="capturing_tool", arguments='{"query": "test"}', call_id="call_1")

await _auto_invoke_function(
function_call,
custom_args={"conversation_id": "conv-123", "other_arg": "value"},
config={
"enabled": True,
"max_iterations": 1,
"max_consecutive_errors_per_request": 3,
"include_detailed_errors": True,
},
tool_map={"capturing_tool": capturing_tool},
)

assert captured_kwargs.get("conversation_id") == "conv-123"
assert captured_kwargs.get("other_arg") == "value"


async def test_auto_invoke_function_still_filters_internal_kwargs() -> None:
"""Test that _auto_invoke_function still filters _function_middleware_pipeline and middleware."""
from agent_framework._tools import _auto_invoke_function

captured_kwargs: dict[str, Any] = {}

@tool(approval_mode="never_require")
def capturing_tool(query: str, **kwargs: Any) -> str:
"""A tool that captures kwargs."""
captured_kwargs.update(kwargs)
return "ok"

function_call = Content.from_function_call(name="capturing_tool", arguments='{"query": "test"}', call_id="call_1")

await _auto_invoke_function(
function_call,
custom_args={
"_function_middleware_pipeline": "should_be_filtered",
"middleware": "should_be_filtered",
"conversation_id": "conv-456",
},
config={
"enabled": True,
"max_iterations": 1,
"max_consecutive_errors_per_request": 3,
"include_detailed_errors": True,
},
tool_map={"capturing_tool": capturing_tool},
)

assert "_function_middleware_pipeline" not in captured_kwargs
assert "middleware" not in captured_kwargs
assert captured_kwargs.get("conversation_id") == "conv-456"
4 changes: 1 addition & 3 deletions python/packages/core/tests/workflow/test_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,7 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -


@pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"])
async def test_prepare_agent_run_args_strips_reserved_kwargs(
reserved_kwarg: str, caplog: "LogCaptureFixture"
) -> None:
async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str, caplog: "LogCaptureFixture") -> None:
"""_prepare_agent_run_args must remove reserved kwargs and log a warning."""
raw = {reserved_kwarg: "should-be-stripped", "custom_key": "keep-me"}

Expand Down
4 changes: 1 addition & 3 deletions python/packages/core/tests/workflow/test_workflow_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,7 @@ async def _done() -> AgentResponse:

# Continue with responses only — no new kwargs
approval = request_events[0]
await workflow.run(
responses={approval.request_id: approval.data.to_function_approval_response(True)}
)
await workflow.run(responses={approval.request_id: approval.data.to_function_approval_response(True)})

# Both calls should have received the original kwargs
assert len(agent.captured_kwargs) == 2
Expand Down
Loading
Loading