diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 267e176ee8..667478408c 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -4,6 +4,7 @@ import asyncio import base64 +import contextvars import json import logging import re @@ -59,6 +60,7 @@ class MCPSpecificApproval(TypedDict, total=False): _MCP_REMOTE_NAME_KEY = "_mcp_remote_name" _MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name" +_mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("_mcp_call_headers") # region: Helpers @@ -1385,6 +1387,7 @@ def __init__( client: SupportsChatGetResponse | None = None, additional_properties: dict[str, Any] | None = None, http_client: AsyncClient | None = None, + header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None, **kwargs: Any, ) -> None: """Initialize the MCP streamable HTTP tool. @@ -1432,6 +1435,11 @@ def __init__( ``streamable_http_client`` API will create and manage a default client. To configure headers, timeouts, or other HTTP client settings, create and pass your own ``asyncClient`` instance. + header_provider: Optional callable that receives the runtime keyword arguments + (from ``FunctionInvocationContext.kwargs``) and returns a ``dict[str, str]`` + of HTTP headers to inject into every outbound request to the MCP server. + Use this to forward per-request context (e.g. authentication tokens set in + agent middleware) without creating a separate ``httpx.AsyncClient``. kwargs: Additional keyword arguments (accepted for backward compatibility but not used). """ super().__init__( @@ -1452,6 +1460,7 @@ def __init__( self.url = url self.terminate_on_close = terminate_on_close self._httpx_client: AsyncClient | None = http_client + self._header_provider = header_provider def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: """Get an MCP streamable HTTP client. @@ -1459,18 +1468,60 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: Returns: An async context manager for the streamable HTTP client transport. """ - try: - from mcp.client.streamable_http import streamable_http_client - except ModuleNotFoundError as ex: - raise ModuleNotFoundError("`mcp` is required to use `MCPStreamableHTTPTool`. Please install `mcp`.") from ex + from httpx import AsyncClient + from mcp.client.streamable_http import streamable_http_client + + http_client = self._httpx_client + if self._header_provider is not None: + if http_client is None: + http_client = AsyncClient( + follow_redirects=True, + timeout=httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT), + ) + self._httpx_client = http_client + + if not hasattr(self, "_inject_headers_hook"): + + async def _inject_headers(request: httpx.Request) -> None: # noqa: RUF029 + headers = _mcp_call_headers.get({}) + for key, value in headers.items(): + request.headers[key] = value + + self._inject_headers_hook = _inject_headers # type: ignore[attr-defined] + http_client.event_hooks["request"].append(self._inject_headers_hook) # type: ignore[attr-defined] - # Pass the http_client (which may be None) to streamable_http_client return streamable_http_client( url=self.url, - http_client=self._httpx_client, + http_client=http_client, terminate_on_close=self.terminate_on_close if self.terminate_on_close is not None else True, ) + async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]: + """Call a tool, injecting headers from the header_provider if configured. + + When a ``header_provider`` was supplied at construction time, the runtime + *kwargs* (originating from ``FunctionInvocationContext.kwargs``) are passed + to the provider. The returned headers are attached to every HTTP request + made during this tool call via a ``contextvars.ContextVar``. + + Args: + tool_name: The name of the tool to call. + + Keyword Args: + kwargs: Arguments to pass to the tool. + + Returns: + A list of Content items representing the tool output. + """ + if self._header_provider is not None: + headers = self._header_provider(kwargs) + token = _mcp_call_headers.set(headers) + try: + return await super().call_tool(tool_name, **kwargs) + finally: + _mcp_call_headers.reset(token) + return await super().call_tool(tool_name, **kwargs) + class MCPWebsocketTool(MCPTool): """MCP tool for connecting to WebSocket-based MCP servers. diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index eb233eea99..1622ad6ca1 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -3796,4 +3796,377 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: assert meta is None +async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client(): + """Test that calling get_mcp_client multiple times does not accumulate duplicate hooks.""" + tool = MCPStreamableHTTPTool( + name="test", + url="http://example.com/mcp", + header_provider=lambda kw: {"X-Token": kw.get("token", "")}, + ) + + try: + with patch("agent_framework._mcp.streamable_http_client"): + tool.get_mcp_client() + tool.get_mcp_client() + tool.get_mcp_client() + + assert tool._httpx_client is not None + hooks = tool._httpx_client.event_hooks.get("request", []) + assert len(hooks) == 1, f"Expected exactly one hook, got {len(hooks)}" + finally: + if getattr(tool, "_httpx_client", None) is not None: + await tool._httpx_client.aclose() + + +# endregion + + +# region: MCPStreamableHTTPTool header_provider + + +async def test_mcp_streamable_http_tool_header_provider_injects_headers(): + """Test that header_provider integrates with call_tool via runtime kwargs. + + When header_provider is configured, runtime kwargs from FunctionInvocationContext + are passed to the provider and the MCP session.call_tool is invoked successfully. + """ + + class _TestServer(MCPStreamableHTTPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="greet", + description="Says hello", + inputSchema={ + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Hello!")]) + ) + self.session.send_ping = AsyncMock() + self.is_connected = True + + def get_mcp_client(self): + return None + + def provider(kwargs): + return {"X-Some-Token": kwargs.get("some_token", "")} + + server = _TestServer( + name="test", + url="http://example.com/mcp", + header_provider=provider, + ) + async with server: + await server.load_tools() + + # Simulate the runtime kwargs that flow from FunctionInvocationContext.kwargs + await server.call_tool("greet", name="Alice", some_token="my-secret") + + # Verify the MCP session.call_tool was called + server.session.call_tool.assert_called_once() + + +async def test_mcp_streamable_http_tool_header_provider_sets_contextvar(): + """Test that call_tool sets the contextvar with headers from header_provider.""" + from agent_framework._mcp import _mcp_call_headers + + observed_headers: list[dict[str, str]] = [] + original_call_tool = MCPTool.call_tool + + async def spy_call_tool(self, tool_name, **kwargs): + # Capture the contextvar value during the super call + try: + observed_headers.append(_mcp_call_headers.get()) + except LookupError: + observed_headers.append({}) + return await original_call_tool(self, tool_name, **kwargs) + + class _TestServer(MCPStreamableHTTPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="greet", + description="Says hello", + inputSchema={"type": "object", "properties": {"name": {"type": "string"}}}, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Hello!")]) + ) + self.session.send_ping = AsyncMock() + self.is_connected = True + + def get_mcp_client(self): + return None + + server = _TestServer( + name="test", + url="http://example.com/mcp", + header_provider=lambda kw: {"X-Auth": kw.get("auth_token", "")}, + ) + async with server: + await server.load_tools() + + with patch.object(MCPTool, "call_tool", spy_call_tool): + await server.call_tool("greet", name="Alice", auth_token="bearer-xyz") + + assert len(observed_headers) == 1 + assert observed_headers[0] == {"X-Auth": "bearer-xyz"} + + +async def test_mcp_streamable_http_tool_header_provider_contextvar_reset_after_call(): + """Test that the contextvar is properly reset after call_tool completes.""" + from agent_framework._mcp import _mcp_call_headers + + class _TestServer(MCPStreamableHTTPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="greet", + description="Says hello", + inputSchema={"type": "object", "properties": {"name": {"type": "string"}}}, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Hello!")]) + ) + self.session.send_ping = AsyncMock() + self.is_connected = True + + def get_mcp_client(self): + return None + + server = _TestServer( + name="test", + url="http://example.com/mcp", + header_provider=lambda kw: {"X-Token": kw.get("token", "")}, + ) + async with server: + await server.load_tools() + await server.call_tool("greet", name="Alice", token="secret") + + # After call_tool, the contextvar should be unset (reset to no value) + with pytest.raises(LookupError): + _mcp_call_headers.get() + + +async def test_mcp_streamable_http_tool_without_header_provider(): + """Test that call_tool works normally when no header_provider is configured.""" + + class _TestServer(MCPStreamableHTTPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="greet", + description="Says hello", + inputSchema={"type": "object", "properties": {"name": {"type": "string"}}}, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Hello!")]) + ) + self.session.send_ping = AsyncMock() + self.is_connected = True + + def get_mcp_client(self): + return None + + server = _TestServer( + name="test", + url="http://example.com/mcp", + ) + async with server: + await server.load_tools() + await server.call_tool("greet", name="Alice") + server.session.call_tool.assert_called_once() + + # Without header_provider, call_tool should delegate directly to MCPTool + assert server._header_provider is None + + +async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook(): + """Test that the httpx event hook injects headers from the contextvar.""" + import httpx + + from agent_framework._mcp import MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT, _mcp_call_headers + + tool = MCPStreamableHTTPTool( + name="test", + url="http://example.com/mcp", + header_provider=lambda kw: {"X-Custom": kw.get("custom", "")}, + ) + + try: + with patch("agent_framework._mcp.streamable_http_client"): + # Trigger get_mcp_client to set up the event hook + tool.get_mcp_client() + + # The tool should have created an httpx client with the event hook + assert tool._httpx_client is not None + assert tool._httpx_client.follow_redirects is True + assert tool._httpx_client.timeout.connect == MCP_DEFAULT_TIMEOUT + assert tool._httpx_client.timeout.read == MCP_DEFAULT_SSE_READ_TIMEOUT + hooks = tool._httpx_client.event_hooks.get("request", []) + assert len(hooks) == 1, "Expected one request event hook" + + # Simulate what happens during a call_tool: contextvar is set + token = _mcp_call_headers.set({"X-Custom": "test-value"}) + try: + request = httpx.Request("POST", "http://example.com/mcp") + await hooks[0](request) + assert request.headers.get("X-Custom") == "test-value" + finally: + _mcp_call_headers.reset(token) + finally: + # Ensure any created httpx client is properly closed + if getattr(tool, "_httpx_client", None) is not None: + await tool._httpx_client.aclose() + + +async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client(): + """Test that header_provider works when the user provides their own httpx client.""" + import httpx + + from agent_framework._mcp import _mcp_call_headers + + user_client = httpx.AsyncClient(headers={"X-Base": "static"}) + + tool = MCPStreamableHTTPTool( + name="test", + url="http://example.com/mcp", + http_client=user_client, + header_provider=lambda kw: {"X-Dynamic": kw.get("dynamic", "")}, + ) + + with patch("agent_framework._mcp.streamable_http_client"): + tool.get_mcp_client() + + # The user's client should still be used + assert tool._httpx_client is user_client + hooks = user_client.event_hooks.get("request", []) + assert len(hooks) == 1 + + # Verify the hook injects headers + token = _mcp_call_headers.set({"X-Dynamic": "per-request"}) + try: + request = httpx.Request("POST", "http://example.com/mcp") + await hooks[0](request) + assert request.headers.get("X-Dynamic") == "per-request" + finally: + _mcp_call_headers.reset(token) + + await user_client.aclose() + + +async def test_mcp_streamable_http_tool_header_provider_via_invoke_with_context(): + """Test that header_provider receives kwargs via FunctionTool.invoke with FunctionInvocationContext. + + This exercises the full pipeline: FunctionInvocationContext.kwargs -> FunctionTool.invoke + -> MCPStreamableHTTPTool.call_tool -> header_provider. + """ + from agent_framework._mcp import _mcp_call_headers + + observed_headers: list[dict[str, str]] = [] + original_call_tool = MCPStreamableHTTPTool.call_tool + + async def spy_call_tool(self, tool_name, **kwargs): + # Capture the contextvar value set by call_tool before delegating + result = await original_call_tool(self, tool_name, **kwargs) + try: + observed_headers.append(_mcp_call_headers.get()) + except LookupError: + observed_headers.append({}) + return result + + class _TestServer(MCPStreamableHTTPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="greet", + description="Says hello", + inputSchema={ + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Hello!")]) + ) + self.session.send_ping = AsyncMock() + self.is_connected = True + + def get_mcp_client(self): + return None + + provider_received: list[dict] = [] + + def provider(kwargs): + provider_received.append(dict(kwargs)) + return {"X-Some-Token": kwargs.get("some_token", "")} + + server = _TestServer( + name="test", + url="http://example.com/mcp", + header_provider=provider, + ) + async with server: + await server.load_tools() + func = server.functions[0] + + # Build a FunctionInvocationContext with runtime kwargs, as the agent framework would + context = FunctionInvocationContext( + function=func, + arguments={"name": "Alice"}, + kwargs={"some_token": "my-secret"}, + ) + + with patch.object(MCPStreamableHTTPTool, "call_tool", spy_call_tool): + result = await func.invoke(arguments={"name": "Alice"}, context=context) + + # Verify the invoke produced a result + assert isinstance(result, list) + assert result[0].text == "Hello!" + + # Verify header_provider was called with the runtime kwargs + assert len(provider_received) == 1 + assert provider_received[0]["some_token"] == "my-secret" + + # Verify session.call_tool was called with the tool arguments (not the runtime kwargs) + server.session.call_tool.assert_called_once() + call_args = server.session.call_tool.call_args + assert call_args.kwargs.get("arguments", {}).get("name") == "Alice" + + # endregion