diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py
index 58f684f01..d45ca6626 100644
--- a/.github/actions/conformance/client.py
+++ b/.github/actions/conformance/client.py
@@ -5,7 +5,7 @@
Contract:
- MCP_CONFORMANCE_SCENARIO env var -> scenario name
- - MCP_CONFORMANCE_CONTEXT env var -> optional JSON (for client-credentials scenarios)
+ - MCP_CONFORMANCE_CONTEXT env var -> optional JSON (for auth scenarios)
- Server URL as last CLI argument (sys.argv[1])
- Must exit 0 within 30 seconds
@@ -16,7 +16,16 @@
elicitation-sep1034-client-defaults - Elicitation with default accept callback
auth/client-credentials-jwt - Client credentials with private_key_jwt
auth/client-credentials-basic - Client credentials with client_secret_basic
+ auth/cross-app-access-complete-flow - Enterprise managed OAuth (SEP-990) - v0.1.14+
auth/* - Authorization code flow (default for auth scenarios)
+
+Enterprise Auth (SEP-990):
+ The conformance package v0.1.14+ (https://github.com/modelcontextprotocol/conformance/pull/110)
+ provides the scenario 'auth/cross-app-access-complete-flow' which tests the complete
+ enterprise managed OAuth flow: IDP ID token → ID-JAG → access token.
+
+ The client receives test context (idp_id_token, idp_token_endpoint, etc.) via
+ MCP_CONFORMANCE_CONTEXT environment variable and performs the token exchange flows automatically.
"""
import asyncio
@@ -314,9 +323,100 @@ async def run_auth_code_client(server_url: str) -> None:
await _run_auth_session(server_url, oauth_auth)
+@register("auth/cross-app-access-complete-flow")
+async def run_cross_app_access_complete_flow(server_url: str) -> None:
+ """Enterprise managed auth: Complete SEP-990 flow (OIDC ID token → ID-JAG → access token).
+
+ This scenario is provided by @modelcontextprotocol/conformance@0.1.14+ (PR #110).
+ It tests the complete enterprise managed OAuth flow using token exchange (RFC 8693)
+ and JWT bearer grant (RFC 7523).
+ """
+ from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+ )
+
+ context = get_conformance_context()
+ # The conformance package provides these fields
+ idp_id_token = context.get("idp_id_token")
+ idp_token_endpoint = context.get("idp_token_endpoint")
+ idp_issuer = context.get("idp_issuer")
+
+ # For cross-app access, we need to determine the MCP server's resource ID and auth issuer
+ # The conformance package sets up the auth server, and the MCP server URL is passed to us
+
+ if not idp_id_token:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_id_token'")
+ if not idp_token_endpoint:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'")
+ if not idp_issuer:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_issuer'")
+
+ # Extract base URL and construct auth issuer and resource ID
+ # The conformance test sets up auth server at a known location
+ base_url = server_url.replace("/mcp", "")
+ auth_issuer = context.get("auth_issuer", base_url)
+ resource_id = context.get("resource_id", server_url)
+
+ logger.debug("Cross-app access flow:")
+ logger.debug(f" IDP Issuer: {idp_issuer}")
+ logger.debug(f" IDP Token Endpoint: {idp_token_endpoint}")
+ logger.debug(f" Auth Issuer: {auth_issuer}")
+ logger.debug(f" Resource ID: {resource_id}")
+
+ # Create token exchange parameters from IDP ID token
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=idp_id_token,
+ mcp_server_auth_issuer=auth_issuer,
+ mcp_server_resource_id=resource_id,
+ scope=context.get("scope"),
+ )
+
+ # Get pre-configured client credentials from context (if provided)
+ client_id = context.get("client_id")
+ client_secret = context.get("client_secret")
+
+ # Create storage and pre-configure client info if credentials are provided
+ storage = InMemoryTokenStorage()
+
+ # Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url=server_url,
+ client_metadata=OAuthClientMetadata(
+ client_name="conformance-cross-app-client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=storage,
+ idp_token_endpoint=idp_token_endpoint,
+ token_exchange_params=token_exchange_params,
+ )
+
+ # If client credentials are provided in context, use them instead of dynamic registration
+ if client_id and client_secret:
+ from mcp.shared.auth import OAuthClientInformationFull
+
+ logger.debug(f"Using pre-configured client credentials: {client_id}")
+ client_info = OAuthClientInformationFull(
+ client_id=client_id,
+ client_secret=client_secret,
+ token_endpoint_auth_method="client_secret_basic",
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ )
+ enterprise_auth.context.client_info = client_info
+ await storage.set_client_info(client_info)
+
+ await _run_auth_session(server_url, enterprise_auth)
+
+
async def _run_auth_session(server_url: str, oauth_auth: OAuthClientProvider) -> None:
"""Common session logic for all OAuth flows."""
- client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0)
+ # Allow timeout to be configured via environment variable for different test scenarios
+ timeout = float(os.environ.get("MCP_CONFORMANCE_TIMEOUT", "30.0"))
+ client = httpx.AsyncClient(auth=oauth_auth, timeout=timeout)
async with streamable_http_client(url=server_url, http_client=client) as (read_stream, write_stream):
async with ClientSession(
read_stream, write_stream, elicitation_callback=default_elicitation_callback
diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml
index d876da00b..73718540f 100644
--- a/.github/workflows/conformance.yml
+++ b/.github/workflows/conformance.yml
@@ -33,13 +33,13 @@ jobs:
runs-on: ubuntu-latest
continue-on-error: true
steps:
- - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1
- - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1
+ - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+ - uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # v7.2.0
with:
enable-cache: true
version: 0.9.5
- - uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0
+ - uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: 24
- run: uv sync --frozen --all-extras --package mcp
- - run: npx @modelcontextprotocol/conformance@0.1.13 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
+ - run: npx @modelcontextprotocol/conformance@0.1.14 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
diff --git a/README.v2.md b/README.v2.md
index 55d867586..b0b0b7286 100644
--- a/README.v2.md
+++ b/README.v2.md
@@ -70,6 +70,7 @@
- [Writing MCP Clients](#writing-mcp-clients)
- [Client Display Utilities](#client-display-utilities)
- [OAuth Authentication for Clients](#oauth-authentication-for-clients)
+ - [Enterprise Managed Authorization](#enterprise-managed-authorization)
- [Parsing Tool Results](#parsing-tool-results)
- [MCP Primitives](#mcp-primitives)
- [Server Capabilities](#server-capabilities)
@@ -2395,6 +2396,328 @@ _Full example: [examples/snippets/clients/oauth_client.py](https://github.com/mo
For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/).
+#### Enterprise Managed Authorization
+
+The SDK includes support for Enterprise Managed Authorization (SEP-990), which enables MCP clients to connect to protected servers using enterprise Single Sign-On (SSO) systems. This implementation supports:
+
+- **RFC 8693**: OAuth 2.0 Token Exchange (ID Token -> ID-JAG)
+- **RFC 7523**: JSON Web Token (JWT) Profile for OAuth 2.0 Authorization Grants (ID-JAG -> Access Token)
+- Integration with enterprise identity providers (Okta, Azure AD, etc.)
+
+**Key Components:**
+
+The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provider to implement the enterprise authorization flow:
+
+**Token Exchange Flow:**
+
+1. **Obtain ID Token** from your enterprise IdP (e.g., Okta, Azure AD)
+2. **Exchange ID Token for ID-JAG** using RFC 8693 Token Exchange
+3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant
+4. **Use Access Token** to call protected MCP server tools
+
+**Using the Access Token with MCP Server:**
+
+1. Once you have obtained the access token, you can use it to authenticate requests to the MCP server
+2. The access token is automatically included in all subsequent requests to the MCP server, allowing you to access protected tools and resources based on your enterprise identity and permissions.
+
+**Handling Token Expiration and Refresh:**
+
+Access tokens have a limited lifetime and will expire. When tokens expire:
+
+- **Check Token Expiration**: Use the `expires_in` field to determine when the token expires
+- **Refresh Flow**: When expired, repeat the token exchange flow with a fresh ID token from your IdP
+- **Automatic Refresh**: Implement automatic token refresh before expiration (recommended for production)
+- **Error Handling**: Catch authentication errors and retry with refreshed tokens
+
+**Important Notes:**
+
+- **ID Token Expiration**: If the ID token from your IdP expires, you must re-authenticate with the IdP to obtain a new ID token before performing token exchange
+- **Token Storage**: Store tokens securely and implement the `TokenStorage` interface to persist tokens between application restarts
+- **Scope Changes**: If you need different scopes, you must obtain a new ID token from the IdP with the required scopes
+- **Security**: Never log or expose access tokens or ID tokens in production environments
+
+**Example Usage:**
+
+
+```python
+import asyncio
+
+import httpx
+from pydantic import AnyUrl
+
+from mcp import ClientSession
+from mcp.client.auth import TokenStorage
+from mcp.client.auth.extensions import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+)
+from mcp.client.streamable_http import streamable_http_client
+from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
+
+
+# Placeholder function for IdP authentication
+async def get_id_token_from_idp() -> str:
+ """Placeholder function to get ID token from your IdP.
+
+ In production, implement actual IdP authentication flow.
+ """
+ raise NotImplementedError("Implement your IdP authentication flow here")
+
+
+# Define token storage implementation
+class SimpleTokenStorage(TokenStorage):
+ def __init__(self) -> None:
+ self._tokens: OAuthToken | None = None
+ self._client_info: OAuthClientInformationFull | None = None
+
+ async def get_tokens(self) -> OAuthToken | None:
+ return self._tokens
+
+ async def set_tokens(self, tokens: OAuthToken) -> None:
+ self._tokens = tokens
+
+ async def get_client_info(self) -> OAuthClientInformationFull | None:
+ return self._client_info
+
+ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
+ self._client_info = client_info
+
+
+async def discover_mcp_server_metadata(server_url: str) -> tuple[str, str]:
+ """Discover MCP server's OAuth metadata and resource identifier.
+
+ Returns:
+ Tuple of (auth_issuer, resource_id)
+ """
+ from mcp.client.auth.utils import (
+ build_oauth_authorization_server_metadata_discovery_urls,
+ build_protected_resource_metadata_discovery_urls,
+ handle_auth_metadata_response,
+ handle_protected_resource_response,
+ )
+
+ async with httpx.AsyncClient() as client:
+ # Step 1: Discover Protected Resource Metadata (PRM)
+ prm_urls = build_protected_resource_metadata_discovery_urls(None, server_url)
+
+ prm = None
+ for url in prm_urls:
+ response = await client.get(url)
+ prm = await handle_protected_resource_response(response)
+ if prm:
+ break
+
+ if not prm:
+ raise ValueError("Could not discover Protected Resource Metadata")
+
+ # Extract resource identifier and authorization server URL
+ resource_id = str(prm.resource)
+ auth_server_url = str(prm.authorization_servers[0]) if prm.authorization_servers else None
+
+ # Step 2: Discover OAuth Authorization Server Metadata
+ oauth_urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
+
+ oauth_metadata = None
+ for url in oauth_urls:
+ response = await client.get(url)
+ ok, asm = await handle_auth_metadata_response(response)
+ if ok and asm:
+ oauth_metadata = asm
+ break
+
+ if not oauth_metadata or not oauth_metadata.issuer:
+ raise ValueError("Could not discover OAuth metadata or issuer")
+
+ auth_issuer = str(oauth_metadata.issuer)
+
+ return auth_issuer, resource_id
+
+
+async def main() -> None:
+ """Example demonstrating enterprise managed authorization with MCP."""
+ server_url = "https://mcp-server.example.com"
+
+ # Step 1: Get ID token from your IdP (e.g., Okta, Azure AD)
+ id_token = await get_id_token_from_idp()
+
+ # Step 2: Discover MCP server's OAuth metadata and resource identifier
+ # This replaces hardcoding these values
+ mcp_server_auth_issuer, mcp_server_resource_id = await discover_mcp_server_metadata(server_url)
+ print(f"Discovered auth issuer: {mcp_server_auth_issuer}")
+ print(f"Discovered resource ID: {mcp_server_resource_id}")
+
+ # Step 3: Configure token exchange parameters using discovered values
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer=mcp_server_auth_issuer,
+ mcp_server_resource_id=mcp_server_resource_id,
+ scope="mcp:tools mcp:resources", # Optional scopes
+ )
+
+ # Step 4: Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url=server_url,
+ client_metadata=OAuthClientMetadata(
+ client_name="Enterprise MCP Client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token", # Your IdP's token endpoint
+ token_exchange_params=token_exchange_params,
+ # Optional: IdP client credentials if your IdP requires client authentication for token exchange
+ # idp_client_id="your-idp-client-id",
+ # idp_client_secret="your-idp-client-secret",
+ )
+
+ # Step 5: Create authenticated HTTP client
+ # The auth provider automatically handles the two-step token exchange:
+ # 1. ID Token -> ID-JAG (via IDP)
+ # 2. ID-JAG -> Access Token (via MCP server)
+ client = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0)
+
+ # Step 6: Connect to MCP server with authenticated client
+ async with streamable_http_client(url=server_url, http_client=client) as (read, write):
+ async with ClientSession(read, write) as session:
+ await session.initialize()
+
+ # List available tools
+ tools_result = await session.list_tools()
+ print(f"Available tools: {[t.name for t in tools_result.tools]}")
+
+ # Call a tool - auth tokens are automatically managed
+ if tools_result.tools:
+ tool_name = tools_result.tools[0].name
+ result = await session.call_tool(tool_name, {})
+ print(f"Tool result: {result.content}")
+
+ # List available resources
+ resources = await session.list_resources()
+ for resource in resources.resources:
+ print(f"Resource: {resource.uri}")
+
+
+async def advanced_manual_flow() -> None:
+ """Advanced example showing manual token exchange.
+
+ Use cases for manual token exchange:
+ - Testing and debugging: Inspect ID-JAG claims before exchanging for access token
+ - Token caching: Store and reuse ID-JAG across multiple MCP server connections
+ - Custom error handling: Implement specific retry logic for each token exchange step
+ - Monitoring: Log token exchange metrics and performance
+ - Token introspection: Validate ID-JAG structure before sending to MCP server
+ """
+ id_token = await get_id_token_from_idp()
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ )
+
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Manual token exchange (useful for debugging, caching, custom error handling, etc.)
+ async with httpx.AsyncClient() as client:
+ # Step 1: Exchange ID token for ID-JAG
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ # WARNING: Only log tokens in development/testing environments
+ # In production, NEVER log tokens or token fragments as they are sensitive credentials
+ print(f"Obtained ID-JAG: {id_jag[:50]}...")
+
+ # Step 2: Build JWT bearer grant request
+ jwt_bearer_request = await enterprise_auth.exchange_id_jag_for_access_token(id_jag)
+ print(f"Built JWT bearer grant request to: {jwt_bearer_request.url}")
+
+ # Step 3: Execute the request to get access token
+ response = await client.send(jwt_bearer_request)
+ response.raise_for_status()
+ token_data = response.json()
+
+ access_token = OAuthToken(
+ access_token=token_data["access_token"],
+ token_type=token_data["token_type"],
+ expires_in=token_data.get("expires_in"),
+ )
+ # WARNING: In production, do not log token expiry or any token information
+ print(f"Access token obtained, expires in: {access_token.expires_in}s")
+
+ # Use the access token for API calls
+ _ = {"Authorization": f"Bearer {access_token.access_token}"}
+ # ... make authenticated requests with headers
+
+
+async def token_refresh_example() -> None:
+ """Example showing how to refresh tokens when they expire.
+
+ When your access token expires, you need to obtain a fresh ID token
+ from your enterprise IdP and use the refresh helper method.
+ """
+ # Initial setup
+ id_token = await get_id_token_from_idp()
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ )
+
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ _ = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0)
+
+ # Use the client for MCP operations...
+ # ... time passes and token expires ...
+
+ # When token expires, get a fresh ID token from your IdP
+ new_id_token = await get_id_token_from_idp()
+
+ # Refresh the authentication using the helper method
+ await enterprise_auth.refresh_with_new_id_token(new_id_token)
+
+ # Next API call will automatically use the refreshed tokens
+ # No need to recreate the client or session
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
+```
+
+_Full example: [examples/snippets/clients/enterprise_managed_auth_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/enterprise_managed_auth_client.py)_
+
+
+**Working with SAML Assertions:**
+
+If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions:
+
+```python
+token_exchange_params = TokenExchangeParameters.from_saml_assertion(
+ saml_assertion=saml_assertion_string,
+ mcp_server_auth_issuer="https://your-idp.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ scope="mcp:tools",
+)
+```
+
+For more details on the enterprise authorization flow, see the [MCP Enterprise Authorization specification](https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization).
+
### Parsing Tool Results
When calling tools through MCP, the `CallToolResult` object contains the tool's response in a structured format. Understanding how to parse this result is essential for properly handling tool outputs.
diff --git a/examples/snippets/clients/enterprise_managed_auth_client.py b/examples/snippets/clients/enterprise_managed_auth_client.py
new file mode 100644
index 000000000..3888fdc21
--- /dev/null
+++ b/examples/snippets/clients/enterprise_managed_auth_client.py
@@ -0,0 +1,258 @@
+import asyncio
+
+import httpx
+from pydantic import AnyUrl
+
+from mcp import ClientSession
+from mcp.client.auth import TokenStorage
+from mcp.client.auth.extensions import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+)
+from mcp.client.streamable_http import streamable_http_client
+from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
+
+
+# Placeholder function for IdP authentication
+async def get_id_token_from_idp() -> str:
+ """Placeholder function to get ID token from your IdP.
+
+ In production, implement actual IdP authentication flow.
+ """
+ raise NotImplementedError("Implement your IdP authentication flow here")
+
+
+# Define token storage implementation
+class SimpleTokenStorage(TokenStorage):
+ def __init__(self) -> None:
+ self._tokens: OAuthToken | None = None
+ self._client_info: OAuthClientInformationFull | None = None
+
+ async def get_tokens(self) -> OAuthToken | None:
+ return self._tokens
+
+ async def set_tokens(self, tokens: OAuthToken) -> None:
+ self._tokens = tokens
+
+ async def get_client_info(self) -> OAuthClientInformationFull | None:
+ return self._client_info
+
+ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
+ self._client_info = client_info
+
+
+async def discover_mcp_server_metadata(server_url: str) -> tuple[str, str]:
+ """Discover MCP server's OAuth metadata and resource identifier.
+
+ Returns:
+ Tuple of (auth_issuer, resource_id)
+ """
+ from mcp.client.auth.utils import (
+ build_oauth_authorization_server_metadata_discovery_urls,
+ build_protected_resource_metadata_discovery_urls,
+ handle_auth_metadata_response,
+ handle_protected_resource_response,
+ )
+
+ async with httpx.AsyncClient() as client:
+ # Step 1: Discover Protected Resource Metadata (PRM)
+ prm_urls = build_protected_resource_metadata_discovery_urls(None, server_url)
+
+ prm = None
+ for url in prm_urls:
+ response = await client.get(url)
+ prm = await handle_protected_resource_response(response)
+ if prm:
+ break
+
+ if not prm:
+ raise ValueError("Could not discover Protected Resource Metadata")
+
+ # Extract resource identifier and authorization server URL
+ resource_id = str(prm.resource)
+ auth_server_url = str(prm.authorization_servers[0]) if prm.authorization_servers else None
+
+ # Step 2: Discover OAuth Authorization Server Metadata
+ oauth_urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
+
+ oauth_metadata = None
+ for url in oauth_urls:
+ response = await client.get(url)
+ ok, asm = await handle_auth_metadata_response(response)
+ if ok and asm:
+ oauth_metadata = asm
+ break
+
+ if not oauth_metadata or not oauth_metadata.issuer:
+ raise ValueError("Could not discover OAuth metadata or issuer")
+
+ auth_issuer = str(oauth_metadata.issuer)
+
+ return auth_issuer, resource_id
+
+
+async def main() -> None:
+ """Example demonstrating enterprise managed authorization with MCP."""
+ server_url = "https://mcp-server.example.com"
+
+ # Step 1: Get ID token from your IdP (e.g., Okta, Azure AD)
+ id_token = await get_id_token_from_idp()
+
+ # Step 2: Discover MCP server's OAuth metadata and resource identifier
+ # This replaces hardcoding these values
+ mcp_server_auth_issuer, mcp_server_resource_id = await discover_mcp_server_metadata(server_url)
+ print(f"Discovered auth issuer: {mcp_server_auth_issuer}")
+ print(f"Discovered resource ID: {mcp_server_resource_id}")
+
+ # Step 3: Configure token exchange parameters using discovered values
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer=mcp_server_auth_issuer,
+ mcp_server_resource_id=mcp_server_resource_id,
+ scope="mcp:tools mcp:resources", # Optional scopes
+ )
+
+ # Step 4: Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url=server_url,
+ client_metadata=OAuthClientMetadata(
+ client_name="Enterprise MCP Client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token", # Your IdP's token endpoint
+ token_exchange_params=token_exchange_params,
+ # Optional: IdP client credentials if your IdP requires client authentication for token exchange
+ # idp_client_id="your-idp-client-id",
+ # idp_client_secret="your-idp-client-secret",
+ )
+
+ # Step 5: Create authenticated HTTP client
+ # The auth provider automatically handles the two-step token exchange:
+ # 1. ID Token -> ID-JAG (via IDP)
+ # 2. ID-JAG -> Access Token (via MCP server)
+ client = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0)
+
+ # Step 6: Connect to MCP server with authenticated client
+ async with streamable_http_client(url=server_url, http_client=client) as (read, write):
+ async with ClientSession(read, write) as session:
+ await session.initialize()
+
+ # List available tools
+ tools_result = await session.list_tools()
+ print(f"Available tools: {[t.name for t in tools_result.tools]}")
+
+ # Call a tool - auth tokens are automatically managed
+ if tools_result.tools:
+ tool_name = tools_result.tools[0].name
+ result = await session.call_tool(tool_name, {})
+ print(f"Tool result: {result.content}")
+
+ # List available resources
+ resources = await session.list_resources()
+ for resource in resources.resources:
+ print(f"Resource: {resource.uri}")
+
+
+async def advanced_manual_flow() -> None:
+ """Advanced example showing manual token exchange.
+
+ Use cases for manual token exchange:
+ - Testing and debugging: Inspect ID-JAG claims before exchanging for access token
+ - Token caching: Store and reuse ID-JAG across multiple MCP server connections
+ - Custom error handling: Implement specific retry logic for each token exchange step
+ - Monitoring: Log token exchange metrics and performance
+ - Token introspection: Validate ID-JAG structure before sending to MCP server
+ """
+ id_token = await get_id_token_from_idp()
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ )
+
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Manual token exchange (useful for debugging, caching, custom error handling, etc.)
+ async with httpx.AsyncClient() as client:
+ # Step 1: Exchange ID token for ID-JAG
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ # WARNING: Only log tokens in development/testing environments
+ # In production, NEVER log tokens or token fragments as they are sensitive credentials
+ print(f"Obtained ID-JAG: {id_jag[:50]}...")
+
+ # Step 2: Build JWT bearer grant request
+ jwt_bearer_request = await enterprise_auth.exchange_id_jag_for_access_token(id_jag)
+ print(f"Built JWT bearer grant request to: {jwt_bearer_request.url}")
+
+ # Step 3: Execute the request to get access token
+ response = await client.send(jwt_bearer_request)
+ response.raise_for_status()
+ token_data = response.json()
+
+ access_token = OAuthToken(
+ access_token=token_data["access_token"],
+ token_type=token_data["token_type"],
+ expires_in=token_data.get("expires_in"),
+ )
+ # WARNING: In production, do not log token expiry or any token information
+ print(f"Access token obtained, expires in: {access_token.expires_in}s")
+
+ # Use the access token for API calls
+ _ = {"Authorization": f"Bearer {access_token.access_token}"}
+ # ... make authenticated requests with headers
+
+
+async def token_refresh_example() -> None:
+ """Example showing how to refresh tokens when they expire.
+
+ When your access token expires, you need to obtain a fresh ID token
+ from your enterprise IdP and use the refresh helper method.
+ """
+ # Initial setup
+ id_token = await get_id_token_from_idp()
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ )
+
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ _ = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0)
+
+ # Use the client for MCP operations...
+ # ... time passes and token expires ...
+
+ # When token expires, get a fresh ID token from your IdP
+ new_id_token = await get_id_token_from_idp()
+
+ # Refresh the authentication using the helper method
+ await enterprise_auth.refresh_with_new_id_token(new_id_token)
+
+ # Next API call will automatically use the refreshed tokens
+ # No need to recreate the client or session
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/src/mcp/client/auth/extensions/__init__.py b/src/mcp/client/auth/extensions/__init__.py
index e69de29bb..f9594864f 100644
--- a/src/mcp/client/auth/extensions/__init__.py
+++ b/src/mcp/client/auth/extensions/__init__.py
@@ -0,0 +1,33 @@
+"""MCP Client Auth Extensions."""
+
+from mcp.client.auth.extensions.client_credentials import (
+ ClientCredentialsOAuthProvider,
+ JWTParameters,
+ PrivateKeyJWTOAuthProvider,
+ RFC7523OAuthClientProvider,
+ SignedJWTParameters,
+ static_assertion_provider,
+)
+from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ IDJAGClaims,
+ IDJAGTokenExchangeResponse,
+ TokenExchangeParameters,
+ decode_id_jag,
+ validate_token_exchange_params,
+)
+
+__all__ = [
+ "ClientCredentialsOAuthProvider",
+ "static_assertion_provider",
+ "SignedJWTParameters",
+ "PrivateKeyJWTOAuthProvider",
+ "JWTParameters",
+ "RFC7523OAuthClientProvider",
+ "EnterpriseAuthOAuthClientProvider",
+ "IDJAGClaims",
+ "IDJAGTokenExchangeResponse",
+ "TokenExchangeParameters",
+ "decode_id_jag",
+ "validate_token_exchange_params",
+]
diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py
new file mode 100644
index 000000000..286f6bd2d
--- /dev/null
+++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py
@@ -0,0 +1,509 @@
+"""Enterprise Managed Authorization extension for MCP (SEP-990).
+
+Implements RFC 8693 Token Exchange and RFC 7523 JWT Bearer Grant for
+enterprise SSO integration.
+"""
+
+import logging
+import time
+from json import JSONDecodeError
+from typing import cast
+
+import httpx
+import jwt
+from pydantic import BaseModel, Field
+from typing_extensions import NotRequired, Required, TypedDict
+
+from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
+from mcp.shared.auth import OAuthClientMetadata
+
+logger = logging.getLogger(__name__)
+
+
+class TokenExchangeRequestData(TypedDict):
+ """Type definition for RFC 8693 Token Exchange request data.
+
+ Required fields are those mandated by RFC 8693.
+ Optional fields (NotRequired) may be included based on IdP requirements.
+ """
+
+ grant_type: Required[str]
+ requested_token_type: Required[str]
+ audience: Required[str]
+ resource: Required[str]
+ subject_token: Required[str]
+ subject_token_type: Required[str]
+ scope: NotRequired[str]
+ client_id: NotRequired[str]
+ client_secret: NotRequired[str]
+
+
+class JWTBearerGrantRequestData(TypedDict):
+ """Type definition for RFC 7523 JWT Bearer Grant request data.
+
+ Required fields are those mandated by RFC 7523.
+ Optional fields (NotRequired) are for client authentication.
+ """
+
+ grant_type: Required[str]
+ assertion: Required[str]
+ client_id: NotRequired[str]
+ client_secret: NotRequired[str]
+
+
+class TokenExchangeParameters(BaseModel):
+ """Parameters for RFC 8693 Token Exchange request."""
+
+ requested_token_type: str = Field(
+ default="urn:ietf:params:oauth:token-type:id-jag",
+ description="Type of token being requested (ID-JAG)",
+ )
+
+ audience: str = Field(
+ ...,
+ description="Issuer URL of the MCP Server's authorization server",
+ )
+
+ resource: str = Field(
+ ...,
+ description="RFC 9728 Resource Identifier of the MCP Server",
+ )
+
+ scope: str | None = Field(
+ default=None,
+ description="Space-separated list of scopes being requested",
+ )
+
+ subject_token: str = Field(
+ ...,
+ description="ID Token or SAML assertion for the end user",
+ )
+
+ subject_token_type: str = Field(
+ ...,
+ description="Type of subject token (id_token or saml2)",
+ )
+
+ @classmethod
+ def from_id_token(
+ cls,
+ id_token: str,
+ mcp_server_auth_issuer: str,
+ mcp_server_resource_id: str,
+ scope: str | None = None,
+ ) -> "TokenExchangeParameters":
+ """Create parameters for OIDC ID Token exchange."""
+ return cls(
+ subject_token=id_token,
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience=mcp_server_auth_issuer,
+ resource=mcp_server_resource_id,
+ scope=scope,
+ )
+
+ @classmethod
+ def from_saml_assertion(
+ cls,
+ saml_assertion: str,
+ mcp_server_auth_issuer: str,
+ mcp_server_resource_id: str,
+ scope: str | None = None,
+ ) -> "TokenExchangeParameters":
+ """Create parameters for SAML assertion exchange."""
+ return cls(
+ subject_token=saml_assertion,
+ subject_token_type="urn:ietf:params:oauth:token-type:saml2",
+ audience=mcp_server_auth_issuer,
+ resource=mcp_server_resource_id,
+ scope=scope,
+ )
+
+
+class IDJAGTokenExchangeResponse(BaseModel):
+ """Response from RFC 8693 Token Exchange for ID-JAG."""
+
+ issued_token_type: str = Field(
+ ...,
+ description="Type of token issued (should be id-jag)",
+ )
+
+ access_token: str = Field(
+ ...,
+ description="The ID-JAG token (named access_token per RFC 8693)",
+ )
+
+ token_type: str = Field(
+ ...,
+ description="Token type (should be N_A for ID-JAG)",
+ )
+
+ scope: str | None = Field(
+ default=None,
+ description="Granted scopes",
+ )
+
+ expires_in: int | None = Field(
+ default=None,
+ description="Lifetime in seconds",
+ )
+
+ @property
+ def id_jag(self) -> str:
+ """Get the ID-JAG token."""
+ return self.access_token
+
+
+class IDJAGClaims(BaseModel):
+ """Claims structure for Identity Assertion JWT Authorization Grant."""
+
+ model_config = {"extra": "allow"}
+
+ # JWT header
+ typ: str = Field(
+ ...,
+ description="JWT type - must be 'oauth-id-jag+jwt'",
+ )
+
+ # Required claims
+ jti: str = Field(..., description="Unique JWT ID")
+ iss: str = Field(..., description="IdP issuer URL")
+ sub: str = Field(..., description="Subject (user) identifier")
+ aud: str = Field(..., description="MCP Server's auth server issuer")
+ resource: str = Field(..., description="MCP Server resource identifier")
+ client_id: str = Field(..., description="MCP Client identifier")
+ exp: int = Field(..., description="Expiration timestamp")
+ iat: int = Field(..., description="Issued-at timestamp")
+
+ # Optional claims
+ scope: str | None = Field(None, description="Space-separated scopes")
+ email: str | None = Field(None, description="User email")
+
+
+class EnterpriseAuthOAuthClientProvider(OAuthClientProvider):
+ """OAuth client provider for Enterprise Managed Authorization (SEP-990).
+
+ Implements:
+ - RFC 8693: Token Exchange (ID Token → ID-JAG)
+ - RFC 7523: JWT Bearer Grant (ID-JAG → Access Token)
+
+ Concurrency & Thread Safety:
+ - SAFE: Concurrent requests within a single asyncio event loop. Token operations
+ are protected by the parent class's `OAuthContext.lock`.
+ - UNSAFE: Sharing a provider instance across multiple OS threads. Each thread
+ must instantiate its own provider and event loop.
+ - Note: Ensure any shared `TokenStorage` implementation is async-safe.
+ """
+
+ # Default ID-JAG expiry when IdP doesn't provide expires_in
+ # 15 minutes is a conservative default for enterprise environments
+ DEFAULT_ID_JAG_EXPIRY_SECONDS = 900
+
+ def __init__(
+ self,
+ server_url: str,
+ client_metadata: OAuthClientMetadata,
+ storage: TokenStorage,
+ idp_token_endpoint: str,
+ token_exchange_params: TokenExchangeParameters,
+ timeout: float = 300.0,
+ idp_client_id: str | None = None,
+ idp_client_secret: str | None = None,
+ default_id_jag_expiry: int = DEFAULT_ID_JAG_EXPIRY_SECONDS,
+ ) -> None:
+ """Initialize Enterprise Auth OAuth Client.
+
+ Args:
+ server_url: MCP server URL
+ client_metadata: OAuth client metadata
+ storage: Token storage implementation
+ idp_token_endpoint: Enterprise IdP token endpoint URL
+ token_exchange_params: Token exchange parameters
+ timeout: Request timeout in seconds
+ idp_client_id: Optional client ID registered with the IdP for token exchange
+ idp_client_secret: Optional client secret registered with the IdP for token exchange
+ default_id_jag_expiry: Fallback ID-JAG expiry in seconds if the IdP
+ omits `expires_in` (default: 900s/15m). Adjust to balance token
+ freshness against IdP request load.
+ """
+ super().__init__(
+ server_url=server_url,
+ client_metadata=client_metadata,
+ storage=storage,
+ timeout=timeout,
+ )
+ self.idp_token_endpoint = idp_token_endpoint
+ self.token_exchange_params = token_exchange_params
+ self.idp_client_id = idp_client_id
+ self.idp_client_secret = idp_client_secret
+ self.default_id_jag_expiry = default_id_jag_expiry
+ self._id_jag: str | None = None
+ self._id_jag_expiry: float | None = None
+
+ # Validate client authentication configuration
+ if idp_client_secret is not None and idp_client_id is None:
+ logger.warning(
+ "idp_client_secret provided without idp_client_id. "
+ "The secret will be sent to the IdP but may be ignored. "
+ "Consider providing both idp_client_id and idp_client_secret together."
+ )
+
+ async def exchange_token_for_id_jag(
+ self,
+ client: httpx.AsyncClient,
+ ) -> str:
+ """Exchange ID Token for ID-JAG using RFC 8693 Token Exchange.
+
+ Note: Overrides the configured `audience` with the discovered OAuth
+ issuer URL (if available) to satisfy MCP server `aud` claim requirements.
+
+ Args:
+ client: HTTP client for making requests
+
+ Returns:
+ The ID-JAG token string
+
+ Raises:
+ OAuthTokenError: If token exchange fails
+ """
+ logger.debug("Starting token exchange for ID-JAG")
+
+ audience = self.token_exchange_params.audience
+ if self.context.oauth_metadata and self.context.oauth_metadata.issuer:
+ discovered_issuer = str(self.context.oauth_metadata.issuer)
+ if audience != discovered_issuer:
+ logger.warning(
+ f"Overriding audience '{audience}' with discovered issuer '{discovered_issuer}'. "
+ f"To prevent this, set token_exchange_params.audience to the issuer URL."
+ )
+ audience = discovered_issuer
+
+ # Build token exchange request
+ token_data: TokenExchangeRequestData = {
+ "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+ "requested_token_type": self.token_exchange_params.requested_token_type,
+ "audience": audience,
+ "resource": self.token_exchange_params.resource,
+ "subject_token": self.token_exchange_params.subject_token,
+ "subject_token_type": self.token_exchange_params.subject_token_type,
+ }
+
+ if self.token_exchange_params.scope and self.token_exchange_params.scope.strip():
+ token_data["scope"] = self.token_exchange_params.scope
+
+ # Add IdP client authentication if provided
+ if self.idp_client_id is not None:
+ token_data["client_id"] = self.idp_client_id
+ if self.idp_client_secret is not None:
+ token_data["client_secret"] = self.idp_client_secret
+
+ try:
+ response = await client.post(
+ self.idp_token_endpoint,
+ data=token_data,
+ timeout=self.context.timeout,
+ )
+
+ if response.status_code != 200:
+ error_data: dict[str, str] = {}
+ try:
+ if response.headers.get("content-type", "").startswith("application/json"):
+ error_data = response.json()
+ except JSONDecodeError:
+ # Response is not valid JSON, use default error handling
+ pass
+
+ error: str = error_data.get("error", "unknown_error")
+ error_description: str = error_data.get(
+ "error_description", f"Token exchange failed (HTTP {response.status_code})"
+ )
+ raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}")
+
+ # Parse response
+ token_response = IDJAGTokenExchangeResponse.model_validate_json(response.content)
+
+ # Validate response
+ if token_response.issued_token_type != "urn:ietf:params:oauth:token-type:id-jag":
+ raise OAuthTokenError(f"Unexpected token type: {token_response.issued_token_type}")
+
+ if token_response.token_type != "N_A":
+ logger.warning(f"Expected token_type 'N_A', got '{token_response.token_type}'")
+
+ logger.debug("Successfully obtained ID-JAG")
+ self._id_jag = token_response.id_jag
+
+ # Track ID-JAG expiry to avoid using stale cached tokens
+ if token_response.expires_in:
+ self._id_jag_expiry = time.time() + token_response.expires_in
+ else:
+ # If no expires_in, use configured default expiry
+ self._id_jag_expiry = time.time() + self.default_id_jag_expiry
+ logger.debug(
+ f"IdP did not provide expires_in, using default expiry of "
+ f"{self.default_id_jag_expiry} seconds for ID-JAG"
+ )
+
+ return token_response.id_jag
+
+ except httpx.HTTPError as e:
+ raise OAuthTokenError(f"HTTP error during token exchange: {e}") from e
+
+ async def exchange_id_jag_for_access_token(
+ self,
+ id_jag: str,
+ ) -> httpx.Request:
+ """Build JWT bearer grant request to exchange ID-JAG for access token (RFC 7523).
+
+ Builds the request without executing it. HTTP execution and error parsing
+ are deferred to the parent class's `async_auth_flow` for consistency.
+
+ Args:
+ id_jag: The ID-JAG token
+
+ Returns:
+ httpx.Request for the JWT bearer grant
+
+ Raises:
+ OAuthFlowError: If OAuth metadata not discovered
+ """
+ logger.info("Building JWT bearer grant request for ID-JAG")
+
+ # Discover token endpoint from MCP server if not already done
+ if not self.context.oauth_metadata or not self.context.oauth_metadata.token_endpoint:
+ raise OAuthFlowError("MCP server token endpoint not discovered")
+
+ token_endpoint = str(self.context.oauth_metadata.token_endpoint)
+
+ # Build JWT bearer grant request
+ token_data: JWTBearerGrantRequestData = {
+ "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+ "assertion": id_jag,
+ }
+
+ # Add client authentication
+ if self.context.client_info:
+ # Default to client_secret_basic if not specified (per OAuth 2.0 spec)
+ if self.context.client_info.token_endpoint_auth_method is None:
+ self.context.client_info.token_endpoint_auth_method = "client_secret_basic"
+
+ if self.context.client_info.client_id is not None:
+ token_data["client_id"] = self.context.client_info.client_id
+ if self.context.client_info.client_secret is not None:
+ token_data["client_secret"] = self.context.client_info.client_secret
+
+ # Apply client authentication method (handles client_secret_basic vs client_secret_post)
+ headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
+ # Cast to dict[str, str] for prepare_token_auth compatibility
+ # Double-cast to bypass TypedDict strictness for prepare_token_auth
+ data_dict = cast(dict[str, str], cast(object, token_data))
+ data_dict, headers = self.context.prepare_token_auth(data_dict, headers)
+
+ return httpx.Request("POST", token_endpoint, data=data_dict, headers=headers)
+
+ async def _perform_authorization(self) -> httpx.Request:
+ """Perform enterprise authorization flow.
+
+ Overrides parent method to use token exchange + JWT bearer grant
+ instead of standard authorization code flow.
+
+ This method:
+ 1. Exchanges IDP ID token for ID-JAG at the IDP server (direct HTTP call)
+ 2. Returns an httpx.Request for JWT bearer grant (ID-JAG → Access token)
+
+ Returns:
+ httpx.Request for the JWT bearer grant to the MCP authorization server
+ """
+ # Check if we already have valid tokens
+ if self.context.is_token_valid():
+ # Reuse unexpired cached ID-JAG to prevent auth failures
+ if self._id_jag and self._id_jag_expiry:
+ if time.time() < self._id_jag_expiry:
+ logger.debug("Reusing cached ID-JAG for JWT bearer grant")
+ return await self.exchange_id_jag_for_access_token(self._id_jag)
+ else:
+ logger.debug("Cached ID-JAG expired, will obtain a new one")
+ # Fall through to full flow if ID-JAG is expired or missing (e.g., loaded from storage)
+
+ # Step 1: Exchange IDP ID token for ID-JAG (RFC 8693)
+ # This is an external call to the IDP, so we make it directly
+ async with httpx.AsyncClient(timeout=self.context.timeout) as client:
+ id_jag = await self.exchange_token_for_id_jag(client)
+ # Cache the ID-JAG for potential reuse
+ self._id_jag = id_jag
+
+ # Step 2: Build JWT bearer grant request (RFC 7523)
+ # This request will be yielded by the parent's async_auth_flow
+ # and the response will be handled by _handle_token_response
+ jwt_bearer_request = await self.exchange_id_jag_for_access_token(id_jag)
+
+ logger.debug("Returning JWT bearer grant request to async_auth_flow")
+ return jwt_bearer_request
+
+ async def refresh_with_new_id_token(self, new_id_token: str) -> None:
+ """Refresh MCP server access tokens using a fresh ID token from the IdP.
+
+ Updates the subject token and clears cached tokens (including ID-JAG),
+ triggering re-authentication on the next API request.
+
+ Note: OAuth metadata is not re-discovered. If the MCP server's OAuth
+ configuration has changed, you must create a new provider instance.
+
+ Args:
+ new_id_token: Fresh ID token obtained from your enterprise IdP.
+ """
+ logger.info("Refreshing tokens with new ID token from IdP")
+ self.token_exchange_params.subject_token = new_id_token
+
+ # Clear caches to force ID-JAG re-exchange and re-authentication
+ self._id_jag = None
+ self._id_jag_expiry = None
+ self.context.clear_tokens()
+ logger.debug("Token refresh prepared - will re-authenticate on next request")
+
+
+def decode_id_jag(id_jag: str) -> IDJAGClaims:
+ """Decode an ID-JAG token without verification.
+
+ Args:
+ id_jag: The ID-JAG token string
+
+ Returns:
+ Decoded ID-JAG claims
+
+ Note:
+ This function does not verify the JWT, instead relying on the receiving server to validate it.
+ """
+ # Decode without verification for inspection
+ claims = jwt.decode(id_jag, options={"verify_signature": False})
+ header = jwt.get_unverified_header(id_jag)
+
+ # Add typ from header to claims
+ claims["typ"] = header.get("typ", "")
+
+ return IDJAGClaims.model_validate(claims)
+
+
+def validate_token_exchange_params(
+ params: TokenExchangeParameters,
+) -> None:
+ """Validate token exchange parameters.
+
+ Args:
+ params: Token exchange parameters to validate
+
+ Raises:
+ OAuthFlowError: If parameters are invalid
+ """
+ if not params.subject_token:
+ raise OAuthFlowError("subject_token is required")
+
+ if not params.audience:
+ raise OAuthFlowError("audience is required")
+
+ if not params.resource:
+ raise OAuthFlowError("resource is required")
+
+ if params.subject_token_type not in [
+ "urn:ietf:params:oauth:token-type:id_token",
+ "urn:ietf:params:oauth:token-type:saml2",
+ ]:
+ raise OAuthFlowError(f"Invalid subject_token_type: {params.subject_token_type}")
diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py
new file mode 100644
index 000000000..b63083076
--- /dev/null
+++ b/tests/client/auth/test_enterprise_managed_auth_client.py
@@ -0,0 +1,1624 @@
+"""Tests for Enterprise Managed Authorization client-side implementation."""
+
+import logging
+import time
+import urllib.parse
+from typing import Any
+from unittest.mock import AsyncMock, Mock, patch
+
+import httpx
+import jwt
+import pytest
+from pydantic import AnyHttpUrl, AnyUrl
+
+from mcp.client.auth import OAuthFlowError, OAuthTokenError
+from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ IDJAGClaims,
+ IDJAGTokenExchangeResponse,
+ TokenExchangeParameters,
+ decode_id_jag,
+ validate_token_exchange_params,
+)
+from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken
+
+
+@pytest.fixture
+def sample_id_token() -> str:
+ """Generate a sample ID token for testing."""
+ payload = {
+ "iss": "https://idp.example.com",
+ "sub": "user123",
+ "aud": "mcp-client-app",
+ "exp": int(time.time()) + 3600,
+ "iat": int(time.time()),
+ "email": "user@example.com",
+ }
+ return jwt.encode(payload, "secret", algorithm="HS256")
+
+
+@pytest.fixture
+def sample_id_jag() -> str:
+ """Generate a sample ID-JAG token for testing."""
+ # Create typed claims using IDJAGClaims model
+ claims = IDJAGClaims(
+ typ="oauth-id-jag+jwt",
+ jti="unique-jwt-id-12345",
+ iss="https://idp.example.com",
+ sub="user123",
+ aud="https://auth.mcp-server.example/",
+ resource="https://mcp-server.example/",
+ client_id="mcp-client-app",
+ exp=int(time.time()) + 300,
+ iat=int(time.time()),
+ scope="read write",
+ email=None, # Optional field
+ )
+
+ # Dump to dict for JWT encoding (exclude typ as it goes in header)
+ payload = claims.model_dump(exclude={"typ"}, exclude_none=True)
+
+ return jwt.encode(payload, "secret", algorithm="HS256", headers={"typ": "oauth-id-jag+jwt"})
+
+
+@pytest.fixture
+def mock_token_storage() -> Any:
+ """Create a mock token storage."""
+ storage = Mock()
+ storage.get_tokens = AsyncMock(return_value=None)
+ storage.set_tokens = AsyncMock()
+ storage.get_client_info = AsyncMock(return_value=None)
+ storage.set_client_info = AsyncMock()
+ return storage
+
+
+def test_token_exchange_params_from_id_token():
+ """Test creating TokenExchangeParameters from ID token."""
+ params = TokenExchangeParameters.from_id_token(
+ id_token="eyJhbGc...",
+ mcp_server_auth_issuer="https://auth.server.example/",
+ mcp_server_resource_id="https://server.example/",
+ scope="read write",
+ )
+
+ assert params.subject_token == "eyJhbGc..."
+ assert params.subject_token_type == "urn:ietf:params:oauth:token-type:id_token"
+ assert params.audience == "https://auth.server.example/"
+ assert params.resource == "https://server.example/"
+ assert params.scope == "read write"
+ assert params.requested_token_type == "urn:ietf:params:oauth:token-type:id-jag"
+
+
+def test_token_exchange_params_from_saml_assertion():
+ """Test creating TokenExchangeParameters from SAML assertion."""
+ params = TokenExchangeParameters.from_saml_assertion(
+ saml_assertion="...",
+ mcp_server_auth_issuer="https://auth.server.example/",
+ mcp_server_resource_id="https://server.example/",
+ scope="read",
+ )
+
+ assert params.subject_token == "..."
+ assert params.subject_token_type == "urn:ietf:params:oauth:token-type:saml2"
+ assert params.audience == "https://auth.server.example/"
+ assert params.resource == "https://server.example/"
+ assert params.scope == "read"
+
+
+def test_validate_token_exchange_params_valid():
+ """Test validating valid token exchange parameters."""
+ params = TokenExchangeParameters.from_id_token(
+ id_token="token",
+ mcp_server_auth_issuer="https://auth.example/",
+ mcp_server_resource_id="https://server.example/",
+ )
+
+ # Should not raise
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_invalid_token_type():
+ """Test validation fails for invalid subject token type."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="invalid:type",
+ audience="https://auth.example/",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(OAuthFlowError, match="Invalid subject_token_type"):
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_missing_subject_token():
+ """Test validation fails for missing subject token."""
+ params = TokenExchangeParameters(
+ subject_token="",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="https://auth.example/",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(OAuthFlowError, match="subject_token is required"):
+ validate_token_exchange_params(params)
+
+
+def test_token_exchange_response_parsing():
+ """Test parsing token exchange response."""
+ response_json = """{
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": "eyJhbGc...",
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300
+ }"""
+
+ response = IDJAGTokenExchangeResponse.model_validate_json(response_json)
+
+ assert response.issued_token_type == "urn:ietf:params:oauth:token-type:id-jag"
+ assert response.id_jag == "eyJhbGc..."
+ assert response.access_token == "eyJhbGc..."
+ assert response.token_type == "N_A"
+ assert response.scope == "read write"
+ assert response.expires_in == 300
+
+
+def test_token_exchange_response_id_jag_property():
+ """Test id_jag property returns access_token."""
+ response = IDJAGTokenExchangeResponse(
+ issued_token_type="urn:ietf:params:oauth:token-type:id-jag",
+ access_token="the-id-jag-token",
+ token_type="N_A",
+ )
+
+ assert response.id_jag == "the-id-jag-token"
+
+
+def test_decode_id_jag(sample_id_jag: str):
+ """Test decoding ID-JAG token."""
+ claims = decode_id_jag(sample_id_jag)
+
+ assert claims.iss == "https://idp.example.com"
+ assert claims.sub == "user123"
+ assert claims.aud == "https://auth.mcp-server.example/"
+ assert claims.resource == "https://mcp-server.example/"
+ assert claims.client_id == "mcp-client-app"
+ assert claims.scope == "read write"
+
+
+def test_decode_id_jag_invalid_jwt():
+ """Test decoding malformed ID-JAG raises appropriate error."""
+ with pytest.raises(jwt.DecodeError):
+ decode_id_jag("not.a.valid.jwt")
+
+
+def test_decode_id_jag_incomplete_jwt():
+ """Test decoding incomplete JWT raises error."""
+ with pytest.raises(jwt.DecodeError):
+ decode_id_jag("only.two.parts")
+
+
+def test_id_jag_claims_with_extra_fields():
+ """Test IDJAGClaims allows extra fields."""
+ claims_data = {
+ "typ": "oauth-id-jag+jwt",
+ "jti": "jti123",
+ "iss": "https://idp.example.com",
+ "sub": "user123",
+ "aud": "https://auth.server.example/",
+ "resource": "https://server.example/",
+ "client_id": "client123",
+ "exp": int(time.time()) + 300,
+ "iat": int(time.time()),
+ "scope": "read",
+ "email": "user@example.com",
+ "custom_claim": "custom_value", # Extra field
+ }
+
+ claims = IDJAGClaims.model_validate(claims_data)
+ assert claims.email == "user@example.com"
+ # Extra field should be preserved
+ assert claims.model_extra is not None and claims.model_extra.get("custom_claim") == "custom_value"
+
+
+# ============================================================================
+# Tests for EnterpriseAuthOAuthClientProvider
+# ============================================================================
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_success(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test successful token exchange for ID-JAG."""
+ # Create provider
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify
+ assert id_jag == sample_id_jag
+ assert provider._id_jag == sample_id_jag
+
+ # Verify request was made correctly
+ mock_client.post.assert_called_once()
+ call_args = mock_client.post.call_args
+ assert call_args[0][0] == "https://idp.example.com/oauth2/token"
+ assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:token-exchange"
+ assert call_args[1]["data"]["requested_token_type"] == "urn:ietf:params:oauth:token-type:id-jag"
+ assert call_args[1]["data"]["audience"] == "https://auth.mcp-server.example/"
+ assert call_args[1]["data"]["resource"] == "https://mcp-server.example/"
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_error(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange failure handling."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock error response
+ mock_response = httpx.Response(
+ status_code=400,
+ json={
+ "error": "invalid_request",
+ "error_description": "Invalid subject token",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="Token exchange failed"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with unexpected token type."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock response with wrong token type
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
+ "access_token": "some-token",
+ "token_type": "Bearer",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="Unexpected token type"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_for_access_token_success(sample_id_jag: str, mock_token_storage: Any):
+ """Test building JWT bearer grant request."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Build JWT bearer grant request
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Verify the request was built correctly
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+ assert str(request.url) == "https://auth.mcp-server.example/oauth2/token"
+
+ # Parse the request body
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["grant_type"][0] == "urn:ietf:params:oauth:grant-type:jwt-bearer"
+ assert body_params["assertion"][0] == sample_id_jag
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant fails without OAuth metadata."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # No OAuth metadata set
+
+ # Should raise OAuthFlowError
+ with pytest.raises(OAuthFlowError, match="token endpoint not discovered"):
+ await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_full_flow(mock_token_storage: Any, sample_id_jag: str):
+ """Test that _perform_authorization performs token exchange and builds JWT bearer request."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock the IDP token exchange response
+ with patch("httpx.AsyncClient") as mock_client_class:
+ mock_client = AsyncMock()
+ mock_client_class.return_value.__aenter__.return_value = mock_client
+
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform authorization
+ request = await provider._perform_authorization()
+
+ # Verify it returns an httpx.Request for JWT bearer grant
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+ assert str(request.url) == "https://auth.mcp-server.example/oauth2/token"
+
+ # Verify the request contains JWT bearer grant
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["grant_type"][0] == "urn:ietf:params:oauth:grant-type:jwt-bearer"
+ assert body_params["assertion"][0] == sample_id_jag
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any, sample_id_jag: str):
+ """Test that _perform_authorization uses cached ID-JAG when tokens are valid."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set valid tokens and cached ID-JAG with valid expiry
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="valid-token",
+ expires_in=3600,
+ )
+ provider.context.token_expiry_time = time.time() + 3600
+ provider._id_jag = sample_id_jag
+ provider._id_jag_expiry = time.time() + 300 # Valid for 5 more minutes
+
+ # Should return a JWT bearer grant request using cached ID-JAG
+ request = await provider._perform_authorization()
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+ assert str(request.url) == "https://auth.mcp-server.example/oauth2/token"
+
+ # Verify it uses the cached ID-JAG
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["assertion"][0] == sample_id_jag
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_authentication(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange with client authentication."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ idp_client_id="test-idp-client-id", # IdP client ID, not MCP client ID
+ idp_client_secret="test-idp-client-secret", # IdP client secret
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client credentials were included
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-idp-client-id"
+ assert call_args[1]["data"]["client_secret"] == "test-idp-client-secret"
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test token exchange with client_id but no client_secret (covers branch 232->235)."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ idp_client_id="test-idp-client-id", # IdP client ID, not MCP client ID
+ idp_client_secret=None, # No secret
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client_id was included but NOT client_secret
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-idp-client-id"
+ assert "client_secret" not in call_args[1]["data"]
+
+
+@pytest.mark.anyio
+async def test_exchange_token_http_error(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with HTTP error."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection failed"))
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="HTTP error during token exchange"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_malformed_json_error_response(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with malformed JSON error response that raises JSONDecodeError."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock error response with malformed JSON (will raise JSONDecodeError when parsing)
+ mock_response = httpx.Response(
+ status_code=400,
+ content=b'{"error": "invalid_request", "invalid json structure', # Malformed JSON
+ headers={"content-type": "application/json"},
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError with default error message including status code
+ with pytest.raises(OAuthTokenError, match=r"Token exchange failed.*HTTP 400"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_non_json_error_response(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with non-JSON error response."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock error response with non-JSON content
+ mock_response = httpx.Response(
+ status_code=500,
+ content=b"Internal Server Error",
+ headers={"content-type": "text/plain"},
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError with default error
+ with pytest.raises(OAuthTokenError, match="Token exchange failed: unknown_error"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_warning_for_non_na_token_type(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange logs warning for non-N_A token type."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock response with different token_type
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "Bearer", # Not N_A
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should succeed but log warning
+ with patch.object(
+ logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning"
+ ) as mock_warning:
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+ assert id_jag == sample_id_jag
+ mock_warning.assert_called_once()
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_authentication(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant request building with client authentication."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info with secret
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret="test-client-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Build JWT bearer grant request
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Verify request was built correctly
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+
+ # Verify client credentials were included in request body
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["client_id"][0] == "test-client-id"
+ # With client_secret_basic (default), credentials should be in Authorization header
+ assert "Authorization" in request.headers
+ assert request.headers["Authorization"].startswith("Basic ")
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_id_only(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant request building with client_id but no client_secret."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info WITHOUT secret (client_secret=None)
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret=None, # No secret
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Build JWT bearer grant request
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Verify request was built correctly
+ assert isinstance(request, httpx.Request)
+
+ # Verify client_id was included but NOT client_secret
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["client_id"][0] == "test-client-id"
+ assert "client_secret" not in body_params
+ # With no client_secret, there should be no Authorization header either
+ assert "Authorization" not in request.headers or not request.headers["Authorization"].startswith("Basic ")
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_info_but_no_client_id(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange when only client_secret is provided (no client_id)."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ idp_client_id=None, # No client ID
+ idp_client_secret="test-idp-secret", # But has secret
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client_id was not included (None), but client_secret was included
+ call_args = mock_client.post.call_args
+ assert "client_id" not in call_args[1]["data"]
+ assert call_args[1]["data"]["client_secret"] == "test-idp-secret"
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_info_but_no_client_id(sample_id_jag: str, mock_token_storage: Any):
+ """Test ID-JAG exchange request building when client_info exists but client_id is None."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set client info with client_id=None
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id=None, # This should skip the client_id assignment
+ client_secret="test-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Build JWT bearer grant request
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Verify request was built correctly
+ assert isinstance(request, httpx.Request)
+
+ # Verify client_id was not included (None), but client_secret should be handled
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert "client_id" not in body_params or body_params["client_id"][0] == ""
+
+
+def test_validate_token_exchange_params_missing_audience():
+ """Test validation fails for missing audience."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(OAuthFlowError, match="audience is required"):
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_missing_resource():
+ """Test validation fails for missing resource."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="https://auth.example/",
+ resource="",
+ )
+
+ with pytest.raises(OAuthFlowError, match="resource is required"):
+ validate_token_exchange_params(params)
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_existing_auth_method(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant when token_endpoint_auth_method is already set (covers branch 323->326)."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info WITH auth method already set
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret="test-client-secret",
+ token_endpoint_auth_method="client_secret_post", # Already set
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Build JWT bearer grant request
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Verify request was built correctly
+ assert isinstance(request, httpx.Request)
+
+ # Verify it used client_secret_post (in body, not header)
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["client_id"][0] == "test-client-id"
+ assert body_params["client_secret"][0] == "test-client-secret"
+ # Should NOT have Authorization header for client_secret_post
+ assert "Authorization" not in request.headers or not request.headers["Authorization"].startswith("Basic ")
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_with_valid_tokens_no_id_jag(mock_token_storage: Any):
+ """Test _perform_authorization when tokens are valid but no cached ID-JAG (covers branch 354->360)."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set valid tokens but NO cached ID-JAG
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="valid-token",
+ expires_in=3600,
+ )
+ provider.context.token_expiry_time = time.time() + 3600
+ provider._id_jag = None # No cached ID-JAG
+
+ # Mock the IDP token exchange response
+ sample_id_jag = "test-id-jag-token"
+ with patch("httpx.AsyncClient") as mock_client_class:
+ mock_client = AsyncMock()
+ mock_client_class.return_value.__aenter__.return_value = mock_client
+
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should fall through and perform full flow
+ request = await provider._perform_authorization()
+
+ # Verify it returns a JWT bearer grant request
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+
+ # Verify it made the IDP token exchange call
+ mock_client.post.assert_called_once()
+
+
+@pytest.mark.anyio
+async def test_refresh_with_new_id_token(mock_token_storage: Any):
+ """Test refresh_with_new_id_token helper method."""
+ old_id_token = "old-id-token"
+ new_id_token = "new-id-token"
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=old_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set some existing tokens and cached ID-JAG
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="old-access-token",
+ expires_in=3600,
+ )
+ provider._id_jag = "old-id-jag"
+ provider._id_jag_expiry = time.time() + 3600
+
+ # Verify initial state
+ assert provider.token_exchange_params.subject_token == old_id_token
+ assert provider._id_jag == "old-id-jag"
+ assert provider._id_jag_expiry is not None
+ assert provider.context.current_tokens.access_token == "old-access-token"
+
+ # Call refresh with new ID token
+ await provider.refresh_with_new_id_token(new_id_token)
+
+ # Verify state after refresh
+ assert provider.token_exchange_params.subject_token == new_id_token
+ assert provider._id_jag is None # Cached ID-JAG should be cleared
+ assert provider._id_jag_expiry is None # Expiry should be cleared
+ assert provider.context.current_tokens is None # Tokens should be cleared
+ assert provider.context.token_expiry_time is None # Expiry should be cleared
+
+
+@pytest.mark.anyio
+async def test_id_jag_expiry_tracking(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test that ID-JAG expiry is tracked when obtained from IdP."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock HTTP response with expires_in
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "expires_in": 300, # 5 minutes
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ before_time = time.time()
+ _ = await provider.exchange_token_for_id_jag(mock_client)
+ after_time = time.time()
+
+ # Verify ID-JAG was cached
+ assert provider._id_jag == sample_id_jag
+ # Verify expiry was set (should be current time + 300 seconds)
+ assert provider._id_jag_expiry is not None
+ assert before_time + 300 <= provider._id_jag_expiry <= after_time + 300
+
+
+@pytest.mark.anyio
+async def test_id_jag_expiry_default_when_not_provided(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test that default expiry (15 minutes = 900 seconds) is used when expires_in not provided."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock HTTP response WITHOUT expires_in
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ # No expires_in field
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ before_time = time.time()
+ await provider.exchange_token_for_id_jag(mock_client)
+ after_time = time.time()
+
+ # Verify default expiry was set (900 seconds = 15 minutes)
+ assert provider._id_jag_expiry is not None
+ assert before_time + 900 <= provider._id_jag_expiry <= after_time + 900
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_checks_id_jag_expiry(mock_token_storage: Any, sample_id_jag: str):
+ """Test that _perform_authorization checks ID-JAG expiry before reusing."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set valid tokens and cached ID-JAG that has EXPIRED
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="valid-token",
+ expires_in=3600,
+ )
+ provider.context.token_expiry_time = time.time() + 3600
+ provider._id_jag = sample_id_jag
+ provider._id_jag_expiry = time.time() - 10 # Expired 10 seconds ago
+
+ # Mock the IDP token exchange response for new ID-JAG
+ new_id_jag = "new-id-jag-token"
+ with patch("httpx.AsyncClient") as mock_client_class:
+ mock_client = AsyncMock()
+ mock_client_class.return_value.__aenter__.return_value = mock_client
+
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": new_id_jag,
+ "token_type": "N_A",
+ },
+ )
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should get a new ID-JAG since the cached one is expired
+ request = await provider._perform_authorization()
+
+ # Verify it made the IDP token exchange call (didn't reuse expired ID-JAG)
+ mock_client.post.assert_called_once()
+
+ # Verify the request uses the NEW ID-JAG
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["assertion"][0] == new_id_jag
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_reuses_valid_cached_id_jag(mock_token_storage: Any, sample_id_jag: str):
+ """Test that _perform_authorization reuses cached ID-JAG when still valid."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set valid tokens and cached ID-JAG that is STILL VALID
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="valid-token",
+ expires_in=3600,
+ )
+ provider.context.token_expiry_time = time.time() + 3600
+ provider._id_jag = sample_id_jag
+ provider._id_jag_expiry = time.time() + 300 # Valid for 5 more minutes
+
+ # Should reuse cached ID-JAG without calling IdP
+ request = await provider._perform_authorization()
+
+ # Verify it returns a JWT bearer grant request using cached ID-JAG
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["assertion"][0] == sample_id_jag
+
+
+@pytest.mark.anyio
+async def test_audience_override_warning(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test that audience override logs a warning when values differ."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://configured-audience.example/", # Different from issuer
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set OAuth metadata with different issuer
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://actual-issuer.example/"), # Different from configured
+ authorization_endpoint=AnyHttpUrl("https://actual-issuer.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://actual-issuer.example/oauth2/token"),
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should log warning about audience override
+ with patch.object(
+ logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning"
+ ) as mock_warning:
+ await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify warning was called with message about override
+ mock_warning.assert_called_once()
+ warning_message = mock_warning.call_args[0][0]
+ assert "Overriding audience" in warning_message
+ assert "https://configured-audience.example/" in warning_message
+ assert "https://actual-issuer.example/" in warning_message
+
+
+@pytest.mark.anyio
+async def test_audience_no_warning_when_matching(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test that no warning is logged when audience matches issuer."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/", # Same as issuer
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set OAuth metadata with matching issuer
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"), # Same as configured
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should NOT log warning when values match
+ with patch.object(
+ logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning"
+ ) as mock_warning:
+ await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify warning was NOT called
+ mock_warning.assert_not_called()
+
+
+@pytest.mark.anyio
+async def test_empty_scope_not_included(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test that empty or whitespace-only scope is not included in token request."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope=" ", # Whitespace-only scope
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify scope was NOT included in request
+ call_args = mock_client.post.call_args
+ assert "scope" not in call_args[1]["data"]
+
+
+@pytest.mark.anyio
+async def test_custom_default_id_jag_expiry(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test that custom default_id_jag_expiry is used when IdP doesn't provide expires_in."""
+ custom_expiry = 1800 # 30 minutes
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ default_id_jag_expiry=custom_expiry, # Custom expiry
+ )
+
+ # Verify the custom default is set
+ assert provider.default_id_jag_expiry == custom_expiry
+
+ # Mock HTTP response WITHOUT expires_in
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ # No expires_in field
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ before_time = time.time()
+ await provider.exchange_token_for_id_jag(mock_client)
+ after_time = time.time()
+
+ # Verify custom expiry was used (1800 seconds)
+ assert provider._id_jag_expiry is not None
+ assert before_time + custom_expiry <= provider._id_jag_expiry <= after_time + custom_expiry
+
+
+@pytest.mark.anyio
+async def test_default_id_jag_expiry_constant(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test that DEFAULT_ID_JAG_EXPIRY_SECONDS class constant is used by default."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ # Not providing default_id_jag_expiry, should use class constant
+ )
+
+ # Verify the class constant is used (900 seconds = 15 minutes)
+ assert provider.default_id_jag_expiry == EnterpriseAuthOAuthClientProvider.DEFAULT_ID_JAG_EXPIRY_SECONDS
+ assert provider.default_id_jag_expiry == 900 # 15 minutes
+
+ # Mock HTTP response WITHOUT expires_in
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ before_time = time.time()
+ await provider.exchange_token_for_id_jag(mock_client)
+ after_time = time.time()
+
+ # Verify default constant was used (900 seconds)
+ assert provider._id_jag_expiry is not None
+ assert before_time + 900 <= provider._id_jag_expiry <= after_time + 900
+
+
+@pytest.mark.anyio
+async def test_exchange_token_without_oauth_metadata(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test token exchange when oauth_metadata is not set.
+
+ This tests the scenario where OAuth metadata discovery hasn't happened yet.
+ The configured audience from token_exchange_params should be used directly.
+
+ Note: Testing issuer=None is not possible because OAuthMetadata.issuer is a
+ required AnyHttpUrl field per RFC 8414, so the Pydantic model prevents None.
+ """
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.configured.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # No OAuth metadata set (None)
+ assert provider.context.oauth_metadata is None
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the configured audience was used (no override since metadata is None)
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["audience"] == "https://auth.configured.example/"