diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 2e52b4066..bfb5d6c2a 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -177,10 +177,7 @@ async def stdin_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with ( - anyio.create_task_group() as tg, - process, - ): + async with anyio.create_task_group() as tg, process: tg.start_soon(stdout_reader) tg.start_soon(stdin_writer) try: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index c48445366..4d8627a4b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -825,7 +825,7 @@ def streamable_http_app( host: str = "127.0.0.1", auth: AuthSettings | None = None, token_verifier: TokenVerifier | None = None, - auth_server_provider: (OAuthAuthorizationServerProvider[Any, Any, Any] | None) = None, + auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, custom_starlette_routes: list[Route] | None = None, debug: bool = False, ) -> Starlette: diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index a6cd093d9..41b9224c1 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -4,10 +4,7 @@ from pydantic import BaseModel -from mcp.types import ( - Icon, - ServerCapabilities, -) +from mcp.types import Icon, ServerCapabilities class InitializationOptions(BaseModel): diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 531404f21..5a1614545 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -30,10 +30,7 @@ async def run_server(): @asynccontextmanager -async def stdio_server( - stdin: anyio.AsyncFile[str] | None = None, - stdout: anyio.AsyncFile[str] | None = None, -): +async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None): """Server transport for stdio: this communicates with an MCP client by reading from the current process' stdin and writing to stdout. """ diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 964c52b6f..7d7f2db85 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -124,20 +124,10 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: # Clear any remaining server instances self._server_instances.clear() - async def handle_request( - self, - scope: Scope, - receive: Receive, - send: Send, - ) -> None: + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. Dispatches to the appropriate handler based on stateless mode. - - Args: - scope: ASGI scope - receive: ASGI receive function - send: ASGI send function """ if self._task_group is None: raise RuntimeError("Task group is not initialized. Make sure to use run().") @@ -148,19 +138,8 @@ async def handle_request( else: await self._handle_stateful_request(scope, receive, send) - async def _handle_stateless_request( - self, - scope: Scope, - receive: Receive, - send: Send, - ) -> None: - """Process request in stateless mode - creating a new transport for each request. - - Args: - scope: ASGI scope - receive: ASGI receive function - send: ASGI send function - """ + async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: Send) -> None: + """Process request in stateless mode - creating a new transport for each request.""" logger.debug("Stateless mode: Creating new transport for this request") # No session ID needed in stateless mode http_transport = StreamableHTTPServerTransport( @@ -196,19 +175,8 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA # Terminate the transport after the request is handled await http_transport.terminate() - async def _handle_stateful_request( - self, - scope: Scope, - receive: Receive, - send: Send, - ) -> None: - """Process request in stateful mode - maintaining session state between requests. - - Args: - scope: ASGI scope - receive: ASGI receive function - send: ASGI send function - """ + async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: Send) -> None: + """Process request in stateful mode - maintaining session state between requests.""" request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) @@ -248,11 +216,8 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE self.app.create_initialization_options(), stateless=False, # Stateful mode ) - except Exception as e: - logger.error( - f"Session {http_transport.mcp_session_id} crashed: {e}", - exc_info=True, - ) + except Exception: + logger.exception(f"Session {http_transport.mcp_session_id} crashed") finally: # Only remove from instances if not terminated if ( # pragma: no branch @@ -262,8 +227,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE ): logger.info( "Cleaning up crashed session " - f"{http_transport.mcp_session_id} from " - "active instances." + f"{http_transport.mcp_session_id} from active instances." ) del self._server_instances[http_transport.mcp_session_id] diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index c8c049901..1ed9842c0 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -9,37 +9,35 @@ logger = logging.getLogger(__name__) +# TODO(Marcelo): We should flatten these settings. To be fair, I don't think we should even have this middleware. class TransportSecuritySettings(BaseModel): """Settings for MCP transport security features. - These settings help protect against DNS rebinding attacks by validating - incoming request headers. + These settings help protect against DNS rebinding attacks by validating incoming request headers. """ - enable_dns_rebinding_protection: bool = Field( - default=True, - description="Enable DNS rebinding protection (recommended for production)", - ) + enable_dns_rebinding_protection: bool = True + """Enable DNS rebinding protection (recommended for production).""" - allowed_hosts: list[str] = Field( - default=[], - description="List of allowed Host header values. Only applies when " - + "enable_dns_rebinding_protection is True.", - ) + allowed_hosts: list[str] = Field(default_factory=list) + """List of allowed Host header values. - allowed_origins: list[str] = Field( - default=[], - description="List of allowed Origin header values. Only applies when " - + "enable_dns_rebinding_protection is True.", - ) + Only applies when `enable_dns_rebinding_protection` is `True`. + """ + + allowed_origins: list[str] = Field(default_factory=list) + """List of allowed Origin header values. + + Only applies when `enable_dns_rebinding_protection` is `True`. + """ +# TODO(Marcelo): This should be a proper ASGI middleware. I'm sad to see this. class TransportSecurityMiddleware: """Middleware to enforce DNS rebinding protection for MCP transport endpoints.""" def __init__(self, settings: TransportSecuritySettings | None = None): - # If not specified, disable DNS rebinding protection by default - # for backwards compatibility + # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) def _validate_host(self, host: str | None) -> bool: # pragma: no cover @@ -88,16 +86,7 @@ def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover def _validate_content_type(self, content_type: str | None) -> bool: """Validate the Content-Type header for POST requests.""" - if not content_type: # pragma: lax no cover - logger.warning("Missing Content-Type header in POST request") - return False - - # Content-Type must start with application/json - if not content_type.lower().startswith("application/json"): - logger.warning(f"Invalid Content-Type header: {content_type}") - return False - - return True + return content_type is not None and content_type.lower().startswith("application/json") async def validate_request(self, request: Request, is_post: bool = False) -> Response | None: """Validate request headers for DNS rebinding protection. diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 7be607fe1..d01d28b80 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -28,10 +28,5 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS client_streams = (server_to_client_receive, client_to_server_send) server_streams = (client_to_server_receive, server_to_client_send) - async with ( - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - ): + async with server_to_client_receive, client_to_server_send, client_to_server_receive, server_to_client_send: yield client_streams, server_streams