Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/mcp/client/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ async def stdin_writer():
except anyio.ClosedResourceError: # pragma: no cover
await anyio.lowlevel.checkpoint()

async with (
anyio.create_task_group() as tg,
process,
):
async with anyio.create_task_group() as tg, process:
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)
try:
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def streamable_http_app(
host: str = "127.0.0.1",
auth: AuthSettings | None = None,
token_verifier: TokenVerifier | None = None,
auth_server_provider: (OAuthAuthorizationServerProvider[Any, Any, Any] | None) = None,
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None,
custom_starlette_routes: list[Route] | None = None,
debug: bool = False,
) -> Starlette:
Expand Down
5 changes: 1 addition & 4 deletions src/mcp/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@

from pydantic import BaseModel

from mcp.types import (
Icon,
ServerCapabilities,
)
from mcp.types import Icon, ServerCapabilities


class InitializationOptions(BaseModel):
Expand Down
5 changes: 1 addition & 4 deletions src/mcp/server/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ async def run_server():


@asynccontextmanager
async def stdio_server(
stdin: anyio.AsyncFile[str] | None = None,
stdout: anyio.AsyncFile[str] | None = None,
):
async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None):
"""Server transport for stdio: this communicates with an MCP client by reading
from the current process' stdin and writing to stdout.
"""
Expand Down
52 changes: 8 additions & 44 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,20 +124,10 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
# Clear any remaining server instances
self._server_instances.clear()

async def handle_request(
self,
scope: Scope,
receive: Receive,
send: Send,
) -> None:
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process ASGI request with proper session handling and transport setup.

Dispatches to the appropriate handler based on stateless mode.

Args:
scope: ASGI scope
receive: ASGI receive function
send: ASGI send function
"""
if self._task_group is None:
raise RuntimeError("Task group is not initialized. Make sure to use run().")
Expand All @@ -148,19 +138,8 @@ async def handle_request(
else:
await self._handle_stateful_request(scope, receive, send)

async def _handle_stateless_request(
self,
scope: Scope,
receive: Receive,
send: Send,
) -> None:
"""Process request in stateless mode - creating a new transport for each request.

Args:
scope: ASGI scope
receive: ASGI receive function
send: ASGI send function
"""
async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process request in stateless mode - creating a new transport for each request."""
logger.debug("Stateless mode: Creating new transport for this request")
# No session ID needed in stateless mode
http_transport = StreamableHTTPServerTransport(
Expand Down Expand Up @@ -196,19 +175,8 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
# Terminate the transport after the request is handled
await http_transport.terminate()

async def _handle_stateful_request(
self,
scope: Scope,
receive: Receive,
send: Send,
) -> None:
"""Process request in stateful mode - maintaining session state between requests.

Args:
scope: ASGI scope
receive: ASGI receive function
send: ASGI send function
"""
async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process request in stateful mode - maintaining session state between requests."""
request = Request(scope, receive)
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)

Expand Down Expand Up @@ -248,11 +216,8 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
self.app.create_initialization_options(),
stateless=False, # Stateful mode
)
except Exception as e:
logger.error(
f"Session {http_transport.mcp_session_id} crashed: {e}",
exc_info=True,
)
except Exception:
logger.exception(f"Session {http_transport.mcp_session_id} crashed")
Comment on lines +219 to +220
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I keep finding those.

finally:
# Only remove from instances if not terminated
if ( # pragma: no branch
Expand All @@ -262,8 +227,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
):
logger.info(
"Cleaning up crashed session "
f"{http_transport.mcp_session_id} from "
"active instances."
f"{http_transport.mcp_session_id} from active instances."
)
del self._server_instances[http_transport.mcp_session_id]

Expand Down
45 changes: 17 additions & 28 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,35 @@
logger = logging.getLogger(__name__)


# TODO(Marcelo): We should flatten these settings. To be fair, I don't think we should even have this middleware.
class TransportSecuritySettings(BaseModel):
"""Settings for MCP transport security features.

These settings help protect against DNS rebinding attacks by validating
incoming request headers.
These settings help protect against DNS rebinding attacks by validating incoming request headers.
"""

enable_dns_rebinding_protection: bool = Field(
default=True,
description="Enable DNS rebinding protection (recommended for production)",
)
enable_dns_rebinding_protection: bool = True
"""Enable DNS rebinding protection (recommended for production)."""

allowed_hosts: list[str] = Field(
default=[],
description="List of allowed Host header values. Only applies when "
+ "enable_dns_rebinding_protection is True.",
)
allowed_hosts: list[str] = Field(default_factory=list)
"""List of allowed Host header values.

allowed_origins: list[str] = Field(
default=[],
description="List of allowed Origin header values. Only applies when "
+ "enable_dns_rebinding_protection is True.",
)
Only applies when `enable_dns_rebinding_protection` is `True`.
"""

allowed_origins: list[str] = Field(default_factory=list)
"""List of allowed Origin header values.

Only applies when `enable_dns_rebinding_protection` is `True`.
"""


# TODO(Marcelo): This should be a proper ASGI middleware. I'm sad to see this.
class TransportSecurityMiddleware:
"""Middleware to enforce DNS rebinding protection for MCP transport endpoints."""

def __init__(self, settings: TransportSecuritySettings | None = None):
# If not specified, disable DNS rebinding protection by default
# for backwards compatibility
# If not specified, disable DNS rebinding protection by default for backwards compatibility
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)

def _validate_host(self, host: str | None) -> bool: # pragma: no cover
Expand Down Expand Up @@ -88,16 +86,7 @@ def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover

def _validate_content_type(self, content_type: str | None) -> bool:
"""Validate the Content-Type header for POST requests."""
if not content_type: # pragma: lax no cover
logger.warning("Missing Content-Type header in POST request")
return False

# Content-Type must start with application/json
if not content_type.lower().startswith("application/json"):
logger.warning(f"Invalid Content-Type header: {content_type}")
return False

return True
return content_type is not None and content_type.lower().startswith("application/json")

async def validate_request(self, request: Request, is_post: bool = False) -> Response | None:
"""Validate request headers for DNS rebinding protection.
Expand Down
7 changes: 1 addition & 6 deletions src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,5 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS
client_streams = (server_to_client_receive, client_to_server_send)
server_streams = (client_to_server_receive, server_to_client_send)

async with (
server_to_client_receive,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
):
async with server_to_client_receive, client_to_server_send, client_to_server_receive, server_to_client_send:
yield client_streams, server_streams