From ce161b6736fbfc35b217d2a34c7d2760b04b2d1f Mon Sep 17 00:00:00 2001 From: Kashish Hora Date: Tue, 31 Mar 2026 16:41:49 +0200 Subject: [PATCH] fix: extract client info per-request in stateless mode to prevent cross-user bleed --- pyproject.toml | 2 +- .../overrides/community/monkey_patch.py | 5 +- .../overrides/community_v3/middleware.py | 24 +++++-- src/mcpcat/modules/overrides/mcp_server.py | 33 +++++++-- .../overrides/official/monkey_patch.py | 8 ++- src/mcpcat/modules/session.py | 71 +++++++++++-------- tests/test_stateless.py | 49 ++++++++++++- 7 files changed, 144 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5368355..aae0fb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mcpcat" -version = "0.1.15b1" +version = "0.1.15b2" description = "Analytics Tool for MCP Servers - provides insights into MCP tool usage patterns" authors = [ { name = "MCPCat", email = "support@mcpcat.io" }, diff --git a/src/mcpcat/modules/overrides/community/monkey_patch.py b/src/mcpcat/modules/overrides/community/monkey_patch.py index 2dc51fb..e1fc760 100644 --- a/src/mcpcat/modules/overrides/community/monkey_patch.py +++ b/src/mcpcat/modules/overrides/community/monkey_patch.py @@ -98,9 +98,10 @@ async def wrapped_call_tool_handler(request: CallToolRequest) -> ServerResult: # Handle session identification try: - get_client_info_from_request_context(lowlevel_server, request_context) + client_name, client_version = get_client_info_from_request_context(lowlevel_server, request_context) identity = identify_session(lowlevel_server, request, request_context) except Exception as e: + client_name, client_version = None, None identity = None write_to_log(f"Non-critical error in session handling: {e}") @@ -124,6 +125,8 @@ async def wrapped_call_tool_handler(request: CallToolRequest) -> ServerResult: identify_actor_given_id=identity.user_id if identity else None, identify_actor_name=identity.user_name if identity else None, identify_data=identity.user_data if identity else None, + client_name=client_name, + client_version=client_version, ) try: diff --git a/src/mcpcat/modules/overrides/community_v3/middleware.py b/src/mcpcat/modules/overrides/community_v3/middleware.py index 03d36db..3d27702 100644 --- a/src/mcpcat/modules/overrides/community_v3/middleware.py +++ b/src/mcpcat/modules/overrides/community_v3/middleware.py @@ -93,20 +93,24 @@ async def on_initialize( session_id = self._get_session_id() params = context.message.params - # Extract client info from initialize params + # Extract client info from initialize params (MCP protocol provides clientInfo here) + client_name, client_version = None, None if params and hasattr(params, "clientInfo") and params.clientInfo: client_info = params.clientInfo if hasattr(client_info, "name") and client_info.name: - self.mcpcat_data.session_info.client_name = client_info.name + client_name = client_info.name if hasattr(client_info, "version") and client_info.version: - self.mcpcat_data.session_info.client_version = client_info.version + client_version = client_info.version # Handle session identification # Note: Use self.server (FastMCP) not self.server._mcp_server because # tracking data is stored with the FastMCP server as the key for v3 request_context = self._get_request_context(context) try: - get_client_info_from_request_context(self.server, request_context) + if not client_name: + client_name, client_version = get_client_info_from_request_context(self.server, request_context) + else: + get_client_info_from_request_context(self.server, request_context) identity = identify_session(self.server, context.message, request_context) except Exception as e: identity = None @@ -120,6 +124,8 @@ async def on_initialize( identify_actor_given_id=identity.user_id if identity else None, identify_actor_name=identity.user_name if identity else None, identify_data=identity.user_data if identity else None, + client_name=client_name, + client_version=client_version, ) try: @@ -157,9 +163,10 @@ async def on_call_tool( # tracking data is stored with the FastMCP server as the key for v3 request_context = self._get_request_context(context) try: - get_client_info_from_request_context(self.server, request_context) + client_name, client_version = get_client_info_from_request_context(self.server, request_context) identity = identify_session(self.server, context.message, request_context) except Exception as e: + client_name, client_version = None, None identity = None write_to_log(f"Non-critical error in session handling: {e}") @@ -188,6 +195,8 @@ async def on_call_tool( identify_actor_given_id=identity.user_id if identity else None, identify_actor_name=identity.user_name if identity else None, identify_data=identity.user_data if identity else None, + client_name=client_name, + client_version=client_version, ) # Create modified context without context parameter if needed @@ -248,9 +257,10 @@ async def on_list_tools( # tracking data is stored with the FastMCP server as the key for v3 request_context = self._get_request_context(context) try: - get_client_info_from_request_context(self.server, request_context) + client_name, client_version = get_client_info_from_request_context(self.server, request_context) identity = identify_session(self.server, context.message, request_context) except Exception as e: + client_name, client_version = None, None identity = None write_to_log(f"Non-critical error in session handling: {e}") @@ -264,6 +274,8 @@ async def on_list_tools( identify_actor_given_id=identity.user_id if identity else None, identify_actor_name=identity.user_name if identity else None, identify_data=identity.user_data if identity else None, + client_name=client_name, + client_version=client_version, ) try: diff --git a/src/mcpcat/modules/overrides/mcp_server.py b/src/mcpcat/modules/overrides/mcp_server.py index 6310497..7506a77 100644 --- a/src/mcpcat/modules/overrides/mcp_server.py +++ b/src/mcpcat/modules/overrides/mcp_server.py @@ -43,6 +43,14 @@ async def wrapped_initialize_handler(request: InitializeRequest) -> ServerResult request_context = safe_request_context(server) identity = identify_session(server, request, request_context) + # Extract clientInfo from InitializeRequest params (MCP protocol provides it here) + client_name, client_version = None, None + if request.params and hasattr(request.params, 'clientInfo') and request.params.clientInfo: + client_name = request.params.clientInfo.name + client_version = getattr(request.params.clientInfo, 'version', None) + if not client_name: + client_name, client_version = get_client_info_from_request_context(server, request_context) + event = UnredactedEvent( session_id=session_id, timestamp=datetime.now(timezone.utc), @@ -51,13 +59,13 @@ async def wrapped_initialize_handler(request: InitializeRequest) -> ServerResult identify_actor_given_id=identity.user_id if identity else None, identify_actor_name=identity.user_name if identity else None, identify_data=identity.user_data if identity else None, + client_name=client_name, + client_version=client_version, ) # Call the original handler result = await original_initialize_handler(request) - # TODO: Grab client and server information from the request - # Record the event event.response = result.model_dump() if result else None event_queue.publish_event(server, event) @@ -67,7 +75,7 @@ async def wrapped_list_tools_handler(request: ListToolsRequest) -> ServerResult: """Intercept list_tools requests to add MCPCat tools and modify existing ones.""" session_id = get_server_session_id(server) request_context = safe_request_context(server) - get_client_info_from_request_context(server, request_context) + client_name, client_version = get_client_info_from_request_context(server, request_context) identity = identify_session(server, request, request_context) event = UnredactedEvent( @@ -80,6 +88,8 @@ async def wrapped_list_tools_handler(request: ListToolsRequest) -> ServerResult: identify_actor_given_id=identity.user_id if identity else None, identify_actor_name=identity.user_name if identity else None, identify_data=identity.user_data if identity else None, + client_name=client_name, + client_version=client_version, ) # Call the original handler to get the tools @@ -149,7 +159,7 @@ async def wrapped_call_tool_handler(request: CallToolRequest) -> ServerResult: arguments = request.params.arguments or {} session_id = get_server_session_id(server) request_context = safe_request_context(server) - get_client_info_from_request_context(server, request_context) + client_name, client_version = get_client_info_from_request_context(server, request_context) identity = identify_session(server, request, request_context) write_to_log( @@ -164,6 +174,8 @@ async def wrapped_call_tool_handler(request: CallToolRequest) -> ServerResult: identify_actor_given_id=identity.user_id if identity else None, identify_actor_name=identity.user_name if identity else None, identify_data=identity.user_data if identity else None, + client_name=client_name, + client_version=client_version, ) # Extract user intent from context (but don't pop yet - we need it for the event) @@ -237,6 +249,13 @@ async def wrapped_initialize_handler(request: InitializeRequest) -> ServerResult identity = None write_to_log(f"Ran into an error in session identification, no identity could be determined: {e}") + client_name, client_version = None, None + if request.params and hasattr(request.params, 'clientInfo') and request.params.clientInfo: + client_name = request.params.clientInfo.name + client_version = getattr(request.params.clientInfo, 'version', None) + if not client_name: + client_name, client_version = get_client_info_from_request_context(server, request_context) + event = UnredactedEvent( session_id=session_id, timestamp=datetime.now(timezone.utc), @@ -245,6 +264,8 @@ async def wrapped_initialize_handler(request: InitializeRequest) -> ServerResult identify_actor_given_id=identity.user_id if identity else None, identify_actor_name=identity.user_name if identity else None, identify_data=identity.user_data if identity else None, + client_name=client_name, + client_version=client_version, ) # Call the original handler @@ -259,7 +280,7 @@ async def wrapped_list_tools_handler(request: ListToolsRequest) -> ServerResult: """Intercept list_tools requests to track the event (tool modifications handled by monkey-patch).""" session_id = get_server_session_id(server) request_context = safe_request_context(server) - get_client_info_from_request_context(server, request_context) + client_name, client_version = get_client_info_from_request_context(server, request_context) identity = identify_session(server, request, request_context) event = UnredactedEvent( @@ -272,6 +293,8 @@ async def wrapped_list_tools_handler(request: ListToolsRequest) -> ServerResult: identify_actor_given_id=identity.user_id if identity else None, identify_actor_name=identity.user_name if identity else None, identify_data=identity.user_data if identity else None, + client_name=client_name, + client_version=client_version, ) # Call the original handler - tool modifications are handled by monkey-patch diff --git a/src/mcpcat/modules/overrides/official/monkey_patch.py b/src/mcpcat/modules/overrides/official/monkey_patch.py index ebe2e66..e535e1e 100644 --- a/src/mcpcat/modules/overrides/official/monkey_patch.py +++ b/src/mcpcat/modules/overrides/official/monkey_patch.py @@ -239,9 +239,9 @@ async def patched_call_tool( # Handle session identification (non-critical) try: request_context = safe_request_context(server._mcp_server) - # Only call if request_context is not None + client_name, client_version = (None, None) if request_context is not None: - get_client_info_from_request_context( + client_name, client_version = get_client_info_from_request_context( server._mcp_server, request_context ) @@ -261,9 +261,9 @@ async def patched_call_tool( identity = identify_session(server._mcp_server, mock_request, request_context) except Exception as e: + client_name, client_version = None, None identity = None write_to_log(f"Non-critical error in session handling: {e}") - # Continue without session identification # Extract user intent (non-critical) user_intent = None @@ -298,6 +298,8 @@ async def patched_call_tool( identify_actor_given_id=identity.user_id if identity else None, identify_actor_name=identity.user_name if identity else None, identify_data=identity.user_data if identity else None, + client_name=client_name, + client_version=client_version, ) except Exception as e: write_to_log(f"Error creating event: {e}") diff --git a/src/mcpcat/modules/session.py b/src/mcpcat/modules/session.py index 803a382..c109434 100644 --- a/src/mcpcat/modules/session.py +++ b/src/mcpcat/modules/session.py @@ -59,80 +59,89 @@ def get_headers_from_request_context( def get_client_info_from_request_context( server: Server, request_context: RequestContext | None -) -> None: +) -> tuple[str | None, str | None]: """Extract client information from request context or HTTP headers. + Returns (client_name, client_version). In stateless mode, extracts per-request + without caching. In stateful mode, caches on shared session_info. + This function is designed to be resilient and never fail - any error is logged but won't affect the server operation. """ # Handle None request_context (e.g., in stateless HTTP mode outside handlers) if request_context is None: write_to_log("Request context is None, skipping client info extraction") - return + return (None, None) try: data = get_server_tracking_data(server) if not data: - return + return (None, None) + + client_name: str | None = None + client_version: str | None = None - # If client name and version are already set, no need to fetch again - if data.session_info.client_name and data.session_info.client_version: - return + # In stateful mode, return cached values if already set + if not data.is_stateless and data.session_info.client_name and data.session_info.client_version: + return (data.session_info.client_name, data.session_info.client_version) try: - # Try to get from session (stateful mode) + # Try to get from MCP session (stateful mode) if hasattr(request_context, "session") and request_context.session: client_info = request_context.session.client_params.clientInfo if client_info: - data.session_info.client_name = client_info.name - data.session_info.client_version = client_info.version - set_server_tracking_data(server, data) - return - except (AttributeError, TypeError) as e: + client_name = client_info.name + client_version = client_info.version + if not data.is_stateless: + data.session_info.client_name = client_name + data.session_info.client_version = client_version + set_server_tracking_data(server, data) + return (client_name, client_version) + except (AttributeError, TypeError): # This is expected in stateless mode, just continue pass except Exception as e: - # Unexpected error, log but continue write_to_log(f"Error extracting client info from session: {e}") # Fallback: Try to extract from HTTP headers (stateless mode) try: headers = get_headers_from_request_context(request_context) if headers: - # Check User-Agent header + # Parse User-Agent header (format: "ClientName/Version ...") user_agent = headers.get("user-agent", "") if user_agent: - # Parse User-Agent for client info - # Format could be: "ClientName/Version (additional info)" match = re.match(r"^([^/]+)/([^\s]+)", user_agent) if match: - data.session_info.client_name = match.group(1) - data.session_info.client_version = match.group(2) + client_name = match.group(1) + client_version = match.group(2) else: - # If no neat match, use the whole string as client_name - data.session_info.client_name = user_agent + # No neat match, use the whole string as client_name + client_name = user_agent - # Also check custom MCP headers if any - # Clients might send: X-MCP-Client-Name, X-MCP-Client-Version + # Custom MCP headers override User-Agent if present if headers.get("x-mcp-client-name"): - data.session_info.client_name = headers.get("x-mcp-client-name") + client_name = headers.get("x-mcp-client-name") if headers.get("x-mcp-client-version"): - data.session_info.client_version = headers.get( - "x-mcp-client-version" - ) + client_version = headers.get("x-mcp-client-version") - if data.session_info.client_name or data.session_info.client_version: + if not data.is_stateless and (client_name or client_version): + data.session_info.client_name = client_name + data.session_info.client_version = client_version set_server_tracking_data(server, data) + + if client_name or client_version: write_to_log( - f"Extracted client info from headers: {data.session_info.client_name} v{data.session_info.client_version}" + f"Extracted client info from headers: {client_name} v{client_version}" ) except Exception as e: write_to_log(f"Error extracting client info from headers: {e}") # Continue without client info + + return (client_name, client_version) except Exception as e: # Catch-all for any unexpected errors - log but never fail write_to_log(f"Unexpected error in get_client_info_from_request_context: {e}") - # Function continues and returns normally + return (None, None) def get_session_info(server: Server, data: MCPCatData | None = None) -> SessionInfo: @@ -148,10 +157,10 @@ def get_session_info(server: Server, data: MCPCatData | None = None) -> SessionI server_name=server.name if hasattr(server, "name") else None, server_version=server.version if hasattr(server, "version") else None, client_name=data.session_info.client_name - if data and data.session_info + if data and data.session_info and not data.is_stateless else None, client_version=data.session_info.client_version - if data and data.session_info + if data and data.session_info and not data.is_stateless else None, identify_actor_given_id=actor_info.user_id if actor_info else None, identify_actor_name=actor_info.user_name if actor_info else None, diff --git a/tests/test_stateless.py b/tests/test_stateless.py index 6287ae9..f89bf9b 100644 --- a/tests/test_stateless.py +++ b/tests/test_stateless.py @@ -9,7 +9,7 @@ set_server_tracking_data, reset_all_tracking_data, ) -from mcpcat.modules.session import get_server_session_id +from mcpcat.modules.session import get_server_session_id, get_client_info_from_request_context from mcpcat.modules.identify import identify_session from mcpcat.types import MCPCatData, MCPCatOptions, SessionInfo, UserIdentity @@ -165,3 +165,50 @@ def test_stateless_identify_exception(self, mock_event_queue): assert result is None assert raising_fn.call_count == 1 + + def _make_request_context(self, user_agent): + """Create a mock request context with a User-Agent header.""" + ctx = MagicMock() + ctx.request.headers = {"user-agent": user_agent} + # No session attribute (stateless HTTP) + ctx.session = None + return ctx + + def test_stateless_client_info_per_request(self): + """In stateless mode, consecutive requests with different clients return different info.""" + self._setup_data(stateless=True) + + ctx1 = self._make_request_context("Cursor/2.6.22") + ctx2 = self._make_request_context("Claude Desktop/1.0") + + result1 = get_client_info_from_request_context(self.server, ctx1) + result2 = get_client_info_from_request_context(self.server, ctx2) + + assert result1 == ("Cursor", "2.6.22") + assert result2 == ("Claude Desktop", "1.0") + + def test_stateless_client_info_returns_values(self): + """In stateless mode, get_client_info_from_request_context returns a tuple.""" + self._setup_data(stateless=True) + + ctx = self._make_request_context("Cursor/2.6.22") + result = get_client_info_from_request_context(self.server, ctx) + + assert isinstance(result, tuple) + assert len(result) == 2 + assert result[0] == "Cursor" + assert result[1] == "2.6.22" + + def test_stateful_client_info_cached_across_requests(self): + """In stateful mode, client info is determined by the first request.""" + self._setup_data(stateless=False) + + ctx1 = self._make_request_context("Cursor/2.6.22") + ctx2 = self._make_request_context("Claude Desktop/1.0") + + get_client_info_from_request_context(self.server, ctx1) + get_client_info_from_request_context(self.server, ctx2) + + data = get_server_tracking_data(self.server) + assert data.session_info.client_name == "Cursor" + assert data.session_info.client_version == "2.6.22"