diff --git a/client/src/App.tsx b/client/src/App.tsx index 12e9a7bd0..59d15ba06 100644 --- a/client/src/App.tsx +++ b/client/src/App.tsx @@ -33,6 +33,7 @@ import { import { getToolUiResourceUri } from "@modelcontextprotocol/ext-apps/app-bridge"; import { AuthDebuggerState, EMPTY_DEBUGGER_STATE } from "./lib/auth-types"; import { OAuthStateMachine } from "./lib/oauth-state-machine"; +import { createProxyFetch } from "./lib/proxyFetch"; import { cacheToolOutputSchemas } from "./utils/schemaUtils"; import { cleanParams } from "./utils/paramUtils"; import type { JsonSchemaType } from "./utils/jsonUtils"; @@ -622,9 +623,17 @@ const App = () => { }; try { - const stateMachine = new OAuthStateMachine(sseUrl, (updates) => { - currentState = { ...currentState, ...updates }; - }); + const fetchFn = + connectionType === "proxy" && config + ? createProxyFetch(config) + : undefined; + const stateMachine = new OAuthStateMachine( + sseUrl, + (updates) => { + currentState = { ...currentState, ...updates }; + }, + fetchFn, + ); while ( currentState.oauthStep !== "complete" && @@ -662,7 +671,7 @@ const App = () => { }); } }, - [sseUrl], + [sseUrl, connectionType, config], ); useEffect(() => { @@ -1264,6 +1273,8 @@ const App = () => { onBack={() => setIsAuthDebuggerVisible(false)} authState={authState} updateAuthState={updateAuthState} + config={config} + connectionType={connectionType} /> ); diff --git a/client/src/__tests__/proxyFetchEndpoint.test.ts b/client/src/__tests__/proxyFetchEndpoint.test.ts new file mode 100644 index 000000000..4be75f6c4 --- /dev/null +++ b/client/src/__tests__/proxyFetchEndpoint.test.ts @@ -0,0 +1,234 @@ +/** + * Tests for the proxy server's POST /fetch endpoint. + * Spawns the server and hits it like any other HTTP client would. + */ +import { spawn, type ChildProcess } from "child_process"; +import { + createServer, + type IncomingMessage, + type Server, + type ServerResponse, +} from "http"; +import { resolve } from "path"; + +const TEST_PORT = 16321; +const TEST_TOKEN = "test-proxy-token-12345"; +const SERVER_PATH = resolve(__dirname, "../../../server/build/index.js"); + +/** Placeholder URL for tests where auth fails before the proxy fetches (no network). */ +const UNUSED_UPSTREAM_URL = "http://127.0.0.1:1/unused"; + +async function waitForServer(baseUrl: string, maxWaitMs = 5000): Promise { + const start = Date.now(); + while (Date.now() - start < maxWaitMs) { + try { + const res = await fetch(`${baseUrl}/health`); + if (res.ok) return; + } catch { + await new Promise((r) => setTimeout(r, 50)); + } + } + throw new Error("Server did not become ready"); +} + +/** + * Runs `fn` with a local HTTP server on 127.0.0.1:ephemeral-port. + * `origin` is `http://127.0.0.1:` (no trailing path). + */ +async function withLocalUpstream( + onRequest: (req: IncomingMessage, res: ServerResponse) => void, + fn: (origin: string) => Promise, +): Promise { + const upstream: Server = createServer(onRequest); + + await new Promise((resolve, reject) => { + upstream.once("error", reject); + upstream.listen(0, "127.0.0.1", () => resolve()); + }); + + const addr = upstream.address(); + if (!addr || typeof addr === "string") { + upstream.close(); + throw new Error("Expected TCP listen address"); + } + + const origin = `http://127.0.0.1:${addr.port}`; + + try { + await fn(origin); + } finally { + await new Promise((r) => upstream.close(() => r())); + } +} + +describe("POST /fetch endpoint", () => { + let server: ChildProcess; + const baseUrl = `http://localhost:${TEST_PORT}`; + + beforeAll(async () => { + server = spawn("node", [SERVER_PATH], { + env: { + ...process.env, + SERVER_PORT: String(TEST_PORT), + MCP_PROXY_AUTH_TOKEN: TEST_TOKEN, + }, + stdio: "ignore", + }); + await waitForServer(baseUrl); + }, 10000); + + afterAll(() => { + server.kill(); + }); + + it("returns 401 when no auth header", async () => { + const res = await fetch(`${baseUrl}/fetch`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + url: UNUSED_UPSTREAM_URL, + init: { method: "GET" }, + }), + }); + expect(res.status).toBe(401); + const body = await res.json(); + expect(body.error).toBe("Unauthorized"); + }); + + it("returns 401 when auth token is invalid", async () => { + const res = await fetch(`${baseUrl}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-MCP-Proxy-Auth": "Bearer wrong-token", + }, + body: JSON.stringify({ + url: UNUSED_UPSTREAM_URL, + init: { method: "GET" }, + }), + }); + expect(res.status).toBe(401); + }); + + it("returns 400 for non-http(s) URL when auth token is valid", async () => { + const res = await fetch(`${baseUrl}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-MCP-Proxy-Auth": `Bearer ${TEST_TOKEN}`, + }, + body: JSON.stringify({ + url: "file:///etc/passwd", + init: { method: "GET" }, + }), + }); + expect(res.status).toBe(400); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe("Only http/https URLs are allowed"); + }); + + it("returns 400 for invalid URL string when auth token is valid", async () => { + const res = await fetch(`${baseUrl}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-MCP-Proxy-Auth": `Bearer ${TEST_TOKEN}`, + }, + body: JSON.stringify({ + url: "not a valid url", + init: { method: "GET" }, + }), + }); + expect(res.status).toBe(400); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe("Invalid URL"); + }); + + it("returns 400 when url is missing when auth token is valid", async () => { + const res = await fetch(`${baseUrl}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-MCP-Proxy-Auth": `Bearer ${TEST_TOKEN}`, + }, + body: JSON.stringify({ init: { method: "GET" } }), + }); + expect(res.status).toBe(400); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe("Missing or invalid url"); + }); + + it("forwards request when auth token is valid", async () => { + const upstreamPayload = JSON.stringify({ hello: "proxy-fetch-test" }); + + await withLocalUpstream( + (req, res) => { + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(upstreamPayload); + }, + async (origin) => { + const upstreamUrl = `${origin}/ok`; + + const res = await fetch(`${baseUrl}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-MCP-Proxy-Auth": `Bearer ${TEST_TOKEN}`, + }, + body: JSON.stringify({ + url: upstreamUrl, + init: { method: "GET" }, + }), + }); + + expect(res.status).toBe(200); + const body = (await res.json()) as { + ok: boolean; + status: number; + statusText: string; + body: string; + headers: Record; + }; + expect(body.ok).toBe(true); + expect(body.status).toBe(200); + expect(body.statusText).toBe("OK"); + expect(body.body).toBe(upstreamPayload); + expect(body.headers["content-type"]).toMatch(/application\/json/i); + }, + ); + }); + + it("mirrors upstream 404 (non-2xx) when auth token is valid", async () => { + await withLocalUpstream( + (req, res) => { + res.writeHead(404, { "Content-Type": "application/json" }); + res.end('{"error":"not_found"}'); + }, + async (origin) => { + const upstreamUrl = `${origin}/missing`; + + const res = await fetch(`${baseUrl}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-MCP-Proxy-Auth": `Bearer ${TEST_TOKEN}`, + }, + body: JSON.stringify({ + url: upstreamUrl, + init: { method: "GET" }, + }), + }); + + expect(res.status).toBe(404); + const body = (await res.json()) as { + ok: boolean; + status: number; + body: string; + }; + expect(body.ok).toBe(false); + expect(body.status).toBe(404); + expect(JSON.parse(body.body)).toEqual({ error: "not_found" }); + }, + ); + }); +}); diff --git a/client/src/components/AuthDebugger.tsx b/client/src/components/AuthDebugger.tsx index 6252c1161..376817d67 100644 --- a/client/src/components/AuthDebugger.tsx +++ b/client/src/components/AuthDebugger.tsx @@ -5,14 +5,18 @@ import { AlertCircle } from "lucide-react"; import { AuthDebuggerState, EMPTY_DEBUGGER_STATE } from "../lib/auth-types"; import { OAuthFlowProgress } from "./OAuthFlowProgress"; import { OAuthStateMachine } from "../lib/oauth-state-machine"; +import { createProxyFetch } from "../lib/proxyFetch"; import { SESSION_KEYS } from "../lib/constants"; import { validateRedirectUrl } from "@/utils/urlValidation"; +import type { InspectorConfig } from "../lib/configurationTypes"; export interface AuthDebuggerProps { serverUrl: string; onBack: () => void; authState: AuthDebuggerState; updateAuthState: (updates: Partial) => void; + config?: InspectorConfig; + connectionType?: "direct" | "proxy"; } interface StatusMessageProps { @@ -60,6 +64,8 @@ const AuthDebugger = ({ onBack, authState, updateAuthState, + config, + connectionType, }: AuthDebuggerProps) => { // Check for existing tokens on mount useEffect(() => { @@ -102,9 +108,17 @@ const AuthDebugger = ({ }); }, [serverUrl, updateAuthState]); + const fetchFn = useMemo( + () => + connectionType === "proxy" && config + ? createProxyFetch(config) + : undefined, + [connectionType, config], + ); + const stateMachine = useMemo( - () => new OAuthStateMachine(serverUrl, updateAuthState), - [serverUrl, updateAuthState], + () => new OAuthStateMachine(serverUrl, updateAuthState, fetchFn), + [serverUrl, updateAuthState, fetchFn], ); const proceedToNextStep = useCallback(async () => { @@ -150,11 +164,15 @@ const AuthDebugger = ({ latestError: null, }; - const oauthMachine = new OAuthStateMachine(serverUrl, (updates) => { - // Update our temporary state during the process - currentState = { ...currentState, ...updates }; - // But don't call updateAuthState yet - }); + const oauthMachine = new OAuthStateMachine( + serverUrl, + (updates) => { + // Update our temporary state during the process + currentState = { ...currentState, ...updates }; + // But don't call updateAuthState yet + }, + fetchFn, + ); // Manually step through each stage of the OAuth flow while (currentState.oauthStep !== "complete") { @@ -214,7 +232,7 @@ const AuthDebugger = ({ } finally { updateAuthState({ isInitiatingAuth: false }); } - }, [serverUrl, updateAuthState, authState]); + }, [serverUrl, updateAuthState, authState, fetchFn]); const handleClearOAuth = useCallback(() => { if (serverUrl) { diff --git a/client/src/components/__tests__/AuthDebugger.test.tsx b/client/src/components/__tests__/AuthDebugger.test.tsx index fec876778..71eec04aa 100644 --- a/client/src/components/__tests__/AuthDebugger.test.tsx +++ b/client/src/components/__tests__/AuthDebugger.test.tsx @@ -1,3 +1,4 @@ +import React from "react"; import { render, screen, @@ -9,7 +10,7 @@ import "@testing-library/jest-dom"; import { describe, it, beforeEach, jest } from "@jest/globals"; import AuthDebugger, { AuthDebuggerProps } from "../AuthDebugger"; import { TooltipProvider } from "../ui/tooltip"; -import { SESSION_KEYS } from "../../lib/constants"; +import { SESSION_KEYS, DEFAULT_INSPECTOR_CONFIG } from "../../lib/constants"; const mockOAuthTokens = { access_token: "test_access_token", @@ -58,7 +59,7 @@ import { OAuthMetadata } from "@modelcontextprotocol/sdk/shared/auth.js"; import { EMPTY_DEBUGGER_STATE } from "../../lib/auth-types"; // Mock local auth module -jest.mock("@/lib/auth", () => ({ +jest.mock("../../lib/auth", () => ({ DebugInspectorOAuthClientProvider: jest.fn().mockImplementation(() => ({ tokens: jest.fn().mockImplementation(() => Promise.resolve(undefined)), clear: jest.fn().mockImplementation(() => { @@ -269,6 +270,7 @@ describe("AuthDebugger", () => { // Should first discover and save OAuth metadata expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( new URL("https://example.com/"), + { fetchFn: undefined }, ); // Check that updateAuthState was called with the right info message @@ -404,6 +406,65 @@ describe("AuthDebugger", () => { }); }); + describe("Proxy Fetch integration", () => { + it("passes fetchFn to SDK when connectionType is proxy", async () => { + const configWithProxy = { + ...DEFAULT_INSPECTOR_CONFIG, + MCP_PROXY_FULL_ADDRESS: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_FULL_ADDRESS, + value: "http://localhost:6277", + }, + MCP_PROXY_AUTH_TOKEN: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_AUTH_TOKEN, + value: "test-proxy-token", + }, + }; + + await act(async () => { + renderAuthDebugger({ + config: configWithProxy, + connectionType: "proxy", + authState: { + ...defaultAuthState, + isInitiatingAuth: false, + oauthStep: "metadata_discovery", + }, + }); + }); + + await act(async () => { + fireEvent.click(screen.getByText("Continue")); + }); + + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( + new URL("https://example.com/"), + { fetchFn: expect.any(Function) }, + ); + }); + + it("passes undefined fetchFn when connectionType is direct", async () => { + await act(async () => { + renderAuthDebugger({ + connectionType: "direct", + authState: { + ...defaultAuthState, + isInitiatingAuth: false, + oauthStep: "metadata_discovery", + }, + }); + }); + + await act(async () => { + fireEvent.click(screen.getByText("Continue")); + }); + + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( + new URL("https://example.com/"), + { fetchFn: undefined }, + ); + }); + }); + describe("OAuth Flow Steps", () => { it("should handle OAuth flow step progression", async () => { const updateAuthState = jest.fn(); @@ -428,6 +489,7 @@ describe("AuthDebugger", () => { expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( new URL("https://example.com/"), + { fetchFn: undefined }, ); }); @@ -725,6 +787,8 @@ describe("AuthDebugger", () => { await waitFor(() => { expect(mockDiscoverOAuthProtectedResourceMetadata).toHaveBeenCalledWith( "https://example.com/mcp", + {}, + undefined, ); }); @@ -743,6 +807,7 @@ describe("AuthDebugger", () => { expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( new URL("https://custom-auth.example.com/mcp/tenant"), + { fetchFn: undefined }, ); }); @@ -779,6 +844,8 @@ describe("AuthDebugger", () => { await waitFor(() => { expect(mockDiscoverOAuthProtectedResourceMetadata).toHaveBeenCalledWith( "https://example.com/mcp", + {}, + undefined, ); }); @@ -797,6 +864,7 @@ describe("AuthDebugger", () => { // Verify that regular OAuth metadata discovery was still called expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( new URL("https://example.com/"), + { fetchFn: undefined }, ); }); }); diff --git a/client/src/lib/__tests__/auth.test.ts b/client/src/lib/__tests__/auth.test.ts index 329b7f027..03c503d81 100644 --- a/client/src/lib/__tests__/auth.test.ts +++ b/client/src/lib/__tests__/auth.test.ts @@ -133,7 +133,10 @@ describe("discoverScopes", () => { expect(result).toBe(expected); if (expectedCallUrl) { - expect(mockDiscoverAuth).toHaveBeenCalledWith(new URL(expectedCallUrl)); + expect(mockDiscoverAuth).toHaveBeenCalledWith( + new URL(expectedCallUrl), + { fetchFn: undefined }, + ); } }, ); diff --git a/client/src/lib/__tests__/mcpProxyTransportErrorCode.test.ts b/client/src/lib/__tests__/mcpProxyTransportErrorCode.test.ts new file mode 100644 index 000000000..7e0a33ac9 --- /dev/null +++ b/client/src/lib/__tests__/mcpProxyTransportErrorCode.test.ts @@ -0,0 +1,24 @@ +/** + * Compare client vs server numeric literal without importing server/mcpProxy.ts: + * that module pulls in the SDK graph and breaks Jest (e.g. optional peer deps). + */ +import { readFileSync } from "fs"; +import { resolve } from "path"; +import { MCP_PROXY_TRANSPORT_ERROR_CODE as clientCode } from "../constants"; + +const EXPORT_RE = + /export const MCP_PROXY_TRANSPORT_ERROR_CODE\s*=\s*(-?\d+)\s*;/; + +describe("MCP_PROXY_TRANSPORT_ERROR_CODE", () => { + it("matches server/src/mcpProxy.ts (avoid silent drift between client and server)", () => { + const serverSrcPath = resolve( + __dirname, + "../../../../server/src/mcpProxy.ts", + ); + const serverSrc = readFileSync(serverSrcPath, "utf-8"); + const match = serverSrc.match(EXPORT_RE); + expect(match).not.toBeNull(); + const serverCode = Number(match![1]); + expect(clientCode).toBe(serverCode); + }); +}); diff --git a/client/src/lib/__tests__/proxyFetch.test.ts b/client/src/lib/__tests__/proxyFetch.test.ts new file mode 100644 index 000000000..5ac239eb9 --- /dev/null +++ b/client/src/lib/__tests__/proxyFetch.test.ts @@ -0,0 +1,244 @@ +import { createProxyFetch } from "../proxyFetch"; +import { DEFAULT_INSPECTOR_CONFIG } from "../constants"; +import type { InspectorConfig } from "../configurationTypes"; + +describe("createProxyFetch", () => { + const mockFetch = jest.fn(); + const proxyAddress = "http://localhost:6277"; + + const configWithProxy: InspectorConfig = { + ...DEFAULT_INSPECTOR_CONFIG, + MCP_PROXY_FULL_ADDRESS: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_FULL_ADDRESS, + value: proxyAddress, + }, + MCP_PROXY_AUTH_TOKEN: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_AUTH_TOKEN, + value: "test-proxy-token", + }, + }; + + beforeEach(() => { + jest.clearAllMocks(); + global.fetch = mockFetch; + }); + + it("returns a function", () => { + const fetchFn = createProxyFetch(configWithProxy); + expect(typeof fetchFn).toBe("function"); + }); + + it("sends POST to proxy /fetch endpoint with correct headers", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + ok: true, + status: 200, + statusText: "OK", + headers: {}, + body: "response body", + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + await fetchFn("https://example.com/.well-known/oauth-authorization-server"); + + expect(mockFetch).toHaveBeenCalledTimes(1); + expect(mockFetch).toHaveBeenCalledWith(`${proxyAddress}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-MCP-Proxy-Auth": "Bearer test-proxy-token", + }, + body: expect.any(String), + }); + }); + + it("includes target url and init in request body", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + ok: true, + status: 200, + statusText: "OK", + headers: { "content-type": "application/json" }, + body: '{"issuer":"https://example.com"}', + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + await fetchFn("https://example.com/oauth/token", { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + body: "grant_type=authorization_code&code=abc", + }); + + const callBody = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(callBody).toEqual({ + url: "https://example.com/oauth/token", + init: { + method: "POST", + headers: { "content-type": "application/x-www-form-urlencoded" }, + body: "grant_type=authorization_code&code=abc", + }, + }); + }); + + it("reconstructs Response from proxy response", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + ok: true, + status: 200, + statusText: "OK", + headers: { "content-type": "application/json" }, + body: '{"issuer":"https://example.com"}', + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + const response = await fetchFn( + "https://example.com/.well-known/oauth-authorization-server", + ); + + expect(response.ok).toBe(true); + expect(response.status).toBe(200); + expect(response.statusText).toBe("OK"); + expect(response.headers.get("content-type")).toBe("application/json"); + const body = await response.text(); + expect(body).toBe('{"issuer":"https://example.com"}'); + }); + + it("returns non-ok Response when upstream status is not 2xx (mirrored by proxy)", async () => { + mockFetch.mockResolvedValue({ + ok: false, + status: 404, + statusText: "Not Found", + json: () => + Promise.resolve({ + ok: false, + status: 404, + statusText: "Not Found", + headers: { "content-type": "application/json" }, + body: '{"error":"not_found"}', + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + const response = await fetchFn("https://example.com/.well-known/missing"); + + expect(response.ok).toBe(false); + expect(response.status).toBe(404); + expect(response.statusText).toBe("Not Found"); + expect(response.headers.get("content-type")).toBe("application/json"); + expect(await response.text()).toBe('{"error":"not_found"}'); + }); + + it("returns non-ok Response when upstream returns 400 with token-endpoint error JSON (RFC 6749)", async () => { + const tokenErrorBody = + '{"error":"invalid_grant","error_description":"code expired"}'; + mockFetch.mockResolvedValue({ + ok: false, + status: 400, + statusText: "Bad Request", + json: () => + Promise.resolve({ + ok: false, + status: 400, + statusText: "Bad Request", + headers: { "content-type": "application/json" }, + body: tokenErrorBody, + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + const response = await fetchFn("https://example.com/oauth/token", { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + body: "grant_type=authorization_code&code=x", + }); + + expect(response.ok).toBe(false); + expect(response.status).toBe(400); + expect(await response.json()).toEqual({ + error: "invalid_grant", + error_description: "code expired", + }); + }); + + it("throws when proxy POST returns JSON error envelope (e.g. 401 invalid session)", async () => { + mockFetch.mockResolvedValue({ + ok: false, + status: 401, + statusText: "Unauthorized", + json: () => + Promise.resolve({ + error: "Unauthorized", + message: + "Authentication required. Use the session token shown in the console when starting the server.", + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + await expect(fetchFn("https://example.com/")).rejects.toThrow( + "Authentication required. Use the session token shown in the console when starting the server.", + ); + }); + + it("throws when proxy response is not valid JSON", async () => { + mockFetch.mockResolvedValue({ + ok: false, + status: 502, + statusText: "Bad Gateway", + json: () => Promise.reject(new SyntaxError("Unexpected token")), + }); + + const fetchFn = createProxyFetch(configWithProxy); + await expect(fetchFn("https://example.com/")).rejects.toThrow( + "Proxy fetch failed: 502 Bad Gateway", + ); + }); + + it("uses URL object as input", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + ok: true, + status: 200, + statusText: "OK", + headers: {}, + body: "", + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + await fetchFn(new URL("https://example.com/discovery")); + + const callBody = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(callBody.url).toBe("https://example.com/discovery"); + }); + + it("uses Request.url when input is a Request (not [object Request])", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + ok: true, + status: 200, + statusText: "OK", + headers: {}, + body: "", + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + await fetchFn(new Request("https://example.com/from-request")); + + const callBody = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(callBody.url).toBe("https://example.com/from-request"); + }); +}); diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index b8366ef6e..f0fc2fc4b 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -22,10 +22,12 @@ import { validateRedirectUrl } from "@/utils/urlValidation"; export const discoverScopes = async ( serverUrl: string, resourceMetadata?: OAuthProtectedResourceMetadata, + fetchFn?: typeof fetch, ): Promise => { try { const metadata = await discoverAuthorizationServerMetadata( new URL("/", serverUrl), + { fetchFn }, ); // Prefer resource metadata scopes, but fall back to OAuth metadata if empty diff --git a/client/src/lib/connectionAuthErrors.ts b/client/src/lib/connectionAuthErrors.ts new file mode 100644 index 000000000..b061d37eb --- /dev/null +++ b/client/src/lib/connectionAuthErrors.ts @@ -0,0 +1,41 @@ +import { SseError } from "@modelcontextprotocol/sdk/client/sse.js"; +import { StreamableHTTPError } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; +import { UnauthorizedError } from "@modelcontextprotocol/sdk/client/auth.js"; +import { McpError } from "@modelcontextprotocol/sdk/types.js"; +import { MCP_PROXY_TRANSPORT_ERROR_CODE } from "./constants"; + +function isPlainObject(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +/** `McpError.data` from server `mcpProxy` `serializeProxyTransportError`. */ +export function mcpProxyTransportErrorDataIndicatesUnauthorized( + data: Record, +): boolean { + if ("upstream401" in data) { + const snapshot = data.upstream401; + if (snapshot != null) return true; + } + const status = data.httpStatus; + return typeof status === "number" && status === 401; +} + +/** + * Whether `handleAuthError` / OAuth recovery should run for this failure. + */ +export function isConnectionAuthError(error: unknown): boolean { + if (error instanceof SseError && error.code === 401) return true; + if (error instanceof StreamableHTTPError && error.code === 401) return true; + if (error instanceof UnauthorizedError) return true; + + if ( + error instanceof McpError && + error.code === MCP_PROXY_TRANSPORT_ERROR_CODE && + isPlainObject(error.data) && + mcpProxyTransportErrorDataIndicatesUnauthorized(error.data) + ) { + return true; + } + + return false; +} diff --git a/client/src/lib/constants.ts b/client/src/lib/constants.ts index 6cb1a02cc..d986d3802 100644 --- a/client/src/lib/constants.ts +++ b/client/src/lib/constants.ts @@ -37,6 +37,15 @@ export type ConnectionStatus = export const DEFAULT_MCP_PROXY_LISTEN_PORT = "6277"; +/** + * JSON-RPC error code when the inspector proxy cannot complete a forward to the upstream MCP server. + * **-32099** — high end of JSON-RPC’s -32000..-32099 server-error band to avoid MCP-assigned codes + * (see comment on `server/src/mcpProxy.ts` `MCP_PROXY_TRANSPORT_ERROR_CODE`). Duplicated like + * `DEFAULT_MCP_PROXY_LISTEN_PORT`. Keep in sync with the server value; + * `src/lib/__tests__/mcpProxyTransportErrorCode.test.ts` fails if they drift. + */ +export const MCP_PROXY_TRANSPORT_ERROR_CODE = -32099; + /** * Default configuration for the MCP Inspector, Currently persisted in local_storage in the Browser. * Future plans: Provide json config file + Browser local_storage to override default values diff --git a/client/src/lib/hooks/__tests__/useConnection.test.tsx b/client/src/lib/hooks/__tests__/useConnection.test.tsx index 4907a085b..1d4f4bd0f 100644 --- a/client/src/lib/hooks/__tests__/useConnection.test.tsx +++ b/client/src/lib/hooks/__tests__/useConnection.test.tsx @@ -5,12 +5,17 @@ import { ClientRequest, CreateTaskResultSchema, JSONRPCMessage, + McpError, } from "@modelcontextprotocol/sdk/types.js"; import type { AnySchema, SchemaOutput, } from "@modelcontextprotocol/sdk/server/zod-compat.js"; -import { DEFAULT_INSPECTOR_CONFIG, CLIENT_IDENTITY } from "../../constants"; +import { + DEFAULT_INSPECTOR_CONFIG, + CLIENT_IDENTITY, + MCP_PROXY_TRANSPORT_ERROR_CODE, +} from "../../constants"; import { SSEClientTransportOptions, SseError, @@ -94,17 +99,38 @@ jest.mock("@modelcontextprotocol/sdk/client/sse.js", () => { }; }); -jest.mock("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ - StreamableHTTPClientTransport: jest.fn((url, options) => { - mockStreamableHTTPTransport.url = url; - mockStreamableHTTPTransport.options = options; - return mockStreamableHTTPTransport; - }), -})); +jest.mock("@modelcontextprotocol/sdk/client/streamableHttp.js", () => { + class StreamableHTTPError extends Error { + code: number; + constructor(code: number, message: string) { + super(`Streamable HTTP error: ${message}`); + this.code = code; + } + } -jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ - auth: jest.fn().mockResolvedValue("AUTHORIZED"), -})); + return { + StreamableHTTPError, + StreamableHTTPClientTransport: jest.fn((url, options) => { + mockStreamableHTTPTransport.url = url; + mockStreamableHTTPTransport.options = options; + return mockStreamableHTTPTransport; + }), + }; +}); + +jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => { + class UnauthorizedError extends Error { + constructor(message?: string) { + super(message ?? "Unauthorized"); + this.name = "UnauthorizedError"; + } + } + + return { + UnauthorizedError, + auth: jest.fn().mockResolvedValue("AUTHORIZED"), + }; +}); // Mock the toast hook const mockToast = jest.fn(); @@ -322,7 +348,7 @@ describe("useConnection", () => { const [, samplingHandler] = samplingHandlerCall; // Invoke handler; should return a CreateTaskResult immediately - let createTaskResult: SchemaOutput; + let createTaskResult!: SchemaOutput; await act(async () => { createTaskResult = await samplingHandler(samplingRequest); }); @@ -449,7 +475,7 @@ describe("useConnection", () => { }); expect(elicitRequestHandlerCall).toBeDefined(); - const [, handler] = elicitRequestHandlerCall; + const [, handler] = elicitRequestHandlerCall!; mockOnElicitationRequest.mockImplementation((_request, resolve) => { resolve({ action: "accept", content: { name: "test" } }); @@ -640,7 +666,7 @@ describe("useConnection", () => { }); expect(elicitRequestHandlerCall).toBeDefined(); - const [, handler] = elicitRequestHandlerCall; + const [, handler] = elicitRequestHandlerCall!; const mockElicitationRequest: ElicitRequest = { method: "elicitation/create", @@ -707,7 +733,8 @@ describe("useConnection", () => { } }); - const [, handler] = elicitRequestHandlerCall; + expect(elicitRequestHandlerCall).toBeDefined(); + const [, handler] = elicitRequestHandlerCall!; const mockElicitationRequest: ElicitRequest = { method: "elicitation/create", @@ -732,7 +759,7 @@ describe("useConnection", () => { resolve(mockResponse); }); - let handlerResult; + let handlerResult!: ElicitResult; await act(async () => { handlerResult = await handler(mockElicitationRequest); }); @@ -1514,15 +1541,19 @@ describe("useConnection", () => { expect(mockDiscoverScopes).toHaveBeenCalledWith( defaultProps.sseUrl, undefined, + expect.any(Function), // fetchFn when connectionType is proxy ); } else { expect(mockDiscoverScopes).not.toHaveBeenCalled(); } - expect(mockAuth).toHaveBeenCalledWith(expect.any(Object), { - serverUrl: defaultProps.sseUrl, - scope: expectedAuthScope, - }); + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + serverUrl: defaultProps.sseUrl, + scope: expectedAuthScope, + }), + ); }, ); @@ -1538,11 +1569,101 @@ describe("useConnection", () => { expect(mockDiscoverScopes).toHaveBeenCalledWith( defaultProps.sseUrl, undefined, + expect.any(Function), // fetchFn when connectionType is proxy + ); + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + serverUrl: defaultProps.sseUrl, + scope: undefined, + }), + ); + }); + + it("passes undefined fetchFn when connectionType is direct", async () => { + mockDiscoverScopes.mockResolvedValue("read write"); + setup401Error(); + + const directProps = { + ...defaultProps, + connectionType: "direct" as const, + }; + await attemptConnection(directProps); + + expect(mockDiscoverScopes).toHaveBeenCalledWith( + defaultProps.sseUrl, + undefined, + undefined, // fetchFn is undefined for direct ); - expect(mockAuth).toHaveBeenCalledWith(expect.any(Object), { - serverUrl: defaultProps.sseUrl, - scope: undefined, + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.not.objectContaining({ fetchFn: expect.anything() }), + ); + }); + }); + + describe("Inspector proxy McpError auth recovery", () => { + beforeEach(() => { + jest.clearAllMocks(); + mockAuth.mockResolvedValue("AUTHORIZED"); + mockDiscoverScopes.mockResolvedValue(undefined); + mockClient.connect.mockResolvedValue(undefined); + }); + + const attemptConnect = async ( + props: Parameters[0] = defaultProps, + ) => { + const { result } = renderHook(() => useConnection(props)); + await act(async () => { + try { + await result.current.connect(); + } catch { + // connect may throw when auth recovery does not retry + } }); + }; + + it("invokes auth when connect fails with inspector proxy transport McpError and upstream401 data", async () => { + mockClient.connect.mockRejectedValueOnce( + new McpError(MCP_PROXY_TRANSPORT_ERROR_CODE, "proxy transport", { + upstream401: { body: "{}", contentType: "application/json" }, + }), + ); + await attemptConnect(); + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + serverUrl: defaultProps.sseUrl, + }), + ); + }); + + it("invokes auth when connect fails with inspector proxy transport McpError and httpStatus 401", async () => { + mockClient.connect.mockRejectedValueOnce( + new McpError(MCP_PROXY_TRANSPORT_ERROR_CODE, "proxy transport", { + httpStatus: 401, + }), + ); + await attemptConnect(); + expect(mockAuth).toHaveBeenCalled(); + }); + + it("does not invoke auth for inspector proxy McpError without auth payload", async () => { + mockClient.connect.mockRejectedValueOnce( + new McpError(MCP_PROXY_TRANSPORT_ERROR_CODE, "proxy transport", { + message: "upstream failure", + }), + ); + await attemptConnect(); + expect(mockAuth).not.toHaveBeenCalled(); + }); + + it("does not invoke auth when httpStatus is 401 but JSON-RPC code is not inspector proxy", async () => { + mockClient.connect.mockRejectedValueOnce( + new McpError(-32603, "Internal error", { httpStatus: 401 }), + ); + await attemptConnect(); + expect(mockAuth).not.toHaveBeenCalled(); }); }); diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index e14d1037f..016f8aa4f 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -1,13 +1,11 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { SSEClientTransport, - SseError, SSEClientTransportOptions, } from "@modelcontextprotocol/sdk/client/sse.js"; import { StreamableHTTPClientTransport, StreamableHTTPClientTransportOptions, - StreamableHTTPError, } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; import { ClientNotification, @@ -51,6 +49,7 @@ import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js"; import { useEffect, useRef, useState } from "react"; import { useToast } from "@/lib/hooks/useToast"; import { ConnectionStatus, CLIENT_IDENTITY } from "../constants"; +import { isConnectionAuthError } from "../connectionAuthErrors"; import { Notification } from "../notificationTypes"; import { auth, @@ -64,6 +63,7 @@ import { clearScopeFromSessionStorage, discoverScopes, } from "../auth"; +import { createProxyFetch } from "../proxyFetch"; import { getMCPProxyAddress, getMCPTaskTtl, @@ -379,17 +379,6 @@ export function useConnection({ } }; - const is401Error = (error: unknown): boolean => { - return ( - (error instanceof SseError && error.code === 401) || - (error instanceof StreamableHTTPError && error.code === 401) || - (error instanceof Error && error.message.includes("401")) || - (error instanceof Error && error.message.includes("Unauthorized")) || - (error instanceof Error && - error.message.includes("Missing Authorization header")) - ); - }; - const isProxyAuthError = (error: unknown): boolean => { return ( error instanceof Error && @@ -398,19 +387,24 @@ export function useConnection({ }; const handleAuthError = async (error: unknown) => { - if (is401Error(error)) { + if (isConnectionAuthError(error)) { let scope = oauthScope?.trim(); + const fetchFn = + connectionType === "proxy" ? createProxyFetch(config) : undefined; + if (!scope) { // Only discover resource metadata when we need to discover scopes let resourceMetadata; try { resourceMetadata = await discoverOAuthProtectedResourceMetadata( new URL("/", sseUrl), + {}, + fetchFn, ); } catch { // Resource metadata is optional, continue without it } - scope = await discoverScopes(sseUrl, resourceMetadata); + scope = await discoverScopes(sseUrl, resourceMetadata, fetchFn); } saveScopeToSessionStorage(sseUrl, scope); @@ -420,6 +414,7 @@ export function useConnection({ const result = await auth(serverAuthProvider, { serverUrl: sseUrl, scope, + ...(fetchFn && { fetchFn }), }); return result === "AUTHORIZED"; } catch (authError) { @@ -823,7 +818,7 @@ export function useConnection({ if (shouldRetry) { return connect(undefined, retryCount + 1); } - if (is401Error(error)) { + if (isConnectionAuthError(error)) { // Don't set error state if we're about to redirect for auth return; diff --git a/client/src/lib/oauth-state-machine.ts b/client/src/lib/oauth-state-machine.ts index 8dc9da8f9..6628b9ad5 100644 --- a/client/src/lib/oauth-state-machine.ts +++ b/client/src/lib/oauth-state-machine.ts @@ -19,6 +19,7 @@ export interface StateMachineContext { serverUrl: string; provider: DebugInspectorOAuthClientProvider; updateState: (updates: Partial) => void; + fetchFn?: typeof fetch; } export interface StateTransition { @@ -38,6 +39,8 @@ export const oauthTransitions: Record = { try { resourceMetadata = await discoverOAuthProtectedResourceMetadata( context.serverUrl, + {}, + context.fetchFn, ); if (resourceMetadata?.authorization_servers?.length) { authServerUrl = new URL(resourceMetadata.authorization_servers[0]); @@ -57,7 +60,10 @@ export const oauthTransitions: Record = { resourceMetadata ?? undefined, ); - const metadata = await discoverAuthorizationServerMetadata(authServerUrl); + const metadata = await discoverAuthorizationServerMetadata( + authServerUrl, + { fetchFn: context.fetchFn }, + ); if (!metadata) { throw new Error("Failed to discover OAuth metadata"); } @@ -98,6 +104,7 @@ export const oauthTransitions: Record = { fullInformation = await registerClient(context.serverUrl, { metadata, clientMetadata, + fetchFn: context.fetchFn, }); context.provider.saveClientInformation(fullInformation); } @@ -122,6 +129,7 @@ export const oauthTransitions: Record = { scope = await discoverScopes( context.serverUrl, context.state.resourceMetadata ?? undefined, + context.fetchFn, ); } @@ -189,6 +197,7 @@ export const oauthTransitions: Record = { ? context.state.resource : new URL(context.state.resource) : undefined, + fetchFn: context.fetchFn, }); context.provider.saveTokens(tokens); @@ -211,6 +220,7 @@ export class OAuthStateMachine { constructor( private serverUrl: string, private updateState: (updates: Partial) => void, + private fetchFn?: typeof fetch, ) {} async executeStep(state: AuthDebuggerState): Promise { @@ -220,6 +230,7 @@ export class OAuthStateMachine { serverUrl: this.serverUrl, provider, updateState: this.updateState, + fetchFn: this.fetchFn, }; const transition = oauthTransitions[state.oauthStep]; diff --git a/client/src/lib/proxyFetch.ts b/client/src/lib/proxyFetch.ts new file mode 100644 index 000000000..a0ff5977f --- /dev/null +++ b/client/src/lib/proxyFetch.ts @@ -0,0 +1,158 @@ +import { getMCPProxyAddress, getMCPProxyAuthToken } from "@/utils/configUtils"; +import type { InspectorConfig } from "./configurationTypes"; + +interface ProxyFetchResponse { + ok: boolean; + status: number; + statusText: string; + headers: Record; + body: string; +} + +function isJsonObject(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +/** + * JSON body from POST /fetch when the proxy itself failed (auth, TLS, etc.), + * not a mirrored upstream HTTP response. + */ +function messageFromProxyInfrastructureError(data: unknown): string | null { + if (!isJsonObject(data)) { + return null; + } + const rec = data; + if (!("error" in rec) || "status" in rec) { + return null; + } + if (typeof rec.message === "string") { + return rec.message; + } + if (typeof rec.error === "string") { + return rec.error; + } + return "Proxy fetch failed"; +} + +/** + * Validates the JSON shape the proxy returns when it successfully forwarded + * a request and is mirroring the upstream response. + */ +function parseMirroredUpstreamJson(data: unknown): ProxyFetchResponse | null { + if (!isJsonObject(data)) { + return null; + } + const rec = data; + + if (typeof rec.status !== "number") { + return null; + } + if (typeof rec.body !== "string") { + return null; + } + if (typeof rec.statusText !== "string") { + return null; + } + if (typeof rec.ok !== "boolean") { + return null; + } + if (rec.headers === null || typeof rec.headers !== "object") { + return null; + } + if (Array.isArray(rec.headers)) { + return null; + } + + const headers: Record = {}; + for (const [key, val] of Object.entries(rec.headers)) { + if (typeof val !== "string") { + return null; + } + headers[key] = val; + } + + return { + ok: rec.ok, + status: rec.status, + statusText: rec.statusText, + headers, + body: rec.body, + }; +} + +/** + * Creates a fetch function that routes requests through the proxy server + * to avoid CORS restrictions on OAuth discovery and token endpoints. + */ +export function createProxyFetch(config: InspectorConfig): typeof fetch { + const proxyAddress = getMCPProxyAddress(config); + const { token, header } = getMCPProxyAuthToken(config); + + return async ( + input: RequestInfo | URL, + init?: RequestInit, + ): Promise => { + const url = + typeof input === "string" + ? input + : input instanceof Request + ? input.url + : input.toString(); + + // Serialize body for JSON transport. URLSearchParams and similar don't + // JSON-serialize (they become {}), so we must convert to string first. + let serializedBody: string | undefined; + if (init?.body != null) { + if (typeof init.body === "string") { + serializedBody = init.body; + } else if (init.body instanceof URLSearchParams) { + serializedBody = init.body.toString(); + } else { + serializedBody = String(init.body); + } + } + + const proxyResponse = await fetch(`${proxyAddress}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + [header]: `Bearer ${token}`, + }, + body: JSON.stringify({ + url, + init: { + method: init?.method, + headers: init?.headers + ? Object.fromEntries(new Headers(init.headers)) + : undefined, + body: serializedBody, + }, + }), + }); + + let data: unknown; + try { + data = await proxyResponse.json(); + } catch { + throw new Error( + `Proxy fetch failed: ${proxyResponse.status} ${proxyResponse.statusText}`, + ); + } + + const infraMessage = messageFromProxyInfrastructureError(data); + if (infraMessage !== null) { + throw new Error(infraMessage); + } + + const mirrored = parseMirroredUpstreamJson(data); + if (mirrored === null) { + throw new Error("Proxy fetch failed: unexpected response shape"); + } + + return new Response(mirrored.body, { + status: mirrored.status, + statusText: mirrored.statusText, + headers: new Headers(mirrored.headers), + }); + }; +} diff --git a/server/src/index.ts b/server/src/index.ts index 4d1fffa29..e6af55e5f 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -27,7 +27,7 @@ import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; import express from "express"; import rateLimit from "express-rate-limit"; import { findActualExecutable } from "spawn-rx"; -import mcpProxy from "./mcpProxy.js"; +import mcpProxy, { type ProxyHeaderHolder } from "./mcpProxy.js"; import { randomUUID, randomBytes, timingSafeEqual } from "node:crypto"; import { fileURLToPath } from "url"; import { dirname, join } from "path"; @@ -72,6 +72,29 @@ const is401Error = (error: unknown): boolean => { return false; }; +/** + * Prefer forwarding the upstream MCP 401 (WWW-Authenticate + body) so the browser + * matches direct-mode OAuth behavior. Falls back to JSON-encoding `error` if unknown. + */ +const sendProxiedUnauthorized = ( + res: express.Response, + error: unknown, + headerHolder?: ProxyHeaderHolder, +) => { + const captured = headerHolder?.lastUpstream401; + if (captured && headerHolder) { + if (captured.wwwAuthenticate) { + res.setHeader("WWW-Authenticate", captured.wwwAuthenticate); + } + res.status(401); + res.setHeader("Content-Type", captured.contentType); + res.send(captured.body); + delete headerHolder.lastUpstream401; + return; + } + res.status(401).json(error); +}; + // Function to get HTTP headers. const getHttpHeaders = (req: express.Request): Record => { const headers: Record = {}; @@ -175,13 +198,16 @@ const updateHeadersInPlace = ( const app = express(); app.use(cors()); app.use((req, res, next) => { - res.header("Access-Control-Expose-Headers", "mcp-session-id"); + res.header( + "Access-Control-Expose-Headers", + "mcp-session-id, WWW-Authenticate", + ); next(); }); const webAppTransports: Map = new Map(); // Web app transports by web app sessionId const serverTransports: Map = new Map(); // Server Transports by web app sessionId -const sessionHeaderHolders: Map = new Map(); // For dynamic header updates +const sessionHeaderHolders: Map = new Map(); // For dynamic header updates // Use provided token from environment or generate a new one const sessionToken = @@ -303,7 +329,7 @@ const createWebReadableStream = (nodeStream: any): ReadableStream => { * `Content-Type` are preserved. For SSE requests, it also converts Node.js * streams to web-compatible streams. */ -const createCustomFetch = (headerHolder: { headers: HeadersInit }) => { +const createCustomFetch = (headerHolder: ProxyHeaderHolder) => { return async ( input: RequestInfo | URL, init?: RequestInit, @@ -334,6 +360,28 @@ const createCustomFetch = (headerHolder: { headers: HeadersInit }) => { { ...init, headers: headersObject } as any, ); + if (response.status === 401) { + const wwwAuthenticate = + response.headers.get("www-authenticate") ?? undefined; + const contentType = + response.headers.get("content-type") ?? "application/json"; + const body = await response.text(); + headerHolder.lastUpstream401 = { + wwwAuthenticate, + body, + contentType, + }; + const responseHeaders: Record = {}; + response.headers.forEach((value: string, key: string) => { + responseHeaders[key] = value; + }); + return new Response(body, { + status: 401, + statusText: response.statusText, + headers: responseHeaders, + }) as Response; + } + // Check if this is an SSE request by looking at the Accept header const acceptHeader = finalHeaders.get("Accept"); const isSSE = acceptHeader?.includes("text/event-stream"); @@ -366,7 +414,7 @@ const createTransport = async ( req: express.Request, ): Promise<{ transport: Transport; - headerHolder?: { headers: HeadersInit }; + headerHolder?: ProxyHeaderHolder; }> => { const query = req.query; console.log("Query parameters:", JSON.stringify(query)); @@ -397,7 +445,7 @@ const createTransport = async ( const headers = getHttpHeaders(req); headers["Accept"] = "text/event-stream"; - const headerHolder = { headers }; + const headerHolder: ProxyHeaderHolder = { headers }; console.log( `SSE transport: url=${url}, headers=${JSON.stringify(headers)}`, @@ -416,7 +464,7 @@ const createTransport = async ( } else if (transportType === "streamable-http") { const headers = getHttpHeaders(req); headers["Accept"] = "text/event-stream, application/json"; - const headerHolder = { headers }; + const headerHolder: ProxyHeaderHolder = { headers }; const transport = new StreamableHTTPClientTransport( new URL(query.url as string), @@ -501,9 +549,11 @@ app.post( } } else { console.log("New StreamableHttp connection request"); + let streamableHeaderHolder: ProxyHeaderHolder | undefined; try { const { transport: serverTransport, headerHolder } = await createTransport(req); + streamableHeaderHolder = headerHolder; const webAppTransport = new StreamableHTTPServerTransport({ sessionIdGenerator: randomUUID, @@ -528,6 +578,7 @@ app.post( mcpProxy({ transportToClient: webAppTransport, transportToServer: serverTransport, + headerHolder, }); await (webAppTransport as StreamableHTTPServerTransport).handleRequest( @@ -541,7 +592,7 @@ app.post( "Received 401 Unauthorized from MCP server:", error instanceof Error ? error.message : error, ); - res.status(401).json(error); + sendProxiedUnauthorized(res, error, streamableHeaderHolder); return; } console.error("Error in /mcp POST route:", error); @@ -681,7 +732,7 @@ app.get( console.error( "Received 401 Unauthorized from MCP server. Authentication failure.", ); - res.status(401).json(error); + sendProxiedUnauthorized(res, error, undefined); return; } console.error("Error in /stdio route:", error); @@ -695,12 +746,14 @@ app.get( originValidationMiddleware, authMiddleware, async (req, res) => { + let sseHeaderHolder: ProxyHeaderHolder | undefined; try { console.log( "New SSE connection request. NOTE: The SSE transport is deprecated and has been replaced by StreamableHttp", ); const { transport: serverTransport, headerHolder } = await createTransport(req); + sseHeaderHolder = headerHolder; const proxyFullAddress = (req.query.proxyFullAddress as string) || ""; const prefix = proxyFullAddress || ""; @@ -721,13 +774,14 @@ app.get( mcpProxy({ transportToClient: webAppTransport, transportToServer: serverTransport, + headerHolder, }); } catch (error) { if (is401Error(error)) { console.error( "Received 401 Unauthorized from MCP server. Authentication failure.", ); - res.status(401).json(error); + sendProxiedUnauthorized(res, error, sseHeaderHolder); return; } else if (error instanceof SseError && error.code === 404) { console.error( @@ -781,6 +835,60 @@ app.get("/health", (req, res) => { }); }); +app.post( + "/fetch", + express.json(), + originValidationMiddleware, + authMiddleware, + async (req, res) => { + try { + const { url, init } = req.body as { url: string; init?: RequestInit }; + + if (typeof url !== "string" || url.length === 0) { + res.status(400).json({ error: "Missing or invalid url" }); + return; + } + + let parsedUrl: URL; + try { + parsedUrl = new URL(url); + } catch { + res.status(400).json({ error: "Invalid URL" }); + return; + } + + if (!["http:", "https:"].includes(parsedUrl.protocol)) { + res.status(400).json({ error: "Only http/https URLs are allowed" }); + return; + } + + const response = await fetch(url, { + method: init?.method ?? "GET", + headers: (init?.headers as Record) ?? {}, + body: init?.body as string | undefined, + }); + + const responseBody = await response.text(); + const headers: Record = {}; + response.headers.forEach((value, key) => { + headers[key] = value; + }); + + res.status(response.status).json({ + ok: response.ok, + status: response.status, + statusText: response.statusText, + headers, + body: responseBody, + }); + } catch (error) { + res.status(500).json({ + error: error instanceof Error ? error.message : String(error), + }); + } + }, +); + app.get("/config", originValidationMiddleware, authMiddleware, (req, res) => { try { res.json({ diff --git a/server/src/mcpProxy.ts b/server/src/mcpProxy.ts index 174eef0ec..3195553e1 100644 --- a/server/src/mcpProxy.ts +++ b/server/src/mcpProxy.ts @@ -1,6 +1,33 @@ +import { SseError } from "@modelcontextprotocol/sdk/client/sse.js"; +import { StreamableHTTPError } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; import { isJSONRPCRequest } from "@modelcontextprotocol/sdk/types.js"; +/** + * JSON-RPC error code for failed proxy → upstream transport sends. + * + * JSON-RPC 2.0 reserves **-32000 .. -32099** for implementation-defined *server* errors. + * MCP uses that band for protocol/SDK errors (see `@modelcontextprotocol/sdk` `ErrorCode`, e.g. + * `-32000` connection closed, `-32001` request timeout, `-32042` URL elicitation; other codes in + * the band are assigned over time — e.g. **-32002** is used for **ResourceNotFound** in parts of + * the MCP ecosystem). We use **-32099** so this inspector-only bridge error stays clear of those + * registered meanings; the client still keys off `error.data` shape, not the number alone. + * + * Keep in sync with `client/src/lib/constants.ts`; drift is caught by + * `client/src/lib/__tests__/mcpProxyTransportErrorCode.test.ts`. + */ +export const MCP_PROXY_TRANSPORT_ERROR_CODE = -32099; + +/** Session header bag used with `createCustomFetch`; may hold last upstream 401 snapshot. */ +export type ProxyHeaderHolder = { + headers: HeadersInit; + lastUpstream401?: { + wwwAuthenticate?: string; + body: string; + contentType: string; + }; +}; + function onClientError(error: Error) { console.error("Error from inspector client:", error); } @@ -15,12 +42,49 @@ function onServerError(error: Error) { } } +/** Exported for unit tests; used by the proxy when forwarding transport failures. */ +export function serializeProxyTransportError( + error: Error, + headerHolder?: ProxyHeaderHolder, +): Record { + const data: Record = { + message: error.message, + name: error.name, + }; + if (error.cause !== undefined) { + data.cause = + error.cause instanceof Error ? error.cause.message : String(error.cause); + } + const attachHttpStatusIfValid = (code: number) => { + if (Number.isInteger(code) && code >= 100 && code <= 599) { + data.httpStatus = code; + } + }; + if (error instanceof StreamableHTTPError) { + if (error.code !== undefined) attachHttpStatusIfValid(error.code); + } else if (error instanceof SseError) { + if (error.code !== undefined) attachHttpStatusIfValid(error.code); + } + if (headerHolder?.lastUpstream401) { + const u = headerHolder.lastUpstream401; + data.upstream401 = { + wwwAuthenticate: u.wwwAuthenticate, + body: u.body, + contentType: u.contentType, + }; + delete headerHolder.lastUpstream401; + } + return data; +} + export default function mcpProxy({ transportToClient, transportToServer, + headerHolder, }: { transportToClient: Transport; transportToServer: Transport; + headerHolder?: ProxyHeaderHolder; }) { let transportToClientClosed = false; let transportToServerClosed = false; @@ -28,18 +92,25 @@ export default function mcpProxy({ let reportedServerSession = false; transportToClient.onmessage = (message) => { - transportToServer.send(message).catch((error) => { + transportToServer.send(message).catch((err: unknown) => { + const error = err instanceof Error ? err : new Error(String(err)); // Send error response back to client if it was a request (has id) and connection is still open if (isJSONRPCRequest(message) && !transportToClientClosed) { + const causeStr = + error.cause !== undefined + ? error.cause instanceof Error + ? error.cause.message + : String(error.cause) + : null; const errorResponse = { jsonrpc: "2.0" as const, id: message.id, error: { - code: -32001, - message: error.cause - ? `${error.message} (cause: ${error.cause})` + code: MCP_PROXY_TRANSPORT_ERROR_CODE, + message: causeStr + ? `${error.message} (cause: ${causeStr})` : error.message, - data: error, + data: serializeProxyTransportError(error, headerHolder), }, }; transportToClient.send(errorResponse).catch(onClientError);