-
Notifications
You must be signed in to change notification settings - Fork 420
mcp: HTTP Header Standardization for x-mcp-header #915
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f00d0f4
d00288d
57659c0
017e0fc
604f2d4
7df5ab6
9de3bec
f429bc5
ad17562
a4e1e23
aeada36
b223143
005d33d
24cc607
9dd6907
8d4e94d
d8c5a76
e754ff7
87093cb
04652ee
d7ef2c1
c218557
e83292b
9bdd1df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -160,6 +160,12 @@ type ClientOptions struct { | |
| KeepAlive time.Duration | ||
| } | ||
|
|
||
| // toolContextKeyType is the context key type for passing tool definitions | ||
| // from CallTool to the transport layer. | ||
| type toolContextKeyType struct{} | ||
|
|
||
| var toolContextKey = toolContextKeyType{} | ||
|
|
||
| // bind implements the binder[*ClientSession] interface, so that Clients can | ||
| // be connected using [connect]. | ||
| func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession { | ||
|
|
@@ -318,6 +324,13 @@ type ClientSession struct { | |
| // Pending URL elicitations waiting for completion notifications. | ||
| pendingElicitationsMu sync.Mutex | ||
| pendingElicitations map[string]chan struct{} | ||
|
|
||
| // toolCacheMu guards toolCache. | ||
| toolCacheMu sync.RWMutex | ||
| // toolCache stores tool definitions keyed by name. | ||
| // It is used to look up x-mcp-header annotations when | ||
| // constructing Mcp-Param-* headers for tools/call requests. | ||
| toolCache map[string]*Tool | ||
| } | ||
|
|
||
| type clientSessionState struct { | ||
|
|
@@ -363,6 +376,21 @@ func (cs *ClientSession) Wait() error { | |
| return cs.conn.Wait() | ||
| } | ||
|
|
||
| func (cs *ClientSession) cacheTools(tools []*Tool) { | ||
| cs.toolCacheMu.Lock() | ||
| defer cs.toolCacheMu.Unlock() | ||
| cs.toolCache = make(map[string]*Tool, len(tools)) | ||
| for _, tool := range tools { | ||
| cs.toolCache[tool.Name] = tool | ||
| } | ||
| } | ||
|
|
||
| func (cs *ClientSession) getCachedTool(name string) *Tool { | ||
| cs.toolCacheMu.RLock() | ||
| defer cs.toolCacheMu.RUnlock() | ||
| return cs.toolCache[name] | ||
| } | ||
|
|
||
| // registerElicitationWaiter registers a waiter for an elicitation complete | ||
| // notification with the given elicitation ID. It returns two functions: an await | ||
| // function that waits for the notification or context cancellation, and a cleanup | ||
|
|
@@ -981,7 +1009,13 @@ func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) | |
|
|
||
| // ListTools lists tools that are currently available on the server. | ||
| func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { | ||
| return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) | ||
| result, err := handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| result.Tools = filterValidTools(cs.client.opts.Logger, result.Tools) | ||
| cs.cacheTools(result.Tools) | ||
| return result, nil | ||
| } | ||
|
|
||
| // CallTool calls the tool with the given parameters. | ||
|
|
@@ -995,6 +1029,9 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( | |
| // Avoid sending nil over the wire. | ||
| params.Arguments = map[string]any{} | ||
| } | ||
| if tool := cs.getCachedTool(params.Name); tool != nil { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please also think how modelcontextprotocol/modelcontextprotocol#2549 will factor in into this solution. |
||
| ctx = context.WithValue(ctx, toolContextKey, tool) | ||
| } | ||
| return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -491,6 +491,14 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque | |
| http.Error(w, "failed connection", http.StatusInternalServerError) | ||
| return | ||
| } | ||
| transport.connection.toolLookup = func(name string) *Tool { | ||
|
guglielmo-san marked this conversation as resolved.
|
||
| server.mu.Lock() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using the mutex from a different file is probably not a good practice. Consider exposing a private method on the server. |
||
| defer server.mu.Unlock() | ||
| if st, ok := server.tools.get(name); ok { | ||
| return st.tool | ||
| } | ||
| return nil | ||
| } | ||
| // Capture the user ID from the token info to enable session hijacking | ||
| // prevention on subsequent requests. | ||
| var userID string | ||
|
|
@@ -669,6 +677,8 @@ type streamableServerConn struct { | |
|
|
||
| logger *slog.Logger | ||
|
|
||
| toolLookup func(name string) *Tool | ||
|
|
||
| incoming chan jsonrpc.Message // messages from the client to the server | ||
|
|
||
| mu sync.Mutex // guards all fields below | ||
|
|
@@ -1186,9 +1196,15 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques | |
| } | ||
| } | ||
|
|
||
| // Validate MCP standard headers (Mcp-Method, Mcp-Name) | ||
| // Validate MCP standard headers (Mcp-Method, Mcp-Name, Mcp-Param-*) | ||
| if !isBatch && len(incoming) == 1 { | ||
| if err := validateMcpHeaders(req.Header, incoming[0]); err != nil { | ||
| var tool *Tool | ||
| if jreq, ok := incoming[0].(*jsonrpc.Request); ok && jreq.Method == "tools/call" && c.toolLookup != nil { | ||
| if name, ok := extractName(jreq.Method, jreq.Params); ok { | ||
| tool = c.toolLookup(name) | ||
| } | ||
| } | ||
| if err := validateMcpHeaders(req.Header, incoming[0], tool); err != nil { | ||
| resp := &jsonrpc.Response{ | ||
| Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()), | ||
| } | ||
|
|
@@ -1813,7 +1829,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e | |
| } | ||
| // Keep this after the setMCPHeaders call to ensure that the | ||
| // protocol version header is set. | ||
| setStandardHeaders(req.Header, msg) | ||
| setStandardHeaders(ctx, req.Header, msg) | ||
| resp, err := c.client.Do(req) | ||
| if err != nil { | ||
| // Any error from client.Do means the request didn't reach the server. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.