From 200453453646e1159e728dd6d6acf07c6c588ea2 Mon Sep 17 00:00:00 2001 From: mukunda katta Date: Thu, 14 May 2026 19:32:15 -0700 Subject: [PATCH] refactor(client): rename message handler callback --- docs/migration.md | 2 +- src/mcp/client/__main__.py | 4 +-- src/mcp/client/client.py | 5 ++-- src/mcp/client/session.py | 10 +++---- src/mcp/client/session_group.py | 4 +-- tests/client/test_logging_callback.py | 4 +-- tests/client/test_notification_response.py | 4 +-- tests/client/test_session.py | 4 +-- tests/client/test_session_group.py | 2 +- .../tasks/client/test_handlers.py | 26 ++++++++--------- .../experimental/tasks/server/test_server.py | 4 +-- tests/server/mcpserver/test_integration.py | 8 +++--- tests/server/test_session.py | 4 +-- tests/shared/test_progress_notifications.py | 2 +- tests/shared/test_session.py | 6 ++-- tests/shared/test_streamable_http.py | 28 +++++++++---------- 16 files changed, 58 insertions(+), 59 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 8b70885e8d..22541d454c 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -690,7 +690,7 @@ async with Client(server) as client: result = await client.call_tool("my_tool", {"x": 1}) ``` -`Client` accepts the same callback parameters the old helper did (`sampling_callback`, `list_roots_callback`, `logging_callback`, `message_handler`, `elicitation_callback`, `client_info`) plus `raise_exceptions` to surface server-side errors. +`Client` accepts the same callback parameters the old helper did (`sampling_callback`, `list_roots_callback`, `logging_callback`, `message_callback`, `elicitation_callback`, `client_info`) plus `raise_exceptions` to surface server-side errors. If you need direct access to the underlying `ClientSession` and memory streams (e.g., for low-level transport testing), `create_client_server_memory_streams` is still available in `mcp.shared.memory`: diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index b9ec344226..4fd6d10296 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -22,7 +22,7 @@ logger = logging.getLogger("client") -async def message_handler( +async def message_callback( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): @@ -40,7 +40,7 @@ async def run_session( async with ClientSession( read_stream, write_stream, - message_handler=message_handler, + message_callback=message_callback, client_info=client_info, ) as session: logger.info("Initializing session") diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 34d6a360fa..a65088a9ac 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -85,8 +85,7 @@ async def main(): logging_callback: LoggingFnT | None = None """Callback for handling logging notifications.""" - # TODO(Marcelo): Why do we have both "callback" and "handler"? - message_handler: MessageHandlerFnT | None = None + message_callback: MessageHandlerFnT | None = None """Callback for handling raw messages.""" client_info: Implementation | None = None @@ -123,7 +122,7 @@ async def __aenter__(self) -> Client: sampling_callback=self.sampling_callback, list_roots_callback=self.list_roots_callback, logging_callback=self.logging_callback, - message_handler=self.message_handler, + message_callback=self.message_callback, client_info=self.client_info, elicitation_callback=self.elicitation_callback, ) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a77..17faeac48d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -54,7 +54,7 @@ async def __call__( ) -> None: ... # pragma: no branch -async def _default_message_handler( +async def _default_message_callback( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: await anyio.lowlevel.checkpoint() @@ -116,7 +116,7 @@ def __init__( elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, - message_handler: MessageHandlerFnT | None = None, + message_callback: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, *, sampling_capabilities: types.SamplingCapability | None = None, @@ -129,7 +129,7 @@ def __init__( self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback - self._message_handler = message_handler or _default_message_handler + self._message_callback = message_callback or _default_message_callback self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._initialize_result: types.InitializeResult | None = None self._experimental_features: ExperimentalClientFeatures | None = None @@ -462,8 +462,8 @@ async def _handle_incoming( self, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - """Handle incoming messages by forwarding to the message handler.""" - await self._message_handler(req) + """Handle incoming messages by forwarding to the message callback.""" + await self._message_callback(req) async def _received_notification(self, notification: types.ServerNotification) -> None: """Handle notifications from the server.""" diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9610212642..4db901fcf8 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -78,7 +78,7 @@ class ClientSessionParameters: elicitation_callback: ElicitationFnT | None = None list_roots_callback: ListRootsFnT | None = None logging_callback: LoggingFnT | None = None - message_handler: MessageHandlerFnT | None = None + message_callback: MessageHandlerFnT | None = None client_info: types.Implementation | None = None @@ -308,7 +308,7 @@ async def _establish_session( elicitation_callback=session_params.elicitation_callback, list_roots_callback=session_params.list_roots_callback, logging_callback=session_params.logging_callback, - message_handler=session_params.message_handler, + message_callback=session_params.message_callback, client_info=session_params.client_info, ) ) diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 454c1d3382..5bdff6f026 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -54,7 +54,7 @@ async def test_tool_with_log_dict( return True # Create a message handler to catch exceptions - async def message_handler( + async def message_callback( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): # pragma: no cover @@ -63,7 +63,7 @@ async def message_handler( async with Client( server, logging_callback=logging_collector, - message_handler=message_handler, + message_callback=message_callback, ) as client: # First verify our test tool works result = await client.call_tool("test_tool", {}) diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 69c8afeb84..db2db89417 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -81,7 +81,7 @@ async def test_non_compliant_notification_response() -> None: """ returned_exception = None - async def message_handler( # pragma: no cover + async def message_callback( # pragma: no cover message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: nonlocal returned_exception @@ -90,7 +90,7 @@ async def message_handler( # pragma: no cover async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_non_sdk_server_app())) as client: async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession(read_stream, write_stream, message_callback=message_callback) as session: await session.initialize() # The test server returns a 204 instead of the expected 202 diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f25c964f03..64ae19482a 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -77,7 +77,7 @@ async def mock_server(): ) # Create a message handler to catch exceptions - async def message_handler( # pragma: no cover + async def message_callback( # pragma: no cover message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): @@ -87,7 +87,7 @@ async def message_handler( # pragma: no cover ClientSession( server_to_client_receive, client_to_server_send, - message_handler=message_handler, + message_callback=message_callback, ) as session, anyio.create_task_group() as tg, client_to_server_send, diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f39..f8e8887a62 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -376,7 +376,7 @@ async def test_client_session_group_establish_session_parameterized( elicitation_callback=None, list_roots_callback=None, logging_callback=None, - message_handler=None, + message_callback=None, client_info=None, ) mock_raw_session_cm.__aenter__.assert_awaited_once() diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py index 137ff80106..b40ac90e1a 100644 --- a/tests/experimental/tasks/client/test_handlers.py +++ b/tests/experimental/tasks/client/test_handlers.py @@ -98,7 +98,7 @@ async def client_streams() -> AsyncIterator[ClientTestStreams]: await client_to_server_receive.aclose() -async def _default_message_handler( +async def _default_message_callback( message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, ) -> None: """Default message handler that ignores messages (tests handle them explicitly).""" @@ -141,7 +141,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, experimental_task_handlers=task_handlers, ): client_ready.set() @@ -200,7 +200,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, experimental_task_handlers=task_handlers, ): client_ready.set() @@ -258,7 +258,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, experimental_task_handlers=task_handlers, ): client_ready.set() @@ -321,7 +321,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, experimental_task_handlers=task_handlers, ): client_ready.set() @@ -422,7 +422,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, experimental_task_handlers=task_handlers, ): client_ready.set() @@ -562,7 +562,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, experimental_task_handlers=task_handlers, ): client_ready.set() @@ -649,7 +649,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, ): client_ready.set() await anyio.sleep_forever() @@ -688,7 +688,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, ): client_ready.set() await anyio.sleep_forever() @@ -724,7 +724,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, ): client_ready.set() await anyio.sleep_forever() @@ -760,7 +760,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, ): client_ready.set() await anyio.sleep_forever() @@ -797,7 +797,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, ): client_ready.set() await anyio.sleep_forever() @@ -843,7 +843,7 @@ async def run_client() -> None: async with ClientSession( client_streams.client_receive, client_streams.client_send, - message_handler=_default_message_handler, + message_callback=_default_message_callback, ): client_ready.set() await anyio.sleep_forever() diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 6a28b274ea..1842ff0f42 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -323,7 +323,7 @@ async def test_default_task_handlers_via_enable_tasks() -> None: server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - async def message_handler( + async def message_callback( message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, ) -> None: ... # pragma: no branch @@ -351,7 +351,7 @@ async def run_server() -> None: async with ClientSession( server_to_client_receive, client_to_server_send, - message_handler=message_handler, + message_callback=message_callback, ) as client_session: await client_session.initialize() diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index f71c0574cd..c6a54df495 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -184,12 +184,12 @@ async def test_tool_progress() -> None: """Test tool progress reporting.""" collector = NotificationCollector() - async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): + async def message_callback(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): await collector.handle_generic_notification(message) if isinstance(message, Exception): # pragma: no cover raise message - async with Client(tool_progress.mcp, message_handler=message_handler) as client: + async with Client(tool_progress.mcp, message_callback=message_callback) as client: # Test progress callback progress_updates = [] @@ -263,12 +263,12 @@ async def test_notifications() -> None: """Test notifications and logging functionality.""" collector = NotificationCollector() - async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): + async def message_callback(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): await collector.handle_generic_notification(message) if isinstance(message, Exception): # pragma: no cover raise message - async with Client(notifications.mcp, message_handler=message_handler) as client: + async with Client(notifications.mcp, message_callback=message_callback) as client: # Call tool that generates notifications tool_result = await client.call_tool("process_data", {"data": "test_data"}) assert len(tool_result.content) == 1 diff --git a/tests/server/test_session.py b/tests/server/test_session.py index a2786d865d..f119fb226d 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -28,7 +28,7 @@ async def test_server_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) # Create a message handler to catch exceptions - async def message_handler( # pragma: no cover + async def message_callback( # pragma: no cover message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): @@ -63,7 +63,7 @@ async def run_server(): ClientSession( server_to_client_receive, client_to_server_send, - message_handler=message_handler, + message_callback=message_callback, ) as client_session, anyio.create_task_group() as tg, ): diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index aad9e5d439..cba2cd5a00 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -139,7 +139,7 @@ async def handle_client_message( ClientSession( server_to_client_receive, client_to_server_send, - message_handler=handle_client_message, + message_callback=handle_client_message, ) as client_session, anyio.create_task_group() as tg, ): diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index d7c6cc3b5f..6063874a5b 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -305,7 +305,7 @@ async def mock_server(): @pytest.mark.anyio -async def test_null_id_error_surfaced_via_message_handler(): +async def test_null_id_error_surfaced_via_message_callback(): """Test that a JSONRPCError with id=None is surfaced to the message handler. Per JSON-RPC 2.0, error responses use id=null when the request id could not @@ -338,7 +338,7 @@ async def mock_server(): ClientSession( read_stream=client_read, write_stream=client_write, - message_handler=capture_errors, + message_callback=capture_errors, ) as _client_session, ): tg.start_soon(mock_server) @@ -399,7 +399,7 @@ async def make_request(client_session: ClientSession): ClientSession( read_stream=client_read, write_stream=client_write, - message_handler=capture_errors, + message_callback=capture_errors, ) as client_session, ): tg.start_soon(mock_server) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..e5b51c304f 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1153,14 +1153,14 @@ async def test_streamable_http_client_get_stream(basic_server: None, basic_serve notifications_received: list[types.ServerNotification] = [] # Define message handler to capture notifications - async def message_handler( # pragma: no branch + async def message_callback( # pragma: no branch message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, types.ServerNotification): # pragma: no branch notifications_received.append(message) async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession(read_stream, write_stream, message_callback=message_callback) as session: # Initialize the session - this triggers the GET stream setup result = await session.initialize() assert isinstance(result, InitializeResult) @@ -1304,7 +1304,7 @@ async def test_streamable_http_client_resumption(event_server: tuple[SimpleEvent captured_notifications: list[types.ServerNotification] = [] first_notification_received = False - async def message_handler( # pragma: no branch + async def message_callback( # pragma: no branch message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, types.ServerNotification): # pragma: no branch @@ -1329,7 +1329,7 @@ async def on_resumption_token_update(token: str) -> None: write_stream, ): async with ClientSession( # pragma: no branch - read_stream, write_stream, message_handler=message_handler + read_stream, write_stream, message_callback=message_callback ) as session: # Initialize the session result = await session.initialize() @@ -1367,7 +1367,7 @@ async def run_tool(): await anyio.sleep(0.1) # The while loop only exits after first_notification_received=True, - # which is set by message_handler immediately after appending to + # which is set by message_callback immediately after appending to # captured_notifications. The server tool is blocked on its lock, # so nothing else can arrive before we cancel. assert len(captured_notifications) == 1 @@ -1385,7 +1385,7 @@ async def run_tool(): write_stream, ): async with ClientSession( - read_stream, write_stream, message_handler=message_handler + read_stream, write_stream, message_callback=message_callback ) as session: # pragma: no branch result = await session.send_request( types.CallToolRequest(params=types.CallToolRequestParams(name="release_lock", arguments={})), @@ -1982,7 +1982,7 @@ async def test_streamable_http_client_auto_reconnects( _, server_url = event_server captured_notifications: list[str] = [] - async def message_handler( + async def message_callback( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): # pragma: no branch @@ -1992,7 +1992,7 @@ async def message_handler( captured_notifications.append(str(message.params.data)) async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession(read_stream, write_stream, message_callback=message_callback) as session: await session.initialize() # Call tool that: @@ -2046,7 +2046,7 @@ async def test_streamable_http_sse_polling_full_cycle( _, server_url = event_server all_notifications: list[str] = [] - async def message_handler( + async def message_callback( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): # pragma: no branch @@ -2056,7 +2056,7 @@ async def message_handler( all_notifications.append(str(message.params.data)) async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession(read_stream, write_stream, message_callback=message_callback) as session: await session.initialize() # Call tool that simulates polling pattern: @@ -2086,7 +2086,7 @@ async def test_streamable_http_events_replayed_after_disconnect( _, server_url = event_server notification_data: list[str] = [] - async def message_handler( + async def message_callback( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): # pragma: no branch @@ -2096,7 +2096,7 @@ async def message_handler( notification_data.append(str(message.params.data)) async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession(read_stream, write_stream, message_callback=message_callback) as session: await session.initialize() # Tool sends: notification1, close_stream, notification2, notification3, response @@ -2186,7 +2186,7 @@ async def test_standalone_get_stream_reconnection(event_server: tuple[SimpleEven _, server_url = event_server received_notifications: list[str] = [] - async def message_handler( + async def message_callback( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): @@ -2196,7 +2196,7 @@ async def message_handler( received_notifications.append(str(message.params.uri)) async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession(read_stream, write_stream, message_callback=message_callback) as session: await session.initialize() # Call tool that: