diff --git a/core/__tests__/auth/discovery.test.ts b/core/__tests__/auth/discovery.test.ts index 591701291..7ce0225a4 100644 --- a/core/__tests__/auth/discovery.test.ts +++ b/core/__tests__/auth/discovery.test.ts @@ -1,5 +1,8 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; -import { discoverScopes } from "../../auth/discovery.js"; +import { + discoverScopes, + getAuthorizationServerUrl, +} from "../../auth/discovery.js"; import type { OAuthProtectedResourceMetadata } from "@modelcontextprotocol/sdk/shared/auth.js"; // Mock SDK functions @@ -176,4 +179,139 @@ describe("OAuth Scope Discovery", () => { { fetchFn: mockFetchFn }, ); }); + + it("should use authorization_servers URL from resource metadata for discovery (different domain)", async () => { + const { discoverAuthorizationServerMetadata } = + await import("@modelcontextprotocol/sdk/client/auth.js"); + vi.mocked(discoverAuthorizationServerMetadata).mockResolvedValue({ + issuer: "https://auth-server.com", + authorization_endpoint: "https://auth-server.com/authorize", + token_endpoint: "https://auth-server.com/token", + response_types_supported: ["code"], + scopes_supported: ["read", "write"], + }); + + const resourceMetadata: OAuthProtectedResourceMetadata = { + resource: "https://mcp-server.com", + authorization_servers: ["https://auth-server.com/"], + scopes_supported: ["read", "write"], + }; + + const scopes = await discoverScopes( + "https://mcp-server.com", + resourceMetadata, + ); + + expect(scopes).toBe("read write"); + expect(discoverAuthorizationServerMetadata).toHaveBeenCalledWith( + new URL("https://auth-server.com/"), + { fetchFn: undefined }, + ); + }); + + it("should preserve full path in authorization_servers URL", async () => { + const { discoverAuthorizationServerMetadata } = + await import("@modelcontextprotocol/sdk/client/auth.js"); + vi.mocked(discoverAuthorizationServerMetadata).mockResolvedValue({ + issuer: "https://auth-server.com/realms/my-realm", + authorization_endpoint: + "https://auth-server.com/realms/my-realm/authorize", + token_endpoint: "https://auth-server.com/realms/my-realm/token", + response_types_supported: ["code"], + scopes_supported: ["read", "write"], + }); + + const resourceMetadata: OAuthProtectedResourceMetadata = { + resource: "https://mcp-server.com", + authorization_servers: ["https://auth-server.com/realms/my-realm/"], + scopes_supported: ["read", "write"], + }; + + const scopes = await discoverScopes( + "https://mcp-server.com", + resourceMetadata, + ); + + expect(scopes).toBe("read write"); + expect(discoverAuthorizationServerMetadata).toHaveBeenCalledWith( + new URL("https://auth-server.com/realms/my-realm/"), + { fetchFn: undefined }, + ); + }); + + it("should fall back to serverUrl when authorization_servers is empty", async () => { + const { discoverAuthorizationServerMetadata } = + await import("@modelcontextprotocol/sdk/client/auth.js"); + vi.mocked(discoverAuthorizationServerMetadata).mockResolvedValue({ + issuer: "https://mcp-server.com", + authorization_endpoint: "https://mcp-server.com/authorize", + token_endpoint: "https://mcp-server.com/token", + response_types_supported: ["code"], + scopes_supported: ["read", "write"], + }); + + const resourceMetadata: OAuthProtectedResourceMetadata = { + resource: "https://mcp-server.com", + authorization_servers: [], + scopes_supported: ["read", "write"], + }; + + const scopes = await discoverScopes( + "https://mcp-server.com", + resourceMetadata, + ); + + expect(scopes).toBe("read write"); + expect(discoverAuthorizationServerMetadata).toHaveBeenCalledWith( + new URL("/", "https://mcp-server.com"), + { fetchFn: undefined }, + ); + }); +}); + +describe("getAuthorizationServerUrl", () => { + const serverUrl = "https://mcp.example.com"; + + it("returns server URL when resourceMetadata is null", () => { + expect(getAuthorizationServerUrl(serverUrl, null)).toEqual( + new URL("/", serverUrl), + ); + }); + + it("returns server URL when resourceMetadata is undefined", () => { + expect(getAuthorizationServerUrl(serverUrl)).toEqual( + new URL("/", serverUrl), + ); + }); + + it("returns server URL when authorization_servers is empty array", () => { + const resourceMetadata: OAuthProtectedResourceMetadata = { + resource: serverUrl, + authorization_servers: [], + }; + expect(getAuthorizationServerUrl(serverUrl, resourceMetadata)).toEqual( + new URL("/", serverUrl), + ); + }); + + it("falls back to server URL when authorization_servers[0] is empty string", () => { + const resourceMetadata: OAuthProtectedResourceMetadata = { + resource: serverUrl, + authorization_servers: [""], + }; + expect(getAuthorizationServerUrl(serverUrl, resourceMetadata)).toEqual( + new URL("/", serverUrl), + ); + }); + + it("returns authorization_servers[0] when present and truthy", () => { + const authUrl = "https://auth.example.com/"; + const resourceMetadata: OAuthProtectedResourceMetadata = { + resource: serverUrl, + authorization_servers: [authUrl], + }; + expect(getAuthorizationServerUrl(serverUrl, resourceMetadata)).toEqual( + new URL(authUrl), + ); + }); }); diff --git a/core/__tests__/auth/state-machine.test.ts b/core/__tests__/auth/state-machine.test.ts index 3fb153ae0..9f91caf94 100644 --- a/core/__tests__/auth/state-machine.test.ts +++ b/core/__tests__/auth/state-machine.test.ts @@ -167,6 +167,44 @@ describe("OAuthStateMachine", () => { ); }); + it("should use authorization_servers URL from resource metadata for auth server discovery", async () => { + const authServerUrl = "https://auth-server.com/"; + const resourceMetaDifferentAuth: OAuthProtectedResourceMetadata = { + resource: serverUrl, + authorization_servers: [authServerUrl], + scopes_supported: ["read", "write"], + }; + const selectedResource = new URL(serverUrl); + const { + discoverOAuthProtectedResourceMetadata, + discoverAuthorizationServerMetadata, + selectResourceURL, + } = await import("@modelcontextprotocol/sdk/client/auth.js"); + vi.mocked(discoverOAuthProtectedResourceMetadata).mockResolvedValue( + resourceMetaDifferentAuth, + ); + vi.mocked(selectResourceURL).mockResolvedValue(selectedResource); + + const stateMachine = new OAuthStateMachine( + serverUrl, + mockProvider, + updateState, + ); + await stateMachine.executeStep(state); + + expect(discoverAuthorizationServerMetadata).toHaveBeenCalledWith( + new URL(authServerUrl), + expect.any(Object), + ); + expect(updateState).toHaveBeenCalledWith( + expect.objectContaining({ + resourceMetadata: resourceMetaDifferentAuth, + authServerUrl: new URL(authServerUrl), + oauthStep: "client_registration", + }), + ); + }); + it("should call selectResourceURL only when resource metadata is present", async () => { const { discoverOAuthProtectedResourceMetadata, selectResourceURL } = await import("@modelcontextprotocol/sdk/client/auth.js"); diff --git a/core/auth/discovery.ts b/core/auth/discovery.ts index ae1b33ef6..f2d9194b7 100644 --- a/core/auth/discovery.ts +++ b/core/auth/discovery.ts @@ -1,6 +1,19 @@ import { discoverAuthorizationServerMetadata } from "@modelcontextprotocol/sdk/client/auth.js"; import type { OAuthProtectedResourceMetadata } from "@modelcontextprotocol/sdk/shared/auth.js"; +/** + * Returns the URL to use for OAuth authorization server metadata discovery. + * Uses resource metadata's authorization_servers[0] when present, otherwise the MCP server URL. + */ +export function getAuthorizationServerUrl( + serverUrl: string, + resourceMetadata?: OAuthProtectedResourceMetadata | null, +): URL { + const first = resourceMetadata?.authorization_servers?.[0]; + // Use truthy check to match original state-machine: empty string falls back to serverUrl + return first ? new URL(first) : new URL("/", serverUrl); +} + /** * Discovers OAuth scopes from server metadata, with preference for resource metadata scopes * @param serverUrl - The MCP server URL @@ -14,10 +27,13 @@ export const discoverScopes = async ( fetchFn?: typeof fetch, ): Promise => { try { - const metadata = await discoverAuthorizationServerMetadata( - new URL("/", serverUrl), - { fetchFn }, + const authServerUrl = getAuthorizationServerUrl( + serverUrl, + resourceMetadata, ); + const metadata = await discoverAuthorizationServerMetadata(authServerUrl, { + fetchFn, + }); // Prefer resource metadata scopes, but fall back to OAuth metadata if empty const resourceScopes = resourceMetadata?.scopes_supported; diff --git a/core/auth/state-machine.ts b/core/auth/state-machine.ts index 49f950718..4f2958e79 100644 --- a/core/auth/state-machine.ts +++ b/core/auth/state-machine.ts @@ -1,6 +1,6 @@ import type { OAuthStep, AuthGuidedState } from "./types.js"; import type { BaseOAuthClientProvider } from "./providers.js"; -import { discoverScopes } from "./discovery.js"; +import { discoverScopes, getAuthorizationServerUrl } from "./discovery.js"; import { discoverAuthorizationServerMetadata, registerClient, @@ -32,20 +32,12 @@ export const oauthTransitions: Record = { metadata_discovery: { canTransition: async () => true, execute: async (context) => { - // Default to discovering from the server's URL - let authServerUrl: URL = new URL("/", context.serverUrl); let resourceMetadata: OAuthProtectedResourceMetadata | null = null; let resourceMetadataError: Error | null = null; try { resourceMetadata = await discoverOAuthProtectedResourceMetadata( context.serverUrl as string | URL, ); - if (resourceMetadata?.authorization_servers?.length) { - const firstServer = resourceMetadata.authorization_servers[0]; - if (firstServer) { - authServerUrl = new URL(firstServer); - } - } } catch (e) { if (e instanceof Error) { resourceMetadataError = e; @@ -54,6 +46,11 @@ export const oauthTransitions: Record = { } } + const authServerUrl = getAuthorizationServerUrl( + context.serverUrl, + resourceMetadata, + ); + const resource: URL | undefined = resourceMetadata ? await selectResourceURL( context.serverUrl,