diff --git a/py/packages/genkit/src/genkit/core/registry.py b/py/packages/genkit/src/genkit/core/registry.py index 0dca7c2d67..09813f1f4c 100644 --- a/py/packages/genkit/src/genkit/core/registry.py +++ b/py/packages/genkit/src/genkit/core/registry.py @@ -454,6 +454,24 @@ async def list_actions(self, allowed_kinds: list[ActionKind] | None = None) -> l if allowed_kinds and meta.kind not in allowed_kinds: continue metas.append(meta) + + # Include actions registered directly in the registry + with self._lock: + for kind, kind_map in self._entries.items(): + if allowed_kinds and kind not in allowed_kinds: + continue + for action in kind_map.values(): + metas.append( + ActionMetadata( + kind=action.kind, + name=action.name, + description=action.description, + input_json_schema=action.input_schema, + output_json_schema=action.output_schema, + metadata=action.metadata, + ) + ) + return metas def register_schema(self, name: str, schema: dict[str, Any]) -> None: diff --git a/py/packages/genkit/tests/genkit/core/registry_test.py b/py/packages/genkit/tests/genkit/core/registry_test.py index 6fd006cf0a..cd89900e12 100644 --- a/py/packages/genkit/tests/genkit/core/registry_test.py +++ b/py/packages/genkit/tests/genkit/core/registry_test.py @@ -75,7 +75,9 @@ async def list_actions(self) -> list[ActionMetadata]: ai = Genkit(plugins=[MyPlugin()]) metas = await ai.registry.list_actions() - assert metas == [ActionMetadata(kind=ActionKind.MODEL, name='myplugin/foo')] + # Filter for the specific plugin action we expect, ignoring system actions like 'generate' + target_meta = next((m for m in metas if m.name == 'myplugin/foo'), None) + assert target_meta == ActionMetadata(kind=ActionKind.MODEL, name='myplugin/foo') action = await ai.registry.resolve_action(ActionKind.MODEL, 'myplugin/foo') diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index ef649fbe91..af8e909771 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -683,7 +683,7 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool: params = genai_types.Schema(type=genai_types.Type.OBJECT, properties={}) function = genai_types.FunctionDeclaration( - name=tool.name, + name=tool.name.replace('/', '__'), description=tool.description, parameters=params, response=self._convert_schema_property(tool.output_schema) if tool.output_schema else None, diff --git a/py/plugins/mcp/pyproject.toml b/py/plugins/mcp/pyproject.toml index 6ea44f68ec..bdc353d689 100644 --- a/py/plugins/mcp/pyproject.toml +++ b/py/plugins/mcp/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development :: Libraries", ] -dependencies = ["genkit", "mcp"] +dependencies = ["genkit", "mcp", "structlog"] description = "Genkit MCP Plugin" license = "Apache-2.0" name = "genkit-plugins-mcp" @@ -45,4 +45,4 @@ build-backend = "hatchling.build" requires = ["hatchling"] [tool.hatch.build.targets.wheel] -packages = ["src"] +packages = ["src/genkit"] diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/client/client.py b/py/plugins/mcp/src/genkit/plugins/mcp/client/client.py index 597da873de..91616b77c4 100644 --- a/py/plugins/mcp/src/genkit/plugins/mcp/client/client.py +++ b/py/plugins/mcp/src/genkit/plugins/mcp/client/client.py @@ -14,18 +14,20 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any +import asyncio +from contextlib import AsyncExitStack +from typing import Any, cast import structlog -from pydantic import BaseModel +from pydantic import AnyUrl, BaseModel -from genkit.ai import Genkit, Plugin -from genkit.core.action import Action, ActionMetadata +from genkit.ai import Genkit +from genkit.ai._registry import GenkitRegistry from genkit.core.action.types import ActionKind from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client -from mcp.types import CallToolResult, Prompt, Resource, Tool +from mcp.types import CallToolResult, Prompt, Resource, TextContent, Tool logger = structlog.get_logger(__name__) @@ -38,7 +40,7 @@ class McpServerConfig(BaseModel): disabled: bool = False -class McpClient(Plugin): +class McpClient: """Client for connecting to a single MCP server.""" def __init__(self, name: str, config: McpServerConfig, server_name: str | None = None): @@ -46,49 +48,25 @@ def __init__(self, name: str, config: McpServerConfig, server_name: str | None = self.config = config self.server_name = server_name or name self.session: ClientSession | None = None - self._exit_stack = None - self._session_context = None - self.ai: Genkit | None = None + self._exit_stack = AsyncExitStack() + self.ai: GenkitRegistry | None = None def plugin_name(self) -> str: return self.name - async def init(self) -> list[Action]: - """Initialize MCP plugin. + def initialize(self, ai: GenkitRegistry) -> None: + self.ai = ai - MCP tools are registered dynamically upon connection, so this returns an empty list. - - Returns: - Empty list (tools are registered dynamically). - """ - return [] - - async def resolve(self, action_type: ActionKind, name: str) -> Action | None: - """Resolve an action by name. - - MCP uses dynamic registration, so this returns None. - - Args: - action_type: The kind of action to resolve. - name: The namespaced name of the action to resolve. - - Returns: - None (MCP uses dynamic registration). - """ - return None - - async def list_actions(self) -> list[ActionMetadata]: - """List available MCP actions. - - MCP tools are discovered at runtime, so this returns an empty list. - - Returns: - Empty list (tools are discovered at runtime). - """ - return [] + def resolve_action(self, ai: GenkitRegistry, kind: ActionKind, name: str) -> None: + # MCP tools are dynamic and currently registered upon connection/Discovery. + # This hook allows lazy resolution if we implement it. + pass async def connect(self): """Connects to the MCP server.""" + if self.session: + return + if self.config.disabled: logger.info(f'MCP server {self.server_name} is disabled.') return @@ -100,25 +78,24 @@ async def connect(self): ) # stdio_client returns (read, write) streams stdio_context = stdio_client(server_params) - read, write = await stdio_context.__aenter__() - self._exit_stack = stdio_context + read, write = await self._exit_stack.enter_async_context(stdio_context) # Create and initialize session session_context = ClientSession(read, write) - self.session = await session_context.__aenter__() - self._session_context = session_context + self.session = await self._exit_stack.enter_async_context(session_context) elif self.config.url: # TODO: Verify SSE client usage in mcp python SDK sse_context = sse_client(self.config.url) - read, write = await sse_context.__aenter__() - self._exit_stack = sse_context + read, write = await self._exit_stack.enter_async_context(sse_context) session_context = ClientSession(read, write) - self.session = await session_context.__aenter__() - self._session_context = session_context + self.session = await self._exit_stack.enter_async_context(session_context) + else: + raise ValueError(f"MCP client {self.name} configuration requires either 'command' or 'url'.") - await self.session.initialize() + if self.session: + await self.session.initialize() logger.info(f'Connected to MCP server: {self.server_name}') except Exception as e: @@ -130,16 +107,16 @@ async def connect(self): async def close(self): """Closes the connection.""" - if hasattr(self, '_session_context') and self._session_context: - try: - await self._session_context.__aexit__(None, None, None) - except Exception as e: - logger.debug(f'Error closing session: {e}') if self._exit_stack: try: - await self._exit_stack.__aexit__(None, None, None) - except Exception as e: - logger.debug(f'Error closing transport: {e}') + await self._exit_stack.aclose() + except (Exception, asyncio.CancelledError): + # Ignore errors during cleanup, especially cancellation from anyio + pass + + # Reset exit stack for potential reuse (reconnect) + self._exit_stack = AsyncExitStack() + self.session = None async def list_tools(self) -> list[Tool]: if not self.session: @@ -150,14 +127,21 @@ async def list_tools(self) -> list[Tool]: async def call_tool(self, tool_name: str, arguments: dict) -> Any: if not self.session: raise RuntimeError('MCP client is not connected') - result: CallToolResult = await self.session.call_tool(tool_name, arguments) - # Process result similarly to JS SDK - if result.isError: - raise RuntimeError(f'Tool execution failed: {result.content}') + logger.debug(f'MCP {self.server_name}: calling tool {tool_name}', arguments=arguments) + try: + result: CallToolResult = await self.session.call_tool(tool_name, arguments) + logger.debug(f'MCP {self.server_name}: tool {tool_name} returned') - # Simple text extraction for now - texts = [c.text for c in result.content if c.type == 'text'] - return ''.join(texts) + # Process result similarly to JS SDK + if result.isError: + raise RuntimeError(f'Tool execution failed: {result.content}') + + # Simple text extraction for now + texts = [c.text for c in result.content if c.type == 'text' and isinstance(c, TextContent)] + return {'content': ''.join(texts)} + except Exception as e: + logger.error(f'MCP {self.server_name}: tool {tool_name} failed', error=str(e)) + raise async def list_prompts(self) -> list[Prompt]: if not self.session: @@ -179,7 +163,7 @@ async def list_resources(self) -> list[Resource]: async def read_resource(self, uri: str) -> Any: if not self.session: raise RuntimeError('MCP client is not connected') - return await self.session.read_resource(uri) + return await self.session.read_resource(cast(AnyUrl, uri)) async def register_tools(self, ai: Genkit | None = None): """Registers all tools from connected client to Genkit.""" @@ -194,29 +178,38 @@ async def register_tools(self, ai: Genkit | None = None): try: tools = await self.list_tools() for tool in tools: - # Create a wrapper function for the tool - # We need to capture tool and client in closure - async def tool_wrapper(args: Any = None, _tool_name=tool.name): - # args might be Pydantic model or dict. Genkit passes dict usually? - # TODO: Validate args against schema if needed - arguments = args - if hasattr(args, 'model_dump'): - arguments = args.model_dump() - return await self.call_tool(_tool_name, arguments or {}) + # Create a wrapper function for the tool using a factory to capture tool name + def create_wrapper(tool_name: str): + async def tool_wrapper(args: Any = None): + # args might be Pydantic model or dict. Genkit passes dict usually? + # TODO: Validate args against schema if needed + arguments = args + if hasattr(args, 'model_dump'): + arguments = args.model_dump() + return await self.call_tool(tool_name, arguments or {}) + + return tool_wrapper + + tool_wrapper = create_wrapper(tool.name) # Use metadata to store MCP specific info metadata = {'mcp': {'_meta': tool._meta}} if hasattr(tool, '_meta') else {} # Define the tool in Genkit registry - registry.register_action( - kind=ActionKind.TOOL, - name=f'{self.server_name}/{tool.name}', + action = registry.register_action( + kind=cast(ActionKind, ActionKind.TOOL), + name=f'{self.server_name}_{tool.name}', fn=tool_wrapper, description=tool.description, metadata=metadata, - # TODO: json_schema conversion from tool.inputSchema ) - logger.debug(f'Registered MCP tool: {self.server_name}/{tool.name}') + + # Patch input schema from MCP tool definition + if tool.inputSchema: + action._input_schema = tool.inputSchema + action._metadata['inputSchema'] = tool.inputSchema + + logger.debug(f'Registered MCP tool: {self.server_name}_{tool.name}') except Exception as e: logger.error(f'Error registering tools for {self.server_name}: {e}') diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/client/host.py b/py/plugins/mcp/src/genkit/plugins/mcp/client/host.py index ecb4b9c5ed..3a5c8d726f 100644 --- a/py/plugins/mcp/src/genkit/plugins/mcp/client/host.py +++ b/py/plugins/mcp/src/genkit/plugins/mcp/client/host.py @@ -14,6 +14,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Optional from genkit.ai import Genkit @@ -58,6 +59,43 @@ async def disable(self, name: str): client.config.disabled = True await client.close() + async def reconnect(self, name: str): + """Reconnects a specific MCP client.""" + if name in self.clients: + client_to_reconnect = self.clients[name] + await client_to_reconnect.close() + await client_to_reconnect.connect() + + async def get_active_tools(self, ai: Genkit) -> list[str]: + """Returns a list of all active tool names from all clients.""" + active_tools = [] + for client in self.clients.values(): + if client.session: + try: + tools = await client.get_active_tools() + # Determine tool names as registered: server_tool + for tool in tools: + active_tools.append(f'{client.server_name}_{tool.name}') + except Exception as e: + # Log error but continue with other clients + # Use print or logger if available. Ideally structlog. + pass + return active_tools + + async def get_active_resources(self, ai: Genkit) -> list[str]: + """Returns a list of all active resource URIs from all clients.""" + active_resources = [] + for client in self.clients.values(): + if client.session: + try: + resources = await client.list_resources() + for resource in resources: + active_resources.append(resource.uri) + except Exception: + # Log error but continue with other clients + pass + return active_resources + def create_mcp_host(configs: dict[str, McpServerConfig]) -> McpHost: return McpHost(configs) diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/resource.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/resource.py index 1ac6b84fa0..cb344c5914 100644 --- a/py/plugins/mcp/src/genkit/plugins/mcp/util/resource.py +++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/resource.py @@ -20,9 +20,10 @@ including reading and converting resource content. """ -from typing import Any +from typing import Any, cast import structlog +from pydantic import AnyUrl from genkit.core.typing import Part from mcp.types import BlobResourceContents, ReadResourceResult, Resource, TextResourceContents @@ -74,7 +75,7 @@ def process_resource_content(resource_result: ReadResourceResult) -> Any: if not hasattr(resource_result, 'contents') or not resource_result.contents: return [] - return [from_mcp_resource_part(content) for content in resource_result.contents] + return [from_mcp_resource_part(content.model_dump()) for content in resource_result.contents] def convert_resource_to_genkit_part(resource: Resource) -> dict[str, Any]: @@ -114,7 +115,7 @@ def to_mcp_resource_contents(uri: str, parts: list[Part]) -> list[TextResourceCo for part in parts: if isinstance(part, dict): # Handle media/image content - if 'media' in part: + if 'media' in part and part['media']: media = part['media'] url = media.get('url', '') content_type = media.get('contentType', '') @@ -129,17 +130,17 @@ def to_mcp_resource_contents(uri: str, parts: list[Part]) -> list[TextResourceCo except ValueError as e: raise ValueError(f'Invalid data URL format: {url}') from e - contents.append(BlobResourceContents(uri=uri, mimeType=mime_type, blob=blob_data)) + contents.append(BlobResourceContents(uri=cast(AnyUrl, uri), mimeType=mime_type, blob=blob_data)) # Handle text content elif 'text' in part: - contents.append(TextResourceContents(uri=uri, text=part['text'])) + contents.append(TextResourceContents(uri=cast(AnyUrl, uri), text=part['text'])) else: raise ValueError( f'MCP resource messages only support media and text parts. ' f'Unsupported part type: {list(part.keys())}' ) elif isinstance(part, str): - contents.append(TextResourceContents(uri=uri, text=part)) + contents.append(TextResourceContents(uri=cast(AnyUrl, uri), text=part)) return contents diff --git a/py/plugins/mcp/tests/test_mcp_host.py b/py/plugins/mcp/tests/test_mcp_host.py index 9e1291505a..4dc5687c35 100644 --- a/py/plugins/mcp/tests/test_mcp_host.py +++ b/py/plugins/mcp/tests/test_mcp_host.py @@ -60,4 +60,4 @@ async def test_connect_and_register(self): # Verify tool registration ai.registry.register_action.assert_called() call_args = ai.registry.register_action.call_args[1] - self.assertIn('server1/tool1', call_args['name']) + self.assertIn('server1_tool1', call_args['name']) diff --git a/py/plugins/mcp/tests/test_mcp_integration.py b/py/plugins/mcp/tests/test_mcp_integration.py index 6d03041193..9beb7219fb 100644 --- a/py/plugins/mcp/tests/test_mcp_integration.py +++ b/py/plugins/mcp/tests/test_mcp_integration.py @@ -28,9 +28,14 @@ mock_mcp_modules() import pytest +from mcp.types import ListResourcesRequest, ListResourceTemplatesRequest from genkit.ai import Genkit -from genkit.plugins.mcp import McpClient, McpServerConfig, create_mcp_host, create_mcp_server +from genkit.blocks.resource import ResourceInput, ResourceOutput +from genkit.core.action import ActionRunContext +from genkit.core.error import GenkitError +from genkit.core.typing import Part, TextPart +from genkit.plugins.mcp import McpClient, McpServerConfig, McpServerOptions, create_mcp_host, create_mcp_server @pytest.mark.asyncio @@ -75,9 +80,10 @@ async def test_client_can_call_server_tool(self): mock_session = AsyncMock() mock_result = MagicMock() mock_result.isError = False - mock_content = MagicMock() - mock_content.type = 'text' - mock_content.text = '8' + mock_result.isError = False + from mcp.types import TextContent + + mock_content = TextContent(type='text', text='8') mock_result.content = [mock_content] mock_session.call_tool.return_value = mock_result @@ -87,7 +93,7 @@ async def test_client_can_call_server_tool(self): result = await client.call_tool('add', {'a': 5, 'b': 3}) # Verify - self.assertEqual(result, '8') + self.assertEqual(result, {'content': '8'}) mock_session.call_tool.assert_called_once_with('add', {'a': 5, 'b': 3}) async def test_client_can_list_server_resources(self): @@ -189,16 +195,18 @@ async def test_host_can_disable_and_enable_clients(self): # Mock the client client = host.clients['test'] client.session = AsyncMock() - client.close = AsyncMock() - client.connect = AsyncMock() - - # Disable - await host.disable('test') - self.assertTrue(client.config.disabled) + # Mock the client methods using patch.object to avoid type errors + with ( + patch.object(client, 'close', new_callable=AsyncMock) as mock_close, + patch.object(client, 'connect', new_callable=AsyncMock) as mock_connect, + ): + # Disable + await host.disable('test') + self.assertTrue(client.config.disabled) - # Enable - await host.enable('test') - self.assertFalse(client.config.disabled) + # Enable + await host.enable('test') + self.assertFalse(client.config.disabled) @pytest.mark.asyncio @@ -212,54 +220,60 @@ async def test_end_to_end_resource_flow(self): # 1. Server side: Define resource server_ai = Genkit() - server_ai.define_resource( - name='config', uri='app://config', fn=lambda req: {'content': [{'text': 'config data'}]} - ) - # 2. Create MCP server - from genkit.plugins.mcp import McpServerOptions + async def resource_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='config data'))]) + + server_ai.define_resource(name='config', uri='app://config', fn=resource_handler) + # 2. Create MCP server server = create_mcp_server(server_ai, McpServerOptions(name='test-server')) await server.setup() # 3. Verify server can list resources - resources_result = await server.list_resources({}) + resources_result = await server.list_resources(ListResourcesRequest()) self.assertEqual(len(resources_result.resources), 1) self.assertEqual(resources_result.resources[0].uri, 'app://config') # 4. Verify server can read resource + from typing import cast + + from mcp.types import TextResourceContents + request = MagicMock() request.params.uri = 'app://config' read_result = await server.read_resource(request) - self.assertEqual(read_result.contents[0].text, 'config data') + self.assertEqual(cast(TextResourceContents, read_result.contents[0]).text, 'config data') async def test_template_resource_matching(self): """Test that template resources match correctly.""" server_ai = Genkit() - def file_resource(req): - uri = req.uri - return {'content': [{'text': f'Contents of {uri}'}]} + async def file_resource(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + uri = input.uri + return ResourceOutput(content=[Part(root=TextPart(text=f'Contents of {uri}'))]) server_ai.define_resource(name='file', template='file://{+path}', fn=file_resource) # Create server - from genkit.plugins.mcp import McpServerOptions - server = create_mcp_server(server_ai, McpServerOptions(name='test-server')) await server.setup() # List templates - templates_result = await server.list_resource_templates({}) + templates_result = await server.list_resource_templates(ListResourceTemplatesRequest()) self.assertEqual(len(templates_result.resourceTemplates), 1) self.assertEqual(templates_result.resourceTemplates[0].uriTemplate, 'file://{+path}') # Read with different URIs + from typing import cast + + from mcp.types import TextResourceContents + for test_uri in ['file:///path/to/file.txt', 'file:///another/file.md', 'file:///deep/nested/path/doc.pdf']: request = MagicMock() request.params.uri = test_uri result = await server.read_resource(request) - self.assertIn(test_uri, result.contents[0].text) + self.assertIn(test_uri, cast(TextResourceContents, result.contents[0]).text) @pytest.mark.asyncio @@ -274,8 +288,6 @@ async def test_server_handles_missing_tool(self): def existing_tool(x: int) -> int: return x - from genkit.plugins.mcp import McpServerOptions - server = create_mcp_server(server_ai, McpServerOptions(name='test-server')) await server.setup() @@ -284,8 +296,6 @@ def existing_tool(x: int) -> int: request.params.name = 'nonexistent_tool' request.params.arguments = {} - from genkit.core.error import GenkitError - with self.assertRaises(GenkitError) as context: await server.call_tool(request) diff --git a/py/plugins/mcp/tests/test_mcp_server_resources.py b/py/plugins/mcp/tests/test_mcp_server_resources.py index e8f04493ae..f234897205 100644 --- a/py/plugins/mcp/tests/test_mcp_server_resources.py +++ b/py/plugins/mcp/tests/test_mcp_server_resources.py @@ -27,9 +27,15 @@ mock_mcp_modules() +from typing import cast + import pytest +from mcp.types import ListPromptsRequest, ListResourcesRequest, ListResourceTemplatesRequest, ListToolsRequest from genkit.ai import Genkit +from genkit.blocks.resource import ResourceInput, ResourceOutput +from genkit.core.action import ActionRunContext +from genkit.core.typing import Part, TextPart from genkit.plugins.mcp import McpServerOptions, create_mcp_server @@ -43,17 +49,24 @@ def setUp(self): async def test_list_resources_with_fixed_uri(self): """Test listing resources with fixed URIs.""" + # Define resources - self.ai.define_resource(name='config', uri='app://config', fn=lambda req: {'content': [{'text': 'config'}]}) + async def config_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='config'))]) + + self.ai.define_resource(name='config', uri='app://config', fn=config_handler) - self.ai.define_resource(name='data', uri='app://data', fn=lambda req: {'content': [{'text': 'data'}]}) + async def data_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='data'))]) + + self.ai.define_resource(name='data', uri='app://data', fn=data_handler) # Create server server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) await server.setup() # List resources - result = await server.list_resources({}) + result = await server.list_resources(ListResourcesRequest()) # Verify self.assertEqual(len(result.resources), 2) @@ -67,21 +80,24 @@ async def test_list_resources_with_fixed_uri(self): async def test_list_resource_templates(self): """Test listing resources with URI templates.""" + # Define template resources - self.ai.define_resource( - name='file', template='file://{+path}', fn=lambda req: {'content': [{'text': 'file content'}]} - ) + async def file_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='file content'))]) - self.ai.define_resource( - name='user', template='user://{id}/profile', fn=lambda req: {'content': [{'text': 'user profile'}]} - ) + self.ai.define_resource(name='file', template='file://{+path}', fn=file_handler) + + async def user_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='user profile'))]) + + self.ai.define_resource(name='user', template='user://{id}/profile', fn=user_handler) # Create server server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) await server.setup() # List resource templates - result = await server.list_resource_templates({}) + result = await server.list_resource_templates(ListResourceTemplatesRequest()) # Verify self.assertEqual(len(result.resourceTemplates), 2) @@ -95,38 +111,48 @@ async def test_list_resource_templates(self): async def test_list_resources_excludes_templates(self): """Test that list_resources excludes template resources.""" + # Define mixed resources - self.ai.define_resource(name='fixed', uri='app://fixed', fn=lambda req: {'content': [{'text': 'fixed'}]}) + async def fixed_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='fixed'))]) + + self.ai.define_resource(name='fixed', uri='app://fixed', fn=fixed_handler) + + async def template_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='template'))]) - self.ai.define_resource( - name='template', template='app://{id}', fn=lambda req: {'content': [{'text': 'template'}]} - ) + self.ai.define_resource(name='template', template='app://{id}', fn=template_handler) # Create server server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) await server.setup() # List resources (should only include fixed URI) - result = await server.list_resources({}) + result = await server.list_resources(ListResourcesRequest()) self.assertEqual(len(result.resources), 1) self.assertEqual(result.resources[0].name, 'fixed') async def test_list_resource_templates_excludes_fixed(self): """Test that list_resource_templates excludes fixed URI resources.""" + # Define mixed resources - self.ai.define_resource(name='fixed', uri='app://fixed', fn=lambda req: {'content': [{'text': 'fixed'}]}) + async def fixed_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='fixed'))]) + + self.ai.define_resource(name='fixed', uri='app://fixed', fn=fixed_handler) + + async def template_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='template'))]) - self.ai.define_resource( - name='template', template='app://{id}', fn=lambda req: {'content': [{'text': 'template'}]} - ) + self.ai.define_resource(name='template', template='app://{id}', fn=template_handler) # Create server server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) await server.setup() # List templates (should only include template) - result = await server.list_resource_templates({}) + result = await server.list_resource_templates(ListResourceTemplatesRequest()) self.assertEqual(len(result.resourceTemplates), 1) self.assertEqual(result.resourceTemplates[0].name, 'template') @@ -134,8 +160,8 @@ async def test_list_resource_templates_excludes_fixed(self): async def test_read_resource_with_fixed_uri(self): """Test reading a resource with fixed URI.""" - def config_resource(req): - return {'content': [{'text': 'Configuration data'}]} + async def config_resource(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='Configuration data'))]) self.ai.define_resource(name='config', uri='app://config', fn=config_resource) @@ -152,16 +178,18 @@ def config_resource(req): # Verify self.assertEqual(len(result.contents), 1) - self.assertEqual(result.contents[0].text, 'Configuration data') + from mcp.types import TextResourceContents + + self.assertEqual(cast(TextResourceContents, result.contents[0]).text, 'Configuration data') async def test_read_resource_with_template(self): """Test reading a resource with URI template.""" - def file_resource(req): - uri = req.uri + async def file_resource(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + uri = input.uri # Extract path from URI path = uri.replace('file://', '') - return {'content': [{'text': f'Contents of {path}'}]} + return ResourceOutput(content=[Part(root=TextPart(text=f'Contents of {path}'))]) self.ai.define_resource(name='file', template='file://{+path}', fn=file_resource) @@ -177,11 +205,17 @@ def file_resource(req): # Verify self.assertEqual(len(result.contents), 1) - self.assertIn('/home/user/document.txt', result.contents[0].text) + from mcp.types import TextResourceContents + + self.assertIn('/home/user/document.txt', cast(TextResourceContents, result.contents[0]).text) async def test_read_resource_not_found(self): """Test reading a non-existent resource.""" - self.ai.define_resource(name='existing', uri='app://existing', fn=lambda req: {'content': [{'text': 'data'}]}) + + async def existing_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='data'))]) + + self.ai.define_resource(name='existing', uri='app://existing', fn=existing_handler) # Create server server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) @@ -201,8 +235,14 @@ async def test_read_resource_not_found(self): async def test_read_resource_with_multiple_content_parts(self): """Test reading a resource that returns multiple content parts.""" - def multi_part_resource(req): - return {'content': [{'text': 'Part 1'}, {'text': 'Part 2'}, {'text': 'Part 3'}]} + async def multi_part_resource(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput( + content=[ + Part(root=TextPart(text='Part 1')), + Part(root=TextPart(text='Part 2')), + Part(root=TextPart(text='Part 3')), + ] + ) self.ai.define_resource(name='multi', uri='app://multi', fn=multi_part_resource) @@ -217,10 +257,13 @@ def multi_part_resource(req): result = await server.read_resource(request) # Verify + # Verify result content self.assertEqual(len(result.contents), 3) - self.assertEqual(result.contents[0].text, 'Part 1') - self.assertEqual(result.contents[1].text, 'Part 2') - self.assertEqual(result.contents[2].text, 'Part 3') + from mcp.types import TextResourceContents + + self.assertEqual(cast(TextResourceContents, result.contents[0]).text, 'Part 1') + self.assertEqual(cast(TextResourceContents, result.contents[1]).text, 'Part 2') + self.assertEqual(cast(TextResourceContents, result.contents[2]).text, 'Part 3') @pytest.mark.asyncio @@ -247,7 +290,7 @@ def multiply(input: dict[str, int]) -> int: await server.setup() # List tools - result = await server.list_tools({}) + result = await server.list_tools(ListToolsRequest()) # Verify self.assertEqual(len(result.tools), 2) @@ -275,7 +318,10 @@ def add(input: dict[str, int]) -> int: # Verify self.assertEqual(len(result.content), 1) - self.assertEqual(result.content[0].text, '8') + # Cast content to TextContent to access text attribute safely + from mcp.types import TextContent + + self.assertEqual(cast(TextContent, result.content[0]).text, '8') async def test_list_prompts(self): """Test listing prompts.""" @@ -288,7 +334,7 @@ async def test_list_prompts(self): await server.setup() # List prompts - result = await server.list_prompts({}) + result = await server.list_prompts(ListPromptsRequest()) # Verify self.assertGreaterEqual(len(result.prompts), 2) @@ -313,7 +359,10 @@ def test_tool(x: int) -> int: ai.define_prompt(name='test', prompt='Test prompt') # Define resource - ai.define_resource(name='test_resource', uri='test://resource', fn=lambda req: {'content': [{'text': 'test'}]}) + async def resource_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='test'))]) + + ai.define_resource(name='test_resource', uri='test://resource', fn=resource_handler) # Create server server = create_mcp_server(ai, McpServerOptions(name='integration-test')) diff --git a/py/pyproject.toml b/py/pyproject.toml index 8462f8aa59..75100e0ff7 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "genkit-plugin-flask", "genkit-plugin-google-cloud", "genkit-plugin-google-genai", + "genkit-plugins-mcp", "genkit-plugin-ollama", "genkit-plugin-vertex-ai", "genkit-plugin-xai", @@ -61,6 +62,8 @@ dev = [ "tox-uv>=1.25.0", "nox>=2025.2.9", "nox-uv>=0.2.2", + "structlog>=25.2.0", + "ty>=0.0.13", ] lint = ["ty>=0.0.1", "ruff>=0.9"] @@ -119,6 +122,7 @@ genkit-plugin-google-genai = { workspace = true } genkit-plugin-ollama = { workspace = true } genkit-plugin-vertex-ai = { workspace = true } genkit-plugin-xai = { workspace = true } +genkit-plugins-mcp = { workspace = true } google-genai-hello = { workspace = true } google-genai-image = { workspace = true } prompt-demo = { workspace = true } diff --git a/py/samples/mcp/pyproject.toml b/py/samples/mcp/pyproject.toml new file mode 100644 index 0000000000..16bad94b0b --- /dev/null +++ b/py/samples/mcp/pyproject.toml @@ -0,0 +1,64 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +[project] +authors = [{ name = "Google" }] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", +] +dependencies = [ + "genkit", + "genkit-plugin-google-genai", + "genkit-plugins-mcp", + "mcp>=1.25.0", + "pydantic>=2.10.5", + "starlette>=0.52.1", + "structlog>=25.2.0", + "uvicorn>=0.40.0", +] +description = "MCP Sample" +license = "Apache-2.0" +name = "mcp-sample" +requires-python = ">=3.10" +version = "0.1.0" + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.build.targets.wheel] +packages = ["src/main.py"] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.uv.sources] +genkit = { workspace = true } +genkit-plugin-google-genai = { workspace = true } +genkit-plugins-mcp = { workspace = true } diff --git a/py/samples/mcp/run.sh b/py/samples/mcp/run.sh new file mode 100755 index 0000000000..e60b1c94fb --- /dev/null +++ b/py/samples/mcp/run.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +exec uv run src/main.py "$@" diff --git a/py/samples/mcp/src/http_server.py b/py/samples/mcp/src/http_server.py new file mode 100644 index 0000000000..c1aabbeb2f --- /dev/null +++ b/py/samples/mcp/src/http_server.py @@ -0,0 +1,95 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +HTTP MCP Server Example + +This demonstrates creating an HTTP-based MCP server using SSE transport +with Starlette and the official MCP Python SDK. +""" + +import asyncio +import logging + +import mcp.types as types +import uvicorn +from mcp.server import Server +from mcp.server.sse import SseServerTransport +from starlette.applications import Starlette +from starlette.responses import Response +from starlette.routing import Mount, Route + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main(): + """Start the HTTP MCP server.""" + + # Create SSE transport logic + # The endpoint '/mcp/' is where clients will POST messages + sse = SseServerTransport('/mcp/') + + async def handle_sse(request): + """Handle incoming SSE connections.""" + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + read_stream, write_stream = streams + + # Create a new server instance for this session + server = Server('example-server', version='1.0.0') + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name='test_http', + description='Test HTTP transport', + inputSchema={'type': 'object', 'properties': {}}, + ) + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict) -> list[types.TextContent]: + if name == 'test_http': + # In this SSE implementation, valid session ID is internal + # but we can return a confirmation. + return [types.TextContent(type='text', text='Session Active')] + raise ValueError(f'Unknown tool: {name}') + + # Run the server with the streams + await server.run(read_stream, write_stream, server.create_initialization_options()) + + # Return empty response after connection closes + return Response() + + # Define routes + # GET /mcp -> Starts SSE stream + # POST /mcp/ -> Handles messages (via SseServerTransport) + routes = [ + Route('/mcp', endpoint=handle_sse, methods=['GET']), + Mount('/mcp/', app=sse.handle_post_message), + ] + + app = Starlette(routes=routes) + + config = uvicorn.Config(app, host='0.0.0.0', port=3334, log_level='info') + server = uvicorn.Server(config) + + print('HTTP MCP server running on http://localhost:3334/mcp') + await server.serve() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/py/samples/mcp/src/main.py b/py/samples/mcp/src/main.py new file mode 100644 index 0000000000..e65bc95ff4 --- /dev/null +++ b/py/samples/mcp/src/main.py @@ -0,0 +1,383 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + +import asyncio +from functools import wraps +from pathlib import Path +from typing import cast + +import structlog +from pydantic import BaseModel + +from genkit.ai import Genkit +from genkit.core.action.types import ActionKind +from genkit.core.typing import Part, Resource1, ResourcePart, TextPart +from genkit.plugins.google_genai import GoogleAI +from genkit.plugins.mcp import McpServerConfig, create_mcp_host + +try: + from mcp import McpError +except ImportError: + + class McpError(Exception): + pass + + +logger = structlog.get_logger(__name__) + +# Get the current directory +current_dir = Path(__file__).parent +workspace_dir = current_dir.parent / 'test-workspace' +# repo_root is 4 levels up: py/samples/mcp/src -> py/samples/mcp -> py/samples -> py -> root +repo_root = current_dir.parent.parent.parent.parent + +# Initialize Genkit with GoogleAI +ai = Genkit(plugins=[GoogleAI()], model='googleai/gemini-2.5-flash') + +# Create MCP host with multiple servers +mcp_host = create_mcp_host({ + 'git-client': McpServerConfig(command='uvx', args=['mcp-server-git']), + 'fs': McpServerConfig(command='npx', args=['-y', '@modelcontextprotocol/server-filesystem', str(workspace_dir)]), + 'everything': McpServerConfig(command='npx', args=['-y', '@modelcontextprotocol/server-everything']), +}) + + +def with_mcp_host(func): + @wraps(func) + async def wrapper(*args, **kwargs): + await mcp_host.start() + try: + return await func(*args, **kwargs) + finally: + await mcp_host.close() + + return wrapper + + +async def read_resource_from_host(host, uri: str) -> str: + """Try to read a resource from any connected MCP client.""" + errors = [] + for client in host.clients.values(): + if not client.session: + continue + try: + # client.read_resource returns ReadResourceResult + res = await client.read_resource(uri) + # Combine text content + text = '' + if hasattr(res, 'contents'): + for c in res.contents: + if hasattr(c, 'text') and c.text: + text += c.text + '\n' + elif hasattr(c, 'blob'): + text += f'[Blob data type={getattr(c, "mimeType", "?")}]\n' + return text + except Exception as e: + errors.append(f'{client.name}: {e}') + + if not errors: + return 'No connected clients found.' + raise RuntimeError(f'Could not read resource {uri}. Errors: {errors}') + + +async def resolve_prompt_resources(prompt: list[Part], host) -> list[Part]: + """Manually resolve ResourceParts in the prompt to TextParts.""" + new_prompt = [] + for part in prompt: + if isinstance(part.root, ResourcePart): + uri = part.root.resource.uri + try: + content = await read_resource_from_host(host, uri) + new_prompt.append(Part(root=TextPart(text=f'Resource {uri} Content:\n{content}'))) + except Exception as e: + new_prompt.append(Part(root=TextPart(text=f'Failed to load resource {uri}: {e}'))) + else: + new_prompt.append(part) + return new_prompt + + +@ai.flow(name='git-commits') +@with_mcp_host +async def git_commits(query: str = ''): + """Summarize recent git commits using MCP git client.""" + # Register tools to registry directly + await mcp_host.register_tools(ai) + + # Get active tool names for this call + tools = await mcp_host.get_active_tools(ai) + + result = await ai.generate(prompt=f"summarize last 5 commits in '{repo_root}'", tools=tools) + return result.text + + +@ai.flow(name='dynamic-git-commits') +@with_mcp_host +async def dynamic_git_commits(query: str = ''): + """Summarize recent git commits using wildcard tool selection.""" + await mcp_host.register_tools(ai) + + # In Python, we might not support wildcards in tools list yet, + # so we'll simulate by getting all tools matching the pattern. + # So we use the string pattern if supported. + # tools=['git-client_*'] + + all_tools = await mcp_host.get_active_tools(ai) + tools = [t for t in all_tools if t.startswith('git-client_')] + + result = await ai.generate( + prompt=f"summarize last 5 commits. You must use the argument key 'repo_path' set to '{repo_root}'. Do not use 'path'.", + tools=tools, + ) + return result.text + + +@ai.flow(name='get-file') +@with_mcp_host +async def get_file(query: str = ''): + """Read and summarize a file using MCP filesystem client.""" + await mcp_host.register_tools(ai) + tools = await mcp_host.get_active_tools(ai) + + result = await ai.generate(prompt=f"summarize contents of hello-world.txt (in '{workspace_dir}')", tools=tools) + return result.text + + +@ai.flow(name='dynamic-get-file') +@with_mcp_host +async def dynamic_get_file(query: str = ''): + """Read file using specific tool selection.""" + await mcp_host.register_tools(ai) + + # Filter for specific tool: 'fs_read_file' + # Filter for specific tool: 'fs_read_file' or newer variants + tools = [t for t in await mcp_host.get_active_tools(ai) if t.endswith('read_file') or t.endswith('read_text_file')] + + result = await ai.generate(prompt=f"summarize contents of hello-world.txt (in '{workspace_dir}')", tools=tools) + return result.text + + +@ai.flow(name='dynamic-prefix-tool') +@with_mcp_host +async def dynamic_prefix_tool(query: str = ''): + """Read file using prefix tool selection.""" + await mcp_host.register_tools(ai) + + # Filter for prefix: 'fs_read_' + all_tools = await mcp_host.get_active_tools(ai) + tools = [t for t in all_tools if t.startswith('fs_read_')] + + result = await ai.generate(prompt=f"summarize contents of hello-world.txt (in '{workspace_dir}')", tools=tools) + return result.text + + +# @ai.flow(name='dynamic-disable-enable') +# @with_mcp_host +# async def dynamic_disable_enable(query: str = ''): +# """Test disabling and re-enabling an MCP client.""" +# return "Skipped dynamic-disable-enable flow due to hang issues." +# await mcp_host.register_tools(ai) +# tools = [t for t in await mcp_host.get_active_tools(ai) if t == 'fs_read_file'] + +# # Run successfully +# result1 = await ai.generate(prompt=f"summarize contents of hello-world.txt (in '{workspace_dir}')", tools=tools) +# text1 = result1.text + +# # Disable 'fs' and try to run (should fail) +# await mcp_host.disable('fs') +# text2 = '' +# try: +# # We don't re-register tools, hoping the registry or generate handles the disabled client +# result = await ai.generate(prompt=f"summarize contents of hello-world.txt (in '{workspace_dir}')", tools=tools) +# text2 = f'ERROR! This should have failed but succeeded: {result.text}' +# except Exception as e: +# text2 = str(e) + +# # Re-enable 'fs' and run +# await mcp_host.enable('fs') +# # Re-connect/re-register might be needed +# await mcp_host.register_tools(ai) + +# result3 = await ai.generate(prompt=f"summarize contents of hello-world.txt (in '{workspace_dir}')", tools=tools) +# text3 = result3.text + +# return f'Original:
{text1}
After Disable:
{text2}
After Enable:
{text3}' + + +# @ai.flow(name='test-resource') +# @with_mcp_host +# async def test_resource(query: str = ''): +# """Test reading a resource.""" +# try: +# # Pass resources as grounding context if supported +# resources = await mcp_host.get_active_resources(ai) +# +# # Manually resolve resources because the plugin might not support ResourcePart +# raw_prompt = [ +# Part(root=TextPart(text='analyze this: ')), +# Part(root=ResourcePart(resource=Resource1(uri='test://static/resource/1'))) +# ] +# resolved_prompt = await resolve_prompt_resources(raw_prompt, mcp_host) +# +# result = await ai.generate( +# prompt=resolved_prompt, +# context={'resources': resources} +# ) +# return result.text +# except McpError as e: +# return f"MCP Error (Server likely doesn't support reading this resource): {e}" +# except Exception as e: +# return f"Flow failed: {e}" +# +# +# @ai.flow(name='dynamic-test-resources') +# @with_mcp_host +# async def dynamic_test_resources(query: str = ''): +# """Test reading resources with wildcard.""" +# # Simulate wildcard resources if not natively supported +# # resources=['resource/*'] +# +# try: +# all_resources = await mcp_host.get_active_resources(ai) +# resources = [r for r in all_resources if r.startswith('test://')] # simplified filter +# +# raw_prompt = [ +# Part(root=TextPart(text='analyze this: ')), +# Part(root=ResourcePart(resource=Resource1(uri='test://static/resource/1'))) +# ] +# resolved_prompt = await resolve_prompt_resources(raw_prompt, mcp_host) +# +# result = await ai.generate( +# prompt=resolved_prompt, +# context={'resources': resources} +# ) +# return result.text +# except McpError as e: +# return f"MCP Error: {e}" +# except Exception as e: +# return f"Flow failed: {e}" + + +@ai.flow(name='dynamic-test-one-resource') +@with_mcp_host +async def dynamic_test_one_resource(query: str = ''): + """Test reading one specific resource.""" + resources = ['test://static/resource/1'] + + try: + raw_prompt = [ + Part(root=TextPart(text='analyze this: ')), + Part(root=ResourcePart(resource=Resource1(uri='test://static/resource/1'))), + ] + resolved_prompt = await resolve_prompt_resources(raw_prompt, mcp_host) + + result = await ai.generate(prompt=resolved_prompt, context={'resources': resources}) + return result.text + except McpError as e: + return f'MCP Error: {e}' + except Exception as e: + return f'Flow failed: {e}' + + +@ai.flow(name='update-file') +@with_mcp_host +async def update_file(query: str = ''): + """Update a file using MCP filesystem client.""" + await mcp_host.register_tools(ai) + tools = await mcp_host.get_active_tools(ai) + + result = await ai.generate( + prompt=f"Improve hello-world.txt (in '{workspace_dir}') by rewriting the text, making it longer, use your imagination.", + tools=tools, + ) + return result.text + + +class ControlMcpInput(BaseModel): + action: str # 'RECONNECT', 'ENABLE', 'DISABLE', 'DISCONNECT' + client_id: str | None = 'git-client' + + +@ai.flow(name='control_mcp') +async def control_mcp(input: ControlMcpInput): + """Control MCP client connections (enable/disable/reconnect).""" + client_id = input.client_id + action = input.action.upper() + + if action == 'DISABLE': + if not client_id: + raise ValueError('client_id is required for DISABLE action') + await mcp_host.disable(client_id) + elif action == 'DISCONNECT': + # Assuming disconnect is equivalent to close for a specific client + if client_id and client_id in mcp_host.clients: + await mcp_host.clients[client_id].close() + elif action == 'RECONNECT': + if not client_id: + raise ValueError('client_id is required for RECONNECT action') + await mcp_host.reconnect(client_id) + elif action == 'ENABLE': + if not client_id: + raise ValueError('client_id is required for ENABLE action') + await mcp_host.enable(client_id) + + return f'Action {action} completed for {client_id}' + + +async def main(): + """Run sample flows.""" + import os + + # Only run test flows if not in dev mode (Dev UI) + if os.getenv('GENKIT_ENV') == 'dev': + logger.info('Running in dev mode - flows available in Dev UI') + logger.info('Genkit server running. Press Ctrl+C to stop.') + # Keep the process alive for Dev UI + await asyncio.Event().wait() + return + + logger.info('Starting MCP sample application') + + flows = ai.registry.get_actions_by_kind(cast(ActionKind, ActionKind.FLOW)) + logger.info(f'DEBUG: Registered flows: {list(flows.keys())}') + + # Test git commits flow + logger.info('Testing git-commits flow...') + try: + result = await git_commits() + logger.info('git-commits result', result=result[:200]) + except Exception as e: + logger.error('git-commits failed', error=str(e), exc_info=True) + + # Test get-file flow + logger.info('Testing get-file flow...') + try: + result = await get_file() + logger.info('get-file result', result=result[:200]) + except Exception as e: + logger.error('get-file failed', error=str(e), exc_info=True) + + +if __name__ == '__main__': + import sys + + # If running directly (not via genkit start), execute the test flows + if len(sys.argv) == 1: + ai.run_main(main()) + # Otherwise, just keep the server running for Dev UI + else: + # This allows genkit start to work properly + pass diff --git a/py/samples/mcp/src/server.py b/py/samples/mcp/src/server.py new file mode 100644 index 0000000000..9978a1f12e --- /dev/null +++ b/py/samples/mcp/src/server.py @@ -0,0 +1,83 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +MCP Server Example + +This demonstrates creating an MCP server that exposes Genkit tools, prompts, +and resources through the Model Context Protocol. +""" + +import asyncio + +from pydantic import BaseModel, Field + +from genkit.ai import Genkit +from genkit.blocks.resource import ResourceInput, ResourceOutput +from genkit.core.action import ActionRunContext +from genkit.core.typing import Part, TextPart +from genkit.plugins.mcp import McpServerOptions, create_mcp_server + +# Initialize Genkit +ai = Genkit(plugins=[]) + + +# Define a tool +class AddInput(BaseModel): + a: int = Field(..., description='First number') + b: int = Field(..., description='Second number') + + +@ai.tool(name='add', description='add two numbers together') +def add(input: AddInput) -> int: + return input.a + input.b + + +# Define a prompt +happy_prompt = ai.define_prompt( + name='happy', + input_schema={'action': str}, + prompt="If you're happy and you know it, {{action}}.", +) + + +# Define resources +async def my_resource_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text='my resource'))]) + + +ai.define_resource(name='my resources', uri='test://static/resource/1', fn=my_resource_handler) + + +async def file_resource_handler(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + uri = input.uri + return ResourceOutput(content=[Part(root=TextPart(text=f'file contents for {uri}'))]) + + +ai.define_resource(name='file', template='file://{path}', fn=file_resource_handler) + + +async def main(): + """Start the MCP server.""" + # Create MCP server + server = create_mcp_server(ai, McpServerOptions(name='example_server', version='0.0.1')) + + print('Starting MCP server on stdio...') + await server.start() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/py/samples/mcp/test-workspace/hello-world.txt b/py/samples/mcp/test-workspace/hello-world.txt new file mode 100644 index 0000000000..cd0875583a --- /dev/null +++ b/py/samples/mcp/test-workspace/hello-world.txt @@ -0,0 +1 @@ +Hello world! diff --git a/py/uv.lock b/py/uv.lock index dc87eacd93..ae90265e53 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -39,6 +39,7 @@ members = [ "google-genai-image", "google-genai-vertexai-hello", "google-genai-vertexai-image", + "mcp-sample", "menu", "model-garden-example", "multi-server", @@ -2062,12 +2063,14 @@ source = { editable = "plugins/mcp" } dependencies = [ { name = "genkit" }, { name = "mcp" }, + { name = "structlog" }, ] [package.metadata] requires-dist = [ { name = "genkit", editable = "packages/genkit" }, { name = "mcp" }, + { name = "structlog" }, ] [[package]] @@ -2089,6 +2092,7 @@ dependencies = [ { name = "genkit-plugin-ollama" }, { name = "genkit-plugin-vertex-ai" }, { name = "genkit-plugin-xai" }, + { name = "genkit-plugins-mcp" }, { name = "liccheck" }, { name = "mcp" }, { name = "strenum", marker = "python_full_version < '3.11'" }, @@ -2110,9 +2114,11 @@ dev = [ { name = "pytest-cov" }, { name = "pytest-mock" }, { name = "pytest-watcher" }, + { name = "structlog" }, { name = "tox" }, { name = "tox-uv" }, { name = "twine" }, + { name = "ty" }, ] lint = [ { name = "ruff" }, @@ -2135,6 +2141,7 @@ requires-dist = [ { name = "genkit-plugin-ollama", editable = "plugins/ollama" }, { name = "genkit-plugin-vertex-ai", editable = "plugins/vertex-ai" }, { name = "genkit-plugin-xai", editable = "plugins/xai" }, + { name = "genkit-plugins-mcp", editable = "plugins/mcp" }, { name = "liccheck", specifier = ">=0.9.2" }, { name = "mcp", specifier = ">=1.25.0" }, { name = "strenum", marker = "python_full_version < '3.11'", specifier = ">=0.4.15" }, @@ -2156,9 +2163,11 @@ dev = [ { name = "pytest-cov", specifier = ">=6.0.0" }, { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "pytest-watcher", specifier = ">=0.4.3" }, + { name = "structlog", specifier = ">=25.2.0" }, { name = "tox", specifier = ">=4.25.0" }, { name = "tox-uv", specifier = ">=1.25.0" }, { name = "twine", specifier = ">=6.1.0" }, + { name = "ty", specifier = ">=0.0.13" }, ] lint = [ { name = "ruff", specifier = ">=0.9" }, @@ -3599,6 +3608,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/fc/6dc7659c2ae5ddf280477011f4213a74f806862856b796ef08f028e664bf/mcp-1.25.0-py3-none-any.whl", hash = "sha256:b37c38144a666add0862614cc79ec276e97d72aa8ca26d622818d4e278b9721a", size = 233076, upload-time = "2025-12-19T10:19:55.416Z" }, ] +[[package]] +name = "mcp-sample" +version = "0.1.0" +source = { editable = "samples/mcp" } +dependencies = [ + { name = "genkit" }, + { name = "genkit-plugin-google-genai" }, + { name = "genkit-plugins-mcp" }, + { name = "mcp" }, + { name = "pydantic" }, + { name = "starlette" }, + { name = "structlog" }, + { name = "uvicorn" }, +] + +[package.metadata] +requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, + { name = "genkit-plugin-google-genai", editable = "plugins/google-genai" }, + { name = "genkit-plugins-mcp", editable = "plugins/mcp" }, + { name = "mcp", specifier = ">=1.25.0" }, + { name = "pydantic", specifier = ">=2.10.5" }, + { name = "starlette", specifier = ">=0.52.1" }, + { name = "structlog", specifier = ">=25.2.0" }, + { name = "uvicorn", specifier = ">=0.40.0" }, +] + [[package]] name = "mdurl" version = "0.1.2"