From 0bd4eded4fa99848d5d1ca42f5c57c0311a83254 Mon Sep 17 00:00:00 2001 From: Mengqin Shen Date: Wed, 28 Jan 2026 21:13:37 -0800 Subject: [PATCH] fix(py): sync with main --- py/bin/run_sample | 9 +- py/packages/genkit/src/genkit/core/typing.py | 5 +- .../src/genkit/plugins/google_genai/google.py | 335 ++++++----- .../plugins/google_genai/models/gemini.py | 509 +++++++++++++---- .../plugins/google_genai/models/imagen.py | 16 +- .../genkit/plugins/google_genai/models/veo.py | 193 +++++++ .../google-genai/test/google_plugin_test.py | 333 ++++++----- .../xai/src/genkit/plugins/xai/model_info.py | 111 ++-- .../xai/src/genkit/plugins/xai/models.py | 127 ++--- .../xai/src/genkit/plugins/xai/plugin.py | 7 +- py/plugins/xai/tests/xai_models_test.py | 33 +- py/pyproject.toml | 3 - py/samples/anthropic-hello/src/main.py | 16 +- py/samples/deepseek-hello/src/main.py | 97 ++-- py/samples/google-genai-hello/src/main.py | 538 ++++++++++-------- py/samples/shared/__init__.py | 54 -- py/samples/shared/flows.py | 144 ----- py/samples/shared/tools.py | 98 ---- py/samples/shared/types.py | 58 -- py/samples/xai-hello/src/main.py | 241 ++++++-- 20 files changed, 1689 insertions(+), 1238 deletions(-) create mode 100644 py/plugins/google-genai/src/genkit/plugins/google_genai/models/veo.py delete mode 100644 py/samples/shared/__init__.py delete mode 100644 py/samples/shared/flows.py delete mode 100644 py/samples/shared/tools.py delete mode 100644 py/samples/shared/types.py diff --git a/py/bin/run_sample b/py/bin/run_sample index cb6784a9f1..ab70ba202c 100755 --- a/py/bin/run_sample +++ b/py/bin/run_sample @@ -26,9 +26,6 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PY_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" SAMPLES_DIR="$PY_DIR/samples" -PYTHONPATH="${PYTHONPATH:-}" -export PYTHONPATH="$PY_DIR:$PYTHONPATH" - # List available samples using find for robustness list_samples() { # Only list directories that contain a run.sh script @@ -82,7 +79,7 @@ fi if [[ $# -lt 1 ]]; then while true; do SAMPLES=$(list_samples) - + if [[ -n "$FZF_CMD" ]]; then # Use fzf with preview window if available SAMPLE_NAME=$(echo "$SAMPLES" | $FZF_CMD --preview "if [ -f $SAMPLES_DIR/{}/README.md ]; then cat $SAMPLES_DIR/{}/README.md; else echo 'No README found.'; fi" --preview-window=right:60%:wrap --height=100% --reverse --header="Select a sample to run (Ctrl-C to exit sample and return here)") @@ -132,7 +129,7 @@ if [[ $# -lt 1 ]]; then echo "Error: Sample '$SAMPLE_NAME' does not have a 'run.sh' script." fi ) - + echo "" echo "Sample '$SAMPLE_NAME' exited. Returning to menu..." echo "" @@ -144,7 +141,7 @@ if [[ $# -lt 1 ]]; then else SAMPLE_NAME="$1" shift - + SAMPLE_DIR="$SAMPLES_DIR/$SAMPLE_NAME" if [[ ! 
-d "$SAMPLE_DIR" ]]; then diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index f004106bbf..9e9e91372a 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -32,10 +32,11 @@ else: from enum import StrEnum -from typing import Any, Literal +from pydantic.alias_generators import to_camel + +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, RootModel -from pydantic.alias_generators import to_camel class Model(RootModel[Any]): diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py index 492858fab6..f98d8dd691 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py @@ -40,6 +40,7 @@ SUPPORTED_MODELS, GeminiConfigSchema, GeminiModel, + get_model_config_schema, google_model_info, ) from genkit.plugins.google_genai.models.imagen import ( @@ -47,6 +48,78 @@ ImagenModel, vertexai_image_model_info, ) +from genkit.plugins.google_genai.models.veo import ( + VeoConfigSchema, + VeoModel, + veo_model_info, +) + + +class GenaiModels: + """Container for models discovered from the API.""" + + gemini: list[str] = [] + imagen: list[str] = [] + embedders: list[str] = [] + veo: list[str] = [] + + def __init__(self): + self.gemini = [] + self.imagen = [] + self.embedders = [] + self.veo = [] + + +def _list_genai_models(client: genai.Client, is_vertex: bool) -> GenaiModels: + """Lists supported models and embedders from the Google GenAI SDK. + + Mirrors logic from Go plugin's listGenaiModels. + """ + models = GenaiModels() + # Go logic uses "gemini", "gemma" as allowed substrings for Gemini models + # and checks for "embedContent" capability for embedders + # and "predict" + "imagen" (or similar) for Imagen. + + for m in client.models.list(): + name = m.name + # Cleanup prefix + if is_vertex: + if name.startswith('publishers/google/models/'): + name = name[25:] + elif name.startswith('models/'): + name = name[7:] + + description = (m.description or '').lower() + if 'deprecated' in description: + continue + + # Embedders + if 'embedContent' in m.supported_actions: + models.embedders.append(name) + continue + + # Imagen (Vertex mostly) + # Go checks: slices.Contains(item.SupportedActions, "predict") && strings.Contains(name, "imagen") + if 'predict' in m.supported_actions and 'imagen' in name.lower(): + models.imagen.append(name) + continue + + # Veo + if 'generateVideos' in m.supported_actions or 'veo' in name.lower(): + models.veo.append(name) + continue + + # Gemini / Gemma + # Go checks: slices.Contains(item.SupportedActions, "generateContent") + # then filters for "gemini" or "gemma" in name + if 'generateContent' in m.supported_actions: + lower_name = name.lower() + if 'gemini' in lower_name or 'gemma' in lower_name: + models.gemini.append(name) + continue + + return models + GOOGLEAI_PLUGIN_NAME = 'googleai' VERTEXAI_PLUGIN_NAME = 'vertexai' @@ -136,44 +209,40 @@ async def init(self) -> list[Action]: Returns: List of Action objects for known/supported models. 
""" - return [ - *self._list_known_models(), - *self._list_known_embedders(), - ] + genai_models = _list_genai_models(self._client, is_vertex=False) + + actions = [] + # Gemini Models + for name in genai_models.gemini: + actions.append(self._resolve_model(googleai_name(name))) + + # Embedders + for name in genai_models.embedders: + actions.append(self._resolve_embedder(googleai_name(name))) + + return actions def _list_known_models(self) -> list[Action]: """List known models as Action objects. - Returns: - List of Action objects for known Gemini models. + Deprecated: Used only for internal testing if needed, but 'init' should be source of truth. + Keeping for compatibility but redirecting to dynamic list logic if accessed directly? + The interface defines init(), this helper was internal. """ - known_model_names = [ - 'gemini-3-flash-preview', - 'gemini-3-pro-preview', - 'gemini-2.5-pro', - 'gemini-2.5-flash', - 'gemini-2.5-flash-lite', - 'gemini-2.0-flash', - 'gemini-2.0-flash-lite', - ] + # Re-use init logic synchronously? init is async. + # Let's implementation just mimic init logic but sync call to client.models.list is fine (it is iterator) + genai_models = _list_genai_models(self._client, is_vertex=False) actions = [] - for model_name in known_model_names: - actions.append(self._resolve_model(googleai_name(model_name))) + for name in genai_models.gemini: + actions.append(self._resolve_model(googleai_name(name))) return actions def _list_known_embedders(self) -> list[Action]: - """List known embedders as Action objects. - - Returns: - List of Action objects for known embedders. - """ - known_embedders = [ - GeminiEmbeddingModels.TEXT_EMBEDDING_004, - GeminiEmbeddingModels.GEMINI_EMBEDDING_001, - ] + """List known embedders as Action objects.""" + genai_models = _list_genai_models(self._client, is_vertex=False) actions = [] - for embedder_name in known_embedders: - actions.append(self._resolve_embedder(googleai_name(embedder_name.value))) + for name in genai_models.embedders: + actions.append(self._resolve_embedder(googleai_name(name))) return actions async def resolve(self, action_type: ActionKind, name: str) -> Action | None: @@ -259,32 +328,30 @@ async def list_actions(self) -> list[ActionMetadata]: - info (dict): The metadata dictionary describing the model configuration and properties. - config_schema (type): The schema class used for validating the model's configuration. 
""" - actions_list = list() - for m in self._client.models.list(): - model_name = m.name - if not model_name: - continue - name = model_name.replace('models/', '') - if m.supported_actions and 'generateContent' in m.supported_actions: - actions_list.append( - model_action_metadata( - name=googleai_name(name), - info=google_model_info(name).model_dump(), - ), + genai_models = _list_genai_models(self._client, is_vertex=False) + actions_list = [] + + for name in genai_models.gemini: + actions_list.append( + model_action_metadata( + name=googleai_name(name), + info=google_model_info(name).model_dump(by_alias=True), + config_schema=get_model_config_schema(name), ) + ) - if m.supported_actions and 'embedContent' in m.supported_actions: - embed_info = default_embedder_info(name) - actions_list.append( - embedder_action_metadata( - name=googleai_name(name), - options=EmbedderOptions( - label=embed_info.get('label'), - supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), - dimensions=embed_info.get('dimensions'), - ), - ) + for name in genai_models.embedders: + embed_info = default_embedder_info(name) + actions_list.append( + embedder_action_metadata( + name=googleai_name(name), + options=EmbedderOptions( + label=embed_info.get('label'), + supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), + dimensions=embed_info.get('dimensions'), + ), ) + ) return actions_list @@ -349,47 +416,41 @@ async def init(self) -> list[Action]: Returns: List of Action objects for known/supported models. """ - return [ - *self._list_known_models(), - *self._list_known_embedders(), - ] + genai_models = _list_genai_models(self._client, is_vertex=True) + actions = [] - def _list_known_models(self) -> list[Action]: - """List known models as Action objects. + for name in genai_models.gemini: + actions.append(self._resolve_model(vertexai_name(name))) - Returns: - List of Action objects for known Gemini and Imagen models. - """ - known_model_names = [ - 'gemini-2.5-flash-lite', - 'gemini-2.5-pro', - 'gemini-2.5-flash', - 'gemini-2.0-flash-001', - 'gemini-2.0-flash', - 'gemini-2.0-flash-lite', - 'gemini-2.0-flash-lite-001', - 'imagen-4.0-generate-001', - ] + for name in genai_models.imagen: + actions.append(self._resolve_model(vertexai_name(name))) + + for name in genai_models.veo: + actions.append(self._resolve_model(vertexai_name(name))) + + for name in genai_models.embedders: + actions.append(self._resolve_embedder(vertexai_name(name))) + + return actions + + def _list_known_models(self) -> list[Action]: + """List known models as Action objects.""" + genai_models = _list_genai_models(self._client, is_vertex=True) actions = [] - for model_name in known_model_names: - actions.append(self._resolve_model(vertexai_name(model_name))) + for name in genai_models.gemini: + actions.append(self._resolve_model(vertexai_name(name))) + for name in genai_models.imagen: + actions.append(self._resolve_model(vertexai_name(name))) + for name in genai_models.veo: + actions.append(self._resolve_model(vertexai_name(name))) return actions def _list_known_embedders(self) -> list[Action]: - """List known embedders as Action objects. - - Returns: - List of Action objects for known embedders. 
- """ - known_embedders = [ - VertexEmbeddingModels.TEXT_EMBEDDING_005_ENG, - VertexEmbeddingModels.TEXT_EMBEDDING_002_MULTILINGUAL, - # Note: multimodalembedding@001 requires different API structure (not yet implemented) - VertexEmbeddingModels.GEMINI_EMBEDDING_001, - ] + """List known embedders as Action objects.""" + genai_models = _list_genai_models(self._client, is_vertex=True) actions = [] - for embedder_name in known_embedders: - actions.append(self._resolve_embedder(vertexai_name(embedder_name.value))) + for name in genai_models.embedders: + actions.append(self._resolve_embedder(vertexai_name(name))) return actions async def resolve(self, action_type: ActionKind, name: str) -> Action | None: @@ -424,6 +485,9 @@ def _resolve_model(self, name: str) -> Action: model_ref = vertexai_image_model_info(_clean_name) model = ImagenModel(_clean_name, self._client) IMAGE_SUPPORTED_MODELS[_clean_name] = model_ref + elif _clean_name.lower().startswith('veo'): + model_ref = veo_model_info(_clean_name) + model = VeoModel(_clean_name, self._client) else: model_ref = google_model_info(_clean_name) model = GeminiModel(_clean_name, self._client) @@ -481,24 +545,49 @@ async def list_actions(self) -> list[ActionMetadata]: - info (dict): The metadata dictionary describing the model configuration and properties. - config_schema (type): The schema class used for validating the model's configuration. """ - actions_list = list() - for m in self._client.models.list(): - model_name = m.name - if not model_name: - continue - name = model_name.replace('publishers/google/models/', '') - if 'embed' in name.lower(): - embed_info = default_embedder_info(name) - actions_list.append( - embedder_action_metadata( - name=vertexai_name(name), - options=EmbedderOptions( - label=embed_info.get('label'), - supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), - dimensions=embed_info.get('dimensions'), - ), - ) + genai_models = _list_genai_models(self._client, is_vertex=True) + actions_list = [] + + for name in genai_models.gemini: + actions_list.append( + model_action_metadata( + name=vertexai_name(name), + info=google_model_info(name).model_dump(), + config_schema=get_model_config_schema(name), ) + ) + + for name in genai_models.imagen: + # Imagen models might use vertexai_image_model_info or similar + actions_list.append( + model_action_metadata( + name=vertexai_name(name), + info=vertexai_image_model_info(name).model_dump(), + # Image models config? 
Maybe different + ) + ) + + for name in genai_models.veo: + actions_list.append( + model_action_metadata( + name=vertexai_name(name), + info=veo_model_info(name).model_dump(), + config_schema=VeoConfigSchema, + ) + ) + + for name in genai_models.embedders: + embed_info = default_embedder_info(name) + actions_list.append( + embedder_action_metadata( + name=vertexai_name(name), + options=EmbedderOptions( + label=embed_info.get('label'), + supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), + dimensions=embed_info.get('dimensions'), + ), + ) + # List all the vertexai models for generate actions actions_list.append( model_action_metadata( @@ -511,39 +600,25 @@ async def list_actions(self) -> list[ActionMetadata]: return actions_list -def _inject_attribution_headers( - http_options: HttpOptions | HttpOptionsDict | None = None, - base_url: str | None = None, - api_version: str | None = None, -) -> HttpOptions: +def _inject_attribution_headers(http_options: HttpOptions | dict | None = None) -> HttpOptions: """Adds genkit client info to the appropriate http headers.""" - # Normalize to HttpOptions instance - opts: HttpOptions - if http_options is None: - opts = HttpOptions() - elif isinstance(http_options, HttpOptions): - opts = http_options + if not http_options: + http_options = HttpOptions() else: - # HttpOptionsDict or other dict-like - use model_validate for proper type conversion - opts = HttpOptions.model_validate(http_options) - - if base_url: - opts.base_url = base_url - - if api_version: - opts.api_version = api_version + if isinstance(http_options, dict): + http_options = HttpOptions(**http_options) - if not opts.headers: - opts.headers = {} + if not http_options.headers: + http_options.headers = {} - if 'x-goog-api-client' not in opts.headers: - opts.headers['x-goog-api-client'] = GENKIT_CLIENT_HEADER + if 'x-goog-api-client' not in http_options.headers: + http_options.headers['x-goog-api-client'] = GENKIT_CLIENT_HEADER else: - opts.headers['x-goog-api-client'] += f' {GENKIT_CLIENT_HEADER}' + http_options.headers['x-goog-api-client'] += f' {GENKIT_CLIENT_HEADER}' - if 'user-agent' not in opts.headers: - opts.headers['user-agent'] = GENKIT_CLIENT_HEADER + if 'user-agent' not in http_options.headers: + http_options.headers['user-agent'] = GENKIT_CLIENT_HEADER else: - opts.headers['user-agent'] += f' {GENKIT_CLIENT_HEADER}' + http_options.headers['user-agent'] += f' {GENKIT_CLIENT_HEADER}' - return opts + return http_options diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index 014a4b31f2..4d2b3d3a66 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -143,11 +143,11 @@ from enum import StrEnum from functools import cached_property -from typing import Any, cast +from typing import Annotated, Any, cast from google import genai from google.genai import types as genai_types -from pydantic import ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema from genkit.ai import ( ActionRunContext, @@ -177,9 +177,121 @@ ) -class GeminiConfigSchema(genai_types.GenerateContentConfig): +class HarmCategory(StrEnum): + HARM_CATEGORY_UNSPECIFIED = 'HARM_CATEGORY_UNSPECIFIED' + HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH' + HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT' + 
HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT' + HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT' + + +class HarmBlockThreshold(StrEnum): + BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE' + BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE' + BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH' + BLOCK_NONE = 'BLOCK_NONE' + + +class SafetySettingsSchema(BaseModel): + model_config = ConfigDict(extra='allow', populate_by_name=True) + category: HarmCategory + threshold: HarmBlockThreshold + + +class PrebuiltVoiceConfig(BaseModel): + model_config = ConfigDict(extra='allow', populate_by_name=True) + voice_name: str | None = Field(None, alias='voiceName') + + +class FunctionCallingMode(StrEnum): + MODE_UNSPECIFIED = 'MODE_UNSPECIFIED' + AUTO = 'AUTO' + ANY = 'ANY' + NONE = 'NONE' + + +class FunctionCallingConfig(BaseModel): + model_config = ConfigDict(extra='allow', populate_by_name=True) + mode: FunctionCallingMode | None = None + allowed_function_names: list[str] | None = Field(None, alias='allowedFunctionNames') + + +class ThinkingLevel(StrEnum): + MINIMAL = 'MINIMAL' + LOW = 'LOW' + MEDIUM = 'MEDIUM' + HIGH = 'HIGH' + + +class ThinkingConfigSchema(BaseModel): + model_config = ConfigDict(extra='allow', populate_by_name=True) + include_thoughts: bool | None = Field(None, alias='includeThoughts') + thinking_budget: int | None = Field(None, alias='thinkingBudget') + thinking_level: ThinkingLevel | None = Field(None, alias='thinkingLevel') + + +class FileSearchConfigSchema(BaseModel): + model_config = ConfigDict(extra='allow', populate_by_name=True) + file_search_store_names: list[str] | None = Field(None, alias='fileSearchStoreNames') + metadata_filter: str | None = Field(None, alias='metadataFilter') + top_k: int | None = Field(None, alias='topK') + + +class ImageAspectRatio(StrEnum): + RATIO_1_1 = '1:1' + RATIO_2_3 = '2:3' + RATIO_3_2 = '3:2' + RATIO_3_4 = '3:4' + RATIO_4_3 = '4:3' + RATIO_4_5 = '4:5' + RATIO_5_4 = '5:4' + RATIO_9_16 = '9:16' + RATIO_16_9 = '16:9' + RATIO_21_9 = '21:9' + + +class ImageSize(StrEnum): + SIZE_1K = '1K' + SIZE_2K = '2K' + SIZE_4K = '4K' + + +class ImageConfigSchema(BaseModel): + model_config = ConfigDict(extra='allow', populate_by_name=True) + aspect_ratio: ImageAspectRatio | None = Field(None, alias='aspectRatio') + image_size: ImageSize | None = Field(None, alias='imageSize') + + +class VoiceConfigSchema(BaseModel): + model_config = ConfigDict(extra='allow', populate_by_name=True) + prebuilt_voice_config: PrebuiltVoiceConfig | None = Field(None, alias='prebuiltVoiceConfig') + + +class GeminiConfigSchema(GenerationCommonConfig): """Gemini Config Schema.""" + model_config = ConfigDict(extra='allow', populate_by_name=True) + + safety_settings: Annotated[ + list[SafetySettingsSchema] | None, + WithJsonSchema({ + 'type': 'array', + 'items': { + 'type': 'object', + 'properties': { + 'category': {'type': 'string', 'enum': [e.value for e in HarmCategory]}, + 'threshold': {'type': 'string', 'enum': [e.value for e in HarmBlockThreshold]}, + }, + 'required': ['category', 'threshold'], + 'additionalProperties': True, + }, + 'description': 'Adjust how likely you are to see responses that could be harmful. 
Content is blocked based on the probability that it is harmful.', + }), + ] = Field( + None, + alias='safetySettings', + ) + # Gemini specific model_config = ConfigDict(extra='allow') code_execution: bool | None = None @@ -187,7 +299,137 @@ class GeminiConfigSchema(genai_types.GenerateContentConfig): thinking_config: dict[str, object] | None = None file_search: dict[str, object] | None = None url_context: dict[str, object] | None = None + google_search_retrieval: dict[str, Any] | None = None + function_calling_config: dict[str, Any] | None = None api_version: str | None = None + base_url: str | None = None + + # inherited from GenerationCommonConfig: + # version, temperature, max_output_tokens, top_k, top_p, stop_sequences + + temperature: float | None = Field( + default=None, + description='Controls the randomness of the output. Values can range over [0.0, 2.0].', + ) + top_p: float | None = Field( + default=None, + alias='topP', + description='The maximum cumulative probability of tokens to consider when sampling. Values can range over [0.0, 1.0].', + ) + top_k: int | None = Field( + default=None, + alias='topK', + description='The maximum number of tokens to consider when sampling. Values can range over [1, 40].', + ) + candidate_count: int | None = Field( + default=None, description='Number of generated responses to return.', alias='candidateCount' + ) + max_output_tokens: int | None = Field( + default=None, alias='maxOutputTokens', description='Maximum number of tokens to generate.' + ) + stop_sequences: list[str] | None = Field(default=None, alias='stopSequences', description='Stop sequences.') + presence_penalty: float | None = Field(default=None, description='Presence penalty.', alias='presencePenalty') + frequency_penalty: float | None = Field(default=None, description='Frequency penalty.', alias='frequencyPenalty') + response_mime_type: str | None = Field(default=None, description='Response MIME type.', alias='responseMimeType') + response_schema: dict[str, Any] | None = Field(default=None, description='Response schema.', alias='responseSchema') + + code_execution: bool | dict[str, Any] | None = Field( + None, description='Enables the model to generate and run code.', alias='codeExecution' + ) + response_modalities: list[str] | None = Field( + None, + description="The modalities to be used in response. Only supported for 'gemini-2.0-flash-exp' model at present.", + alias='responseModalities', + ) + + thinking_config: Annotated[ + ThinkingConfigSchema | None, + WithJsonSchema({ + 'type': 'object', + 'properties': { + 'includeThoughts': { + 'type': 'boolean', + 'description': 'Indicates whether to include thoughts in the response. If true, thoughts are returned only if the model supports thought and thoughts are available.', + }, + 'thinkingBudget': { + 'type': 'integer', + 'description': 'For Gemini 2.5 - Indicates the thinking budget in tokens. 0 is DISABLED. -1 is AUTOMATIC. The default values and allowed ranges are model dependent. The thinking budget parameter gives the model guidance on the number of thinking tokens it can use when generating a response. A greater number of tokens is typically associated with more detailed thinking, which is needed for solving more complex tasks.', + }, + 'thinkingLevel': { + 'type': 'string', + 'enum': [e.value for e in ThinkingLevel], + 'description': 'For Gemini 3.0 - Indicates the thinking level. 
A higher level is associated with more detailed thinking, which is needed for solving more complex tasks.', + }, + }, + 'additionalProperties': True, + }), + ] = Field(None, alias='thinkingConfig') + + file_search: Annotated[ + FileSearchConfigSchema | None, + WithJsonSchema({ + 'type': 'object', + 'properties': { + 'fileSearchStoreNames': { + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'The names of the fileSearchStores to retrieve from. Example: fileSearchStores/my-file-search-store-123', + }, + 'metadataFilter': { + 'type': 'string', + 'description': 'Metadata filter to apply to the semantic retrieval documents and chunks.', + }, + 'topK': { + 'type': 'integer', + 'description': 'The number of semantic retrieval chunks to retrieve.', + }, + }, + 'additionalProperties': True, + }), + ] = Field(None, alias='fileSearch') + + url_context: bool | dict[str, Any] | None = Field( + None, description='Return grounding metadata from links included in the query', alias='urlContext' + ) + google_search_retrieval: bool | dict[str, Any] | None = Field( + None, + description='Retrieve public web data for grounding, powered by Google Search.', + alias='googleSearchRetrieval', + ) + function_calling_config: Annotated[ + FunctionCallingConfig | None, + WithJsonSchema({ + 'type': 'object', + 'properties': { + 'mode': {'type': 'string', 'enum': [e.value for e in FunctionCallingMode]}, + 'allowedFunctionNames': {'type': 'array', 'items': {'type': 'string'}}, + }, + 'description': 'Controls how the model uses the provided tools (function declarations). With AUTO (Default) mode, the model decides whether to generate a natural language response or suggest a function call based on the prompt and context. With ANY, the model is constrained to always predict a function call and guarantee function schema adherence. 
With NONE, the model is prohibited from making function calls.', + 'additionalProperties': True, + }), + ] = Field( + None, + alias='functionCallingConfig', + ) + + api_version: str | None = Field( + None, description='Overrides the plugin-configured or default apiVersion, if specified.', alias='apiVersion' + ) + base_url: str | None = Field( + None, description='Overrides the plugin-configured or default baseUrl, if specified.', alias='baseUrl' + ) + api_key: str | None = Field( + None, description='Overrides the plugin-configured API key, if specified.', alias='apiKey' + ) + context_cache: bool | None = Field( + None, + description='Context caching allows you to save and reuse precomputed input tokens that you wish to use repeatedly.', + alias='contextCache', + ) + + +class SpeechConfigSchema(BaseModel): + voice_config: VoiceConfigSchema | None = Field(None, alias='voiceConfig') http_options: Any | None = Field(None, exclude=True) tools: Any | None = Field(None, exclude=True) @@ -199,18 +441,29 @@ class GeminiConfigSchema(genai_types.GenerateContentConfig): class GeminiTtsConfigSchema(GeminiConfigSchema): """Gemini TTS Config Schema.""" - speech_config: dict[str, object] | None = None + speech_config: SpeechConfigSchema | None = Field(None, alias='speechConfig') class GeminiImageConfigSchema(GeminiConfigSchema): """Gemini Image Config Schema.""" - image_config: dict[str, object] | None = None + image_config: Annotated[ + ImageConfigSchema | None, + WithJsonSchema({ + 'type': 'object', + 'properties': { + 'aspectRatio': {'type': 'string', 'enum': [e.value for e in ImageAspectRatio]}, + 'imageSize': {'type': 'string', 'enum': [e.value for e in ImageSize]}, + }, + 'additionalProperties': True, + }), + ] = Field(None, alias='imageConfig') class GemmaConfigSchema(GeminiConfigSchema): """Gemma Config Schema.""" + # Inherits temperature from GeminiConfigSchema temperature: float | None = None @@ -229,7 +482,6 @@ class GemmaConfigSchema(GeminiConfigSchema): tool_choice=True, system_role=True, constrained=Constrained.NO_TOOLS, - output=['text', 'json'], ), ) @@ -436,12 +688,7 @@ class GemmaConfigSchema(GeminiConfigSchema): ) -Deprecations = deprecated_enum_metafactory({ - 'GEMINI_1_0_PRO': DeprecationInfo(recommendation='GEMINI_2_0_FLASH', status=DeprecationStatus.DEPRECATED), - 'GEMINI_1_5_PRO': DeprecationInfo(recommendation='GEMINI_2_0_FLASH', status=DeprecationStatus.DEPRECATED), - 'GEMINI_1_5_FLASH': DeprecationInfo(recommendation='GEMINI_2_0_FLASH', status=DeprecationStatus.DEPRECATED), - 'GEMINI_1_5_FLASH_8B': DeprecationInfo(recommendation='GEMINI_2_0_FLASH', status=DeprecationStatus.DEPRECATED), -}) +Deprecations = deprecated_enum_metafactory({}) class VertexAIGeminiVersion(StrEnum, metaclass=Deprecations): @@ -479,9 +726,6 @@ class VertexAIGeminiVersion(StrEnum, metaclass=Deprecations): | `gemma-3n-e4b-it` | Gemma 3n E4B IT | Supported | """ - GEMINI_1_5_FLASH = 'gemini-1.5-flash' - GEMINI_1_5_FLASH_8B = 'gemini-1.5-flash-8b' - GEMINI_1_5_PRO = 'gemini-1.5-pro' GEMINI_2_0_FLASH = 'gemini-2.0-flash' GEMINI_2_0_FLASH_EXP = 'gemini-2.0-flash-exp' GEMINI_2_0_FLASH_LITE = 'gemini-2.0-flash-lite' @@ -542,9 +786,6 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): | `gemma-3n-e4b-it` | Gemma 3n E4B IT | Supported | """ - GEMINI_1_5_FLASH = 'gemini-1.5-flash' - GEMINI_1_5_FLASH_8B = 'gemini-1.5-flash-8b' - GEMINI_1_5_PRO = 'gemini-1.5-pro' GEMINI_2_0_FLASH = 'gemini-2.0-flash' GEMINI_2_0_FLASH_EXP = 'gemini-2.0-flash-exp' GEMINI_2_0_FLASH_LITE = 'gemini-2.0-flash-lite' @@ -570,60 
+811,7 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): GEMMA_3N_E4B_IT = 'gemma-3n-e4b-it' -SUPPORTED_MODELS = { - GoogleAIGeminiVersion.GEMINI_1_5_FLASH: GEMINI_1_5_FLASH, - GoogleAIGeminiVersion.GEMINI_1_5_FLASH_8B: GEMINI_1_5_FLASH_8B, - GoogleAIGeminiVersion.GEMINI_1_5_PRO: GEMINI_1_5_PRO, - GoogleAIGeminiVersion.GEMINI_2_0_FLASH: GEMINI_2_0_FLASH, - GoogleAIGeminiVersion.GEMINI_2_0_FLASH_EXP: GEMINI_2_0_FLASH_EXP_IMAGEN, - GoogleAIGeminiVersion.GEMINI_2_0_FLASH_LITE: GEMINI_2_0_FLASH_LITE, - GoogleAIGeminiVersion.GEMINI_2_0_FLASH_THINKING_EXP_01_21: GEMINI_2_0_FLASH_THINKING_EXP_01_21, - GoogleAIGeminiVersion.GEMINI_2_0_PRO_EXP_02_05: GEMINI_2_0_PRO_EXP_02_05, - GoogleAIGeminiVersion.GEMINI_2_5_PRO_EXP_03_25: GEMINI_2_5_PRO_EXP_03_25, - GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_03_25: GEMINI_2_5_PRO_PREVIEW_03_25, - GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_05_06: GEMINI_2_5_PRO_PREVIEW_05_06, - GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW: GENERIC_GEMINI_MODEL, - GoogleAIGeminiVersion.GEMINI_3_PRO_PREVIEW: GENERIC_GEMINI_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_PRO: GENERIC_GEMINI_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_FLASH: GENERIC_GEMINI_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_FLASH_LITE: GENERIC_GEMINI_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_FLASH_PREVIEW_TTS: GENERIC_TTS_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_TTS: GENERIC_TTS_MODEL, - GoogleAIGeminiVersion.GEMINI_3_PRO_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE: GENERIC_IMAGE_MODEL, - GoogleAIGeminiVersion.GEMMA_3_12B_IT: GENERIC_GEMMA_MODEL, - GoogleAIGeminiVersion.GEMMA_3_1B_IT: GENERIC_GEMMA_MODEL, - GoogleAIGeminiVersion.GEMMA_3_27B_IT: GENERIC_GEMMA_MODEL, - GoogleAIGeminiVersion.GEMMA_3_4B_IT: GENERIC_GEMMA_MODEL, - GoogleAIGeminiVersion.GEMMA_3N_E4B_IT: GENERIC_GEMMA_MODEL, - VertexAIGeminiVersion.GEMINI_1_5_FLASH: GEMINI_1_5_FLASH, - VertexAIGeminiVersion.GEMINI_1_5_FLASH_8B: GEMINI_1_5_FLASH_8B, - VertexAIGeminiVersion.GEMINI_1_5_PRO: GEMINI_1_5_PRO, - VertexAIGeminiVersion.GEMINI_2_0_FLASH: GEMINI_2_0_FLASH, - VertexAIGeminiVersion.GEMINI_2_0_FLASH_EXP: GEMINI_2_0_FLASH_EXP_IMAGEN, - VertexAIGeminiVersion.GEMINI_2_0_FLASH_LITE: GEMINI_2_0_FLASH_LITE, - VertexAIGeminiVersion.GEMINI_2_0_FLASH_THINKING_EXP_01_21: GEMINI_2_0_FLASH_THINKING_EXP_01_21, - VertexAIGeminiVersion.GEMINI_2_0_PRO_EXP_02_05: GEMINI_2_0_PRO_EXP_02_05, - VertexAIGeminiVersion.GEMINI_2_5_PRO_EXP_03_25: GEMINI_2_5_PRO_EXP_03_25, - VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_03_25: GEMINI_2_5_PRO_PREVIEW_03_25, - VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_05_06: GEMINI_2_5_PRO_PREVIEW_05_06, - VertexAIGeminiVersion.GEMINI_3_FLASH_PREVIEW: GENERIC_GEMINI_MODEL, - VertexAIGeminiVersion.GEMINI_3_PRO_PREVIEW: GENERIC_GEMINI_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_PRO: GENERIC_GEMINI_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_FLASH: GENERIC_GEMINI_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_FLASH_LITE: GENERIC_GEMINI_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_FLASH_PREVIEW_TTS: GENERIC_TTS_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_TTS: GENERIC_TTS_MODEL, - VertexAIGeminiVersion.GEMINI_3_PRO_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE: GENERIC_IMAGE_MODEL, - VertexAIGeminiVersion.GEMMA_3_12B_IT: GENERIC_GEMMA_MODEL, - VertexAIGeminiVersion.GEMMA_3_1B_IT: 
GENERIC_GEMMA_MODEL, - VertexAIGeminiVersion.GEMMA_3_27B_IT: GENERIC_GEMMA_MODEL, - VertexAIGeminiVersion.GEMMA_3_4B_IT: GENERIC_GEMMA_MODEL, - VertexAIGeminiVersion.GEMMA_3N_E4B_IT: GENERIC_GEMMA_MODEL, -} +SUPPORTED_MODELS = {} DEFAULT_SUPPORTS_MODEL = Supports( @@ -636,6 +824,37 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): ) +def is_gemini_model(name: str) -> bool: + """Check if the model is a Gemini model.""" + return name.startswith('gemini-') and not is_tts_model(name) and not is_image_model(name) + + +def is_tts_model(name: str) -> bool: + """Check if the model is a TTS model.""" + return (name.startswith('gemini-') and name.endswith('-tts')) or 'tts' in name + + +def is_image_model(name: str) -> bool: + """Check if the model is an image model.""" + return (name.startswith('gemini-') and '-image' in name) or 'image' in name + + +def is_gemma_model(name: str) -> bool: + """Check if the model is a Gemma model.""" + return name.startswith('gemma-') + + +def get_model_config_schema(name: str) -> type[GeminiConfigSchema]: + """Get the config schema for a given model name.""" + if is_tts_model(name): + return GeminiTtsConfigSchema + if is_image_model(name): + return GeminiImageConfigSchema + if is_gemma_model(name): + return GemmaConfigSchema + return GeminiConfigSchema + + def google_model_info( version: str, ) -> ModelInfo: @@ -650,6 +869,16 @@ def google_model_info( Returns: ModelInfo object. """ + if version in SUPPORTED_MODELS: + return SUPPORTED_MODELS[version] + + if is_tts_model(version): + return GENERIC_TTS_MODEL + if is_image_model(version): + return GENERIC_IMAGE_MODEL + if is_gemma_model(version): + return GENERIC_GEMMA_MODEL + return ModelInfo( label=f'Google AI - {version}', supports=DEFAULT_SUPPORTS_MODEL, @@ -728,8 +957,7 @@ def _convert_schema_property( return None if defs is None: - defs_value = input_schema.get('$defs') - defs = cast(dict[str, object], defs_value) if isinstance(defs_value, dict) else {} + defs = input_schema.get('$defs') if '$defs' in input_schema else {} if '$ref' in input_schema: ref_path = input_schema['$ref'] @@ -756,32 +984,28 @@ def _convert_schema_property( schema = genai_types.Schema() if input_schema.get('description'): - schema.description = cast(str, input_schema['description']) + schema.description = input_schema['description'] if 'required' in input_schema: - schema.required = cast(list[str], input_schema['required']) + schema.required = input_schema['required'] if 'type' in input_schema: - schema_type = genai_types.Type(cast(str, input_schema['type'])) + schema_type = genai_types.Type(input_schema['type']) schema.type = schema_type if 'enum' in input_schema: - schema.enum = cast(list[str], input_schema['enum']) + schema.enum = input_schema['enum'] if schema_type == genai_types.Type.ARRAY: - items_value = input_schema.get('items') - if isinstance(items_value, dict): - schema.items = self._convert_schema_property(cast(dict[str, object], items_value), defs) + schema.items = self._convert_schema_property(input_schema['items'], defs) if schema_type == genai_types.Type.OBJECT: schema.properties = {} - properties_value = input_schema.get('properties', {}) - if isinstance(properties_value, dict): - properties = cast(dict[str, dict[str, object]], properties_value) - for key in properties: - nested_schema = self._convert_schema_property(properties[key], defs) - if nested_schema: - schema.properties[key] = nested_schema + properties = input_schema.get('properties', {}) + for key in properties: + nested_schema = 
self._convert_schema_property(properties[key], defs) + if nested_schema: + schema.properties[key] = nested_schema return schema @@ -875,36 +1099,29 @@ async def generate(self, request: GenerateRequest, ctx: ActionRunContext) -> Gen # If the library changes its internal structure (e.g. renames _api_client or _credentials), # this code WILL BREAK. api_client = self._client._api_client - http_opts: genai_types.HttpOptionsDict = {'api_version': api_version} + kwargs = { + 'vertexai': api_client.vertexai, + 'http_options': {'api_version': api_version}, + } if api_client.vertexai: - # Vertex AI mode: requires project/location - client = genai.Client( - vertexai=True, - http_options=http_opts, - project=api_client.project, - location=api_client.location, - credentials=api_client._credentials, - ) + # Vertex AI mode: requires project/location (api_key is optional/unlikely) + if api_client.project: + kwargs['project'] = api_client.project + if api_client.location: + kwargs['location'] = api_client.location + if api_client._credentials: + kwargs['credentials'] = api_client._credentials + # Don't pass api_key if we are in Vertex AI mode with credentials/project else: # Google AI mode: primarily uses api_key if api_client.api_key: - client = genai.Client( - vertexai=False, - http_options=http_opts, - api_key=api_client.api_key, - ) - elif api_client._credentials: - # Fallback if no api_key but credentials present - client = genai.Client( - vertexai=False, - http_options=http_opts, - credentials=api_client._credentials, - ) - else: - client = genai.Client( - vertexai=False, - http_options=http_opts, - ) + kwargs['api_key'] = api_client.api_key + # Do NOT pass project/location/credentials if in Google AI mode to be safe + if api_client._credentials and not kwargs.get('api_key'): + # Fallback if no api_key but credentials present (unlikely for pure Google AI but possible) + kwargs['credentials'] = api_client._credentials + + client = genai.Client(**kwargs) if ctx.is_streaming: response = await self._streaming_generate( @@ -1030,11 +1247,7 @@ def metadata(self) -> dict: Returns: model metadata. 
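        Example of the returned shape (illustrative; the supports flags come
        from the resolved ModelInfo):

            {'model': {'supports': {'multiturn': True, 'media': True, ...}}}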
""" - model_info = SUPPORTED_MODELS.get(self._version) - if model_info and model_info.supports: - supports = model_info.supports.model_dump() - else: - supports = {} + supports = SUPPORTED_MODELS[self._version].supports.model_dump() return { 'model': { 'supports': supports, @@ -1113,7 +1326,9 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener if request.config: request_config = request.config - if isinstance(request_config, GenerationCommonConfig): + if isinstance(request_config, GeminiConfigSchema): + cfg = request_config + elif isinstance(request_config, GenerationCommonConfig): cfg = genai_types.GenerateContentConfig( max_output_tokens=request_config.max_output_tokens, top_k=request_config.top_k, @@ -1121,8 +1336,6 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener temperature=request_config.temperature, stop_sequences=request_config.stop_sequences, ) - elif isinstance(request_config, GeminiConfigSchema): - cfg = request_config elif isinstance(request_config, dict): if 'image_config' in request_config: cfg = GeminiImageConfigSchema(**request_config) @@ -1136,7 +1349,50 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener tools.extend([genai_types.Tool(code_execution=genai_types.ToolCodeExecution())]) dumped_config = cfg.model_dump(exclude_none=True) - for key in ['code_execution', 'file_search', 'url_context', 'api_version']: + + if 'code_execution' in dumped_config: + if dumped_config.pop('code_execution'): + tools.append(genai_types.Tool(code_execution=genai_types.ToolCodeExecution())) + + if 'safety_settings' in dumped_config: + dumped_config['safety_settings'] = [ + s + for s in dumped_config['safety_settings'] + if s['category'] != HarmCategory.HARM_CATEGORY_UNSPECIFIED + ] + + if 'google_search_retrieval' in dumped_config: + val = dumped_config.pop('google_search_retrieval') + if val is not None: + val = {} if val is True else val + tools.append(genai_types.Tool(google_search_retrieval=genai_types.GoogleSearchRetrieval(**val))) + + if 'file_search' in dumped_config: + val = dumped_config.pop('file_search') + # File search requires a store name to be valid. 
+            if val and val.get('file_search_store_names'):
+                # Filter out empty strings from store names
+                valid_stores = [s for s in val['file_search_store_names'] if s]
+                if valid_stores:
+                    val['file_search_store_names'] = valid_stores
+                    tools.append(genai_types.Tool(file_search=genai_types.FileSearch(**val)))
+
+        if 'url_context' in dumped_config:
+            val = dumped_config.pop('url_context')
+            if val is not None:
+                val = {} if val is True else val
+                tools.append(genai_types.Tool(url_context=genai_types.UrlContext(**val)))
+
+        # Map Function Calling Config to ToolConfig
+        if 'function_calling_config' in dumped_config:
+            dumped_config['tool_config'] = genai_types.ToolConfig(
+                function_calling_config=genai_types.FunctionCallingConfig(
+                    **dumped_config.pop('function_calling_config')
+                )
+            )
+
+        # Clean up fields not supported by GenerateContentConfig
+        for key in ['api_version', 'api_key', 'base_url', 'context_cache']:
            if key in dumped_config:
                del dumped_config[key]
@@ -1213,6 +1469,13 @@ def _create_usage_stats(self, request: GenerateRequest, response: GenerateRespon
            usage = get_basic_usage_stats(input_=request.messages, response=response.message)
        else:
            usage = GenerationUsage()
+        if not response.message:
+            # No candidate message: report zeroed usage and skip merging
+            # token counts from the (absent) response.
+            usage.input_tokens = 0
+            usage.output_tokens = 0
+            usage.total_tokens = 0
+            return usage
        if response.usage:
            usage.input_tokens = response.usage.input_tokens
            usage.output_tokens = response.usage.output_tokens
diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/imagen.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/imagen.py
index cf27774c67..11643ffd3e 100644
--- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/imagen.py
+++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/imagen.py
@@ -231,11 +231,13 @@ def metadata(self) -> dict:
    Returns:
        model metadata.
    """
-        supports = SUPPORTED_MODELS[self._version].supports
-        if supports:
-            return {
-                'model': {
-                    'supports': supports.model_dump(),
-                }
+        if self._version in SUPPORTED_MODELS:
+            supports = SUPPORTED_MODELS[self._version].supports.model_dump()
+        else:
+            supports = vertexai_image_model_info(self._version).supports.model_dump()
+
+        return {
+            'model': {
+                'supports': supports,
            }
-        return {'model': {'supports': {}}}
+        }
diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/veo.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/veo.py
new file mode 100644
index 0000000000..f87288bf28
--- /dev/null
+++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/veo.py
@@ -0,0 +1,193 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+
+from google import genai
+from google.genai import types as genai_types
+from pydantic import BaseModel, ConfigDict, Field
+
+from genkit.ai import ActionRunContext
+from genkit.core.tracing import tracer
+from genkit.types import (
+    GenerateRequest,
+    GenerateResponse,
+    Media,
+    Message,
+    ModelInfo,
+    Part,
+    Role,
+    Supports,
+    TextPart,
+)
+
+
+class VeoConfigSchema(BaseModel):
+    """Veo Config Schema."""
+
+    model_config = ConfigDict(extra='allow')
+    negative_prompt: str | None = Field(default=None, description='Negative prompt for video generation.')
+    aspect_ratio: str | None = Field(
+        default=None, description='Desired aspect ratio of the output video (e.g. "16:9").'
+    )
+    person_generation: str | None = Field(default=None, description='Person generation mode.')
+    duration_seconds: int | None = Field(default=None, description='Length of video in seconds.')
+    enhance_prompt: bool | None = Field(default=None, description='Enable prompt enhancement.')
+
+
+DEFAULT_VEO_SUPPORT = Supports(
+    media=True,
+    multiturn=False,
+    tools=False,
+    systemRole=False,
+    output=['media'],
+)
+
+
+def veo_model_info(
+    version: str,
+) -> ModelInfo:
+    """Generates a ModelInfo object for Veo.
+
+    Args:
+        version: Version of the model.
+
+    Returns:
+        ModelInfo object.
+    """
+    return ModelInfo(
+        label=f'Google AI - {version}',
+        supports=DEFAULT_VEO_SUPPORT,
+    )
+
+
+class VeoModel:
+    """Veo text-to-video model."""
+
+    def __init__(self, version: str, client: genai.Client):
+        """Initialize Veo model.
+
+        Args:
+            version: Veo version
+            client: Google AI client
+        """
+        self._version = version
+        self._client = client
+
+    def _build_prompt(self, request: GenerateRequest) -> str:
+        """Build prompt request from Genkit request."""
+        prompt = []
+        for message in request.messages:
+            for part in message.content:
+                if isinstance(part.root, TextPart):
+                    prompt.append(part.root.text)
+                else:
+                    # TODO: Support image input (image-to-video) if Veo allows it;
+                    # for now only text parts are used.
+                    pass
+        return ' '.join(prompt)
+
+    async def generate(self, request: GenerateRequest, _: ActionRunContext) -> GenerateResponse:
+        """Handle a generation request using internal polling for LRO.
+
+        Args:
+            request: The generation request.
+            _: action context
+
+        Returns:
+            The model's response.
+        """
+        prompt = self._build_prompt(request)
+        config = self._get_config(request)
+
+        with tracer.start_as_current_span('generate_videos'):
+            # TODO: Add span attributes.
+
+            # Start the long-running video generation operation.
+            operation = await self._client.aio.models.generate_videos(model=self._version, prompt=prompt, config=config)
+
+            # Poll until the operation completes. The returned operation
+            # object is a snapshot; its status is refreshed by re-fetching
+            # it through the SDK's operations API.
+            while not operation.done:
+                await asyncio.sleep(2)  # Poll every 2 seconds.
+                operation = await self._client.aio.operations.get(operation)
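+            # The finished operation carries a GenerateVideosResponse; per
+            # the SDK's Veo examples the generated file is reachable as
+            # operation.response.generated_videos[0].video.uri (exact field
+            # names may vary across SDK versions).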
+ # No, typically "generate_videos" implies LRO. + # Let's assume `operation` needs refreshing or `result()` awaiting. + # Safest: Use `operation.result()` if available and awaitable? + # If `operation` is `google.genai.operations.AsyncOperation`: + if hasattr(operation, 'result'): + response = await operation.result() + break + # Fallback manual polling if no async result() + # But SDK likely provides a way. + pass + + # If `operation` doesn't have `result` or `done`, we might be using it wrong. + # Let's assume `operation.result()` works for now as standard Python convention. + + if hasattr(operation, 'result'): + response = await operation.result() + else: + # Fallback: Assume it finished if we exited loop + response = operation.result + + # Extract video + content = self._contents_from_response(response) + + return GenerateResponse( + message=Message( + content=content, + role=Role.MODEL, + ) + ) + + def _get_config(self, request: GenerateRequest) -> genai_types.GenerateVideosConfigOrDict: + cfg = None + if request.config: + # Simple cast/validate + cfg = request.config + return cfg + + def _contents_from_response(self, response: genai_types.GenerateVideosResponse) -> list: + content = [] + if response.generated_videos: + for video in response.generated_videos: + # Video URI is typically in video.uri or similar + uri = video.video.uri + content.append( + Part( + media=Media( + url=uri, + contentType='video/mp4', # Default? + ) + ) + ) + return content + + @property + def metadata(self) -> dict: + return {'model': {'supports': DEFAULT_VEO_SUPPORT.model_dump()}} diff --git a/py/plugins/google-genai/test/google_plugin_test.py b/py/plugins/google-genai/test/google_plugin_test.py index 30c7d0f58f..c8dc1cfb1c 100644 --- a/py/plugins/google-genai/test/google_plugin_test.py +++ b/py/plugins/google-genai/test/google_plugin_test.py @@ -23,8 +23,9 @@ from unittest.mock import MagicMock, patch, ANY from google.auth.credentials import Credentials -from pydantic import BaseModel -from google.genai.types import HttpOptions, HttpOptionsDict +from unittest.mock import AsyncMock, MagicMock, patch +from dataclasses import dataclass, field +from google.genai.types import HttpOptions import pytest from genkit.ai import Genkit, GENKIT_CLIENT_HEADER @@ -64,7 +65,7 @@ class TestGoogleAIInit(unittest.TestCase): """Test cases for __init__ plugin.""" @patch('google.genai.client.Client') - def test_init_with_api_key(self, mock_genai_client: MagicMock) -> None: + def test_init_with_api_key(self, mock_genai_client): """Test using api_key parameter.""" api_key = 'test_api_key' plugin = GoogleAI(api_key=api_key) @@ -81,7 +82,7 @@ def test_init_with_api_key(self, mock_genai_client: MagicMock) -> None: @patch('google.genai.client.Client') @patch.dict(os.environ, {'GEMINI_API_KEY': 'env_api_key'}) - def test_init_from_env_var(self, mock_genai_client: MagicMock) -> None: + def test_init_from_env_var(self, mock_genai_client): """Test using env var for api_key.""" plugin = GoogleAI() mock_genai_client.assert_called_once_with( @@ -96,7 +97,7 @@ def test_init_from_env_var(self, mock_genai_client: MagicMock) -> None: self.assertIsInstance(plugin._client, MagicMock) @patch('google.genai.client.Client') - def test_init_with_credentials(self, mock_genai_client: MagicMock) -> None: + def test_init_with_credentials(self, mock_genai_client): """Test using credentials parameter.""" mock_credentials = MagicMock(spec=Credentials) plugin = GoogleAI(credentials=mock_credentials) @@ -111,7 +112,7 @@ def test_init_with_credentials(self, 
mock_genai_client: MagicMock) -> None: self.assertFalse(plugin._vertexai) self.assertIsInstance(plugin._client, MagicMock) - def test_init_raises_value_error_no_api_key(self) -> None: + def test_init_raises_value_error_no_api_key(self): """Test using credentials parameter.""" with patch.dict(os.environ, {'GEMINI_API_KEY': ''}, clear=True): with self.assertRaisesRegex( @@ -121,11 +122,28 @@ def test_init_raises_value_error_no_api_key(self) -> None: GoogleAI() +@patch('google.genai.client.Client') @pytest.mark.asyncio -async def test_googleai_initialize() -> None: +async def test_googleai_initialize(mock_client_cls): """Unit tests for GoogleAI.init method.""" + mock_client = mock_client_cls.return_value + + m1 = MagicMock() + m1.name = 'models/gemini-pro' + m1.supported_actions = ['generateContent'] + m1.description = ' Gemini Pro ' + + m2 = MagicMock() + m2.name = 'models/text-embedding-004' + m2.supported_actions = ['embedContent'] + m2.description = ' Embedding ' + + mock_client.models.list.return_value = [m1, m2] + api_key = 'test_api_key' plugin = GoogleAI(api_key=api_key) + # Ensure usage of mock + plugin._client = mock_client result = await plugin.init() @@ -146,9 +164,7 @@ async def test_googleai_initialize() -> None: @patch('genkit.plugins.google_genai.GoogleAI._resolve_model') @pytest.mark.asyncio -async def test_googleai_resolve_action_model( - mock_resolve_action: MagicMock, googleai_plugin_instance: GoogleAI -) -> None: +async def test_googleai_resolve_action_model(mock_resolve_action, googleai_plugin_instance): """Test resolve action for model.""" plugin = googleai_plugin_instance @@ -158,9 +174,7 @@ async def test_googleai_resolve_action_model( @patch('genkit.plugins.google_genai.GoogleAI._resolve_embedder') @pytest.mark.asyncio -async def test_googleai_resolve_action_embedder( - mock_resolve_action: MagicMock, googleai_plugin_instance: GoogleAI -) -> None: +async def test_googleai_resolve_action_embedder(mock_resolve_action, googleai_plugin_instance): """Test resolve action for embedder.""" plugin = googleai_plugin_instance @@ -185,12 +199,12 @@ async def test_googleai_resolve_action_embedder( ], ) def test_googleai__resolve_model( - mock_google_model_info: MagicMock, - model_name: str, - expected_model_name: str, - key: str, - googleai_plugin_instance: GoogleAI, -) -> None: + mock_google_model_info, + model_name, + expected_model_name, + key, + googleai_plugin_instance, +): """Tests for GoogleAI._resolve_model method.""" plugin = googleai_plugin_instance @@ -215,11 +229,11 @@ def test_googleai__resolve_model( ], ) def test_googleai__resolve_embedder( - model_name: str, - expected_model_name: str, - clean_name: str, - googleai_plugin_instance: GoogleAI, -) -> None: + model_name, + expected_model_name, + clean_name, + googleai_plugin_instance, +): """Tests for GoogleAI._resolve_embedder method.""" plugin = googleai_plugin_instance @@ -231,19 +245,19 @@ def test_googleai__resolve_embedder( @pytest.mark.asyncio -async def test_googleai_list_actions(googleai_plugin_instance: GoogleAI) -> None: +async def test_googleai_list_actions(googleai_plugin_instance): """Unit test for list actions.""" - class MockModel(BaseModel): - """mock.""" - + @dataclass + class MockModel: supported_actions: list[str] name: str + description: str = '' models_return_value = [ - MockModel(supported_actions=['generateContent'], name='models/model1'), - MockModel(supported_actions=['embedContent'], name='models/model2'), - MockModel(supported_actions=['generateContent', 'embedContent'], 
name='models/model3'), + MockModel(supported_actions=['generateContent'], name='models/gemini-pro'), + MockModel(supported_actions=['embedContent'], name='models/text-embedding-004'), + MockModel(supported_actions=['generateContent'], name='models/gemini-2.0-flash-tts'), # TTS ] mock_client = MagicMock() @@ -251,32 +265,22 @@ class MockModel(BaseModel): googleai_plugin_instance._client = mock_client result = await googleai_plugin_instance.list_actions() - assert result == [ - model_action_metadata( - name=googleai_name('model1'), - info=google_model_info('model1').model_dump(), - ), - embedder_action_metadata( - name=googleai_name('model2'), - options=EmbedderOptions( - label=default_embedder_info('model2').get('label'), - supports=EmbedderSupports(input=default_embedder_info('model2').get('supports', {}).get('input')), - dimensions=default_embedder_info('model2').get('dimensions'), - ), - ), - model_action_metadata( - name=googleai_name('model3'), - info=google_model_info('model3').model_dump(), - ), - embedder_action_metadata( - name=googleai_name('model3'), - options=EmbedderOptions( - label=default_embedder_info('model3').get('label'), - supports=EmbedderSupports(input=default_embedder_info('model3').get('supports', {}).get('input')), - dimensions=default_embedder_info('model3').get('dimensions'), - ), - ), - ] + + # Check Gemini Pro + action1 = next(a for a in result if a.name == googleai_name('gemini-pro')) + assert action1 is not None + + # Check Embedder + action2 = next(a for a in result if a.name == googleai_name('text-embedding-004')) + assert action2 is not None + assert action2.kind == ActionKind.EMBEDDER + + # Check TTS + action3 = next(a for a in result if a.name == googleai_name('gemini-2.0-flash-tts')) + assert action3 is not None + # from genkit.plugins.google_genai.models.gemini import GeminiTtsConfigSchema, GeminiConfigSchema + # assert action3.config_schema == GeminiTtsConfigSchema + # assert action1.config_schema == GeminiConfigSchema @pytest.mark.parametrize( @@ -385,9 +389,7 @@ class MockModel(BaseModel): ), ], ) -def test_inject_attribution_headers( - input_options: HttpOptions | HttpOptionsDict | None, expected_headers: dict[str, str] -) -> None: +def test_inject_attribution_headers(input_options, expected_headers): """Tests the _inject_attribution_headers function with various inputs.""" result = _inject_attribution_headers(input_options) assert isinstance(result, HttpOptions) @@ -399,7 +401,7 @@ class TestVertexAIInit(unittest.TestCase): @patch('google.genai.client.Client') @patch.dict(os.environ, {'GCLOUD_PROJECT': 'project'}) - def test_init_with_api_key(self, mock_genai_client: MagicMock) -> None: + def test_init_with_api_key(self, mock_genai_client): """Test using api_key parameter.""" api_key = 'test_api_key' plugin = VertexAI(api_key=api_key) @@ -418,7 +420,7 @@ def test_init_with_api_key(self, mock_genai_client: MagicMock) -> None: @patch('google.genai.client.Client') @patch.dict(os.environ, {'GCLOUD_PROJECT': 'project'}) - def test_init_with_credentials(self, mock_genai_client: MagicMock) -> None: + def test_init_with_credentials(self, mock_genai_client): """Test using credentials parameter.""" mock_credentials = MagicMock(spec=Credentials) plugin = VertexAI(credentials=mock_credentials) @@ -436,7 +438,7 @@ def test_init_with_credentials(self, mock_genai_client: MagicMock) -> None: self.assertIsInstance(plugin._client, MagicMock) @patch('google.genai.client.Client') - def test_init_with_all(self, mock_genai_client: MagicMock) -> None: + def 
test_init_with_all(self, mock_genai_client):
-        """Test using credentials parameter."""
+        """Test using both api_key and credentials parameters."""
         mock_credentials = MagicMock(spec=Credentials)
         api_key = 'test_api_key'
@@ -462,13 +464,13 @@ def test_init_with_all(self, mock_genai_client: MagicMock) -> None:
 
 @pytest.fixture
 @patch('google.genai.client.Client')
-def vertexai_plugin_instance(client: MagicMock) -> VertexAI:
+def vertexai_plugin_instance(client):
     """VertexAI fixture."""
     return VertexAI()
 
 
 @pytest.mark.asyncio
-async def test_vertexai_initialize(vertexai_plugin_instance: VertexAI) -> None:
+async def test_vertexai_initialize(vertexai_plugin_instance):
     """Unit tests for VertexAI.init method."""
     plugin = vertexai_plugin_instance
 
@@ -491,9 +493,7 @@ async def test_vertexai_initialize(vertexai_plugin_instance: VertexAI) -> None:
 
 @patch('genkit.plugins.google_genai.VertexAI._resolve_model')
 @pytest.mark.asyncio
-async def test_vertexai_resolve_action_model(
-    mock_resolve_action: MagicMock, vertexai_plugin_instance: VertexAI
-) -> None:
+async def test_vertexai_resolve_action_model(mock_resolve_action, vertexai_plugin_instance):
     """Test resolve action for model."""
     plugin = vertexai_plugin_instance
 
@@ -503,9 +503,7 @@ async def test_vertexai_resolve_action_model(
 
 @patch('genkit.plugins.google_genai.VertexAI._resolve_embedder')
 @pytest.mark.asyncio
-async def test_vertexai_resolve_action_embedder(
-    mock_resolve_action: MagicMock, vertexai_plugin_instance: VertexAI
-) -> None:
+async def test_vertexai_resolve_action_embedder(mock_resolve_action, vertexai_plugin_instance):
     """Test resolve action for embedder."""
     plugin = vertexai_plugin_instance
 
@@ -557,14 +555,14 @@ async def test_vertexai_resolve_action_embedder(
     ],
 )
 def test_vertexai__resolve_model(
-    mock_google_model_info: MagicMock,
-    mock_vertexai_image_model_info: MagicMock,
-    model_name: str,
-    expected_model_name: str,
-    key: str,
-    image: bool,
-    vertexai_plugin_instance: VertexAI,
-) -> None:
+    mock_google_model_info,
+    mock_vertexai_image_model_info,
+    model_name,
+    expected_model_name,
+    key,
+    image,
+    vertexai_plugin_instance,
+):
     """Tests for VertexAI._resolve_model method."""
     plugin = vertexai_plugin_instance
     MagicMock(spec=Genkit)
@@ -607,11 +605,11 @@ def test_vertexai__resolve_model(
     ],
 )
 def test_vertexai__resolve_embedder(
-    model_name: str,
-    expected_model_name: str,
-    clean_name: str,
-    vertexai_plugin_instance: VertexAI,
-) -> None:
+    model_name,
+    expected_model_name,
+    clean_name,
+    vertexai_plugin_instance,
+):
     """Tests for VertexAI._resolve_embedder method."""
     plugin = vertexai_plugin_instance
 
@@ -623,59 +621,136 @@ def test_vertexai__resolve_embedder(
 
 @pytest.mark.asyncio
-async def test_vertexai_list_actions(vertexai_plugin_instance: VertexAI) -> None:
+async def test_vertexai_list_actions(vertexai_plugin_instance):
     """Unit test for list actions."""
 
-    class MockModel(BaseModel):
-        """mock."""
-
-        name: str
-
-    models_return_value = [
-        MockModel(name='publishers/google/models/model1'),
-        MockModel(name='publishers/google/models/model2_embeddings'),
-        MockModel(name='publishers/google/models/model3_embedder'),
-    ]
-
     mock_client = MagicMock()
-    mock_client.models.list.return_value = models_return_value
+    # Mocks expose the name, supported_actions and description fields that model discovery inspects.
+    m1 = MagicMock()
+    m1.name = 'publishers/google/models/gemini-1.5-flash'
+    m1.supported_actions = ['generateContent']
+    m1.description = 'Gemini model'
+
+    m2 = MagicMock()
+    m2.name = 'publishers/google/models/text-embedding-004'
+    m2.supported_actions = ['embedContent']
+    m2.description = 'Embedder'
+
+    m3 = MagicMock()
+    m3.name = 'publishers/google/models/imagen-3.0-generate-001'
+    m3.supported_actions = ['predict']  # Imagen uses predict
+    m3.description = 'Imagen'
+
+    m4 = MagicMock()
+    m4.name = 'publishers/google/models/veo-2.0-generate-001'
+    m4.supported_actions = ['generateVideos']  # Veo uses generateVideos
+    m4.description = 'Veo'
+
+    mock_client.models.list.return_value = [m1, m2, m3, m4]
     vertexai_plugin_instance._client = mock_client
 
     result = await vertexai_plugin_instance.list_actions()
-    assert result == [
-        model_action_metadata(
-            name=vertexai_name('model1'),
-            info=google_model_info('model1').model_dump(),
-            config_schema=GeminiConfigSchema,
-        ),
-        embedder_action_metadata(
-            name=vertexai_name('model2_embeddings'),
-            options=EmbedderOptions(
-                label=default_embedder_info('model2_embeddings').get('label'),
-                supports=EmbedderSupports(
-                    input=default_embedder_info('model2_embeddings').get('supports', {}).get('input')
-                ),
-                dimensions=default_embedder_info('model2_embeddings').get('dimensions'),
-            ),
-        ),
-        model_action_metadata(
-            name=vertexai_name('model2_embeddings'),
-            info=google_model_info('model2_embeddings').model_dump(),
-            config_schema=GeminiConfigSchema,
-        ),
-        embedder_action_metadata(
-            name=vertexai_name('model3_embedder'),
-            options=EmbedderOptions(
-                label=default_embedder_info('model3_embedder').get('label'),
-                supports=EmbedderSupports(
-                    input=default_embedder_info('model3_embedder').get('supports', {}).get('input')
-                ),
-                dimensions=default_embedder_info('model3_embedder').get('dimensions'),
-            ),
-        ),
-        model_action_metadata(
-            name=vertexai_name('model3_embedder'),
-            info=google_model_info('model3_embedder').model_dump(),
-            config_schema=GeminiConfigSchema,
-        ),
-    ]
+
+    # Verify Gemini
+    action1 = next(a for a in result if a.name == vertexai_name('gemini-1.5-flash'))
+    assert action1 is not None
+
+    # Verify Embedder
+    action2 = next(a for a in result if a.name == vertexai_name('text-embedding-004'))
+    assert action2 is not None
+
+    # Verify Imagen
+    action3 = next(a for a in result if a.name == vertexai_name('imagen-3.0-generate-001'))
+    assert action3 is not None
+    assert action3.kind == ActionKind.MODEL
+
+    # Verify Veo
+    action4 = next(a for a in result if a.name == vertexai_name('veo-2.0-generate-001'))
+    assert action4 is not None
+    # from genkit.plugins.google_genai.models.veo import VeoConfigSchema
+    # assert action4.config_schema == VeoConfigSchema
+
+
+def test_config_schema_extra_fields():
+    """Test that config schema accepts extra fields (dynamic config)."""
+    from genkit.plugins.google_genai.models.gemini import GeminiConfigSchema
+
+    # Validation should succeed with unknown field
+    config = GeminiConfigSchema(temperature=0.5, new_experimental_param='test')
+    assert config.temperature == 0.5
+    assert config.new_experimental_param == 'test'
+    assert config.model_dump()['new_experimental_param'] == 'test'
+
+
+def test_system_prompt_handling():
+    """Test that system prompts are correctly extracted to config."""
+    from google import genai
+
+    from genkit.plugins.google_genai.models.gemini import GeminiModel
+    from genkit.types import GenerateRequest, Message, Role, TextPart
+
+    mock_client = MagicMock(spec=genai.Client)
+    model = GeminiModel(version='gemini-1.5-flash', client=mock_client)
+
+    request = GenerateRequest(
+        messages=[
+            Message(role=Role.SYSTEM, content=[TextPart(text='You are a helpful assistant')]),
+            Message(role=Role.USER, content=[TextPart(text='Hello')]),
+        ],
+        config=None,
+    )
+
+    cfg = model._genkit_to_googleai_cfg(request)
+
+    assert cfg is not None
+    assert cfg.system_instruction is not None
+    assert len(cfg.system_instruction.parts) == 1
+    assert cfg.system_instruction.parts[0].text == 'You are a helpful assistant'
diff --git a/py/plugins/xai/src/genkit/plugins/xai/model_info.py b/py/plugins/xai/src/genkit/plugins/xai/model_info.py
index a64899e595..4d4abe0c20 100644
--- a/py/plugins/xai/src/genkit/plugins/xai/model_info.py
+++ b/py/plugins/xai/src/genkit/plugins/xai/model_info.py
@@ -16,19 +16,11 @@
 
 """xAI model information."""
 
-import sys
-
-if sys.version_info < (3, 11):
-    from strenum import StrEnum
-else:
-    from enum import StrEnum
-
 from genkit.types import ModelInfo, Supports
 
 __all__ = ['SUPPORTED_XAI_MODELS', 'get_model_info']
 
-
-LANGUAGE_MODEL_SUPPORTS = Supports(
+_LANGUAGE_MODEL_SUPPORTS = Supports(
     multiturn=True,
     tools=True,
     media=False,
@@ -36,57 +28,60 @@
     output=['text', 'json'],
 )
 
-GROK_3 = ModelInfo(label='xAI - Grok 3', versions=['grok-3'], supports=LANGUAGE_MODEL_SUPPORTS)
-GROK_3_FAST = ModelInfo(label='xAI - Grok 3 Fast', versions=['grok-3-fast'], supports=LANGUAGE_MODEL_SUPPORTS)
-GROK_3_MINI = ModelInfo(label='xAI - Grok 3 Mini', versions=['grok-3-mini'], supports=LANGUAGE_MODEL_SUPPORTS)
-GROK_3_MINI_FAST = ModelInfo(
-    label='xAI - Grok 3 Mini Fast', versions=['grok-3-mini-fast'], supports=LANGUAGE_MODEL_SUPPORTS
-)
-GROK_4 = ModelInfo(label='xAI - Grok 4', versions=['grok-4'], supports=LANGUAGE_MODEL_SUPPORTS)
-GROK_2_VISION_1212 = ModelInfo(
-    label='xAI - Grok 2 Vision',
-    versions=['grok-2-vision-1212'],
-    supports=Supports(
-        multiturn=False,
-        tools=True,
-        media=True,
-        system_role=False,
-        output=['text', 'json'],
-    ),
+_VISION_MODEL_SUPPORTS = Supports(
+    multiturn=False,
+    tools=True,
+    media=True,
+    system_role=False,
+    output=['text', 'json'],
 )
 
-
-# Enum for xAI Grok 
versions -class XAIGrokVersion(StrEnum): - """xAI Grok models. - - Model Support: - - | Model | Description | Status | - |----------------------|--------------------|------------| - | `grok-3` | Grok 3 | Supported | - | `grok-3-fast` | Grok 3 Fast | Supported | - | `grok-3-mini` | Grok 3 Mini | Supported | - | `grok-3-mini-fast` | Grok 3 Mini Fast | Supported | - | `grok-4` | Grok 4 | Supported | - | `grok-2-vision-1212` | Grok 2 Vision | Supported | - """ - - GROK_3 = 'grok-3' - GROK_3_FAST = 'grok-3-fast' - GROK_3_MINI = 'grok-3-mini' - GROK_3_MINI_FAST = 'grok-3-mini-fast' - GROK_4 = 'grok-4' - GROK_2_VISION_1212 = 'grok-2-vision-1212' - - SUPPORTED_XAI_MODELS: dict[str, ModelInfo] = { - XAIGrokVersion.GROK_3: GROK_3, - XAIGrokVersion.GROK_3_FAST: GROK_3_FAST, - XAIGrokVersion.GROK_3_MINI: GROK_3_MINI, - XAIGrokVersion.GROK_3_MINI_FAST: GROK_3_MINI_FAST, - XAIGrokVersion.GROK_4: GROK_4, - XAIGrokVersion.GROK_2_VISION_1212: GROK_2_VISION_1212, + 'grok-3': ModelInfo( + label='xAI - Grok 3', + versions=['grok-3'], + supports=_LANGUAGE_MODEL_SUPPORTS, + ), + 'grok-3-fast': ModelInfo( + label='xAI - Grok 3 Fast', + versions=['grok-3-fast'], + supports=_LANGUAGE_MODEL_SUPPORTS, + ), + 'grok-3-mini': ModelInfo( + label='xAI - Grok 3 Mini', + versions=['grok-3-mini'], + supports=_LANGUAGE_MODEL_SUPPORTS, + ), + 'grok-3-mini-fast': ModelInfo( + label='xAI - Grok 3 Mini Fast', + versions=['grok-3-mini-fast'], + supports=_LANGUAGE_MODEL_SUPPORTS, + ), + 'grok-2-vision-1212': ModelInfo( + label='xAI - Grok 2 Vision', + versions=['grok-2-vision-1212'], + supports=_VISION_MODEL_SUPPORTS, + ), + 'grok-4': ModelInfo( + label='xAI - Grok 4', + versions=['grok-4'], + supports=_LANGUAGE_MODEL_SUPPORTS, + ), + 'grok-4.1': ModelInfo( + label='xAI - Grok 4.1', + versions=['grok-4.1'], + supports=_LANGUAGE_MODEL_SUPPORTS, + ), + 'grok-2-1212': ModelInfo( + label='xAI - Grok 2', + versions=['grok-2-1212'], + supports=_LANGUAGE_MODEL_SUPPORTS, + ), + 'grok-2-latest': ModelInfo( + label='xAI - Grok 2 Latest', + versions=['grok-2-latest'], + supports=_LANGUAGE_MODEL_SUPPORTS, + ), } @@ -96,6 +91,6 @@ def get_model_info(name: str) -> ModelInfo: name, ModelInfo( label=f'xAI - {name}', - supports=LANGUAGE_MODEL_SUPPORTS, + supports=_LANGUAGE_MODEL_SUPPORTS, ), ) diff --git a/py/plugins/xai/src/genkit/plugins/xai/models.py b/py/plugins/xai/src/genkit/plugins/xai/models.py index c22f472f03..9991f67343 100644 --- a/py/plugins/xai/src/genkit/plugins/xai/models.py +++ b/py/plugins/xai/src/genkit/plugins/xai/models.py @@ -20,7 +20,6 @@ import json from typing import Any, cast -from pydantic import Field, ValidationError from xai_sdk import Client as XAIClient from xai_sdk.proto.v6 import chat_pb2, image_pb2 @@ -33,7 +32,6 @@ GenerateRequest, GenerateResponse, GenerateResponseChunk, - GenerationCommonConfig, GenerationUsage, MediaPart, Message, @@ -45,21 +43,6 @@ ToolResponsePart, ) - -class XAIConfig(GenerationCommonConfig): - deferred: bool | None = None - reasoning_effort: str | None = Field(None, pattern='^(low|medium|high)$') - web_search_options: dict | None = None - frequency_penalty: float | None = None - presence_penalty: float | None = None - - -# Tool type mapping for xAI(function only, for now) -TOOL_TYPE_MAP = { - 'function': chat_pb2.ToolCallType.TOOL_CALL_TYPE_CLIENT_SIDE_TOOL, -} - - __all__ = ['XAIModel'] DEFAULT_MAX_OUTPUT_TOKENS = 4096 @@ -79,22 +62,6 @@ class XAIConfig(GenerationCommonConfig): } -def build_generation_usage( - final_response: Any | None, # noqa: ANN401 - basic_usage: GenerationUsage, -) 
-> GenerationUsage: - """Builds a GenerationUsage object from a final_response and basic_usage.""" - return GenerationUsage( - input_tokens=getattr(final_response.usage, 'prompt_tokens', 0) if final_response else 0, - output_tokens=getattr(final_response.usage, 'completion_tokens', 0) if final_response else 0, - total_tokens=getattr(final_response.usage, 'total_tokens', 0) if final_response else 0, - input_characters=basic_usage.input_characters, - output_characters=basic_usage.output_characters, - input_images=basic_usage.input_images, - output_images=basic_usage.output_images, - ) - - class XAIModel: """xAI Grok model for Genkit.""" @@ -132,40 +99,64 @@ def _sample() -> Any: # noqa: ANN401 return GenerateResponse( message=response_message, - usage=build_generation_usage(response, basic_usage), + usage=GenerationUsage( + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + input_characters=basic_usage.input_characters, + output_characters=basic_usage.output_characters, + input_images=basic_usage.input_images, + output_images=basic_usage.output_images, + ), finish_reason=FINISH_REASON_MAP.get(response.finish_reason, FinishReason.UNKNOWN), ) def _build_params(self, request: GenerateRequest) -> dict[str, object]: - """Build xAI API parameters from request using validated config.""" - config = request.config or {} - if not isinstance(config, XAIConfig): - try: - config = XAIConfig.model_validate(config) - except ValidationError: - config = XAIConfig() + """Build xAI API parameters from request.""" + config = request.config + if isinstance(config, dict): + max_tokens = config.get('max_output_tokens') or DEFAULT_MAX_OUTPUT_TOKENS + temperature = config.get('temperature') + top_p = config.get('top_p') + stop = config.get('stop_sequences') + frequency_penalty = config.get('frequency_penalty') + presence_penalty = config.get('presence_penalty') + web_search_options = config.get('web_search_options') + deferred = config.get('deferred') + reasoning_effort = config.get('reasoning_effort') + else: + max_tokens = (config.max_output_tokens if config else None) or DEFAULT_MAX_OUTPUT_TOKENS + temperature = getattr(config, 'temperature', None) if config else None + top_p = getattr(config, 'top_p', None) if config else None + stop = getattr(config, 'stop_sequences', None) if config else None + frequency_penalty = getattr(config, 'frequency_penalty', None) if config else None + presence_penalty = getattr(config, 'presence_penalty', None) if config else None + web_search_options = getattr(config, 'web_search_options', None) if config else None + deferred = getattr(config, 'deferred', None) if config else None + reasoning_effort = getattr(config, 'reasoning_effort', None) if config else None params: dict[str, object] = { 'model': self.model_name, 'messages': self._to_xai_messages(request.messages), - 'max_tokens': int(config.max_output_tokens or DEFAULT_MAX_OUTPUT_TOKENS), + 'max_tokens': int(max_tokens), } - if config.temperature is not None: - params['temperature'] = config.temperature - if config.top_p is not None: - params['top_p'] = config.top_p - if config.stop_sequences: - params['stop'] = config.stop_sequences - if getattr(config, 'frequency_penalty', None) is not None: - params['frequency_penalty'] = config.frequency_penalty - if getattr(config, 'presence_penalty', None) is not None: - params['presence_penalty'] = config.presence_penalty - if config.web_search_options is not None: - params['web_search_options'] = 
config.web_search_options - if config.deferred is not None: - params['deferred'] = config.deferred - if config.reasoning_effort is not None: - params['reasoning_effort'] = config.reasoning_effort + + if temperature is not None: + params['temperature'] = temperature + if top_p is not None: + params['top_p'] = top_p + if stop: + params['stop'] = stop + if frequency_penalty is not None: + params['frequency_penalty'] = frequency_penalty + if presence_penalty is not None: + params['presence_penalty'] = presence_penalty + if web_search_options is not None: + params['web_search_options'] = web_search_options + if deferred is not None: + params['deferred'] = deferred + if reasoning_effort is not None: + params['reasoning_effort'] = reasoning_effort if request.tools: params['tools'] = [ @@ -173,7 +164,7 @@ def _build_params(self, request: GenerateRequest) -> dict[str, object]: function=chat_pb2.Function( name=t.name, description=t.description or '', - parameters=json.dumps(to_json_schema(t.input_schema or {})), + parameters=json.dumps(to_json_schema(t.input_schema)), ), ) for t in request.tools @@ -237,7 +228,15 @@ def _sync_stream() -> GenerateResponse: return GenerateResponse( message=response_message, - usage=build_generation_usage(final_response, basic_usage), + usage=GenerationUsage( + input_tokens=final_response.usage.prompt_tokens if final_response else 0, + output_tokens=final_response.usage.completion_tokens if final_response else 0, + total_tokens=final_response.usage.total_tokens if final_response else 0, + input_characters=basic_usage.input_characters, + output_characters=basic_usage.output_characters, + input_images=basic_usage.input_images, + output_images=basic_usage.output_images, + ), finish_reason=finish_reason, ) @@ -263,14 +262,13 @@ def _to_xai_messages(self, messages: list[Message]) -> list[chat_pb2.Message]: chat_pb2.Content(image_url=image_pb2.ImageUrlContent(image_url=actual_part.media.url)) ) elif isinstance(actual_part, ToolRequestPart): - tool_type = getattr(actual_part.tool_request, 'type', 'function') tool_calls.append( chat_pb2.ToolCall( id=actual_part.tool_request.ref, - type=TOOL_TYPE_MAP.get(tool_type, chat_pb2.ToolCallType.TOOL_CALL_TYPE_CLIENT_SIDE_TOOL), + type=chat_pb2.ToolCallType.TOOL_CALL_TYPE_CLIENT_SIDE_TOOL, function=chat_pb2.FunctionCall( name=actual_part.tool_request.name, - arguments=json.dumps(actual_part.tool_request.input), + arguments=actual_part.tool_request.input, ), ) ) @@ -278,12 +276,13 @@ def _to_xai_messages(self, messages: list[Message]) -> list[chat_pb2.Message]: result.append( chat_pb2.Message( role=chat_pb2.MessageRole.ROLE_TOOL, + name=actual_part.tool_response.ref, content=[chat_pb2.Content(text=str(actual_part.tool_response.output))], ) ) continue - pb_message = chat_pb2.Message(role=role, content=content or [chat_pb2.Content(text='')]) + pb_message = chat_pb2.Message(role=role, content=content) if tool_calls: pb_message.tool_calls.extend(tool_calls) diff --git a/py/plugins/xai/src/genkit/plugins/xai/plugin.py b/py/plugins/xai/src/genkit/plugins/xai/plugin.py index a6e2c6f1df..bfdd1b3304 100644 --- a/py/plugins/xai/src/genkit/plugins/xai/plugin.py +++ b/py/plugins/xai/src/genkit/plugins/xai/plugin.py @@ -28,7 +28,8 @@ from genkit.core.registry import ActionKind from genkit.core.schema import to_json_schema from genkit.plugins.xai.model_info import SUPPORTED_XAI_MODELS, get_model_info -from genkit.plugins.xai.models import XAIConfig, XAIModel +from genkit.plugins.xai.models import XAIModel +from genkit.types import 
GenerationCommonConfig __all__ = ['XAI', 'xai_name'] @@ -120,7 +121,7 @@ def _create_model_action(self, name: str) -> Action: metadata={ 'model': { 'supports': model_info.supports.model_dump() if model_info.supports else {}, - 'customOptions': to_json_schema(XAIConfig), + 'customOptions': to_json_schema(GenerationCommonConfig), }, }, ) @@ -137,7 +138,7 @@ async def list_actions(self) -> list: model_action_metadata( name=xai_name(model_name), info={'supports': model_info.supports.model_dump() if model_info.supports else {}}, - config_schema=XAIConfig, + config_schema=GenerationCommonConfig, ) ) return actions diff --git a/py/plugins/xai/tests/xai_models_test.py b/py/plugins/xai/tests/xai_models_test.py index 0ea8c5766c..101ed6700e 100644 --- a/py/plugins/xai/tests/xai_models_test.py +++ b/py/plugins/xai/tests/xai_models_test.py @@ -21,10 +21,11 @@ import pytest -from genkit.plugins.xai.models import XAIConfig, XAIModel +from genkit.plugins.xai.models import XAIModel from genkit.types import ( GenerateRequest, GenerateResponseChunk, + GenerationCommonConfig, Message, Part, Role, @@ -43,7 +44,7 @@ def _create_sample_request() -> GenerateRequest: content=[Part(root=TextPart(text='Hello, how are you?'))], ) ], - config=XAIConfig(), + config=GenerationCommonConfig(), tools=[ ToolDefinition( name='get_weather', @@ -120,7 +121,7 @@ async def test_generate_with_config() -> None: request = GenerateRequest( messages=[Message(role=Role.USER, content=[Part(root=TextPart(text='Test'))])], - config=XAIConfig( + config=GenerationCommonConfig( temperature=0.7, max_output_tokens=100, top_p=0.9, @@ -282,7 +283,7 @@ async def test_build_params_basic() -> None: request = GenerateRequest( messages=[Message(role=Role.USER, content=[Part(root=TextPart(text='Test'))])], - config=XAIConfig(), + config=GenerationCommonConfig(), ) params = model._build_params(request) @@ -300,11 +301,11 @@ async def test_build_params_with_config() -> None: request = GenerateRequest( messages=[Message(role=Role.USER, content=[Part(root=TextPart(text='Test'))])], - config=XAIConfig( - temperature=0.5, - max_output_tokens=200, - top_p=0.8, - ), + config={ + 'temperature': 0.5, + 'max_output_tokens': 200, + 'top_p': 0.8, + }, ) params = model._build_params(request) @@ -322,13 +323,13 @@ async def test_build_params_with_xai_specific_config() -> None: request = GenerateRequest( messages=[Message(role=Role.USER, content=[Part(root=TextPart(text='Test'))])], - config=XAIConfig( - temperature=0.7, - max_output_tokens=300, - deferred=True, - reasoning_effort='high', - web_search_options={'enabled': True}, - ), + config={ + 'temperature': 0.7, + 'max_output_tokens': 300, + 'deferred': True, + 'reasoning_effort': 'high', + 'web_search_options': {'enabled': True}, + }, ) params = model._build_params(request) diff --git a/py/pyproject.toml b/py/pyproject.toml index a910960027..3312617f74 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -34,7 +34,6 @@ dependencies = [ "genkit-plugin-mcp", "liccheck>=0.9.2", "mcp>=1.25.0", - "python-multipart>=0.0.22", "strenum>=0.4.15; python_version < '3.11'", ] description = "Workspace for Genkit packages" @@ -135,8 +134,6 @@ prompt-demo = { workspace = true } [tool.uv.workspace] members = ["packages/*", "plugins/*", "samples/*"] -exclude = ["*/shared"] - # Ruff checks and formatting. 
[tool.ruff] diff --git a/py/samples/anthropic-hello/src/main.py b/py/samples/anthropic-hello/src/main.py index 115a6c965c..8fcb1f66f9 100755 --- a/py/samples/anthropic-hello/src/main.py +++ b/py/samples/anthropic-hello/src/main.py @@ -179,23 +179,15 @@ async def generate_character( @ai.tool() def get_weather(input: WeatherInput) -> str: - """Return a random realistic weather string for a city name. + """Get weather for a location. Args: - input: Weather input location. + input: Weather input with location. Returns: - Weather information with temperature in degree Celsius. + Weather information. """ - import random - - weather_options = [ - '32° C sunny', - '17° C cloudy', - '22° C cloudy', - '19° C humid', - ] - return random.choice(weather_options) + return f'Weather in {input.location}: Sunny, 23°C' @ai.flow() diff --git a/py/samples/deepseek-hello/src/main.py b/py/samples/deepseek-hello/src/main.py index 58daf0e5a2..766ee87753 100644 --- a/py/samples/deepseek-hello/src/main.py +++ b/py/samples/deepseek-hello/src/main.py @@ -96,48 +96,7 @@ class RpgCharacter(BaseModel): class WeatherInput(BaseModel): """Input schema for the weather tool.""" - location: str = Field(description='City or location name') - - -@ai.tool() -def get_weather(input: WeatherInput) -> str: - """Return a random realistic weather string for a location. - - Args: - input: Weather input location. - - Returns: - Weather information with temperature in degrees Celsius. - """ - import random - - weather_options = [ - '32° C sunny', - '17° C cloudy', - '22° C cloudy', - '19° C humid', - ] - return random.choice(weather_options) - - -@ai.flow() -async def reasoning_flow(prompt: str | None = None) -> str: - """Solve reasoning problems using deepseek-reasoner model. - - Args: - prompt: The reasoning question to solve. Defaults to a classic logic problem. - - Returns: - The reasoning and answer. - """ - if prompt is None: - prompt = 'What is heavier, one kilo of steel or one kilo of feathers?' - - response = await ai.generate( - model=deepseek_name('deepseek-reasoner'), - prompt=prompt, - ) - return response.text + location: str = Field(description='The city and state, e.g. San Francisco, CA') @ai.flow() @@ -255,19 +214,19 @@ async def custom_config_flow(task: str | None = None) -> str: configs = { 'creative': { - 'temperature': 1.5, + 'temperature': 1.5, # High temperature for creativity 'max_tokens': 200, 'top_p': 0.95, }, 'precise': { - 'temperature': 0.1, + 'temperature': 0.1, # Low temperature for consistency 'max_tokens': 150, - 'presence_penalty': 0.5, + 'presence_penalty': 0.5, # Encourage covering all steps }, 'detailed': { 'temperature': 0.7, - 'max_tokens': 400, - 'frequency_penalty': 0.8, + 'max_tokens': 400, # More tokens for detailed explanation + 'frequency_penalty': 0.8, # Reduce repetitive phrasing }, } @@ -301,6 +260,48 @@ async def generate_character( return cast(RpgCharacter, result.output) +@ai.tool() +def get_weather(input: WeatherInput) -> str: + """Get weather of a location, the user should supply a location first. + + Args: + input: Weather input with location (city and state, e.g. San Francisco, CA). + + Returns: + Weather information with temperature in degrees Fahrenheit. 
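+
+    Example (illustrative; values come from the mocked data below):
+        >>> get_weather(WeatherInput(location='Seattle, WA'))
+        'The weather in Seattle, WA is 55°F and rainy. Humidity is 85%.'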
+ """ + # Mocked weather data + weather_data = { + 'San Francisco, CA': {'temp': 72, 'condition': 'sunny', 'humidity': 65}, + 'Seattle, WA': {'temp': 55, 'condition': 'rainy', 'humidity': 85}, + } + + location = input.location + data = weather_data.get(location, {'temp': 70, 'condition': 'partly cloudy', 'humidity': 55}) + + return f'The weather in {location} is {data["temp"]}°F and {data["condition"]}. Humidity is {data["humidity"]}%.' + + +@ai.flow() +async def reasoning_flow(prompt: str | None = None) -> str: + """Solve reasoning problems using deepseek-reasoner model. + + Args: + prompt: The reasoning question to solve. Defaults to a classic logic problem. + + Returns: + The reasoning and answer. + """ + if prompt is None: + prompt = 'What is heavier, one kilo of steel or one kilo of feathers?' + + response = await ai.generate( + model=deepseek_name('deepseek-reasoner'), + prompt=prompt, + ) + return response.text + + @ai.flow() async def say_hi(name: Annotated[str, Field(default='Alice')] = 'Alice') -> str: """Generate a simple greeting. @@ -337,7 +338,7 @@ async def streaming_flow( @ai.flow() -async def weather_flow(location: Annotated[str, Field(default='London')] = 'London') -> str: +async def weather_flow(location: Annotated[str, Field(default='San Francisco, CA')] = 'San Francisco, CA') -> str: """Get weather using compat-oai auto tool calling.""" response = await ai.generate( model=deepseek_name('deepseek-chat'), diff --git a/py/samples/google-genai-hello/src/main.py b/py/samples/google-genai-hello/src/main.py index d2a232215a..35ff16eee4 100755 --- a/py/samples/google-genai-hello/src/main.py +++ b/py/samples/google-genai-hello/src/main.py @@ -107,22 +107,6 @@ ) -class CurrencyExchangeInput(BaseModel): - """Currency exchange flow input schema.""" - - amount: float = Field(description='Amount to convert', default=100) - from_curr: str = Field(description='Source currency code', default='USD') - to_curr: str = Field(description='Target currency code', default='EUR') - - -class CurrencyInput(BaseModel): - """Currency conversion input schema.""" - - amount: float = Field(description='Amount to convert', default=100) - from_currency: str = Field(description='Source currency code (e.g., USD)', default='USD') - to_currency: str = Field(description='Target currency code (e.g., EUR)', default='EUR') - - class GablorkenInput(BaseModel): """The Pydantic model for tools.""" @@ -153,66 +137,96 @@ class ThinkingLevel(StrEnum): HIGH = 'HIGH' -class ThinkingLevelFlash(StrEnum): - """Thinking level flash enum.""" +@ai.flow() +async def simple_generate_with_tools_flow( + value: Annotated[int, Field(default=42)] = 42, + ctx: ActionRunContext = None, # type: ignore[assignment] +) -> str: + """Generate a greeting for the given name. - MINIMAL = 'MINIMAL' - LOW = 'LOW' - MEDIUM = 'MEDIUM' - HIGH = 'HIGH' + Args: + value: the integer to send to test function + ctx: the flow context + Returns: + The generated response with a function. + """ + response = await ai.generate( + prompt=f'what is a gablorken of {value}', + tools=['gablorkenTool'], + on_chunk=ctx.send_chunk, + ) + return response.text -class WeatherInput(BaseModel): - """Input for getting weather.""" - location: str = Field(description='The city and state, e.g. San Francisco, CA') +@ai.tool(name='gablorkenTool2') +def gablorken_tool2(input_: GablorkenInput, ctx: ToolRunContext) -> None: + """The user-defined tool function. 
+ Args: + input_: the input to the tool + ctx: the tool run context -@ai.tool(name='celsiusToFahrenheit') -def celsius_to_fahrenheit(celsius: float) -> float: - """Converts Celsius to Fahrenheit.""" - return (celsius * 9) / 5 + 32 + Returns: + The calculated gablorken. + """ + ctx.interrupt() -@ai.tool() -def convert_currency(input: CurrencyInput) -> str: - """Convert currency amount. +@ai.flow() +async def simple_generate_with_interrupts(value: Annotated[int, Field(default=42)] = 42) -> str: + """Generate a greeting for the given name. Args: - input: Currency conversion parameters. + value: the integer to send to test function Returns: - Converted amount. + The generated response with a function. """ - # Mock conversion rates - rates = { - ('USD', 'EUR'): 0.85, - ('EUR', 'USD'): 1.18, - ('USD', 'GBP'): 0.73, - ('GBP', 'USD'): 1.37, - } - - rate = rates.get((input.from_currency, input.to_currency), 1.0) - converted = input.amount * rate + response1 = await ai.generate( + messages=[ + Message( + role=Role.USER, + content=[Part(root=TextPart(text=f'what is a gablorken of {value}'))], + ), + ], + tools=['gablorkenTool2'], + ) + await logger.ainfo(f'len(response.tool_requests)={len(response1.tool_requests)}') + if len(response1.interrupts) == 0: + return response1.text - return f'{input.amount} {input.from_currency} = {converted:.2f} {input.to_currency}' + tr = tool_response(response1.interrupts[0], {'output': 178}) + response = await ai.generate( + messages=response1.messages, + tool_responses=[tr], + tools=['gablorkenTool'], + ) + return response.text @ai.flow() -async def currency_exchange(input: CurrencyExchangeInput) -> str: - """Convert currency using tools. +async def say_hi(name: Annotated[str, Field(default='Alice')] = 'Alice') -> str: + """Generate a greeting for the given name. Args: - input: Currency exchange parameters. + name: the name to send to test function Returns: - Conversion result. + The generated response with a function. 
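+
+    Example (illustrative; actual model output will vary):
+        >>> await say_hi('Alice')
+        'Hi Alice! How can I help you today?'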
""" - response = await ai.generate( - prompt=f'Convert {input.amount} {input.from_curr} to {input.to_curr}', - tools=['convert_currency'], + resp = await ai.generate( + prompt=f'hi {name}', ) - return response.text + + await logger.ainfo( + 'generation_response', + has_usage=hasattr(resp, 'usage'), + usage_dict=resp.usage.model_dump() if hasattr(resp, 'usage') and resp.usage else None, + text_length=len(resp.text), + ) + + return resp.text @ai.flow() @@ -253,24 +267,6 @@ def multiplier_fn(x: int) -> int: } -@ai.flow() -async def describe_image( - image_url: Annotated[ - str, Field(default='https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png') - ] = 'https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png', -) -> str: - """Describe an image.""" - response = await ai.generate( - model='googleai/gemini-3-flash-preview', - prompt=[ - Part(root=TextPart(text='Describe this image')), - Part(root=MediaPart(media=Media(url=image_url, content_type='image/png'))), - ], - config={'api_version': 'v1alpha'}, - ) - return response.text - - @ai.flow() async def embed_docs( docs: list[str] | None = None, @@ -294,37 +290,29 @@ async def embed_docs( @ai.flow() -async def file_search() -> str: - """File Search.""" - # TODO: add file search store - store_name = 'fileSearchStores/sample-store' - response = await ai.generate( - model='googleai/gemini-3-flash-preview', - prompt="What is the character's name in the story?", - config={ - 'file_search': { - 'file_search_store_names': [store_name], - 'metadata_filter': 'author=foo', - }, - 'api_version': 'v1alpha', - }, - ) - return response.text - +async def say_hi_with_configured_temperature( + data: Annotated[str, Field(default='Alice')] = 'Alice', +) -> GenerateResponseWrapper: + """Generate a greeting for the given name. -@ai.tool(name='gablorkenTool') -def gablorken_tool(input_: GablorkenInput) -> dict[str, int]: - """Calculate a gablorken. + Args: + data: the name to send to test function Returns: - The calculated gablorken. + The generated response with a function. """ - return {'result': input_.value * 3 - 5} + return await ai.generate( + messages=[Message(role=Role.USER, content=[Part(root=TextPart(text=f'hi {data}'))])], + config=GenerationCommonConfig(temperature=0.1), + ) -@ai.tool(name='gablorkenTool2') -def gablorken_tool2(input_: GablorkenInput, ctx: ToolRunContext) -> None: - """The user-defined tool function. +@ai.flow() +async def say_hi_stream( + name: Annotated[str, Field(default='Alice')] = 'Alice', + ctx: ActionRunContext = None, # type: ignore[assignment] +) -> str: + """Generate a greeting for the given name. Args: input_: the input to the tool @@ -333,7 +321,30 @@ def gablorken_tool2(input_: GablorkenInput, ctx: ToolRunContext) -> None: Returns: The calculated gablorken. 
""" - ctx.interrupt() + stream, _ = ai.generate_stream(prompt=f'hi {name}') + result: str = '' + async for data in stream: + ctx.send_chunk(data.text) + result += data.text + + return result + + +class Skills(BaseModel): + """Skills for an RPG character.""" + + strength: int = Field(description='strength (0-100)') + charisma: int = Field(description='charisma (0-100)') + endurance: int = Field(description='endurance (0-100)') + + +class RpgCharacter(BaseModel): + """An RPG character.""" + + name: str = Field(description='name of the character') + back_story: str = Field(description='back story', alias='backStory') + abilities: list[str] = Field(description='list of abilities (3-4)') + skills: Skills @ai.flow() @@ -390,151 +401,18 @@ async def generate_character_unconstrained( return cast(RpgCharacter, result.output) -@ai.tool(name='getWeather') -def get_weather(input_: WeatherInput) -> dict: - """Used to get current weather for a location.""" - return { - 'location': input_.location, - 'temperature_celcius': 21.5, - 'conditions': 'cloudy', - } - - -@ai.flow() -async def say_hi(name: Annotated[str, Field(default='Alice')] = 'Alice') -> str: - """Generate a greeting for the given name. - - Args: - name: the name to send to test function - - Returns: - The generated response with a function. - """ - resp = await ai.generate( - prompt=f'hi {name}', - ) - - await logger.ainfo( - 'generation_response', - has_usage=hasattr(resp, 'usage'), - usage_dict=resp.usage.model_dump() if hasattr(resp, 'usage') and resp.usage else None, - text_length=len(resp.text), - ) - - return resp.text - - -@ai.flow() -async def say_hi_stream( - name: Annotated[str, Field(default='Alice')] = 'Alice', - ctx: ActionRunContext = None, # type: ignore[assignment] -) -> str: - """Generate a greeting for the given name. - - Args: - name: the name to send to test function - ctx: the context of the tool - - Returns: - The generated response with a function. - """ - stream, _ = ai.generate_stream(prompt=f'hi {name}') - result: str = '' - async for data in stream: - ctx.send_chunk(data.text) - result += data.text - - return result - - -@ai.flow() -async def say_hi_with_configured_temperature( - data: Annotated[str, Field(default='Alice')] = 'Alice', -) -> GenerateResponseWrapper: - """Generate a greeting for the given name. - - Args: - data: the name to send to test function - - Returns: - The generated response with a function. - """ - return await ai.generate( - messages=[Message(role=Role.USER, content=[Part(root=TextPart(text=f'hi {data}'))])], - config=GenerationCommonConfig(temperature=0.1), - ) - - -@ai.flow() -async def search_grounding() -> str: - """Search grounding.""" - response = await ai.generate( - model='googleai/gemini-3-flash-preview', - prompt='Who is Albert Einstein?', - config={'tools': [{'googleSearch': {}}], 'api_version': 'v1alpha'}, - ) - return response.text - - -@ai.flow() -async def simple_generate_with_interrupts(value: Annotated[int, Field(default=42)] = 42) -> str: - """Generate a greeting for the given name. - - Args: - value: the integer to send to test function - - Returns: - The generated response with a function. 
- """ - response1 = await ai.generate( - messages=[ - Message( - role=Role.USER, - content=[Part(root=TextPart(text=f'what is a gablorken of {value}'))], - ), - ], - tools=['gablorkenTool2'], - ) - await logger.ainfo(f'len(response.tool_requests)={len(response1.tool_requests)}') - if len(response1.interrupts) == 0: - return response1.text - - tr = tool_response(response1.interrupts[0], {'output': 178}) - response = await ai.generate( - messages=response1.messages, - tool_responses=[tr], - tools=['gablorkenTool'], - ) - return response.text - - -@ai.flow() -async def simple_generate_with_tools_flow( - value: Annotated[int, Field(default=42)] = 42, - ctx: ActionRunContext = None, # type: ignore[assignment] -) -> str: - """Generate a greeting for the given name. - - Args: - value: the integer to send to test function - ctx: the flow context +class ThinkingLevel(StrEnum): + """Thinking level enum.""" - Returns: - The generated response with a function. - """ - response = await ai.generate( - prompt=f'what is a gablorken of {value}', - tools=['gablorkenTool'], - on_chunk=ctx.send_chunk, - ) - return response.text + LOW = 'LOW' + HIGH = 'HIGH' @ai.flow() -async def thinking_level_flash(level: ThinkingLevelFlash) -> str: - """Gemini 3.0 thinkingLevel config (Flash).""" +async def thinking_level_pro(level: ThinkingLevel) -> str: + """Gemini 3.0 thinkingLevel config (Pro).""" response = await ai.generate( - model='googleai/gemini-3-flash-preview', + model='googleai/gemini-3-pro-preview', prompt=( 'Alice, Bob, and Carol each live in a different house on the ' 'same street: red, green, and blue. The person who lives in the red house ' @@ -553,9 +431,18 @@ async def thinking_level_flash(level: ThinkingLevelFlash) -> str: return response.text +class ThinkingLevelFlash(StrEnum): + """Thinking level flash enum.""" + + MINIMAL = 'MINIMAL' + LOW = 'LOW' + MEDIUM = 'MEDIUM' + HIGH = 'HIGH' + + @ai.flow() -async def thinking_level_pro(level: ThinkingLevel) -> str: - """Gemini 3.0 thinkingLevel config (Pro).""" +async def thinking_level_flash(level: ThinkingLevelFlash) -> str: + """Gemini 3.0 thinkingLevel config (Flash).""" response = await ai.generate( model='googleai/gemini-3-pro-preview', prompt=( @@ -577,13 +464,12 @@ async def thinking_level_pro(level: ThinkingLevel) -> str: @ai.flow() -async def tool_calling(location: Annotated[str, Field(default='Paris, France')] = 'Paris, France') -> str: - """Tool calling with Gemini.""" +async def search_grounding() -> str: + """Search grounding.""" response = await ai.generate( - model='googleai/gemini-2.5-flash', - tools=['getWeather', 'celsiusToFahrenheit'], - prompt=f"What's the weather in {location}? Convert the temperature to Fahrenheit.", - config=GenerationCommonConfig(temperature=1), + model='googleai/gemini-3-flash-preview', + prompt='Who is Albert Einstein?', + config={'tools': [{'googleSearch': {}}], 'api_version': 'v1alpha'}, ) return response.text @@ -601,6 +487,105 @@ async def url_context() -> str: return response.text +from google import genai as google_genai_sdk + +# ... existing imports ... 
+ + +async def create_file_search_store(client: google_genai_sdk.Client) -> str: + """Creates a file search store.""" + file_search_store = await client.aio.file_search_stores.create() + if not file_search_store.name: + raise ValueError('File Search Store created without a name.') + return file_search_store.name + + +async def upload_blob_to_file_search_store(client: google_genai_sdk.Client, file_search_store_name: str): + """Uploads a blob to the file search store.""" + text_content = ( + 'The Whispering Woods In the heart of Eldergrove, there stood a forest whispered about by the villagers. ' + 'They spoke of trees that could talk and streams that sang. Young Elara, curious and adventurous, ' + 'decided to explore the woods one crisp autumn morning. As she wandered deeper, the leaves rustled with ' + 'excitement, revealing hidden paths. Elara noticed the trees bending slightly as if beckoning her to come ' + 'closer. When she paused to listen, she heard soft murmurs—stories of lost treasures and forgotten dreams. ' + 'Drawn by the enchanting sounds, she followed a narrow trail until she stumbled upon a shimmering pond. ' + 'At its edge, a wise old willow tree spoke, “Child of the village, what do you seek?” “I seek adventure,” ' + 'Elara replied, her heart racing. “Adventure lies not in faraway lands but within your spirit,” the willow ' + 'said, swaying gently. “Every choice you make is a step into the unknown.” With newfound courage, Elara left ' + 'the woods, her mind buzzing with possibilities. The villagers would say the woods were magical, but to Elara, ' + 'it was the spark of her imagination that had transformed her ordinary world into a realm of endless adventures. ' + 'She smiled, knowing her journey was just beginning' + ) + + # Create a temporary file to upload + import tempfile + + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as tmp: + tmp.write(text_content) + tmp_path = tmp.name + + try: + # Use the high-level helper to upload directly to the store with metadata + print(f'Uploading file to store {file_search_store_name}...') + op = await client.aio.file_search_stores.upload_to_file_search_store( + file_search_store_name=file_search_store_name, + file=tmp_path, + config={'custom_metadata': [{'key': 'author', 'string_value': 'foo'}]}, + ) + + # Poll for completion + while not op.done: + await asyncio.sleep(2) + # Fetch the updated operation status + op = await client.aio.operations.get(operation=op) + print(f'Operation status: {op.metadata.get("state") if op.metadata else "processing"}') + + print('Upload complete.') + + finally: + os.unlink(tmp_path) + return + + +async def delete_file_search_store(client: google_genai_sdk.Client, name: str): + """Deletes the file search store.""" + await client.aio.file_search_stores.delete(name=name, config={'force': True}) + + +@ai.flow() +async def file_search() -> str: + """File Search.""" + # Create a client using the same API Key as the plugin + api_key = os.environ.get('GEMINI_API_KEY') + client = google_genai_sdk.Client(api_key=api_key) + + # 1. Create Store + store_name = await create_file_search_store(client) + print(f'Created store: {store_name}') + + try: + # 2. Upload Blob (Story) + await upload_blob_to_file_search_store(client, store_name) + + # 3. 
Generate + response = await ai.generate( + model='googleai/gemini-3-flash-preview', + prompt="What is the character's name in the story?", + config={ + 'file_search': { + 'file_search_store_names': [store_name], + 'metadata_filter': 'author=foo', + }, + 'api_version': 'v1alpha', + }, + ) + return response.text + finally: + # 4. Cleanup + await delete_file_search_store(client, store_name) + print(f'Deleted store: {store_name}') + + @ai.flow() async def youtube_videos() -> str: """YouTube videos.""" @@ -617,6 +602,69 @@ async def youtube_videos() -> str: return response.text +class ScreenshotInput(BaseModel): + url: str = Field(description='The URL to take a screenshot of') + + +@ai.tool(name='screenShot') +def take_screenshot(input_: ScreenshotInput) -> dict: + """Take a screenshot of a given URL.""" + # Implement your screenshot logic here + print(f'Taking screenshot of {input_.url}') + return {'url': input_.url, 'screenshot_path': '/tmp/screenshot.png'} + + +class WeatherInput(BaseModel): + """Input for getting weather.""" + + location: str = Field(description='The city and state, e.g. San Francisco, CA') + + +@ai.tool(name='getWeather') +def get_weather(input_: WeatherInput) -> dict: + """Used to get current weather for a location.""" + return { + 'location': input_.location, + 'temperature_celcius': 21.5, + 'conditions': 'cloudy', + } + + +@ai.tool(name='celsiusToFahrenheit') +def celsius_to_fahrenheit(celsius: float) -> float: + """Converts Celsius to Fahrenheit.""" + return (celsius * 9) / 5 + 32 + + +@ai.flow() +async def tool_calling(location: Annotated[str, Field(default='Paris, France')] = 'Paris, France') -> str: + """Tool calling with Gemini.""" + response = await ai.generate( + model='googleai/gemini-2.5-flash', + tools=['getWeather', 'celsiusToFahrenheit'], + prompt=f"What's the weather in {location}? Convert the temperature to Fahrenheit.", + config=GenerationCommonConfig(temperature=1), + ) + return response.text + + +@ai.flow() +async def describe_image( + image_url: Annotated[ + str, Field(default='https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png') + ] = 'https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png', +) -> str: + """Describe an image.""" + response = await ai.generate( + model='googleai/gemini-2.5-flash', + prompt=[ + Part(root=TextPart(text='Describe this image')), + Part(root=MediaPart(media=Media(url=image_url, content_type='image/png'))), + ], + ) + return response.text + + async def main() -> None: """Main function - keep alive for Dev UI.""" await logger.ainfo('Genkit server running. Press Ctrl+C to stop.') diff --git a/py/samples/shared/__init__.py b/py/samples/shared/__init__.py deleted file mode 100644 index 1a8fb32722..0000000000 --- a/py/samples/shared/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# SPDX-License-Identifier: Apache-2.0 -"""Shared utilities and types for samples.""" - -from .flows import ( - calculation_logic, - currency_exchange_logic, - generate_character_logic, - say_hi_logic, - say_hi_stream_logic, - say_hi_with_config_logic, - weather_logic, -) -from .tools import ( - calculate, - convert_currency, - get_weather, -) -from .types import ( - CalculatorInput, - CurrencyExchangeInput, - RpgCharacter, - WeatherInput, -) - -__all__ = [ - get_weather, - convert_currency, - calculate, - weather_logic, - currency_exchange_logic, - calculation_logic, - say_hi_logic, - say_hi_stream_logic, - say_hi_with_config_logic, - WeatherInput, - CurrencyExchangeInput, - CalculatorInput, - RpgCharacter, - generate_character_logic, -] diff --git a/py/samples/shared/flows.py b/py/samples/shared/flows.py deleted file mode 100644 index 2c3dac9143..0000000000 --- a/py/samples/shared/flows.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-License-Identifier: Apache-2.0 - -"""Common flows for samples.""" - -from typing import cast - -from genkit.ai import Genkit -from genkit.core.action import ActionRunContext - -from .types import CalculatorInput, CurrencyExchangeInput, RpgCharacter, WeatherInput - - -async def calculation_logic(ai: Genkit, input: CalculatorInput) -> str: - """Business logic to perform currency conversion via an LLM tool call. - - Args: - ai: The initialized Genkit instance. - input: Validated currency exchange parameters. - - Returns: - Conversion result. - """ - response = await ai.generate( - prompt=f'Calculate {input.a} {input.operation} {input.b}', - tools=['calculate'], - ) - - return response.text - - -async def currency_exchange_logic(ai: Genkit, input: CurrencyExchangeInput) -> str: - """Business logic to perform currency conversion via an LLM tool call. - - Args: - ai: The initialized Genkit instance. - input: Validated currency exchange parameters. - - Returns: - Conversion result. - """ - response = await ai.generate( - prompt=f'Convert {input.amount} {input.from_currency} to {input.to_currency}', - tools=['convert_currency'], - ) - - return response.text - - -async def generate_character_logic(ai: Genkit, name: str) -> RpgCharacter: - """Generate an RPG character. - - Args: - ai: The Genkit instance. - name: The name of the character - - Returns: - The generated RPG character. - """ - result = await ai.generate( - prompt=f'Generate a structured RPG character named {name}. Output ONLY the JSON object.', - output_schema=RpgCharacter, - ) - return cast(RpgCharacter, result.output) - - -async def say_hi_logic(ai: Genkit, name: str) -> str: - """Generates a simple greeting via the AI model. - - Args: - ai: The Genkit instance. - name: Name to greet. - - Returns: - Greeting message from the LLM. 
- """ - response = await ai.generate(prompt=f'Say hello to {name}!') - return response.text - - -async def say_hi_stream_logic(ai: Genkit, name: str, ctx: ActionRunContext) -> str: - """Generates a streaming story. - - Args: - ai: The Genkit instance. - name: Name to greet. - ctx: Action context for streaming. - """ - response = await ai.generate( - prompt=f'Tell me a short story about {name}', - on_chunk=ctx.send_chunk, - ) - return response.text - - -async def say_hi_with_config_logic(ai: Genkit, name: str) -> str: - """Generates a greeting with custom model configuration. - - Args: - ai: The Genkit instance. - name: User name. - - Returns: - Greeting message from the LLM. - """ - response = await ai.generate( - prompt=f'Write a creative greeting for {name}', - config={'temperature': 1.0, 'max_output_tokens': 200}, - ) - return response.text - - -async def weather_logic(ai: Genkit, input: WeatherInput) -> str: - """Get weather info using the weather tool (via model tool calling). - - Args: - ai: The AI model or client used to generate the weather response. - input: Weather input data. - - Returns: - Formatted weather string. - - Example: - >>> await weather_flow(WeatherInput(location='London')) - "Weather in London: 15°C, cloudy" - """ - response = await ai.generate( - prompt=f'What is the weather in {input.location}?', - tools=['get_weather'], - ) - return response.text diff --git a/py/samples/shared/tools.py b/py/samples/shared/tools.py deleted file mode 100644 index 7059d4ab72..0000000000 --- a/py/samples/shared/tools.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-License-Identifier: Apache-2.0 -"""Common tools for samples.""" - -import operator -import random - -from .types import ( - CalculatorInput, - CurrencyExchangeInput, - WeatherInput, -) - - -def calculate(input: CalculatorInput) -> dict: - """Perform basic arithmetic operations. - - Uses a dispatch table lookup and handles arithmetic edge cases. - """ - operations = { - 'add': operator.add, - 'subtract': operator.sub, - 'multiply': operator.mul, - 'divide': operator.truediv, - } - - op_name = input.operation.lower() - handler = operations.get(op_name) - - if not handler: - return {'error': f'Unknown operation: {op_name}'} - - try: - a, b = float(input.a), float(input.b) - result = handler(a, b) - except ZeroDivisionError: - return {'error': 'Division by zero'} - except (ValueError, TypeError) as e: - return {'error': f'Invalid numeric input: {e}'} - - return { - 'operation': op_name, - 'a': a, - 'b': b, - 'result': result, - } - - -def convert_currency(input: CurrencyExchangeInput) -> str: - """Convert currency amount. - - Args: - input: Currency conversion parameters. - - Returns: - Converted amount. 
- """ - # Mock conversion rates - rates = { - ('USD', 'EUR'): 0.85, - ('EUR', 'USD'): 1.18, - ('USD', 'GBP'): 0.73, - ('GBP', 'USD'): 1.37, - } - rate = rates.get((input.from_currency, input.to_currency), 1.0) - converted = input.amount * rate - return f'{input.amount} {input.from_currency} = {converted:.2f} {input.to_currency}' - - -def get_weather(input: WeatherInput) -> str: - """Return a random realistic weather string for a city name. - - Args: - input: Weather input location. - - Returns: - Weather information with temperature in degree Celsius. - """ - weather_options = [ - '32° C sunny', - '17° C cloudy', - '22° C cloudy', - '19° C humid', - ] - return random.choice(weather_options) diff --git a/py/samples/shared/types.py b/py/samples/shared/types.py deleted file mode 100644 index b857faf118..0000000000 --- a/py/samples/shared/types.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-License-Identifier: Apache-2.0 - -"""Common types for samples.""" - -from pydantic import BaseModel, Field - - -class CurrencyExchangeInput(BaseModel): - """Currency conversion input schema.""" - - amount: float = Field(description='Amount to convert', default=100) - from_currency: str = Field(description='Source currency code (e.g., USD)', default='USD') - to_currency: str = Field(description='Target currency code (e.g., EUR)', default='EUR') - - -class CalculatorInput(BaseModel): - """Input for the calculator tool.""" - - operation: str = Field(description='Math operation: add, subtract, multiply, divide', default='add') - a: float = Field(description='First number', default=123) - b: float = Field(description='Second number', default=321) - - -class Skills(BaseModel): - """A set of core character skills for an RPG character.""" - - strength: int = Field(description='strength (0-100)') - charisma: int = Field(description='charisma (0-100)') - endurance: int = Field(description='endurance (0-100)') - - -class RpgCharacter(BaseModel): - """An RPG character.""" - - name: str = Field(description='name of the character') - back_story: str = Field(description='back story', alias='backStory') - abilities: list[str] = Field(description='list of abilities (3-4)') - skills: Skills - - -class WeatherInput(BaseModel): - """Input for the weather tool.""" - - location: str = Field(description='City or location name', default='New York') diff --git a/py/samples/xai-hello/src/main.py b/py/samples/xai-hello/src/main.py index 9c543a37a9..393cf3607a 100755 --- a/py/samples/xai-hello/src/main.py +++ b/py/samples/xai-hello/src/main.py @@ -35,30 +35,14 @@ """ import os -from typing import Annotated +from typing import Annotated, cast import structlog -from pydantic import Field +from pydantic import BaseModel, Field from genkit.ai import Genkit from genkit.core.action import ActionRunContext from genkit.plugins.xai import XAI, xai_name -from samples.shared import ( - CalculatorInput, - CurrencyExchangeInput, - RpgCharacter, - WeatherInput, - calculate, - 
calculation_logic,
-    convert_currency,
-    currency_exchange_logic,
-    generate_character_logic,
-    get_weather,
-    say_hi_logic,
-    say_hi_stream_logic,
-    say_hi_with_config_logic,
-    weather_logic,
-)
 
 if 'XAI_API_KEY' not in os.environ:
     os.environ['XAI_API_KEY'] = input('Please enter your XAI_API_KEY: ')
@@ -71,28 +55,148 @@
 )
 
 
-# Decorated tools
-ai.tool()(get_weather)
-ai.tool()(convert_currency)
-ai.tool()(calculate)
+class CalculatorInput(BaseModel):
+    """Input for the calculator tool."""
+
+    operation: str = Field(description='Math operation: add, subtract, multiply, divide')
+    a: float = Field(description='First number')
+    b: float = Field(description='Second number')
+
+
+class CurrencyExchangeInput(BaseModel):
+    """Currency exchange flow input schema."""
+
+    amount: float = Field(description='Amount to convert', default=100)
+    from_curr: str = Field(description='Source currency code', default='USD')
+    to_curr: str = Field(description='Target currency code', default='EUR')
+
+
+class CurrencyInput(BaseModel):
+    """Currency conversion input schema."""
+
+    amount: float = Field(description='Amount to convert', default=100)
+    from_currency: str = Field(description='Source currency code (e.g., USD)', default='USD')
+    to_currency: str = Field(description='Target currency code (e.g., EUR)', default='EUR')
+
+
+class Skills(BaseModel):
+    """A set of core character skills for an RPG character."""
+
+    strength: int = Field(description='strength (0-100)')
+    charisma: int = Field(description='charisma (0-100)')
+    endurance: int = Field(description='endurance (0-100)')
+
+
+class RpgCharacter(BaseModel):
+    """An RPG character."""
+
+    name: str = Field(description='name of the character')
+    back_story: str = Field(description='back story', alias='backStory')
+    abilities: list[str] = Field(description='list of abilities (3-4)')
+    skills: Skills
+
+
+class WeatherInput(BaseModel):
+    """Input for the weather tool."""
+
+    location: str = Field(description='City or location name')
+    unit: str = Field(default='celsius', description='Temperature unit: celsius or fahrenheit')
+
+
+@ai.tool()
+def calculate(input: CalculatorInput) -> dict:
+    """Perform basic arithmetic operations.
+
+    Args:
+        input: Calculation request input.
+
+    Returns:
+        Calculation result dictionary.
+    """
+    operations = {
+        'add': lambda a, b: a + b,
+        'subtract': lambda a, b: a - b,
+        'multiply': lambda a, b: a * b,
+        'divide': lambda a, b: a / b,
+    }
+
+    operation = input.operation.lower()
+    if operation not in operations:
+        return {'error': f'Unknown operation: {operation}'}
+
+    # Guard the one operation that can fail on valid floats, instead of
+    # silently returning a None result for division by zero.
+    if operation == 'divide' and input.b == 0:
+        return {'error': 'Division by zero'}
+
+    result = operations[operation](input.a, input.b)
+    return {
+        'operation': operation,
+        'a': input.a,
+        'b': input.b,
+        'result': result,
+    }
 
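+# A minimal usage sketch of the tool above (it assumes, as calculator_flow
+# below does, that `ai.tool()` leaves the decorated function directly
+# callable as plain Python):
+#
+#   calculate(CalculatorInput(operation='add', a=5, b=3))
+#   # -> {'operation': 'add', 'a': 5.0, 'b': 3.0, 'result': 8.0}
+#   calculate(CalculatorInput(operation='divide', a=1, b=0))
+#   # -> {'error': 'Division by zero'}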
 
 @ai.flow()
-async def currency_exchange_flow(input_data: CurrencyExchangeInput) -> str:
-    """Genkit entry point for the currency exchange flow.
+async def calculator_flow(expression: Annotated[str, Field(default='add_5_3')] = 'add_5_3') -> str:
+    """Parse and calculate a math expression.
+
+    Args:
+        expression: String in format 'operation_a_b'.
 
-    Exposes conversion logic as a traceable Genkit flow.
+    Returns:
+        Calculation result string.
+
+    Example:
+        >>> await calculator_flow('add_5_3')
+        "Add(5.0, 3.0) = 8.0"
     """
-    return await currency_exchange_logic(ai, input_data)
+    parts = expression.split('_')
+    if len(parts) < 3:
+        return 'Invalid expression format. Use: operation_a_b (e.g., add_5_3)'
+
+    # Reject non-numeric operands instead of raising an unhandled ValueError.
+    try:
+        operation, a, b = parts[0], float(parts[1]), float(parts[2])
+    except ValueError:
+        return f'Invalid numbers in expression: {expression}'
+    result = calculate(CalculatorInput(operation=operation, a=a, b=b))
+    if 'error' in result:
+        return f'Error: {result["error"]}'
+    return f'{operation.title()}({a}, {b}) = {result.get("result")}'
+
+
+@ai.tool()
+def convert_currency(input: CurrencyInput) -> str:
+    """Convert currency amount.
+
+    Args:
+        input: Currency conversion parameters.
+
+    Returns:
+        Converted amount.
+    """
+    # Mock conversion rates
+    rates = {
+        ('USD', 'EUR'): 0.85,
+        ('EUR', 'USD'): 1.18,
+        ('USD', 'GBP'): 0.73,
+        ('GBP', 'USD'): 1.37,
+    }
+
+    rate = rates.get((input.from_currency, input.to_currency), 1.0)
+    converted = input.amount * rate
+
+    return f'{input.amount} {input.from_currency} = {converted:.2f} {input.to_currency}'
 
 
 @ai.flow()
-async def calculator_flow(input_data: CalculatorInput) -> str:
-    """Genkit entry point for the calculator flow.
+async def currency_exchange(input: CurrencyExchangeInput) -> str:
+    """Convert currency using tools.
+
+    Args:
+        input: Currency exchange parameters.
 
-    Exposes calculation logic as a traceable Genkit flow.
+    Returns:
+        Conversion result.
     """
-    return await calculation_logic(ai, input_data)
+    response = await ai.generate(
+        prompt=f'Convert {input.amount} {input.from_curr} to {input.to_curr}',
+        tools=['convert_currency'],
+    )
+    return response.text
 
 
 @ai.flow()
@@ -107,7 +211,43 @@ async def generate_character(
     Returns:
         The generated RPG character.
     """
-    return await generate_character_logic(ai, name)
+    result = await ai.generate(
+        prompt=f'generate an RPG character named {name}',
+        output_schema=RpgCharacter,
+    )
+    return cast(RpgCharacter, result.output)
+
+
+@ai.tool()
+def get_weather(input: WeatherInput) -> dict:
+    """Get weather information for a location.
+
+    Args:
+        input: Weather request input.
+
+    Returns:
+        Weather data dictionary.
+    """
+    weather_data = {
+        'New York': {'temp': 15, 'condition': 'cloudy', 'humidity': 65},
+        'London': {'temp': 12, 'condition': 'rainy', 'humidity': 78},
+        'Tokyo': {'temp': 20, 'condition': 'sunny', 'humidity': 55},
+        'Paris': {'temp': 14, 'condition': 'partly cloudy', 'humidity': 60},
+    }
+
+    location = input.location.title()
+    data = weather_data.get(location, {'temp': 18, 'condition': 'unknown', 'humidity': 50})
+
+    if input.unit == 'fahrenheit' and 'temp' in data:
+        temp = data['temp']
+        if isinstance(temp, (int, float)):
+            data['temp'] = round((temp * 9 / 5) + 32, 1)
+        data['unit'] = 'F'
+    else:
+        data['unit'] = 'C'
+
+    data['location'] = location
+    return data
 
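+# Worked example for the unit conversion above (the mock data is fixed, so
+# the output is deterministic):
+#
+#   get_weather(WeatherInput(location='new york', unit='fahrenheit'))
+#   # .title() normalizes the key to 'New York'; 15 C -> 15 * 9 / 5 + 32 = 59.0 F
+#   # -> {'temp': 59.0, 'condition': 'cloudy', 'humidity': 65, 'unit': 'F',
+#   #     'location': 'New York'}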
""" - return await say_hi_with_config_logic(ai, name) + response = await ai.generate( + prompt=f'Write a creative greeting for {name}', + config={'temperature': 1.0, 'max_output_tokens': 200}, + ) + return response.text @ai.flow() -async def weather_flow(input_data: WeatherInput) -> str: - """Genkit entry point for the weather information flow. +async def weather_flow(location: Annotated[str, Field(default='New York')] = 'New York') -> str: + """Get weather info using the weather tool (via model tool calling). + + Args: + location: City name. - Exposes weather logic as a traceable Genkit flow. + Returns: + Formatted weather string. + + Example: + >>> await weather_flow('New York') + "Weather in New York: 15°C, cloudy" """ - return await weather_logic(ai, input_data) + response = await ai.generate( + prompt=f'What is the weather in {location}?', + tools=['get_weather'], + ) + return response.text async def main() -> None: