diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 580b6e2c6d..f0dacad825 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -499,16 +499,34 @@ async def agent_wrapper(**kwargs: Any) -> str: # Extract the input from kwargs using the specified arg_name input_text = kwargs.get(arg_name, "") - # Forward runtime context kwargs, excluding arg_name and conversation_id. + # Extract conversation_id forwarded from parent agent's tool invocation loop + parent_conversation_id = kwargs.get("conversation_id") + + # Forward runtime context kwargs, excluding arg_name, conversation_id, and options. forwarded_kwargs = {k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options")} + # Pass parent's conversation_id via additional_function_arguments so it reaches + # the sub-agent's tools through **kwargs for correlation purposes. + # We do NOT pass it via chat options because the sub-agent's chat client should + # start its own conversation, not try to continue the parent's. + # Merge into any existing options/additional_function_arguments to avoid replacing them. + existing_options = kwargs.get("options") + run_options: dict[str, Any] | None = dict(existing_options) if isinstance(existing_options, dict) else None + if parent_conversation_id: + if run_options is None: + run_options = {} + existing_additional = run_options.get("additional_function_arguments") or {} + merged_additional = dict(existing_additional) if isinstance(existing_additional, dict) else {} + merged_additional["parent_conversation_id"] = parent_conversation_id + run_options["additional_function_arguments"] = merged_additional + if stream_callback is None: # Use non-streaming mode - return (await self.run(input_text, stream=False, **forwarded_kwargs)).text + return (await self.run(input_text, stream=False, options=run_options, **forwarded_kwargs)).text # Use streaming mode - accumulate updates and create final response response_updates: list[AgentResponseUpdate] = [] - async for update in self.run(input_text, stream=True, **forwarded_kwargs): + async for update in self.run(input_text, stream=True, options=run_options, **forwarded_kwargs): response_updates.append(update) if is_async_callback: await stream_callback(update) # type: ignore[misc] @@ -937,7 +955,11 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda def _finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]: ctx = ctx_holder["ctx"] - rf = ctx.get("chat_options", {}).get("response_format") if ctx else (options.get("response_format") if options else None) + rf = ( + ctx.get("chat_options", {}).get("response_format") + if ctx + else (options.get("response_format") if options else None) + ) return self._finalize_response_updates(updates, response_format=rf) return ( diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 3ec167d4f7..3b2217082a 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1406,11 +1406,11 @@ async def _auto_invoke_function( parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {}) # Filter out internal framework kwargs before passing to tools. - # conversation_id is an internal tracking ID that should not be forwarded to tools. + # conversation_id is forwarded so agent-as-tool wrappers can correlate sub-agent conversations. runtime_kwargs: dict[str, Any] = { key: value for key, value in (custom_args or {}).items() - if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"} + if key not in {"_function_middleware_pipeline", "middleware"} } try: if not tool._schema_supplied and tool.input_model is not None: @@ -2083,7 +2083,7 @@ def get_response( max_errors: int = self.function_invocation_configuration["max_consecutive_errors_per_request"] # type: ignore[assignment] additional_function_arguments: dict[str, Any] = {} if options and (additional_opts := options.get("additional_function_arguments")): # type: ignore[attr-defined] - additional_function_arguments = additional_opts # type: ignore + additional_function_arguments = dict(additional_opts) # type: ignore[call-overload] # defensive copy execute_function_calls = partial( _execute_function_calls, custom_args=additional_function_arguments, @@ -2145,6 +2145,7 @@ async def _get_response() -> ChatResponse: if response.conversation_id is not None: _update_conversation_id(kwargs, response.conversation_id, mutable_options) + additional_function_arguments["conversation_id"] = response.conversation_id prepped_messages = [] result = await _process_function_requests( @@ -2279,6 +2280,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: if response.conversation_id is not None: _update_conversation_id(kwargs, response.conversation_id, mutable_options) + additional_function_arguments["conversation_id"] = response.conversation_id prepped_messages = [] result = await _process_function_requests( diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index b6f84dc970..df35c77ee3 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import contextlib -from collections.abc import AsyncIterable, MutableSequence +from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence from typing import Any from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 @@ -23,10 +23,14 @@ Message, SupportsAgentRun, SupportsChatGetResponse, + agent_middleware, tool, ) from agent_framework._agents import _merge_options, _sanitize_agent_name from agent_framework._mcp import MCPTool +from agent_framework._middleware import AgentContext + +from .conftest import MockChatClient def test_agent_session_type(agent_session: AgentSession) -> None: @@ -707,6 +711,51 @@ async def test_chat_agent_as_tool_name_sanitization(client: SupportsChatGetRespo assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}" +async def test_chat_agent_as_tool_propagates_conversation_id(client: SupportsChatGetResponse) -> None: + """Test that as_tool passes parent_conversation_id to sub-agent via additional_function_arguments.""" + mock_client: MockChatClient = client # type: ignore[assignment] + captured_options: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + captured_options.update(context.options or {}) + await call_next() + + mock_client.responses = [ + ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]), + ] + + sub_agent = Agent(client=mock_client, name="sub_agent", middleware=[capture_middleware]) + t = sub_agent.as_tool(name="delegate", arg_name="task") + + await t.invoke(arguments=t.input_model(task="Test delegation"), conversation_id="conv-parent-123") + + additional_args = captured_options.get("additional_function_arguments", {}) + assert additional_args.get("parent_conversation_id") == "conv-parent-123" + + +async def test_chat_agent_as_tool_no_conversation_id_when_absent(client: SupportsChatGetResponse) -> None: + """Test that as_tool does not inject additional_function_arguments when no conversation_id provided.""" + mock_client: MockChatClient = client # type: ignore[assignment] + captured_options: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + captured_options.update(context.options or {}) + await call_next() + + mock_client.responses = [ + ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]), + ] + + sub_agent = Agent(client=mock_client, name="sub_agent", middleware=[capture_middleware]) + t = sub_agent.as_tool(name="delegate", arg_name="task") + + await t.invoke(arguments=t.input_model(task="Test delegation"), user_id="user-789") + + assert "additional_function_arguments" not in captured_options + + async def test_chat_agent_as_mcp_server_basic(client: SupportsChatGetResponse) -> None: """Test basic as_mcp_server functionality.""" agent = Agent(client=client, name="TestAgent", description="Test agent for MCP") diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index da8e907c40..e059d78560 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -339,3 +339,63 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Verify other kwargs were still forwarded assert captured_kwargs.get("api_token") == "secret-xyz-123" assert captured_kwargs.get("user_id") == "user-456" + + async def test_as_tool_propagates_conversation_id_via_options(self, client: MockChatClient) -> None: + """Test that parent_conversation_id from parent is passed to sub-agent's additional_function_arguments.""" + captured_options: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + captured_options.update(context.options or {}) + await call_next() + + # Setup mock response + client.responses = [ + ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]), + ] + + sub_agent = Agent( + client=client, + name="sub_agent", + middleware=[capture_middleware], + ) + + tool = sub_agent.as_tool(name="delegate", arg_name="task") + + await tool.invoke( + arguments=tool.input_model(task="Test delegation"), + conversation_id="conv-parent-123", + ) + + # Verify parent_conversation_id was passed via additional_function_arguments + additional_args = captured_options.get("additional_function_arguments", {}) + assert additional_args.get("parent_conversation_id") == "conv-parent-123" + + async def test_as_tool_no_conversation_id_when_absent(self, client: MockChatClient) -> None: + """Test that no parent_conversation_id is injected when parent has none.""" + captured_options: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + captured_options.update(context.options or {}) + await call_next() + + client.responses = [ + ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]), + ] + + sub_agent = Agent( + client=client, + name="sub_agent", + middleware=[capture_middleware], + ) + + tool = sub_agent.as_tool(name="delegate", arg_name="task") + + await tool.invoke( + arguments=tool.input_model(task="Test delegation"), + user_id="user-789", + ) + + # additional_function_arguments should not be in options when no conversation_id provided + assert "additional_function_arguments" not in captured_options diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index 8d74dc181d..9440f964ab 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -1464,4 +1464,66 @@ def test_nested_object_with_const_and_enum(): model(config={"type": "production", "level": "critical"}) -# endregion +async def test_auto_invoke_function_forwards_conversation_id() -> None: + """Test that _auto_invoke_function forwards conversation_id to tools that accept **kwargs.""" + from agent_framework._tools import _auto_invoke_function + + captured_kwargs: dict[str, Any] = {} + + @tool(approval_mode="never_require") + def capturing_tool(query: str, **kwargs: Any) -> str: + """A tool that captures kwargs.""" + captured_kwargs.update(kwargs) + return "ok" + + function_call = Content.from_function_call(name="capturing_tool", arguments='{"query": "test"}', call_id="call_1") + + await _auto_invoke_function( + function_call, + custom_args={"conversation_id": "conv-123", "other_arg": "value"}, + config={ + "enabled": True, + "max_iterations": 1, + "max_consecutive_errors_per_request": 3, + "include_detailed_errors": True, + }, + tool_map={"capturing_tool": capturing_tool}, + ) + + assert captured_kwargs.get("conversation_id") == "conv-123" + assert captured_kwargs.get("other_arg") == "value" + + +async def test_auto_invoke_function_still_filters_internal_kwargs() -> None: + """Test that _auto_invoke_function still filters _function_middleware_pipeline and middleware.""" + from agent_framework._tools import _auto_invoke_function + + captured_kwargs: dict[str, Any] = {} + + @tool(approval_mode="never_require") + def capturing_tool(query: str, **kwargs: Any) -> str: + """A tool that captures kwargs.""" + captured_kwargs.update(kwargs) + return "ok" + + function_call = Content.from_function_call(name="capturing_tool", arguments='{"query": "test"}', call_id="call_1") + + await _auto_invoke_function( + function_call, + custom_args={ + "_function_middleware_pipeline": "should_be_filtered", + "middleware": "should_be_filtered", + "conversation_id": "conv-456", + }, + config={ + "enabled": True, + "max_iterations": 1, + "max_consecutive_errors_per_request": 3, + "include_detailed_errors": True, + }, + tool_map={"capturing_tool": capturing_tool}, + ) + + assert "_function_middleware_pipeline" not in captured_kwargs + assert "middleware" not in captured_kwargs + assert captured_kwargs.get("conversation_id") == "conv-456" diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index db53868ee1..4a850db642 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -286,9 +286,7 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() - @pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"]) -async def test_prepare_agent_run_args_strips_reserved_kwargs( - reserved_kwarg: str, caplog: "LogCaptureFixture" -) -> None: +async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str, caplog: "LogCaptureFixture") -> None: """_prepare_agent_run_args must remove reserved kwargs and log a warning.""" raw = {reserved_kwarg: "should-be-stripped", "custom_key": "keep-me"} diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 379435e124..ce1465effc 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -499,9 +499,7 @@ async def _done() -> AgentResponse: # Continue with responses only — no new kwargs approval = request_events[0] - await workflow.run( - responses={approval.request_id: approval.data.to_function_approval_response(True)} - ) + await workflow.run(responses={approval.request_id: approval.data.to_function_approval_response(True)}) # Both calls should have received the original kwargs assert len(agent.captured_kwargs) == 2 diff --git a/python/samples/02-agents/tools/agent_as_tool_conversation_id_propagation.py b/python/samples/02-agents/tools/agent_as_tool_conversation_id_propagation.py new file mode 100644 index 0000000000..54b69740f3 --- /dev/null +++ b/python/samples/02-agents/tools/agent_as_tool_conversation_id_propagation.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import Awaitable, Callable +from typing import Any + +from agent_framework import Agent, FunctionInvocationContext, agent_middleware, tool +from agent_framework._middleware import AgentContext +from agent_framework.openai import OpenAIResponsesClient + +""" +Agent-as-Tool with Conversation ID Propagation Example + +Demonstrates how a parent agent's conversation_id is automatically propagated +to sub-agents wrapped as tools via as_tool(). This enables correlating +multi-agent conversations in storage systems. + +The middleware below is ONLY for observability — to print the conversation_id +at each stage so you can verify the propagation. It is NOT required for the +feature to work. The propagation happens automatically at the framework level. + +NOTE: conversation_id propagation requires a chat client that returns +conversation_id in its responses (e.g., OpenAI Responses API). +""" + + +# --- Observability middleware (not required for the feature) --- + + +@agent_middleware +async def log_conversation_id(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + """Prints the conversation_id seen by each agent. Only for demonstration.""" + agent_name = context.agent.name if hasattr(context.agent, "name") else "unknown" + conv_id = (context.options or {}).get("conversation_id") + additional = (context.options or {}).get("additional_function_arguments", {}) + parent_id = additional.get("parent_conversation_id") + print(f" [{agent_name}] conversation_id={conv_id}, parent_conversation_id={parent_id}") + await call_next() + + +async def log_tool_kwargs( + context: FunctionInvocationContext, + call_next: Callable[[], Awaitable[None]], +) -> None: + """Prints the kwargs forwarded to a tool. Only for demonstration.""" + conv_id = context.kwargs.get("conversation_id") + parent_id = context.kwargs.get("parent_conversation_id") + print(f" [tool:{context.function.name}] conversation_id={conv_id}, parent_conversation_id={parent_id}") + await call_next() + + +# --- Application code --- + + +# This tool is NOT required for conversation_id propagation to work. +# It is included only to show the parent_conversation_id arriving via **kwargs. +# NOTE: approval_mode="never_require" is for sample brevity. +@tool(approval_mode="never_require") +def lookup_info(query: str, **kwargs: Any) -> str: + """Look up information for a given query. + + Args: + query: The search query. + + Keyword Args: + kwargs: Runtime context forwarded by the framework, including + parent_conversation_id if the parent agent propagated one. + + Returns: + The lookup result. + """ + parent_id = kwargs.get("parent_conversation_id") + return f"Results for '{query}' (tracked under parent conversation {parent_id})" + + +async def main() -> None: + print("=== Agent-as-Tool: Conversation ID Propagation ===\n") + + client = OpenAIResponsesClient() + + # Create a specialized research agent + researcher = Agent( + client=client, + name="ResearchAgent", + instructions="You are a research assistant. Use the lookup_info tool to find information.", + tools=[lookup_info], + middleware=[log_conversation_id], + function_middleware=[log_tool_kwargs], + ) + + # Wrap the research agent as a tool for the coordinator + research_tool = researcher.as_tool( + name="research", + description="Delegate research tasks to a specialized research agent", + arg_name="task", + arg_description="The research task to perform", + ) + + # Create coordinator with the same observability middleware + coordinator = Agent( + client=client, + name="CoordinatorAgent", + instructions=( + "You are a coordinator. When the user asks a question, delegate to the research tool to find the answer." + ), + tools=[research_tool], + middleware=[log_conversation_id], + function_middleware=[log_tool_kwargs], + ) + + # Run — watch the printed output to see conversation_id flow: + # 1. CoordinatorAgent gets a conversation_id from the API + # 2. The tool invocation forwards it to the research tool's **kwargs + # 3. ResearchAgent receives it as parent_conversation_id in its options + # 4. ResearchAgent's own tools see it via **kwargs + response = await coordinator.run("What are the latest developments in quantum computing?") + + print(f"\nCoordinator: {response.text}") + + +if __name__ == "__main__": + asyncio.run(main())