Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 252 additions & 0 deletions .github/actions/conformance/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
elicitation-sep1034-client-defaults - Elicitation with default accept callback
auth/client-credentials-jwt - Client credentials with private_key_jwt
auth/client-credentials-basic - Client credentials with client_secret_basic
auth/enterprise-token-exchange - Enterprise auth with OIDC ID token (SEP-990)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these tests have been consolidated to 1, see modelcontextprotocol/conformance#110 and please update to match

auth/enterprise-saml-exchange - Enterprise auth with SAML assertion (SEP-990)
auth/enterprise-id-jag-validation - Validate ID-JAG token structure (SEP-990)
auth/* - Authorization code flow (default for auth scenarios)
"""

Expand Down Expand Up @@ -293,6 +296,255 @@ async def run_auth_code_client(server_url: str) -> None:
await _run_auth_session(server_url, oauth_auth)


@register("auth/enterprise-token-exchange")
async def run_enterprise_token_exchange(server_url: str) -> None:
"""Enterprise managed auth: Token exchange flow (RFC 8693)."""
from mcp.client.auth.extensions.enterprise_managed_auth import (
EnterpriseAuthOAuthClientProvider,
TokenExchangeParameters,
)

context = get_conformance_context()
id_token = context.get("id_token")
idp_token_endpoint = context.get("idp_token_endpoint")
mcp_server_auth_issuer = context.get("mcp_server_auth_issuer")
mcp_server_resource_id = context.get("mcp_server_resource_id")
scope = context.get("scope")

if not id_token:
raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'id_token'")
if not idp_token_endpoint:
raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'")
if not mcp_server_auth_issuer:
raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_auth_issuer'")
if not mcp_server_resource_id:
raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_resource_id'")

# Create token exchange parameters
token_exchange_params = TokenExchangeParameters.from_id_token(
id_token=id_token,
mcp_server_auth_issuer=mcp_server_auth_issuer,
mcp_server_resource_id=mcp_server_resource_id,
scope=scope,
)

# Create enterprise auth provider
enterprise_auth = EnterpriseAuthOAuthClientProvider(
server_url=server_url,
client_metadata=OAuthClientMetadata(
client_name="conformance-enterprise-client",
redirect_uris=[AnyUrl("http://localhost:3000/callback")],
grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
response_types=["token"],
),
storage=InMemoryTokenStorage(),
idp_token_endpoint=idp_token_endpoint,
token_exchange_params=token_exchange_params,
)

# Perform token exchange flow
async with httpx.AsyncClient() as client:
# Step 1: Set OAuth metadata manually (since we're not going through full OAuth flow)
logger.debug(f"Setting OAuth metadata for {server_url}")
from pydantic import AnyUrl as PydanticAnyUrl

from mcp.shared.auth import OAuthMetadata

# Extract base URL from server_url
base_url = server_url.replace("/mcp", "")
token_endpoint_url = f"{base_url}/oauth/token"
auth_endpoint_url = f"{base_url}/oauth/authorize"

enterprise_auth.context.oauth_metadata = OAuthMetadata(
issuer=mcp_server_auth_issuer,
authorization_endpoint=PydanticAnyUrl(auth_endpoint_url),
token_endpoint=PydanticAnyUrl(token_endpoint_url),
)
logger.debug(f"OAuth metadata set, token_endpoint: {token_endpoint_url}")

# Step 2: Exchange ID token for ID-JAG
logger.debug("Exchanging ID token for ID-JAG")
id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
logger.debug(f"Obtained ID-JAG: {id_jag[:50]}...")

# Step 3: Exchange ID-JAG for access token
logger.debug("Exchanging ID-JAG for access token")
access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
logger.debug(f"Obtained access token, expires in: {access_token.expires_in}s")

# Step 4: Verify we can make authenticated requests
logger.debug("Verifying access token with MCP endpoint")
auth_client = httpx.AsyncClient(headers={"Authorization": f"Bearer {access_token.access_token}"})
response = await auth_client.get(server_url.replace("/mcp", "") + "/mcp")
if response.status_code == 200:
logger.debug(f"Successfully authenticated with MCP server: {response.json()}")
else:
logger.warning(f"MCP server returned {response.status_code}")

logger.debug("Enterprise auth flow completed successfully")


@register("auth/enterprise-saml-exchange")
async def run_enterprise_saml_exchange(server_url: str) -> None:
"""Enterprise managed auth: SAML assertion exchange flow."""
from mcp.client.auth.extensions.enterprise_managed_auth import (
EnterpriseAuthOAuthClientProvider,
TokenExchangeParameters,
)

context = get_conformance_context()
saml_assertion = context.get("saml_assertion")
idp_token_endpoint = context.get("idp_token_endpoint")
mcp_server_auth_issuer = context.get("mcp_server_auth_issuer")
mcp_server_resource_id = context.get("mcp_server_resource_id")
scope = context.get("scope")

if not saml_assertion:
raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'saml_assertion'")
if not idp_token_endpoint:
raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'")
if not mcp_server_auth_issuer:
raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_auth_issuer'")
if not mcp_server_resource_id:
raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_resource_id'")

# Create token exchange parameters for SAML
token_exchange_params = TokenExchangeParameters.from_saml_assertion(
saml_assertion=saml_assertion,
mcp_server_auth_issuer=mcp_server_auth_issuer,
mcp_server_resource_id=mcp_server_resource_id,
scope=scope,
)

# Create enterprise auth provider
enterprise_auth = EnterpriseAuthOAuthClientProvider(
server_url=server_url,
client_metadata=OAuthClientMetadata(
client_name="conformance-enterprise-saml-client",
redirect_uris=[AnyUrl("http://localhost:3000/callback")],
grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
response_types=["token"],
),
storage=InMemoryTokenStorage(),
idp_token_endpoint=idp_token_endpoint,
token_exchange_params=token_exchange_params,
)

# Perform token exchange flow
async with httpx.AsyncClient() as client:
# Step 1: Set OAuth metadata manually (since we're not going through full OAuth flow)
logger.debug(f"Setting OAuth metadata for {server_url}")
from pydantic import AnyUrl as PydanticAnyUrl

from mcp.shared.auth import OAuthMetadata

# Extract base URL from server_url
base_url = server_url.replace("/mcp", "")
token_endpoint_url = f"{base_url}/oauth/token"
auth_endpoint_url = f"{base_url}/oauth/authorize"

enterprise_auth.context.oauth_metadata = OAuthMetadata(
issuer=mcp_server_auth_issuer,
authorization_endpoint=PydanticAnyUrl(auth_endpoint_url),
token_endpoint=PydanticAnyUrl(token_endpoint_url),
)
logger.debug(f"OAuth metadata set, token_endpoint: {token_endpoint_url}")

# Step 2: Exchange SAML assertion for ID-JAG
logger.debug("Exchanging SAML assertion for ID-JAG")
id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
logger.debug(f"Obtained ID-JAG from SAML: {id_jag[:50]}...")

# Step 3: Exchange ID-JAG for access token
logger.debug("Exchanging ID-JAG for access token")
access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
logger.debug(f"Obtained access token, expires in: {access_token.expires_in}s")

# Step 4: Verify we can make authenticated requests
logger.debug("Verifying access token with MCP endpoint")
auth_client = httpx.AsyncClient(headers={"Authorization": f"Bearer {access_token.access_token}"})
response = await auth_client.get(server_url.replace("/mcp", "") + "/mcp")
if response.status_code == 200:
logger.debug(f"Successfully authenticated with MCP server: {response.json()}")
else:
logger.warning(f"MCP server returned {response.status_code}")

logger.debug("SAML enterprise auth flow completed successfully")


@register("auth/enterprise-id-jag-validation")
async def run_id_jag_validation(server_url: str) -> None:
"""Validate ID-JAG token structure and claims."""
from mcp.client.auth.extensions.enterprise_managed_auth import (
EnterpriseAuthOAuthClientProvider,
TokenExchangeParameters,
decode_id_jag,
validate_token_exchange_params,
)

context = get_conformance_context()
id_token = context.get("id_token")
idp_token_endpoint = context.get("idp_token_endpoint")
mcp_server_auth_issuer = context.get("mcp_server_auth_issuer")
mcp_server_resource_id = context.get("mcp_server_resource_id")

if not all([id_token, idp_token_endpoint, mcp_server_auth_issuer, mcp_server_resource_id]):
raise RuntimeError("Missing required context parameters for ID-JAG validation")

# Create and validate token exchange parameters
token_exchange_params = TokenExchangeParameters.from_id_token(
id_token=id_token,
mcp_server_auth_issuer=mcp_server_auth_issuer,
mcp_server_resource_id=mcp_server_resource_id,
)

logger.debug("Validating token exchange parameters")
validate_token_exchange_params(token_exchange_params)
logger.debug("Token exchange parameters validated successfully")

# Create enterprise auth provider
enterprise_auth = EnterpriseAuthOAuthClientProvider(
server_url=server_url,
client_metadata=OAuthClientMetadata(
client_name="conformance-validation-client",
redirect_uris=[AnyUrl("http://localhost:3000/callback")],
grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
response_types=["token"],
),
storage=InMemoryTokenStorage(),
idp_token_endpoint=idp_token_endpoint,
token_exchange_params=token_exchange_params,
)

async with httpx.AsyncClient() as client:
# Get ID-JAG
id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
logger.debug(f"Obtained ID-JAG for validation: {id_jag[:50]}...")

# Decode and validate ID-JAG claims
logger.debug("Decoding ID-JAG token")
claims = decode_id_jag(id_jag)

# Validate required claims
assert claims.typ == "oauth-id-jag+jwt", f"Invalid typ: {claims.typ}"
assert claims.jti, "Missing jti claim"
assert claims.iss == mcp_server_auth_issuer or claims.iss, "Missing or invalid iss claim"
assert claims.sub, "Missing sub claim"
assert claims.aud, "Missing aud claim"
assert claims.resource == mcp_server_resource_id, f"Invalid resource: {claims.resource}"
assert claims.client_id, "Missing client_id claim"
assert claims.exp > claims.iat, "Invalid expiration"

logger.debug("ID-JAG validated successfully:")
logger.debug(f" Subject: {claims.sub}")
logger.debug(f" Issuer: {claims.iss}")
logger.debug(f" Audience: {claims.aud}")
logger.debug(f" Resource: {claims.resource}")
logger.debug(f" Client ID: {claims.client_id}")

logger.debug("ID-JAG validation completed successfully")


async def _run_auth_session(server_url: str, oauth_auth: OAuthClientProvider) -> None:
"""Common session logic for all OAuth flows."""
client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0)
Expand Down
Loading