Skip to content

Commit 65e5d63

Browse files
committed
refactor: move request_ctx contextvar to lowlevel server, remove per-handler boilerplate
Set request_ctx once in _handle_request and _handle_notification instead of wrapping every MCPServer handler closure with token set/reset. MCPServer's get_context() now reads request_ctx directly.
1 parent 711af73 commit 65e5d63

File tree

3 files changed

+65
-91
lines changed

3 files changed

+65
-91
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .handler import Handler, NotificationHandler, RequestHandler
2-
from .server import NotificationOptions, Server
2+
from .server import NotificationOptions, Server, request_ctx
33

4-
__all__ = ["Handler", "NotificationHandler", "NotificationOptions", "RequestHandler", "Server"]
4+
__all__ = ["Handler", "NotificationHandler", "NotificationOptions", "RequestHandler", "Server", "request_ctx"]

src/mcp/server/lowlevel/server.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ async def main():
3838

3939
from __future__ import annotations
4040

41+
import contextvars
4142
import logging
4243
import warnings
4344
from collections.abc import AsyncIterator, Callable, Sequence
@@ -77,6 +78,8 @@ async def main():
7778
LifespanResultT = TypeVar("LifespanResultT", default=Any)
7879
RequestT = TypeVar("RequestT", default=Any)
7980

81+
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")
82+
8083

8184
class NotificationOptions:
8285
def __init__(
@@ -387,7 +390,11 @@ async def _handle_request(
387390
close_sse_stream=close_sse_stream_cb,
388391
close_standalone_sse_stream=close_standalone_sse_stream_cb,
389392
)
390-
response = await handler.handle(ctx, req.params)
393+
token = request_ctx.set(ctx)
394+
try:
395+
response = await handler.handle(ctx, req.params)
396+
finally:
397+
request_ctx.reset(token)
391398
except MCPError as err:
392399
response = err.error
393400
except anyio.get_cancelled_exc_class():
@@ -426,7 +433,11 @@ async def _handle_notification(
426433
_task_support=task_support,
427434
),
428435
)
429-
await handler.handle(ctx, getattr(notify, "params", None))
436+
token = request_ctx.set(ctx)
437+
try:
438+
await handler.handle(ctx, getattr(notify, "params", None))
439+
finally:
440+
request_ctx.reset(token)
430441
except Exception: # pragma: no cover
431442
logger.exception("Uncaught exception in notification handler")
432443

src/mcp/server/mcpserver/server.py

Lines changed: 50 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import base64
6-
import contextvars
76
import inspect
87
import json
98
import re
@@ -32,7 +31,7 @@
3231
from mcp.server.elicitation import elicit_url as _elicit_url
3332
from mcp.server.lowlevel.handler import Handler, RequestHandler
3433
from mcp.server.lowlevel.helper_types import ReadResourceContents
35-
from mcp.server.lowlevel.server import LifespanResultT, Server
34+
from mcp.server.lowlevel.server import LifespanResultT, Server, request_ctx
3635
from mcp.server.lowlevel.server import lifespan as default_lifespan
3736
from mcp.server.mcpserver.exceptions import ResourceError
3837
from mcp.server.mcpserver.prompts import Prompt, PromptManager
@@ -76,10 +75,6 @@
7675

7776
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
7877

79-
_mcp_server_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar(
80-
"_mcp_server_ctx"
81-
)
82-
8378

8479
class Settings(BaseSettings, Generic[LifespanResultT]):
8580
"""MCPServer settings.
@@ -292,93 +287,65 @@ def _create_handlers(self) -> list[Handler]:
292287
"""Create core MCP protocol handlers."""
293288

294289
async def handle_list_tools(ctx: Any, params: Any) -> ListToolsResult:
295-
token = _mcp_server_ctx.set(ctx)
296-
try:
297-
return ListToolsResult(tools=await self.list_tools())
298-
finally:
299-
_mcp_server_ctx.reset(token)
290+
return ListToolsResult(tools=await self.list_tools())
300291

301292
async def handle_call_tool(ctx: Any, params: Any) -> CallToolResult:
302-
token = _mcp_server_ctx.set(ctx)
303293
try:
304-
try:
305-
result = await self.call_tool(params.name, params.arguments or {})
306-
except MCPError:
307-
raise
308-
except Exception as e:
309-
return CallToolResult(content=[TextContent(type="text", text=str(e))], is_error=True)
310-
if isinstance(result, CallToolResult):
311-
return result
312-
if isinstance(result, tuple) and len(result) == 2:
313-
unstructured_content, structured_content = result
314-
return CallToolResult(
315-
content=list(unstructured_content), # type: ignore[arg-type]
316-
structured_content=structured_content, # type: ignore[arg-type]
317-
)
318-
if isinstance(result, dict):
319-
return CallToolResult(
320-
content=[TextContent(type="text", text=json.dumps(result, indent=2))],
321-
structured_content=result,
322-
)
323-
return CallToolResult(content=list(result))
324-
finally:
325-
_mcp_server_ctx.reset(token)
294+
result = await self.call_tool(params.name, params.arguments or {})
295+
except MCPError:
296+
raise
297+
except Exception as e:
298+
return CallToolResult(content=[TextContent(type="text", text=str(e))], is_error=True)
299+
if isinstance(result, CallToolResult):
300+
return result
301+
if isinstance(result, tuple) and len(result) == 2:
302+
unstructured_content, structured_content = result
303+
return CallToolResult(
304+
content=list(unstructured_content), # type: ignore[arg-type]
305+
structured_content=structured_content, # type: ignore[arg-type]
306+
)
307+
if isinstance(result, dict):
308+
return CallToolResult(
309+
content=[TextContent(type="text", text=json.dumps(result, indent=2))],
310+
structured_content=result,
311+
)
312+
return CallToolResult(content=list(result))
326313

327314
async def handle_list_resources(ctx: Any, params: Any) -> ListResourcesResult:
328-
token = _mcp_server_ctx.set(ctx)
329-
try:
330-
return ListResourcesResult(resources=await self.list_resources())
331-
finally:
332-
_mcp_server_ctx.reset(token)
315+
return ListResourcesResult(resources=await self.list_resources())
333316

334317
async def handle_read_resource(ctx: Any, params: Any) -> ReadResourceResult:
335-
token = _mcp_server_ctx.set(ctx)
336-
try:
337-
results = await self.read_resource(params.uri)
338-
contents: list[TextResourceContents | BlobResourceContents] = []
339-
for item in results:
340-
if isinstance(item.content, bytes):
341-
contents.append(
342-
BlobResourceContents(
343-
uri=params.uri,
344-
blob=base64.b64encode(item.content).decode(),
345-
mime_type=item.mime_type or "application/octet-stream",
346-
_meta=item.meta,
347-
)
318+
results = await self.read_resource(params.uri)
319+
contents: list[TextResourceContents | BlobResourceContents] = []
320+
for item in results:
321+
if isinstance(item.content, bytes):
322+
contents.append(
323+
BlobResourceContents(
324+
uri=params.uri,
325+
blob=base64.b64encode(item.content).decode(),
326+
mime_type=item.mime_type or "application/octet-stream",
327+
_meta=item.meta,
348328
)
349-
else:
350-
contents.append(
351-
TextResourceContents(
352-
uri=params.uri,
353-
text=item.content,
354-
mime_type=item.mime_type or "text/plain",
355-
_meta=item.meta,
356-
)
329+
)
330+
else:
331+
contents.append(
332+
TextResourceContents(
333+
uri=params.uri,
334+
text=item.content,
335+
mime_type=item.mime_type or "text/plain",
336+
_meta=item.meta,
357337
)
358-
return ReadResourceResult(contents=contents)
359-
finally:
360-
_mcp_server_ctx.reset(token)
338+
)
339+
return ReadResourceResult(contents=contents)
361340

362341
async def handle_list_resource_templates(ctx: Any, params: Any) -> ListResourceTemplatesResult:
363-
token = _mcp_server_ctx.set(ctx)
364-
try:
365-
return ListResourceTemplatesResult(resource_templates=await self.list_resource_templates())
366-
finally:
367-
_mcp_server_ctx.reset(token)
342+
return ListResourceTemplatesResult(resource_templates=await self.list_resource_templates())
368343

369344
async def handle_list_prompts(ctx: Any, params: Any) -> ListPromptsResult:
370-
token = _mcp_server_ctx.set(ctx)
371-
try:
372-
return ListPromptsResult(prompts=await self.list_prompts())
373-
finally:
374-
_mcp_server_ctx.reset(token)
345+
return ListPromptsResult(prompts=await self.list_prompts())
375346

376347
async def handle_get_prompt(ctx: Any, params: Any) -> GetPromptResult:
377-
token = _mcp_server_ctx.set(ctx)
378-
try:
379-
return await self.get_prompt(params.name, params.arguments)
380-
finally:
381-
_mcp_server_ctx.reset(token)
348+
return await self.get_prompt(params.name, params.arguments)
382349

383350
return [
384351
RequestHandler("tools/list", handler=handle_list_tools),
@@ -412,7 +379,7 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]:
412379
during a request; outside a request, most methods will error.
413380
"""
414381
try:
415-
request_context = _mcp_server_ctx.get()
382+
request_context = request_ctx.get()
416383
except LookupError:
417384
request_context = None
418385
return Context(request_context=request_context, mcp_server=self)
@@ -603,14 +570,10 @@ async def handle_completion(ref, argument, context):
603570

604571
def decorator(func: _CallableT) -> _CallableT:
605572
async def handler(ctx: Any, params: Any) -> CompleteResult:
606-
token = _mcp_server_ctx.set(ctx)
607-
try:
608-
result = await func(params.ref, params.argument, params.context)
609-
return CompleteResult(
610-
completion=result if result is not None else Completion(values=[], total=None, has_more=None),
611-
)
612-
finally:
613-
_mcp_server_ctx.reset(token)
573+
result = await func(params.ref, params.argument, params.context)
574+
return CompleteResult(
575+
completion=result if result is not None else Completion(values=[], total=None, has_more=None),
576+
)
614577

615578
# TODO(maxisbey): remove private access — completion needs post-construction
616579
# handler registration, find a better pattern for this

0 commit comments

Comments
 (0)