Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/mcp/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""MCP Client module."""

from mcp.client.client import Client
from mcp.client.client import Client, ClientTarget
from mcp.client.session import ClientSession

__all__ = [
"Client",
"ClientSession",
"ClientTarget",
]
109 changes: 83 additions & 26 deletions src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pydantic import AnyUrl

import mcp.types as types
from mcp.client._memory import InMemoryTransport
from mcp.client.session import (
ClientSession,
ElicitationFnT,
Expand All @@ -18,21 +17,62 @@
MessageHandlerFnT,
SamplingFnT,
)
from mcp.client.transports import HttpTransport, InMemoryTransport, Transport
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.shared.session import ProgressFnT

logger = logging.getLogger(__name__)

# Type alias for all accepted target types
ClientTarget = Server[Any] | FastMCP | Transport | str


def _infer_transport(
target: ClientTarget,
*,
raise_exceptions: bool = False,
) -> Transport:
"""Infer the appropriate transport from the target type.

Args:
target: The target to connect to. Can be:
- Server or FastMCP instance: Uses InMemoryTransport
- Transport instance: Uses the transport directly
- str (URL): Uses HttpTransport (Streamable HTTP)
raise_exceptions: For InMemoryTransport, whether to raise exceptions
from the server. Ignored for other transport types.

Returns:
A Transport instance ready to connect.

Raises:
TypeError: If the target type is not recognized.
"""
# Already a transport - use directly
if isinstance(target, Transport):
return target

# Server or FastMCP - use in-memory transport for testing
if isinstance(target, Server | FastMCP):
return InMemoryTransport(target, raise_exceptions=raise_exceptions)

# URL string - use Streamable HTTP transport (modern standard)
# Note: After type narrowing above, target is str here
return HttpTransport(target)


class Client:
"""A high-level MCP client for connecting to MCP servers.

Currently supports in-memory transport for testing. Pass a Server or
FastMCP instance directly to the constructor.
Supports multiple transport types:
- In-memory: Pass a Server or FastMCP instance directly (for testing)
- HTTP: Pass a URL string or HttpTransport instance
- SSE: Pass an SSETransport instance (legacy)

Example:
Examples:
```python
# In-memory testing (recommended for unit tests)
from mcp.client import Client
from mcp.server.fastmcp import FastMCP

Expand All @@ -44,21 +84,34 @@ def add(a: int, b: int) -> int:

async with Client(server) as client:
result = await client.call_tool("add", {"a": 1, "b": 2})

# HTTP connection via URL string
async with Client("http://localhost:8000/mcp") as client:
result = await client.call_tool("my_tool", {...})

# HTTP connection with custom headers
from mcp.client.transports import HttpTransport

transport = HttpTransport(
"http://localhost:8000/mcp",
headers={"Authorization": "Bearer token"},
)
async with Client(transport) as client:
result = await client.call_tool("my_tool", {...})

# Legacy SSE connection
from mcp.client.transports import SSETransport

async with Client(SSETransport("http://localhost:8000/sse")) as client:
result = await client.call_tool("my_tool", {...})
```
"""

# TODO(felixweinberger): Expand to support all transport types (like FastMCP 2):
# - Add ClientTransport base class with connect_session() method
# - Add StreamableHttpTransport, SSETransport, StdioTransport
# - Add infer_transport() to auto-detect transport from input type
# - Accept URL strings, Path objects, config dicts in constructor
# - Add auth support (OAuth, bearer tokens)

def __init__(
self,
server: Server[Any] | FastMCP,
target: ClientTarget,
*,
# TODO(Marcelo): When do `raise_exceptions=True` actually raises?
# TODO(Marcelo): When does `raise_exceptions=True` actually raise?
raise_exceptions: bool = False,
read_timeout_seconds: float | None = None,
sampling_callback: SamplingFnT | None = None,
Expand All @@ -68,20 +121,24 @@ def __init__(
client_info: types.Implementation | None = None,
elicitation_callback: ElicitationFnT | None = None,
) -> None:
"""Initialize the client with a server.
"""Initialize the client.

Args:
server: The MCP server to connect to (Server or FastMCP instance)
raise_exceptions: Whether to raise exceptions from the server
read_timeout_seconds: Timeout for read operations
sampling_callback: Callback for handling sampling requests
list_roots_callback: Callback for handling list roots requests
logging_callback: Callback for handling logging notifications
message_handler: Callback for handling raw messages
client_info: Client implementation info to send to server
elicitation_callback: Callback for handling elicitation requests
target: The target to connect to. Can be:
- Server or FastMCP instance: Uses in-memory transport (for testing)
- Transport instance: Uses the transport directly
- str (URL): Uses HTTP transport (Streamable HTTP protocol)
raise_exceptions: For in-memory transport, whether to raise exceptions
from the server. Ignored for other transport types.
read_timeout_seconds: Timeout for read operations.
sampling_callback: Callback for handling sampling requests.
list_roots_callback: Callback for handling list roots requests.
logging_callback: Callback for handling logging notifications.
message_handler: Callback for handling raw messages.
client_info: Client implementation info to send to server.
elicitation_callback: Callback for handling elicitation requests.
"""
self._server = server
self._target = target
self._raise_exceptions = raise_exceptions
self._read_timeout_seconds = read_timeout_seconds
self._sampling_callback = sampling_callback
Expand All @@ -100,8 +157,8 @@ async def __aenter__(self) -> Client:
raise RuntimeError("Client is already entered; cannot reenter")

async with AsyncExitStack() as exit_stack:
# Create transport and connect
transport = InMemoryTransport(self._server, raise_exceptions=self._raise_exceptions)
# Infer and connect transport
transport = _infer_transport(self._target, raise_exceptions=self._raise_exceptions)
read_stream, write_stream = await exit_stack.enter_async_context(transport.connect())

# Create session
Expand Down
35 changes: 35 additions & 0 deletions src/mcp/client/transports/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Transport implementations for MCP clients.

This module provides transport abstractions for connecting to MCP servers
using different protocols:

- InMemoryTransport: For testing servers without network overhead
- HttpTransport: For Streamable HTTP connections (recommended for HTTP)
- SSETransport: For legacy Server-Sent Events connections

Example:
```python
from mcp.client import Client
from mcp.client.transports import HttpTransport, SSETransport

# Using Streamable HTTP (recommended)
async with Client(HttpTransport("http://localhost:8000/mcp")) as client:
result = await client.call_tool("my_tool", {...})

# Using legacy SSE
async with Client(SSETransport("http://localhost:8000/sse")) as client:
result = await client.call_tool("my_tool", {...})
```
"""

from mcp.client.transports.base import Transport
from mcp.client.transports.http import HttpTransport
from mcp.client.transports.memory import InMemoryTransport
from mcp.client.transports.sse import SSETransport

__all__ = [
"Transport",
"HttpTransport",
"InMemoryTransport",
"SSETransport",
]
47 changes: 47 additions & 0 deletions src/mcp/client/transports/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Base transport protocol for MCP clients."""

from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Protocol, runtime_checkable

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp.shared.message import SessionMessage


@runtime_checkable
class Transport(Protocol):
"""Protocol for MCP client transports.

All transports must implement a connect() async context manager that yields
a tuple of (read_stream, write_stream) for bidirectional communication.

Example:
```python
class MyTransport:
@asynccontextmanager
async def connect(self):
# Set up connection...
yield read_stream, write_stream
# Clean up...
```
"""

@asynccontextmanager
async def connect(
self,
) -> AsyncIterator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
]
]:
"""Connect to the server and yield streams for communication.

Yields:
A tuple of (read_stream, write_stream) for bidirectional communication.
"""
...
yield # type: ignore[misc]
103 changes: 103 additions & 0 deletions src/mcp/client/transports/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Streamable HTTP transport for MCP clients."""

from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp.client.streamable_http import streamable_http_client
from mcp.shared.message import SessionMessage


class HttpTransport:
"""Streamable HTTP transport for connecting to MCP servers over HTTP.

This transport uses the Streamable HTTP protocol, which is the recommended
transport for HTTP-based MCP connections.

Example:
```python
from mcp.client import Client
from mcp.client.transports import HttpTransport

# Basic usage
async with Client(HttpTransport("http://localhost:8000/mcp")) as client:
result = await client.call_tool("my_tool", {...})

# Or use the convenience URL syntax
async with Client("http://localhost:8000/mcp") as client:
result = await client.call_tool("my_tool", {...})

# With custom headers (e.g., authentication)
transport = HttpTransport(
"http://localhost:8000/mcp",
headers={"Authorization": "Bearer token"},
)
async with Client(transport) as client:
result = await client.call_tool("my_tool", {...})

# With a pre-configured httpx client
http_client = httpx.AsyncClient(
headers={"Authorization": "Bearer token"},
timeout=30.0,
)
transport = HttpTransport("http://localhost:8000/mcp", httpx_client=http_client)
async with Client(transport) as client:
result = await client.call_tool("my_tool", {...})
```
"""

def __init__(
self,
url: str,
*,
headers: dict[str, str] | None = None,
httpx_client: httpx.AsyncClient | None = None,
terminate_on_close: bool = True,
) -> None:
"""Initialize the HTTP transport.

Args:
url: The MCP server endpoint URL.
headers: Optional headers to include in requests. For authentication,
include an "Authorization" header or use httpx_client with auth
configured. Ignored if httpx_client is provided.
httpx_client: Optional pre-configured httpx.AsyncClient. If provided,
the headers parameter is ignored. The client lifecycle is managed
externally (not closed by this transport).
terminate_on_close: If True, send a DELETE request to terminate the
session when the context exits. Defaults to True.
"""
self._url = url
self._headers = headers
self._httpx_client = httpx_client
self._terminate_on_close = terminate_on_close

@asynccontextmanager
async def connect(
self,
) -> AsyncIterator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
]
]:
"""Connect to the server and return streams for communication.

Yields:
A tuple of (read_stream, write_stream) for bidirectional communication.
"""
# If headers are provided without a custom client, create one with those headers
client = self._httpx_client
if client is None and self._headers is not None:
client = httpx.AsyncClient(headers=self._headers)

async with streamable_http_client(
self._url,
http_client=client,
terminate_on_close=self._terminate_on_close,
) as (read_stream, write_stream, _get_session_id):
yield read_stream, write_stream
Loading