diff --git a/src/mcp/client/__init__.py b/src/mcp/client/__init__.py index 7b9464710..87414bf50 100644 --- a/src/mcp/client/__init__.py +++ b/src/mcp/client/__init__.py @@ -1,9 +1,10 @@ """MCP Client module.""" -from mcp.client.client import Client +from mcp.client.client import Client, ClientTarget from mcp.client.session import ClientSession __all__ = [ "Client", "ClientSession", + "ClientTarget", ] diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 6eafb794a..a3b039fae 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -9,7 +9,6 @@ from pydantic import AnyUrl import mcp.types as types -from mcp.client._memory import InMemoryTransport from mcp.client.session import ( ClientSession, ElicitationFnT, @@ -18,21 +17,62 @@ MessageHandlerFnT, SamplingFnT, ) +from mcp.client.transports import HttpTransport, InMemoryTransport, Transport from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.session import ProgressFnT logger = logging.getLogger(__name__) +# Type alias for all accepted target types +ClientTarget = Server[Any] | FastMCP | Transport | str + + +def _infer_transport( + target: ClientTarget, + *, + raise_exceptions: bool = False, +) -> Transport: + """Infer the appropriate transport from the target type. + + Args: + target: The target to connect to. Can be: + - Server or FastMCP instance: Uses InMemoryTransport + - Transport instance: Uses the transport directly + - str (URL): Uses HttpTransport (Streamable HTTP) + raise_exceptions: For InMemoryTransport, whether to raise exceptions + from the server. Ignored for other transport types. + + Returns: + A Transport instance ready to connect. + + Raises: + TypeError: If the target type is not recognized. + """ + # Already a transport - use directly + if isinstance(target, Transport): + return target + + # Server or FastMCP - use in-memory transport for testing + if isinstance(target, Server | FastMCP): + return InMemoryTransport(target, raise_exceptions=raise_exceptions) + + # URL string - use Streamable HTTP transport (modern standard) + # Note: After type narrowing above, target is str here + return HttpTransport(target) + class Client: """A high-level MCP client for connecting to MCP servers. - Currently supports in-memory transport for testing. Pass a Server or - FastMCP instance directly to the constructor. + Supports multiple transport types: + - In-memory: Pass a Server or FastMCP instance directly (for testing) + - HTTP: Pass a URL string or HttpTransport instance + - SSE: Pass an SSETransport instance (legacy) - Example: + Examples: ```python + # In-memory testing (recommended for unit tests) from mcp.client import Client from mcp.server.fastmcp import FastMCP @@ -44,21 +84,34 @@ def add(a: int, b: int) -> int: async with Client(server) as client: result = await client.call_tool("add", {"a": 1, "b": 2}) + + # HTTP connection via URL string + async with Client("http://localhost:8000/mcp") as client: + result = await client.call_tool("my_tool", {...}) + + # HTTP connection with custom headers + from mcp.client.transports import HttpTransport + + transport = HttpTransport( + "http://localhost:8000/mcp", + headers={"Authorization": "Bearer token"}, + ) + async with Client(transport) as client: + result = await client.call_tool("my_tool", {...}) + + # Legacy SSE connection + from mcp.client.transports import SSETransport + + async with Client(SSETransport("http://localhost:8000/sse")) as client: + result = await client.call_tool("my_tool", {...}) ``` """ - # TODO(felixweinberger): Expand to support all transport types (like FastMCP 2): - # - Add ClientTransport base class with connect_session() method - # - Add StreamableHttpTransport, SSETransport, StdioTransport - # - Add infer_transport() to auto-detect transport from input type - # - Accept URL strings, Path objects, config dicts in constructor - # - Add auth support (OAuth, bearer tokens) - def __init__( self, - server: Server[Any] | FastMCP, + target: ClientTarget, *, - # TODO(Marcelo): When do `raise_exceptions=True` actually raises? + # TODO(Marcelo): When does `raise_exceptions=True` actually raise? raise_exceptions: bool = False, read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, @@ -68,20 +121,24 @@ def __init__( client_info: types.Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, ) -> None: - """Initialize the client with a server. + """Initialize the client. Args: - server: The MCP server to connect to (Server or FastMCP instance) - raise_exceptions: Whether to raise exceptions from the server - read_timeout_seconds: Timeout for read operations - sampling_callback: Callback for handling sampling requests - list_roots_callback: Callback for handling list roots requests - logging_callback: Callback for handling logging notifications - message_handler: Callback for handling raw messages - client_info: Client implementation info to send to server - elicitation_callback: Callback for handling elicitation requests + target: The target to connect to. Can be: + - Server or FastMCP instance: Uses in-memory transport (for testing) + - Transport instance: Uses the transport directly + - str (URL): Uses HTTP transport (Streamable HTTP protocol) + raise_exceptions: For in-memory transport, whether to raise exceptions + from the server. Ignored for other transport types. + read_timeout_seconds: Timeout for read operations. + sampling_callback: Callback for handling sampling requests. + list_roots_callback: Callback for handling list roots requests. + logging_callback: Callback for handling logging notifications. + message_handler: Callback for handling raw messages. + client_info: Client implementation info to send to server. + elicitation_callback: Callback for handling elicitation requests. """ - self._server = server + self._target = target self._raise_exceptions = raise_exceptions self._read_timeout_seconds = read_timeout_seconds self._sampling_callback = sampling_callback @@ -100,8 +157,8 @@ async def __aenter__(self) -> Client: raise RuntimeError("Client is already entered; cannot reenter") async with AsyncExitStack() as exit_stack: - # Create transport and connect - transport = InMemoryTransport(self._server, raise_exceptions=self._raise_exceptions) + # Infer and connect transport + transport = _infer_transport(self._target, raise_exceptions=self._raise_exceptions) read_stream, write_stream = await exit_stack.enter_async_context(transport.connect()) # Create session diff --git a/src/mcp/client/transports/__init__.py b/src/mcp/client/transports/__init__.py new file mode 100644 index 000000000..f2b7cab29 --- /dev/null +++ b/src/mcp/client/transports/__init__.py @@ -0,0 +1,35 @@ +"""Transport implementations for MCP clients. + +This module provides transport abstractions for connecting to MCP servers +using different protocols: + +- InMemoryTransport: For testing servers without network overhead +- HttpTransport: For Streamable HTTP connections (recommended for HTTP) +- SSETransport: For legacy Server-Sent Events connections + +Example: + ```python + from mcp.client import Client + from mcp.client.transports import HttpTransport, SSETransport + + # Using Streamable HTTP (recommended) + async with Client(HttpTransport("http://localhost:8000/mcp")) as client: + result = await client.call_tool("my_tool", {...}) + + # Using legacy SSE + async with Client(SSETransport("http://localhost:8000/sse")) as client: + result = await client.call_tool("my_tool", {...}) + ``` +""" + +from mcp.client.transports.base import Transport +from mcp.client.transports.http import HttpTransport +from mcp.client.transports.memory import InMemoryTransport +from mcp.client.transports.sse import SSETransport + +__all__ = [ + "Transport", + "HttpTransport", + "InMemoryTransport", + "SSETransport", +] diff --git a/src/mcp/client/transports/base.py b/src/mcp/client/transports/base.py new file mode 100644 index 000000000..5422e9b75 --- /dev/null +++ b/src/mcp/client/transports/base.py @@ -0,0 +1,47 @@ +"""Base transport protocol for MCP clients.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Protocol, runtime_checkable + +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from mcp.shared.message import SessionMessage + + +@runtime_checkable +class Transport(Protocol): + """Protocol for MCP client transports. + + All transports must implement a connect() async context manager that yields + a tuple of (read_stream, write_stream) for bidirectional communication. + + Example: + ```python + class MyTransport: + @asynccontextmanager + async def connect(self): + # Set up connection... + yield read_stream, write_stream + # Clean up... + ``` + """ + + @asynccontextmanager + async def connect( + self, + ) -> AsyncIterator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + ] + ]: + """Connect to the server and yield streams for communication. + + Yields: + A tuple of (read_stream, write_stream) for bidirectional communication. + """ + ... + yield # type: ignore[misc] diff --git a/src/mcp/client/transports/http.py b/src/mcp/client/transports/http.py new file mode 100644 index 000000000..47263d379 --- /dev/null +++ b/src/mcp/client/transports/http.py @@ -0,0 +1,103 @@ +"""Streamable HTTP transport for MCP clients.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from mcp.client.streamable_http import streamable_http_client +from mcp.shared.message import SessionMessage + + +class HttpTransport: + """Streamable HTTP transport for connecting to MCP servers over HTTP. + + This transport uses the Streamable HTTP protocol, which is the recommended + transport for HTTP-based MCP connections. + + Example: + ```python + from mcp.client import Client + from mcp.client.transports import HttpTransport + + # Basic usage + async with Client(HttpTransport("http://localhost:8000/mcp")) as client: + result = await client.call_tool("my_tool", {...}) + + # Or use the convenience URL syntax + async with Client("http://localhost:8000/mcp") as client: + result = await client.call_tool("my_tool", {...}) + + # With custom headers (e.g., authentication) + transport = HttpTransport( + "http://localhost:8000/mcp", + headers={"Authorization": "Bearer token"}, + ) + async with Client(transport) as client: + result = await client.call_tool("my_tool", {...}) + + # With a pre-configured httpx client + http_client = httpx.AsyncClient( + headers={"Authorization": "Bearer token"}, + timeout=30.0, + ) + transport = HttpTransport("http://localhost:8000/mcp", httpx_client=http_client) + async with Client(transport) as client: + result = await client.call_tool("my_tool", {...}) + ``` + """ + + def __init__( + self, + url: str, + *, + headers: dict[str, str] | None = None, + httpx_client: httpx.AsyncClient | None = None, + terminate_on_close: bool = True, + ) -> None: + """Initialize the HTTP transport. + + Args: + url: The MCP server endpoint URL. + headers: Optional headers to include in requests. For authentication, + include an "Authorization" header or use httpx_client with auth + configured. Ignored if httpx_client is provided. + httpx_client: Optional pre-configured httpx.AsyncClient. If provided, + the headers parameter is ignored. The client lifecycle is managed + externally (not closed by this transport). + terminate_on_close: If True, send a DELETE request to terminate the + session when the context exits. Defaults to True. + """ + self._url = url + self._headers = headers + self._httpx_client = httpx_client + self._terminate_on_close = terminate_on_close + + @asynccontextmanager + async def connect( + self, + ) -> AsyncIterator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + ] + ]: + """Connect to the server and return streams for communication. + + Yields: + A tuple of (read_stream, write_stream) for bidirectional communication. + """ + # If headers are provided without a custom client, create one with those headers + client = self._httpx_client + if client is None and self._headers is not None: + client = httpx.AsyncClient(headers=self._headers) + + async with streamable_http_client( + self._url, + http_client=client, + terminate_on_close=self._terminate_on_close, + ) as (read_stream, write_stream, _get_session_id): + yield read_stream, write_stream diff --git a/src/mcp/client/_memory.py b/src/mcp/client/transports/memory.py similarity index 77% rename from src/mcp/client/_memory.py rename to src/mcp/client/transports/memory.py index 3589d0da7..91c35aa37 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/transports/memory.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any @@ -23,25 +23,37 @@ class InMemoryTransport: stopped when the context manager exits. Example: + ```python + from mcp.client import Client + from mcp.server.fastmcp import FastMCP + server = FastMCP("test") - transport = InMemoryTransport(server) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - # Use the session... + @server.tool() + def add(a: int, b: int) -> int: + return a + b - Or more commonly, use with Client: + # Direct usage with Client (recommended) async with Client(server) as client: - result = await client.call_tool("my_tool", {...}) + result = await client.call_tool("add", {"a": 1, "b": 2}) + + # Or explicit transport usage + async with Client(InMemoryTransport(server)) as client: + result = await client.call_tool("add", {"a": 1, "b": 2}) + ``` """ - def __init__(self, server: Server[Any] | FastMCP, *, raise_exceptions: bool = False) -> None: + def __init__( + self, + server: Server[Any] | FastMCP, + *, + raise_exceptions: bool = False, + ) -> None: """Initialize the in-memory transport. Args: - server: The MCP server to connect to (Server or FastMCP instance) - raise_exceptions: Whether to raise exceptions from the server + server: The MCP server to connect to (Server or FastMCP instance). + raise_exceptions: Whether to raise exceptions from the server. """ self._server = server self._raise_exceptions = raise_exceptions @@ -49,17 +61,16 @@ def __init__(self, server: Server[Any] | FastMCP, *, raise_exceptions: bool = Fa @asynccontextmanager async def connect( self, - ) -> AsyncGenerator[ + ) -> AsyncIterator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], - ], - None, + ] ]: """Connect to the server and return streams for communication. Yields: - A tuple of (read_stream, write_stream) for bidirectional communication + A tuple of (read_stream, write_stream) for bidirectional communication. """ # Unwrap FastMCP to get underlying Server actual_server: Server[Any] diff --git a/src/mcp/client/transports/sse.py b/src/mcp/client/transports/sse.py new file mode 100644 index 000000000..a10b9da96 --- /dev/null +++ b/src/mcp/client/transports/sse.py @@ -0,0 +1,93 @@ +"""Server-Sent Events (SSE) transport for MCP clients. + +Note: SSE is a legacy transport. For new implementations, prefer HttpTransport +which uses the Streamable HTTP protocol. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager +from typing import Any + +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from mcp.client.sse import sse_client +from mcp.shared.message import SessionMessage + + +class SSETransport: + """Server-Sent Events (SSE) transport for connecting to MCP servers. + + Note: SSE is a legacy transport. For new implementations, prefer + HttpTransport which uses the Streamable HTTP protocol. + + Example: + ```python + from mcp.client import Client + from mcp.client.transports import SSETransport + + async with Client(SSETransport("http://localhost:8000/sse")) as client: + result = await client.call_tool("my_tool", {...}) + + # With authentication + transport = SSETransport( + "http://localhost:8000/sse", + headers={"Authorization": "Bearer token"}, + ) + async with Client(transport) as client: + result = await client.call_tool("my_tool", {...}) + ``` + """ + + def __init__( + self, + url: str, + *, + headers: dict[str, Any] | None = None, + timeout: float = 5.0, + sse_read_timeout: float = 300.0, + auth: httpx.Auth | None = None, + on_session_created: Callable[[str], None] | None = None, + ) -> None: + """Initialize the SSE transport. + + Args: + url: The SSE endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations (in seconds). Defaults to 5.0. + sse_read_timeout: Timeout for SSE read operations (in seconds). Defaults to 300.0. + auth: Optional HTTPX authentication handler. + on_session_created: Optional callback invoked with the session ID when received. + """ + self._url = url + self._headers = headers + self._timeout = timeout + self._sse_read_timeout = sse_read_timeout + self._auth = auth + self._on_session_created = on_session_created + + @asynccontextmanager + async def connect( + self, + ) -> AsyncIterator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + ] + ]: + """Connect to the server and return streams for communication. + + Yields: + A tuple of (read_stream, write_stream) for bidirectional communication. + """ + async with sse_client( + self._url, + headers=self._headers, + timeout=self._timeout, + sse_read_timeout=self._sse_read_timeout, + auth=self._auth, + on_session_created=self._on_session_created, + ) as (read_stream, write_stream): + yield read_stream, write_stream diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 7314a3735..ebd398cb0 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -123,7 +123,7 @@ async def patched_create_streams(): # Apply the patch for the duration of the test # Patch both locations since InMemoryTransport imports it directly with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams): - with patch("mcp.client._memory.create_client_server_memory_streams", patched_create_streams): + with patch("mcp.client.transports.memory.create_client_server_memory_streams", patched_create_streams): # Return a collection with helper methods def get_spy_collection() -> StreamSpyCollection: assert client_spy is not None, "client_spy was not initialized" diff --git a/tests/client/test_client_transports.py b/tests/client/test_client_transports.py new file mode 100644 index 000000000..53f76c14c --- /dev/null +++ b/tests/client/test_client_transports.py @@ -0,0 +1,83 @@ +"""Tests for Client with different transport types.""" + +import pytest + +from mcp.client import Client +from mcp.client.transports import HttpTransport, InMemoryTransport, SSETransport +from mcp.server.fastmcp import FastMCP + + +@pytest.fixture +def test_server() -> FastMCP: + """Create a simple test server.""" + server = FastMCP("test") + + @server.tool() + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + return server + + +pytestmark = pytest.mark.anyio + + +async def test_client_with_server_directly(test_server: FastMCP): + """Test Client accepts a Server/FastMCP instance directly.""" + async with Client(test_server) as client: + result = await client.call_tool("add", {"a": 1, "b": 2}) + assert "3" in str(result.content[0]) + + +async def test_client_with_in_memory_transport(test_server: FastMCP): + """Test Client accepts an InMemoryTransport instance.""" + transport = InMemoryTransport(test_server) + async with Client(transport) as client: + result = await client.call_tool("add", {"a": 5, "b": 7}) + assert "12" in str(result.content[0]) + + +async def test_client_with_raise_exceptions(test_server: FastMCP): + """Test that raise_exceptions is passed through for in-memory transport.""" + async with Client(test_server, raise_exceptions=True) as client: + # If we got here without error, raise_exceptions was accepted + assert client.server_capabilities is not None + + +# Note: The following tests verify type acceptance but don't make network calls +# since they would require a real server running. + + +def test_client_accepts_http_transport(): + """Test that Client constructor accepts HttpTransport.""" + transport = HttpTransport("http://localhost:8000/mcp") + # Just verify it can be constructed - don't enter context manager + client = Client(transport) + assert client._target is transport + + +def test_client_accepts_sse_transport(): + """Test that Client constructor accepts SSETransport.""" + transport = SSETransport("http://localhost:8000/sse") + # Just verify it can be constructed - don't enter context manager + client = Client(transport) + assert client._target is transport + + +def test_client_accepts_url_string(): + """Test that Client constructor accepts a URL string.""" + client = Client("http://localhost:8000/mcp") + # URL string should be stored as the target + assert client._target == "http://localhost:8000/mcp" + + +def test_client_with_http_transport_and_headers(): + """Test that HttpTransport with headers can be passed to Client.""" + transport = HttpTransport( + "http://localhost:8000/mcp", + headers={"Authorization": "Bearer token123"}, + ) + client = Client(transport) + assert client._target is transport + assert transport._headers == {"Authorization": "Bearer token123"} diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index b97ebcea2..af9cc5302 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -3,7 +3,7 @@ import pytest from mcp import Client -from mcp.client._memory import InMemoryTransport +from mcp.client.transports import InMemoryTransport from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.types import Resource diff --git a/tests/client/transports/test_transport_inference.py b/tests/client/transports/test_transport_inference.py new file mode 100644 index 000000000..eef7f140c --- /dev/null +++ b/tests/client/transports/test_transport_inference.py @@ -0,0 +1,96 @@ +"""Tests for transport type inference.""" + +import pytest + +from mcp.client.client import _infer_transport +from mcp.client.transports import HttpTransport, InMemoryTransport, SSETransport, Transport +from mcp.server import Server +from mcp.server.fastmcp import FastMCP + + +def test_infer_transport_from_server(): + """Test that Server instances are wrapped in InMemoryTransport.""" + server = Server(name="test") + transport = _infer_transport(server) + + assert isinstance(transport, InMemoryTransport) + + +def test_infer_transport_from_fastmcp(): + """Test that FastMCP instances are wrapped in InMemoryTransport.""" + server = FastMCP("test") + transport = _infer_transport(server) + + assert isinstance(transport, InMemoryTransport) + + +def test_infer_transport_from_url_string(): + """Test that URL strings are wrapped in HttpTransport.""" + transport = _infer_transport("http://localhost:8000/mcp") + + assert isinstance(transport, HttpTransport) + + +def test_infer_transport_from_https_url(): + """Test that HTTPS URLs are wrapped in HttpTransport.""" + transport = _infer_transport("https://example.com/mcp") + + assert isinstance(transport, HttpTransport) + + +def test_infer_transport_passthrough_http(): + """Test that HttpTransport instances are passed through unchanged.""" + original = HttpTransport("http://localhost:8000/mcp") + transport = _infer_transport(original) + + assert transport is original + + +def test_infer_transport_passthrough_sse(): + """Test that SSETransport instances are passed through unchanged.""" + original = SSETransport("http://localhost:8000/sse") + transport = _infer_transport(original) + + assert transport is original + + +def test_infer_transport_passthrough_memory(): + """Test that InMemoryTransport instances are passed through unchanged.""" + server = FastMCP("test") + original = InMemoryTransport(server) + transport = _infer_transport(original) + + assert transport is original + + +def test_infer_transport_invalid_type(): + """Test that invalid types are passed through to HttpTransport. + + Note: After type narrowing (Transport, Server|FastMCP), remaining types + are treated as URL strings and passed to HttpTransport. This means + invalid types will fail at HttpTransport construction time, not in + _infer_transport itself. + """ + # Invalid types are treated as URL strings and passed to HttpTransport + # HttpTransport will accept any string-like input + transport = _infer_transport(12345) # type: ignore[arg-type] + assert isinstance(transport, HttpTransport) + + +def test_infer_transport_raise_exceptions_passed_to_memory(): + """Test that raise_exceptions is passed to InMemoryTransport.""" + server = FastMCP("test") + transport = _infer_transport(server, raise_exceptions=True) + + assert isinstance(transport, InMemoryTransport) + assert transport._raise_exceptions is True + + +def test_transport_protocol_compliance(): + """Test that all transport classes implement the Transport protocol.""" + server = FastMCP("test") + + # Check that each transport is recognized as a Transport + assert isinstance(InMemoryTransport(server), Transport) + assert isinstance(HttpTransport("http://localhost:8000"), Transport) + assert isinstance(SSETransport("http://localhost:8000"), Transport)