diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index ab40d25d..d65bf02e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -3,6 +3,7 @@ from __future__ import annotations +import asyncio import functools import logging from abc import ABC, abstractmethod @@ -43,6 +44,14 @@ def generate(self, data: pd.DataFrame) -> pd.DataFrame: ... @abstractmethod def generate(self, data: DataT) -> DataT: ... + async def agenerate(self, data: dict) -> dict: + """Async fallback — delegates to sync generate via thread pool. + + Subclasses with native async support (e.g. ColumnGeneratorWithModelChatCompletion) + should override this with a direct async implementation. + """ + return await asyncio.to_thread(self.generate, data) + def log_pre_generation(self) -> None: """A shared method to log info before the generator's `generate` method is called. diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py index 3cc31e46..a0ff447a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py @@ -5,6 +5,7 @@ import functools import logging +from typing import TYPE_CHECKING, Any from data_designer.config.column_configs import ( LLMCodeColumnConfig, @@ -24,6 +25,9 @@ from data_designer.engine.models.recipes.base import ResponseRecipe from data_designer.engine.processing.utils import deserialize_json_values +if TYPE_CHECKING: + from data_designer.engine.models.utils import ChatMessage + logger = logging.getLogger(__name__) @@ -56,36 +60,55 @@ def prompt_renderer(self) -> RecordBasedPromptRenderer: ) def generate(self, data: dict) -> dict: + kwargs = self._prepare_generation_kwargs(data) + response, trace = self.model.generate(**kwargs) + return self._process_generation_result(data, response, trace) + + async def agenerate(self, data: dict) -> dict: + kwargs = self._prepare_generation_kwargs(data) + response, trace = await self.model.agenerate(**kwargs) + return self._process_generation_result(data, response, trace) + + def _prepare_generation_kwargs(self, data: dict) -> dict[str, Any]: + """Prepare keyword arguments for model.generate() / model.agenerate(). + + Deserializes input data, builds multi-modal context, and renders prompts. + """ # Deserialize input data from previous columns so Jinja2 templates can access nested fields # Example: If prev column stored '{"key": "value"}', templates can use {{ prev_column.key }} # Note: This creates a new dict and doesn't mutate the original `data` argument deserialized_record = deserialize_json_values(data) - multi_modal_context = None + multi_modal_context: list[dict[str, Any]] | None = None if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0: multi_modal_context = [] for context in self.config.multi_modal_context: multi_modal_context.extend(context.get_contexts(deserialized_record)) - response, trace = self.model.generate( - prompt=self.prompt_renderer.render( + return { + "prompt": self.prompt_renderer.render( record=deserialized_record, prompt_template=self.config.prompt, prompt_type=PromptType.USER_PROMPT, ), - system_prompt=self.prompt_renderer.render( + "system_prompt": self.prompt_renderer.render( record=deserialized_record, prompt_template=self.config.system_prompt, prompt_type=PromptType.SYSTEM_PROMPT, ), - parser=self.response_recipe.parse, - multi_modal_context=multi_modal_context, - tool_alias=self.config.tool_alias, - max_correction_steps=self.max_conversation_correction_steps, - max_conversation_restarts=self.max_conversation_restarts, - purpose=f"running generation for column '{self.config.name}'", - ) - + "parser": self.response_recipe.parse, + "multi_modal_context": multi_modal_context, + "tool_alias": self.config.tool_alias, + "max_correction_steps": self.max_conversation_correction_steps, + "max_conversation_restarts": self.max_conversation_restarts, + "purpose": f"running generation for column '{self.config.name}'", + } + + def _process_generation_result(self, data: dict, response: Any, trace: list[ChatMessage]) -> dict: + """Process model response and trace into the output data dict. + + Serializes the response, applies trace column logic, and extracts reasoning content. + """ serialized_output = self.response_recipe.serialize_output(response) data[self.config.name] = self._process_serialized_output(serialized_output) @@ -102,7 +125,7 @@ def generate(self, data: dict) -> dict: return data - def _extract_reasoning_content(self, trace: list) -> str | None: + def _extract_reasoning_content(self, trace: list[ChatMessage]) -> str | None: """Extract reasoning_content from the final assistant message in the trace. Args: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 6e42844b..1accdaaf 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -5,10 +5,11 @@ import functools import logging +import os import time import uuid from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable from data_designer.config.column_configs import CustomColumnConfig from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType @@ -30,6 +31,7 @@ from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig +from data_designer.engine.dataset_builders.utils.async_concurrency import AsyncConcurrentExecutor from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager @@ -51,6 +53,11 @@ logger = logging.getLogger(__name__) +DATA_DESIGNER_ASYNC_ENGINE = os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "0") == "1" + +if DATA_DESIGNER_ASYNC_ENGINE: + logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled — using async concurrency") + _CLIENT_VERSION: str = get_library_version() @@ -255,7 +262,11 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None: max_workers = self._resource_provider.run_config.non_inference_max_parallel_workers if isinstance(generator, ColumnGeneratorWithModel): max_workers = generator.inference_parameters.max_parallel_requests - self._fan_out_with_threads(generator, max_workers=max_workers) + if DATA_DESIGNER_ASYNC_ENGINE: + logger.info("⚡ Using async engine for concurrent execution") + self._fan_out_with_async(generator, max_workers=max_workers) + else: + self._fan_out_with_threads(generator, max_workers=max_workers) def _run_full_column_generator(self, generator: ColumnGenerator) -> None: df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True)) @@ -282,16 +293,15 @@ def _run_mcp_tool_check_if_needed(self) -> None: raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.") self._resource_provider.mcp_registry.run_health_check(tool_aliases) - def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None: + def _setup_fan_out( + self, generator: ColumnGeneratorWithModelRegistry, max_workers: int + ) -> tuple[ProgressTracker, dict[str, Any]]: if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL: raise DatasetGenerationError( f"Generator {generator.name} is not a {GenerationStrategy.CELL_BY_CELL} " - "generator so concurrency through threads is not supported." + "generator so concurrent fan-out is not supported." ) - if getattr(generator.config, "tool_alias", None): - logger.info("🛠️ Tool calling enabled") - progress_tracker = ProgressTracker( total_records=self.batch_manager.num_records_batch, label=f"{generator.config.column_type} column '{generator.config.name}'", @@ -299,25 +309,42 @@ def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max progress_tracker.log_start(max_workers) settings = self._resource_provider.run_config - with ConcurrentThreadExecutor( - max_workers=max_workers, - column_name=generator.config.name, - result_callback=self._make_result_callback(progress_tracker), - error_callback=self._make_error_callback(progress_tracker), - shutdown_error_rate=settings.shutdown_error_rate, - shutdown_error_window=settings.shutdown_error_window, - disable_early_shutdown=settings.disable_early_shutdown, - ) as executor: - for i, record in self.batch_manager.iter_current_batch(): - executor.submit(lambda record: generator.generate(record), record, context={"index": i}) - + executor_kwargs: dict = { + "column_name": generator.config.name, + "result_callback": self._make_result_callback(progress_tracker), + "error_callback": self._make_error_callback(progress_tracker), + "shutdown_error_rate": settings.shutdown_error_rate, + "shutdown_error_window": settings.shutdown_error_window, + "disable_early_shutdown": settings.disable_early_shutdown, + } + + return progress_tracker, executor_kwargs + + def _finalize_fan_out(self, progress_tracker: ProgressTracker) -> None: progress_tracker.log_final() - if len(self._records_to_drop) > 0: self._cleanup_dropped_record_images(self._records_to_drop) self.batch_manager.drop_records(self._records_to_drop) self._records_to_drop.clear() + def _fan_out_with_async(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None: + progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers) + executor = AsyncConcurrentExecutor(max_workers=max_workers, **executor_kwargs) + work_items = [ + (generator.agenerate(record), {"index": i}) for i, record in self.batch_manager.iter_current_batch() + ] + executor.run(work_items) + self._finalize_fan_out(progress_tracker) + + def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None: + if getattr(generator.config, "tool_alias", None): + logger.info("🛠️ Tool calling enabled") + progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers) + with ConcurrentThreadExecutor(max_workers=max_workers, **executor_kwargs) as executor: + for i, record in self.batch_manager.iter_current_batch(): + executor.submit(lambda record: generator.generate(record), record, context={"index": i}) + self._finalize_fan_out(progress_tracker) + def _make_result_callback(self, progress_tracker: ProgressTracker) -> Callable[[dict], None]: def callback(result: dict, *, context: dict | None = None) -> None: self._worker_result_callback(result, context=context) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py new file mode 100644 index 00000000..28ae22b5 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py @@ -0,0 +1,225 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Async batch execution with bounded concurrency and early-shutdown semantics. + +Async counterpart to ``concurrency.py``. Same operational contract (callbacks +with optional context, error aggregation, shutdown thresholds), different +runtime model. The sync module runs callables in a ``ThreadPoolExecutor``; +this module runs coroutines in ``asyncio.TaskGroup`` on a dedicated loop +thread. Callers stay synchronous. + +Architecture: + ``AsyncConcurrentExecutor.run()`` is a blocking call that submits + coroutines to a shared background event loop via + ``run_coroutine_threadsafe``. Bounded concurrency is enforced with an + ``asyncio.Semaphore``. Success/error counts use the same + ``ExecutorResults`` model as the sync executor. + + Caller Thread ──► run() ──► run_coroutine_threadsafe ──► Background Loop + (TaskGroup) + +Singleton Event Loop: + The background loop is a process-wide singleton. LiteLLM and similar + libraries bind internal async state to a specific event loop, so creating + per-call or per-instance loops breaks connection reuse and triggers + cross-loop errors. ``_ensure_async_engine_loop()`` creates one daemon + loop thread and reuses it for all executor instances. + +Startup Handshake: + Loop creation uses a ``threading.Event`` readiness handshake. The + background thread signals readiness via ``loop.call_soon(ready.set)``, + and the creating thread holds the lock until that event fires (or a + timeout expires). This prevents a race where a second caller could see + ``_loop.is_running() == False`` before the first loop has entered + ``run_forever()``, which would create a duplicate loop. On timeout, + globals are reset and the orphaned loop is cleaned up before raising. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import threading +from collections.abc import Coroutine +from dataclasses import dataclass +from typing import Any, Generic, TypeVar + +from data_designer.engine.dataset_builders.utils.concurrency import ( + CallbackWithContext, + ErrorCallbackWithContext, + ExecutorResults, +) +from data_designer.engine.errors import DataDesignerRuntimeError + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +@dataclass(frozen=True, slots=True) +class Success(Generic[T]): + index: int + value: T + + +@dataclass(frozen=True, slots=True) +class Failure: + index: int + error: Exception + + +TaskResult = Success[T] | Failure + +_loop: asyncio.AbstractEventLoop | None = None +_thread: threading.Thread | None = None +_lock = threading.Lock() + +_LOOP_READY_TIMEOUT = 5.0 # seconds to wait for the background loop to start + + +def _run_loop(loop: asyncio.AbstractEventLoop, ready: threading.Event) -> None: + asyncio.set_event_loop(loop) + loop.call_soon(ready.set) + loop.run_forever() + + +def _ensure_async_engine_loop() -> asyncio.AbstractEventLoop: + """Get or create a persistent event loop for async engine work. + + A single event loop is shared across all AsyncConcurrentExecutor instances + to avoid breaking libraries (like LiteLLM) that bind internal async state + to a specific event loop. + """ + global _loop, _thread + with _lock: + if _loop is None or not _loop.is_running(): + ready = threading.Event() + _loop = asyncio.new_event_loop() + _thread = threading.Thread(target=_run_loop, args=(_loop, ready), daemon=True, name="AsyncEngine-EventLoop") + _thread.start() + if not ready.wait(timeout=_LOOP_READY_TIMEOUT): + orphan_loop = _loop + orphan_thread = _thread + _loop = None + _thread = None + + if orphan_loop is not None: + try: + if orphan_thread is not None and orphan_thread.is_alive(): + orphan_loop.call_soon_threadsafe(orphan_loop.stop) + if not orphan_loop.is_running(): + orphan_loop.close() + except Exception: + logger.warning("Failed to clean up timed-out AsyncEngine loop startup", exc_info=True) + + raise RuntimeError("AsyncEngine event loop failed to start within timeout") + return _loop + + +class AsyncConcurrentExecutor: + """Async equivalent of ConcurrentThreadExecutor. + + Executes a batch of coroutines with bounded concurrency, error rate + monitoring, and early shutdown semantics. Callers remain synchronous — + the ``run()`` method submits work to a persistent background event loop. + + No locks are needed because asyncio tasks run cooperatively on a + single thread — mutations to ``_results`` are always sequential. + """ + + def __init__( + self, + *, + max_workers: int, + column_name: str, + result_callback: CallbackWithContext | None = None, + error_callback: ErrorCallbackWithContext | None = None, + shutdown_error_rate: float = 0.50, + shutdown_error_window: int = 10, + disable_early_shutdown: bool = False, + ) -> None: + self._column_name = column_name + self._max_workers = max_workers + self._result_callback = result_callback + self._error_callback = error_callback + self._shutdown_error_rate = shutdown_error_rate + self._shutdown_window_size = shutdown_error_window + self._disable_early_shutdown = disable_early_shutdown + self._results = ExecutorResults(failure_threshold=shutdown_error_rate) + + @property + def results(self) -> ExecutorResults: + return self._results + + @property + def max_workers(self) -> int: + return self._max_workers + + @property + def shutdown_error_rate(self) -> float: + return self._shutdown_error_rate + + @property + def shutdown_window_size(self) -> int: + return self._shutdown_window_size + + def run(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None: + """Execute all work items concurrently. Callers remain synchronous.""" + logger.debug( + f"AsyncConcurrentExecutor: launching {len(work_items)} tasks " + f"with max_workers={self._max_workers} for column '{self._column_name}'" + ) + loop = _ensure_async_engine_loop() + future = asyncio.run_coroutine_threadsafe(self._run_all(work_items), loop) + future.result() + + async def _run_all(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None: + self._semaphore = asyncio.Semaphore(self._max_workers) + self._shutdown_event = asyncio.Event() + + async with asyncio.TaskGroup() as tg: + for i, (coro, context) in enumerate(work_items): + tg.create_task(self._run_task(i, coro, context)) + + if not self._disable_early_shutdown and self._results.early_shutdown: + self._raise_task_error() + + async def _run_task(self, index: int, coro: Coroutine[Any, Any, Any], context: dict | None) -> None: + if self._shutdown_event.is_set(): + coro.close() + return + + async with self._semaphore: + if self._shutdown_event.is_set(): + coro.close() + return + + try: + result = await coro + self._results.completed_count += 1 + self._results.success_count += 1 + if self._result_callback is not None: + self._result_callback(result, context=context) + except Exception as err: + self._results.completed_count += 1 + self._results.error_trap.handle_error(err) + if not self._disable_early_shutdown and self._results.is_error_rate_exceeded( + self._shutdown_window_size + ): + if not self._results.early_shutdown: + self._results.early_shutdown = True + self._shutdown_event.set() + if self._error_callback is not None: + self._error_callback(err, context=context) + + def _raise_task_error(self) -> None: + raise DataDesignerRuntimeError( + "\n".join( + [ + " |-- Data generation was terminated early due to error rate exceeding threshold.", + f" |-- The summary of encountered errors is: \n{json.dumps(self._results.summary, indent=4)}", + ] + ) + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/__init__.py index 52a7a9da..e5725ea5 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/__init__.py @@ -1,2 +1,2 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index 8ca1ebfd..b3195730 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py @@ -251,6 +251,31 @@ def wrapper(model_facade: Any, *args, **kwargs): return wrapper +def acatch_llm_exceptions(func: Callable) -> Callable: + @wraps(func) + async def wrapper(model_facade: Any, *args: Any, **kwargs: Any) -> Any: + try: + return await func(model_facade, *args, **kwargs) + except Exception as e: + logger.debug( + "\n".join( + [ + "", + "|----------", + f"| Caught an exception downstream of type {type(e)!r}. Re-raising it below as a custom error with more context.", + "|----------", + ] + ), + exc_info=True, + stack_info=True, + ) + handle_llm_exceptions( + e, model_facade.model_name, model_facade.model_provider_name, purpose=kwargs.get("purpose") + ) + + return wrapper + + class DownstreamLLMExceptionMessageParser: def __init__(self, model_name: str, model_provider_name: str, purpose: str): self.model_name = model_name diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index ef328a9a..9e44c6d6 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -3,6 +3,7 @@ from __future__ import annotations +import asyncio import logging from collections.abc import Callable from copy import deepcopy @@ -20,6 +21,7 @@ from data_designer.engine.models.errors import ( GenerationValidationFailureError, ImageGenerationError, + acatch_llm_exceptions, catch_llm_exceptions, get_exception_primary_cause, ) @@ -108,6 +110,10 @@ def model_provider_name(self) -> str: def model_alias(self) -> str: return self._model_config.alias + @property + def max_parallel_requests(self) -> int: + return self._model_config.inference_parameters.max_parallel_requests + @property def usage_stats(self) -> ModelUsageStats: return self._usage_stats @@ -511,6 +517,7 @@ def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.Deployme model=f"{provider.provider_type}/{model_config.model}", api_base=provider.endpoint, api_key=api_key, + max_parallel_requests=model_config.inference_parameters.max_parallel_requests, ) return { "model_name": model_config.model, @@ -564,3 +571,327 @@ def _track_token_usage_from_image_diffusion(self, response: litellm.types.utils. else: # Successful response but no token usage data (some providers don't report it) self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=1, failed_requests=0)) + + async def acompletion( + self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any + ) -> litellm.ModelResponse: + message_payloads = [message.to_dict() for message in messages] + logger.debug( + f"Prompting model {self.model_name!r}...", + extra={"model": self.model_name, "messages": message_payloads}, + ) + response = None + kwargs = self.consolidate_kwargs(**kwargs) + try: + response = await self._router.acompletion(model=self.model_name, messages=message_payloads, **kwargs) + logger.debug( + f"Received completion from model {self.model_name!r}", + extra={ + "model": self.model_name, + "response": response, + "text": response.choices[0].message.content, + "usage": self._usage_stats.model_dump(), + }, + ) + return response + except Exception as e: + raise e + finally: + if not skip_usage_tracking and response is not None: + self._track_token_usage_from_completion(response) + + @acatch_llm_exceptions + async def agenerate_text_embeddings( + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs: Any + ) -> list[list[float]]: + logger.debug( + f"Generating embeddings with model {self.model_name!r}...", + extra={ + "model": self.model_name, + "input_count": len(input_texts), + }, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response = None + try: + response = await self._router.aembedding(model=self.model_name, input=input_texts, **kwargs) + logger.debug( + f"Received embeddings from model {self.model_name!r}", + extra={ + "model": self.model_name, + "embedding_count": len(response.data) if response.data else 0, + "usage": self._usage_stats.model_dump(), + }, + ) + if response.data and len(response.data) == len(input_texts): + return [data["embedding"] for data in response.data] + else: + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") + except Exception as e: + raise e + finally: + if not skip_usage_tracking and response is not None: + self._track_token_usage_from_embedding(response) + + @acatch_llm_exceptions + async def agenerate( + self, + prompt: str, + *, + parser: Callable[[str], Any] = _identity, + system_prompt: str | None = None, + multi_modal_context: list[dict[str, Any]] | None = None, + tool_alias: str | None = None, + max_correction_steps: int = 0, + max_conversation_restarts: int = 0, + skip_usage_tracking: bool = False, + purpose: str | None = None, + **kwargs: Any, + ) -> tuple[Any, list[ChatMessage]]: + output_obj = None + tool_schemas = None + tool_call_turns = 0 + total_tool_calls = 0 + curr_num_correction_steps = 0 + curr_num_restarts = 0 + + mcp_facade = self._get_mcp_facade(tool_alias) + + restart_checkpoint = prompt_to_messages( + user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context + ) + checkpoint_tool_call_turns = 0 + messages: list[ChatMessage] = deepcopy(restart_checkpoint) + + if mcp_facade is not None: + tool_schemas = await asyncio.to_thread(mcp_facade.get_tool_schemas) + + while True: + completion_kwargs = dict(kwargs) + if tool_schemas is not None: + completion_kwargs["tools"] = tool_schemas + + completion_response = await self.acompletion( + messages, + skip_usage_tracking=skip_usage_tracking, + **completion_kwargs, + ) + + if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response): + tool_call_turns += 1 + total_tool_calls += mcp_facade.tool_call_count(completion_response) + + if tool_call_turns > mcp_facade.max_tool_call_turns: + messages.extend(mcp_facade.refuse_completion_response(completion_response)) + else: + messages.extend( + await asyncio.to_thread(mcp_facade.process_completion_response, completion_response) + ) + + restart_checkpoint = deepcopy(messages) + checkpoint_tool_call_turns = tool_call_turns + + continue + + response = (completion_response.choices[0].message.content or "").strip() + reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None) + messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) + curr_num_correction_steps += 1 + + try: + output_obj = parser(response) + break + except ParserException as exc: + if max_correction_steps == 0 and max_conversation_restarts == 0: + raise GenerationValidationFailureError( + "Unsuccessful generation attempt. No retries were attempted." + ) from exc + + if curr_num_correction_steps <= max_correction_steps: + messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc)))) + + elif curr_num_restarts < max_conversation_restarts: + curr_num_correction_steps = 0 + curr_num_restarts += 1 + messages = deepcopy(restart_checkpoint) + tool_call_turns = checkpoint_tool_call_turns + + else: + raise GenerationValidationFailureError( + f"Unsuccessful generation despite {max_correction_steps} correction steps " + f"and {max_conversation_restarts} conversation restarts." + ) from exc + + if not skip_usage_tracking and mcp_facade is not None: + self._usage_stats.tool_usage.extend( + tool_calls=total_tool_calls, + tool_call_turns=tool_call_turns, + ) + + return output_obj, messages + + @acatch_llm_exceptions + async def agenerate_image( + self, + prompt: str, + multi_modal_context: list[dict[str, Any]] | None = None, + skip_usage_tracking: bool = False, + **kwargs: Any, + ) -> list[str]: + """Async version of generate_image. Generate image(s) and return base64-encoded data. + + Automatically detects the appropriate API based on model name: + - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) → image_generation API + - All other models → chat/completions API (default) + + Both paths return base64-encoded image data. If the API returns multiple images, + all are returned in the list. + + Args: + prompt: The prompt for image generation + multi_modal_context: Optional list of image contexts for multi-modal generation. + Only used with autoregressive models via chat completions API. + skip_usage_tracking: Whether to skip usage tracking + **kwargs: Additional arguments to pass to the model (including n=number of images) + + Returns: + List of base64-encoded image strings (without data URI prefix) + + Raises: + ImageGenerationError: If image generation fails or returns invalid data + """ + logger.debug( + f"Generating image with model {self.model_name!r}...", + extra={"model": self.model_name, "prompt": prompt}, + ) + + # Auto-detect API type based on model name + if is_image_diffusion_model(self.model_name): + images = await self._agenerate_image_diffusion(prompt, skip_usage_tracking, **kwargs) + else: + images = await self._agenerate_image_chat_completion( + prompt, multi_modal_context, skip_usage_tracking, **kwargs + ) + + # Track image usage + if not skip_usage_tracking and len(images) > 0: + self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) + + return images + + async def _agenerate_image_chat_completion( + self, + prompt: str, + multi_modal_context: list[dict[str, Any]] | None = None, + skip_usage_tracking: bool = False, + **kwargs: Any, + ) -> list[str]: + """Async version of _generate_image_chat_completion. + + Generate image(s) using autoregressive model via chat completions API. + + Args: + prompt: The prompt for image generation + multi_modal_context: Optional list of image contexts for multi-modal generation + skip_usage_tracking: Whether to skip usage tracking + **kwargs: Additional arguments to pass to the model + + Returns: + List of base64-encoded image strings + """ + messages = prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context) + + response = None + try: + response = await self.acompletion( + messages=messages, + skip_usage_tracking=skip_usage_tracking, + **kwargs, + ) + + logger.debug( + f"Received image(s) from autoregressive model {self.model_name!r}", + extra={"model": self.model_name, "response": response}, + ) + + # Validate response structure + if not response.choices or len(response.choices) == 0: + raise ImageGenerationError("Image generation response missing choices") + + message = response.choices[0].message + images = [] + + # Extract base64 from images attribute (primary path) + if hasattr(message, "images") and message.images: + for image in message.images: + # Handle different response formats + if isinstance(image, dict) and "image_url" in image: + image_url = image["image_url"] + + if isinstance(image_url, dict) and "url" in image_url: + if (b64 := _try_extract_base64(image_url["url"])) is not None: + images.append(b64) + elif isinstance(image_url, str): + if (b64 := _try_extract_base64(image_url)) is not None: + images.append(b64) + # Fallback: treat as base64 string + elif isinstance(image, str): + if (b64 := _try_extract_base64(image)) is not None: + images.append(b64) + + # Fallback: check content field if it looks like image data + if not images: + content = message.content or "" + if content and (content.startswith("data:image/") or is_base64_image(content)): + if (b64 := _try_extract_base64(content)) is not None: + images.append(b64) + + if not images: + raise ImageGenerationError("No image data found in image generation response") + + return images + + except Exception: + raise + + async def _agenerate_image_diffusion( + self, prompt: str, skip_usage_tracking: bool = False, **kwargs: Any + ) -> list[str]: + """Async version of _generate_image_diffusion. + + Generate image(s) using diffusion model via image_generation API. + + Always returns base64. If the API returns URLs instead of inline base64, + the images are downloaded and converted automatically. + + Returns: + List of base64-encoded image strings + """ + kwargs = self.consolidate_kwargs(**kwargs) + + response = None + + try: + response = await self._router.aimage_generation(prompt=prompt, model=self.model_name, **kwargs) + + logger.debug( + f"Received {len(response.data)} image(s) from diffusion model {self.model_name!r}", + extra={"model": self.model_name, "response": response}, + ) + + # Validate response + if not response.data or len(response.data) == 0: + raise ImageGenerationError("Image generation returned no data") + + images = [b64 for img in response.data if (b64 := _try_extract_base64(img)) is not None] + + if not images: + raise ImageGenerationError("No image data could be extracted from response") + + return images + + except Exception: + raise + finally: + if not skip_usage_tracking and response is not None: + self._track_token_usage_from_image_diffusion(response) diff --git a/packages/data-designer-engine/tests/engine/models/test_async_engine_switch.py b/packages/data-designer-engine/tests/engine/models/test_async_engine_switch.py new file mode 100644 index 00000000..685c9289 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/test_async_engine_switch.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from data_designer.config.column_configs import GenerationStrategy +from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder +from data_designer.engine.models.facade import ModelFacade + + +def test_model_facade_has_async_methods() -> None: + """ModelFacade exposes async variants of its core methods.""" + assert hasattr(ModelFacade, "acompletion") + assert hasattr(ModelFacade, "agenerate") + assert hasattr(ModelFacade, "agenerate_text_embeddings") + + +def test_model_facade_has_sync_methods() -> None: + """ModelFacade exposes synchronous core methods.""" + assert hasattr(ModelFacade, "completion") + assert hasattr(ModelFacade, "generate") + assert hasattr(ModelFacade, "generate_text_embeddings") + + +def test_async_engine_env_controls_builder_execution_path(monkeypatch: pytest.MonkeyPatch) -> None: + """When DATA_DESIGNER_ASYNC_ENGINE is set, _run_cell_by_cell_generator dispatches to async fan-out.""" + import data_designer.engine.dataset_builders.column_wise_builder as cwb_module + + mock_generator = MagicMock() + mock_generator.get_generation_strategy.return_value = GenerationStrategy.CELL_BY_CELL + mock_generator.inference_parameters.max_parallel_requests = 4 + + builder = MagicMock() + builder._resource_provider.run_config.non_inference_max_parallel_workers = 4 + + # Test with async enabled — uses max_parallel_requests from generator (same as sync) + with patch.object(cwb_module, "DATA_DESIGNER_ASYNC_ENGINE", True): + ColumnWiseDatasetBuilder._run_cell_by_cell_generator(builder, mock_generator) + builder._fan_out_with_async.assert_called_once_with(mock_generator, max_workers=4) + builder._fan_out_with_threads.assert_not_called() + + builder.reset_mock() + + # Test with async disabled — uses max_parallel_requests from generator + with patch.object(cwb_module, "DATA_DESIGNER_ASYNC_ENGINE", False): + ColumnWiseDatasetBuilder._run_cell_by_cell_generator(builder, mock_generator) + builder._fan_out_with_threads.assert_called_once_with(mock_generator, max_workers=4) + builder._fan_out_with_async.assert_not_called() diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 84da6ebb..53dc4625 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -4,13 +4,13 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError from data_designer.engine.models.errors import ImageGenerationError, ModelGenerationValidationFailureError -from data_designer.engine.models.facade import ModelFacade +from data_designer.engine.models.facade import CustomRouter, ModelFacade from data_designer.engine.models.parsers.errors import ParserException from data_designer.engine.models.utils import ChatMessage from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, StubMessage, StubResponse @@ -18,6 +18,7 @@ if TYPE_CHECKING: import litellm + from litellm.types.utils import EmbeddingResponse, ModelResponse def mock_oai_response_object(response_text: str) -> StubResponse: @@ -61,18 +62,18 @@ def stub_expected_embedding_response(): (3, 3, 16), ], ) -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +@patch.object(ModelFacade, "completion", autospec=True) def test_generate( - mock_completion, - stub_model_facade, - max_correction_steps, - max_conversation_restarts, - total_calls, -): + mock_completion: Any, + stub_model_facade: ModelFacade, + max_correction_steps: int, + max_conversation_restarts: int, + total_calls: int, +) -> None: bad_response = mock_oai_response_object("bad response") mock_completion.side_effect = lambda *args, **kwargs: bad_response - def _failing_parser(response: str): + def _failing_parser(response: str) -> str: raise ParserException("parser exception") with pytest.raises(ModelGenerationValidationFailureError): @@ -103,7 +104,7 @@ def _failing_parser(response: str): ("hello!", [ChatMessage.as_system("hello!"), ChatMessage.as_user("does not matter")]), ], ) -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +@patch.object(ModelFacade, "completion", autospec=True) def test_generate_with_system_prompt( mock_completion: Any, stub_model_facade: ModelFacade, @@ -191,7 +192,7 @@ def test_consolidate_kwargs(stub_model_configs, stub_model_facade): True, ], ) -@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) +@patch.object(CustomRouter, "completion", autospec=True) def test_completion_success( mock_router_completion: Any, stub_completion_messages: list[ChatMessage], @@ -212,7 +213,7 @@ def test_completion_success( } -@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) +@patch.object(CustomRouter, "completion", autospec=True) def test_completion_with_exception( mock_router_completion: Any, stub_completion_messages: list[ChatMessage], @@ -224,7 +225,7 @@ def test_completion_with_exception( stub_model_facade.completion(stub_completion_messages) -@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) +@patch.object(CustomRouter, "completion", autospec=True) def test_completion_with_kwargs( mock_router_completion: Any, stub_completion_messages: list[ChatMessage], @@ -250,29 +251,36 @@ def mock_completion( assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs} -@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True) -def test_generate_text_embeddings_success(mock_router_embedding, stub_model_facade, stub_expected_embedding_response): +@patch.object(CustomRouter, "embedding", autospec=True) +def test_generate_text_embeddings_success( + mock_router_embedding: Any, + stub_model_facade: ModelFacade, + stub_expected_embedding_response: EmbeddingResponse, +) -> None: mock_router_embedding.side_effect = lambda self, model, input, **kwargs: stub_expected_embedding_response input_texts = ["test1", "test2"] result = stub_model_facade.generate_text_embeddings(input_texts) assert result == [data["embedding"] for data in stub_expected_embedding_response.data] -@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True) -def test_generate_text_embeddings_with_exception(mock_router_embedding, stub_model_facade): +@patch.object(CustomRouter, "embedding", autospec=True) +def test_generate_text_embeddings_with_exception(mock_router_embedding: Any, stub_model_facade: ModelFacade) -> None: mock_router_embedding.side_effect = Exception("Router error") with pytest.raises(Exception, match="Router error"): stub_model_facade.generate_text_embeddings(["test1", "test2"]) -@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True) +@patch.object(CustomRouter, "embedding", autospec=True) def test_generate_text_embeddings_with_kwargs( - mock_router_embedding, stub_model_configs, stub_model_facade, stub_expected_embedding_response -): + mock_router_embedding: Any, + stub_model_configs: Any, + stub_model_facade: ModelFacade, + stub_expected_embedding_response: EmbeddingResponse, +) -> None: captured_kwargs = {} - def mock_embedding(self, model, input, **kwargs): + def mock_embedding(self: Any, model: str, input: list[str], **kwargs: Any) -> EmbeddingResponse: captured_kwargs.update(kwargs) return stub_expected_embedding_response @@ -1272,3 +1280,232 @@ def test_generate_image_accumulates_usage( assert len(images2) == 3 # Usage should accumulate assert stub_model_facade.usage_stats.image_usage.total_images == 5 + + +# ============================================================================= +# Async behavior tests +# ============================================================================= + + +@pytest.mark.parametrize( + "skip_usage_tracking", + [ + False, + True, + ], +) +@patch.object(CustomRouter, "acompletion", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_acompletion_success( + mock_router_acompletion: AsyncMock, + stub_completion_messages: list[ChatMessage], + stub_model_configs: Any, + stub_model_facade: ModelFacade, + stub_expected_completion_response: ModelResponse, + skip_usage_tracking: bool, +) -> None: + mock_router_acompletion.return_value = stub_expected_completion_response + result = await stub_model_facade.acompletion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking) + expected_messages = [message.to_dict() for message in stub_completion_messages] + assert result == stub_expected_completion_response + assert mock_router_acompletion.call_count == 1 + assert mock_router_acompletion.call_args[1] == { + "model": "stub-model-text", + "messages": expected_messages, + **stub_model_configs[0].inference_parameters.generate_kwargs, + } + + +@patch.object(CustomRouter, "acompletion", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_acompletion_with_exception( + mock_router_acompletion: AsyncMock, + stub_completion_messages: list[ChatMessage], + stub_model_facade: ModelFacade, +) -> None: + mock_router_acompletion.side_effect = Exception("Router error") + + with pytest.raises(Exception, match="Router error"): + await stub_model_facade.acompletion(stub_completion_messages) + + +@patch.object(CustomRouter, "aembedding", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_text_embeddings_success( + mock_router_aembedding: AsyncMock, + stub_model_facade: ModelFacade, + stub_expected_embedding_response: EmbeddingResponse, +) -> None: + mock_router_aembedding.return_value = stub_expected_embedding_response + input_texts = ["test1", "test2"] + result = await stub_model_facade.agenerate_text_embeddings(input_texts) + assert result == [data["embedding"] for data in stub_expected_embedding_response.data] + + +@pytest.mark.parametrize( + "max_correction_steps,max_conversation_restarts,total_calls", + [ + (0, 0, 1), + (1, 1, 4), + (1, 2, 6), + (5, 0, 6), + (0, 5, 6), + (3, 3, 16), + ], +) +@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_correction_retries( + mock_acompletion: AsyncMock, + stub_model_facade: ModelFacade, + max_correction_steps: int, + max_conversation_restarts: int, + total_calls: int, +) -> None: + bad_response = mock_oai_response_object("bad response") + mock_acompletion.return_value = bad_response + + def _failing_parser(response: str) -> str: + raise ParserException("parser exception") + + with pytest.raises(ModelGenerationValidationFailureError): + await stub_model_facade.agenerate( + prompt="foo", + system_prompt="bar", + parser=_failing_parser, + max_correction_steps=max_correction_steps, + max_conversation_restarts=max_conversation_restarts, + ) + assert mock_acompletion.call_count == total_calls + + with pytest.raises(ModelGenerationValidationFailureError): + await stub_model_facade.agenerate( + prompt="foo", + parser=_failing_parser, + system_prompt="bar", + max_correction_steps=max_correction_steps, + max_conversation_restarts=max_conversation_restarts, + ) + assert mock_acompletion.call_count == 2 * total_calls + + +@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_success( + mock_acompletion: AsyncMock, + stub_model_facade: ModelFacade, +) -> None: + good_response = mock_oai_response_object("parsed output") + mock_acompletion.return_value = good_response + + result, trace = await stub_model_facade.agenerate(prompt="test", parser=lambda x: x) + assert result == "parsed output" + assert mock_acompletion.call_count == 1 + # Trace should contain at least the user prompt and the assistant response + assert any(msg.role == "user" for msg in trace) + assert any(msg.role == "assistant" and msg.content == "parsed output" for msg in trace) + + +# ============================================================================= +# Async image generation tests +# ============================================================================= + + +@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_image_diffusion_success( + mock_aimage_generation: AsyncMock, + stub_model_facade: ModelFacade, +) -> None: + """Test async image generation via diffusion API.""" + mock_response = litellm.types.utils.ImageResponse( + data=[ + litellm.types.utils.ImageObject(b64_json="image1_base64"), + litellm.types.utils.ImageObject(b64_json="image2_base64"), + ] + ) + mock_aimage_generation.return_value = mock_response + + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): + images = await stub_model_facade.agenerate_image(prompt="test prompt") + + assert len(images) == 2 + assert images == ["image1_base64", "image2_base64"] + assert mock_aimage_generation.call_count == 1 + # Verify image usage was tracked + assert stub_model_facade.usage_stats.image_usage.total_images == 2 + + +@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_image_chat_completion_success( + mock_acompletion: AsyncMock, + stub_model_facade: ModelFacade, +) -> None: + """Test async image generation via chat completion API.""" + mock_message = litellm.types.utils.Message( + role="assistant", + content="", + images=[ + litellm.types.utils.ImageURLListItem( + type="image_url", image_url={"url": "data:image/png;base64,image1"}, index=0 + ), + ], + ) + mock_response = litellm.types.utils.ModelResponse(choices=[litellm.types.utils.Choices(message=mock_message)]) + mock_acompletion.return_value = mock_response + + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): + images = await stub_model_facade.agenerate_image(prompt="test prompt") + + assert len(images) == 1 + assert images == ["image1"] + assert mock_acompletion.call_count == 1 + assert stub_model_facade.usage_stats.image_usage.total_images == 1 + + +@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_image_diffusion_no_data( + mock_aimage_generation: AsyncMock, + stub_model_facade: ModelFacade, +) -> None: + """Test async image generation raises error when diffusion API returns no data.""" + mock_response = litellm.types.utils.ImageResponse(data=[]) + mock_aimage_generation.return_value = mock_response + + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): + with pytest.raises(ImageGenerationError, match="Image generation returned no data"): + await stub_model_facade.agenerate_image(prompt="test prompt") + + +@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_image_chat_completion_no_choices( + mock_acompletion: AsyncMock, + stub_model_facade: ModelFacade, +) -> None: + """Test async image generation raises error when response has no choices.""" + mock_response = litellm.types.utils.ModelResponse(choices=[]) + mock_acompletion.return_value = mock_response + + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): + with pytest.raises(ImageGenerationError, match="Image generation response missing choices"): + await stub_model_facade.agenerate_image(prompt="test prompt") + + +@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_image_skip_usage_tracking( + mock_aimage_generation: AsyncMock, + stub_model_facade: ModelFacade, +) -> None: + """Test that async image generation respects skip_usage_tracking flag.""" + mock_response = litellm.types.utils.ImageResponse(data=[litellm.types.utils.ImageObject(b64_json="image1_base64")]) + mock_aimage_generation.return_value = mock_response + + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): + images = await stub_model_facade.agenerate_image(prompt="test prompt", skip_usage_tracking=True) + + assert len(images) == 1 + assert stub_model_facade.usage_stats.image_usage.total_images == 0 diff --git a/packages/data-designer-engine/tests/engine/models/test_litellm_overrides.py b/packages/data-designer-engine/tests/engine/models/test_litellm_overrides.py index 05f3ed79..6ccbe63a 100644 --- a/packages/data-designer-engine/tests/engine/models/test_litellm_overrides.py +++ b/packages/data-designer-engine/tests/engine/models/test_litellm_overrides.py @@ -8,6 +8,7 @@ import litellm import pytest +from data_designer.engine.models import litellm_overrides from data_designer.engine.models.litellm_overrides import ( DEFAULT_MAX_CALLBACKS, CustomRouter, @@ -56,9 +57,9 @@ def test_apply_litellm_patches_no_exceptions(): pytest.fail(f"apply_litellm_patches() raised an unexpected exception: {e}") -@patch("data_designer.engine.models.litellm_overrides.quiet_noisy_logger", autospec=True) -def test_apply_litellm_patches(mock_quiet_noisy_logger): - apply_litellm_patches() +@patch.object(litellm_overrides, "quiet_noisy_logger", autospec=True) +def test_apply_litellm_patches(mock_quiet_noisy_logger: object) -> None: + litellm_overrides.apply_litellm_patches() assert isinstance(litellm.in_memory_llm_clients_cache, ThreadSafeCache) assert ( litellm.litellm_core_utils.logging_callback_manager.LoggingCallbackManager.MAX_CALLBACKS diff --git a/packages/data-designer-engine/tests/engine/models/test_model_registry.py b/packages/data-designer-engine/tests/engine/models/test_model_registry.py index 17dfabdb..23d2a56e 100644 --- a/packages/data-designer-engine/tests/engine/models/test_model_registry.py +++ b/packages/data-designer-engine/tests/engine/models/test_model_registry.py @@ -6,6 +6,7 @@ import pytest from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig +from data_designer.engine.models import litellm_overrides from data_designer.engine.models.errors import ModelAuthenticationError from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.factory import create_model_registry @@ -41,10 +42,13 @@ def stub_no_usage_config(): ) -@patch("data_designer.engine.models.litellm_overrides.apply_litellm_patches", autospec=True) +@patch.object(litellm_overrides, "apply_litellm_patches", autospec=True) def test_create_model_registry( - mock_apply_litellm_patches, stub_model_configs, stub_secrets_resolver, stub_model_provider_registry -): + mock_apply_litellm_patches: object, + stub_model_configs: list[ModelConfig], + stub_secrets_resolver: object, + stub_model_provider_registry: object, +) -> None: model_registry = create_model_registry( model_configs=stub_model_configs, secret_resolver=stub_secrets_resolver, @@ -273,20 +277,26 @@ def test_get_usage_deltas( assert deltas == {} -@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True) -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) -def test_run_health_check_success(mock_completion, mock_generate_text_embeddings, stub_model_registry): +@patch.object(ModelFacade, "generate_text_embeddings", autospec=True) +@patch.object(ModelFacade, "completion", autospec=True) +def test_run_health_check_success( + mock_completion: object, + mock_generate_text_embeddings: object, + stub_model_registry: ModelRegistry, +) -> None: model_aliases = {"stub-text", "stub-reasoning", "stub-embedding"} stub_model_registry.run_health_check(model_aliases) assert mock_completion.call_count == 2 assert mock_generate_text_embeddings.call_count == 1 -@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True) -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +@patch.object(ModelFacade, "generate_text_embeddings", autospec=True) +@patch.object(ModelFacade, "completion", autospec=True) def test_run_health_check_completion_authentication_error( - mock_completion, mock_generate_text_embeddings, stub_model_registry -): + mock_completion: object, + mock_generate_text_embeddings: object, + stub_model_registry: ModelRegistry, +) -> None: auth_error = ModelAuthenticationError("Invalid API key for completion model") mock_completion.side_effect = auth_error model_aliases = ["stub-text", "stub-reasoning", "stub-embedding"] @@ -298,11 +308,13 @@ def test_run_health_check_completion_authentication_error( mock_generate_text_embeddings.assert_not_called() -@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True) -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +@patch.object(ModelFacade, "generate_text_embeddings", autospec=True) +@patch.object(ModelFacade, "completion", autospec=True) def test_run_health_check_embedding_authentication_error( - mock_completion, mock_generate_text_embeddings, stub_model_registry -): + mock_completion: object, + mock_generate_text_embeddings: object, + stub_model_registry: ModelRegistry, +) -> None: auth_error = ModelAuthenticationError("Invalid API key for embedding model") mock_generate_text_embeddings.side_effect = auth_error model_aliases = ["stub-text", "stub-reasoning", "stub-embedding"] @@ -314,12 +326,12 @@ def test_run_health_check_embedding_authentication_error( mock_generate_text_embeddings.assert_called_once() -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +@patch.object(ModelFacade, "completion", autospec=True) def test_run_health_check_skip_health_check_flag( - mock_completion, - stub_secrets_resolver, - stub_model_provider_registry, -): + mock_completion: object, + stub_secrets_resolver: object, + stub_model_provider_registry: object, +) -> None: # Create model configs: one with skip_health_check=True, others with default (False) model_configs = [ ModelConfig( diff --git a/packages/data-designer-engine/tests/engine/validators/test_sql.py b/packages/data-designer-engine/tests/engine/validators/test_sql.py index 756d0856..3ab599e1 100644 --- a/packages/data-designer-engine/tests/engine/validators/test_sql.py +++ b/packages/data-designer-engine/tests/engine/validators/test_sql.py @@ -1,12 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from unittest.mock import patch + +import pytest + from data_designer.config.utils.code_lang import CodeLang from data_designer.config.validator_params import CodeValidatorParams +from data_designer.engine.validators import sql as sql_validator_module from data_designer.engine.validators.sql import SQLValidator -def test_valid_ansi_sql_code(): +def test_valid_ansi_sql_code() -> None: sql_validator = SQLValidator(CodeValidatorParams(code_lang=CodeLang.SQL_ANSI)) code = "SELECT category, COUNT(*) as total_incidents FROM security_incidents_2 GROUP BY category;" result = sql_validator.run_validation([{"sql": code}]) @@ -14,9 +19,31 @@ def test_valid_ansi_sql_code(): assert result.data[0].error_messages == "" -def test_invalid_ansi_sql_code(): +def test_invalid_ansi_sql_code() -> None: sql_validator = SQLValidator(CodeValidatorParams(code_lang=CodeLang.SQL_ANSI)) code = "NOT SQL" result = sql_validator.run_validation([{"sql": code}]) assert not result.data[0].is_valid assert result.data[0].error_messages == "PRS: Line 1, Position 1: Found unparsable section: 'NOT SQL'" + + +def test_sql_validator_multi_column_input_raises() -> None: + sql_validator = SQLValidator(CodeValidatorParams(code_lang=CodeLang.SQL_ANSI)) + with pytest.raises(ValueError, match="single column input"): + sql_validator.run_validation([{"sql": "SELECT 1", "extra": "ignored"}]) + + +def test_sql_validator_decimal_without_scale_fails() -> None: + sql_validator = SQLValidator(CodeValidatorParams(code_lang=CodeLang.SQL_ANSI)) + code = "CREATE TABLE example (amount DECIMAL(10));" + result = sql_validator.run_validation([{"sql": code}]) + assert not result.data[0].is_valid + assert "DECIMAL definitions without a scale" in result.data[0].error_messages + + +def test_sql_validator_handles_lint_exception() -> None: + sql_validator = SQLValidator(CodeValidatorParams(code_lang=CodeLang.SQL_ANSI)) + with patch.object(sql_validator_module.sqlfluff, "lint", side_effect=RuntimeError("boom")): + result = sql_validator.run_validation([{"sql": "SELECT 1"}]) + assert not result.data[0].is_valid + assert "Exception during SQL parsing" in result.data[0].error_messages diff --git a/scripts/benchmarks/benchmark_engine_v2.py b/scripts/benchmarks/benchmark_engine_v2.py new file mode 100644 index 00000000..5a3bb073 --- /dev/null +++ b/scripts/benchmarks/benchmark_engine_v2.py @@ -0,0 +1,853 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Benchmark DataDesigner engine performance with mock LLMs.""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import hashlib +import json +import math +import os +import random +import statistics +import subprocess +import sys +import tempfile +import time +from collections.abc import Iterator +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig, ValidationColumnConfig +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.mcp import MCPProvider, ToolConfig +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider +from data_designer.config.run_config import RunConfig +from data_designer.config.sampler_params import SamplerType, UniformSamplerParams +from data_designer.config.validator_params import LocalCallableValidatorParams, ValidatorType +from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult +from data_designer.lazy_heavy_imports import np, pd + +if TYPE_CHECKING: + import numpy as np + import pandas as pd + + +RESULT_PREFIX = "BENCHMARK_RESULT=" +DEFAULT_NUM_RECORDS = 1024 +DEFAULT_BUFFER_SIZE = 1024 +DEFAULT_SEED = 11 +DEFAULT_MAX_PARALLEL_REQUESTS = 16 +DEFAULT_VALIDATOR_BATCH_SIZE = 256 +DEFAULT_ITERATIONS = 5 + +MOCK_MCP_PROVIDER_NAME = "mock-mcp" +MOCK_TOOL_ALIAS = "mock-tools" +MOCK_TOOL_NAME = "mock_lookup" +MOCK_TOOL_DESCRIPTION = "Mock lookup tool for benchmark runs." +MOCK_TOOL_SCHEMA = { + "type": "object", + "properties": { + "query": {"type": "string"}, + "limit": {"type": "integer"}, + }, + "required": ["query"], +} + + +@dataclass(frozen=True) +class BenchmarkSettings: + num_records: int + buffer_size: int + seed: int + max_parallel_requests: int + validator_batch_size: int + simulated_latency: bool = False + + def to_cli_args(self) -> list[str]: + args = [ + "--num-records", + str(self.num_records), + "--buffer-size", + str(self.buffer_size), + "--seed", + str(self.seed), + "--max-parallel-requests", + str(self.max_parallel_requests), + "--validator-batch-size", + str(self.validator_batch_size), + ] + if self.simulated_latency: + args.append("--simulated-latency") + return args + + +@dataclass(frozen=True) +class BenchmarkResult: + engine_mode: str + num_records: int + buffer_size: int + build_time_sec: float + total_time_sec: float + dataset_hash: str + row_count: int + column_count: int + + def to_dict(self) -> dict[str, Any]: + return { + "engine_mode": self.engine_mode, + "num_records": self.num_records, + "buffer_size": self.buffer_size, + "build_time_sec": self.build_time_sec, + "total_time_sec": self.total_time_sec, + "dataset_hash": self.dataset_hash, + "row_count": self.row_count, + "column_count": self.column_count, + } + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> BenchmarkResult: + return cls( + engine_mode=str(payload["engine_mode"]), + num_records=int(payload["num_records"]), + buffer_size=int(payload["buffer_size"]), + build_time_sec=float(payload["build_time_sec"]), + total_time_sec=float(payload["total_time_sec"]), + dataset_hash=str(payload["dataset_hash"]), + row_count=int(payload["row_count"]), + column_count=int(payload["column_count"]), + ) + + +@dataclass(frozen=True) +class MetricStats: + mean: float + stdev: float + ci_half_width: float + n: int + + @property + def ci_low(self) -> float: + return self.mean - self.ci_half_width + + @property + def ci_high(self) -> float: + return self.mean + self.ci_half_width + + +@dataclass(frozen=True) +class ResponseProfile: + label: str + score_mu: float + score_sigma: float + latency_alpha: float + latency_beta: float + volatility_sigma: float + categories: tuple[str, ...] + category_weights: tuple[float, ...] + + +MODEL_PROFILES: dict[str, ResponseProfile] = { + "mock-alpha": ResponseProfile( + label="alpha", + score_mu=0.1, + score_sigma=0.35, + latency_alpha=2.2, + latency_beta=6.0, + volatility_sigma=0.25, + categories=("low", "mid", "high"), + category_weights=(0.25, 0.55, 0.2), + ), + "mock-beta": ResponseProfile( + label="beta", + score_mu=0.3, + score_sigma=0.45, + latency_alpha=2.6, + latency_beta=4.8, + volatility_sigma=0.3, + categories=("low", "mid", "high"), + category_weights=(0.2, 0.5, 0.3), + ), + "mock-gamma": ResponseProfile( + label="gamma", + score_mu=0.5, + score_sigma=0.5, + latency_alpha=3.0, + latency_beta=3.6, + volatility_sigma=0.35, + categories=("low", "mid", "high"), + category_weights=(0.15, 0.45, 0.4), + ), +} + +DEFAULT_PROFILE = ResponseProfile( + label="default", + score_mu=0.2, + score_sigma=0.4, + latency_alpha=2.4, + latency_beta=5.0, + volatility_sigma=0.3, + categories=("low", "mid", "high"), + category_weights=(0.3, 0.5, 0.2), +) + + +@dataclass(frozen=True) +class FakeMessage: + content: str + tool_calls: list[dict[str, Any]] | None = None + reasoning_content: str | None = None + + +@dataclass(frozen=True) +class FakeChoice: + message: FakeMessage + + +@dataclass(frozen=True) +class FakeResponse: + choices: list[FakeChoice] + usage: Any | None = None + model: str | None = None + latency_ms: float = 0.0 + + +def _distinct_parallel_requests(base: int) -> tuple[int, int, int]: + if base < 3: + raise ValueError("max_parallel_requests must be >= 3 to create distinct per-model limits.") + high = base + mid = max(1, int(round(high / 2))) + low = max(1, int(round(high / 5))) + + if mid >= high: + mid = high - 1 + if low >= mid: + low = max(1, mid - 1) + + return high, mid, low + + +def _t_critical_95(df: int) -> float: + table = { + 1: 12.706, + 2: 4.303, + 3: 3.182, + 4: 2.776, + 5: 2.571, + 6: 2.447, + 7: 2.365, + 8: 2.306, + 9: 2.262, + 10: 2.228, + 11: 2.201, + 12: 2.179, + 13: 2.160, + 14: 2.145, + 15: 2.131, + 16: 2.120, + 17: 2.110, + 18: 2.101, + 19: 2.093, + 20: 2.086, + 21: 2.080, + 22: 2.074, + 23: 2.069, + 24: 2.064, + 25: 2.060, + 26: 2.056, + 27: 2.052, + 28: 2.048, + 29: 2.045, + 30: 2.042, + } + return table.get(df, 1.96) + + +def _compute_stats(values: list[float]) -> MetricStats: + if not values: + return MetricStats(mean=0.0, stdev=0.0, ci_half_width=0.0, n=0) + if len(values) == 1: + return MetricStats(mean=values[0], stdev=0.0, ci_half_width=0.0, n=1) + stdev = statistics.stdev(values) + mean = statistics.mean(values) + t_value = _t_critical_95(len(values) - 1) + ci_half_width = t_value * stdev / math.sqrt(len(values)) + return MetricStats(mean=mean, stdev=stdev, ci_half_width=ci_half_width, n=len(values)) + + +def _format_stats(stats: MetricStats, *, unit: str, precision: int = 3) -> str: + fmt = f"{{:.{precision}f}}" + mean = fmt.format(stats.mean) + ci = fmt.format(stats.ci_half_width) + stdev = fmt.format(stats.stdev) + return f"{mean}{unit} ± {ci}{unit} (stdev {stdev}{unit}, n={stats.n})" + + +def _format_speed_stats(stats: MetricStats, *, precision: int = 2) -> str: + fmt = f"{{:.{precision}f}}" + mean = fmt.format(stats.mean) + ci = fmt.format(stats.ci_half_width) + stdev = fmt.format(stats.stdev) + return f"{mean}x ± {ci}x (stdev {stdev}x, n={stats.n})" + + +def _significant_diff(stats: MetricStats) -> bool: + return stats.n > 1 and abs(stats.mean) > stats.ci_half_width + + +def _json_default(value: Any) -> Any: + if isinstance(value, np.generic): + return value.item() + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, (pd.Timestamp, pd.Timedelta)): + return value.isoformat() + if isinstance(value, set): + return sorted(value) + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + return str(value) + + +def _stable_seed(model: str, messages: list[dict[str, Any]]) -> int: + payload = json.dumps( + {"model": model, "messages": messages}, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=True, + default=_json_default, + ) + digest = hashlib.sha256(payload.encode("utf-8")).digest() + return int.from_bytes(digest[:8], "big") + + +def _profile_for_model(model: str) -> ResponseProfile: + for key, profile in MODEL_PROFILES.items(): + if key in model: + return profile + return DEFAULT_PROFILE + + +def _mock_response_text(model: str, messages: list[dict[str, Any]]) -> tuple[str, float]: + profile = _profile_for_model(model) + rng = random.Random(_stable_seed(model, messages)) + category = rng.choices(profile.categories, weights=profile.category_weights, k=1)[0] + score = rng.lognormvariate(profile.score_mu, profile.score_sigma) + latency_ms = int(rng.betavariate(profile.latency_alpha, profile.latency_beta) * 900.0) + volatility = rng.gauss(0.0, profile.volatility_sigma) + text = f"{profile.label}:{category}|score={score:.3f}|latency_ms={latency_ms}|vol={volatility:.3f}" + return text, float(latency_ms) + + +def _tool_call_id(model: str, messages: list[dict[str, Any]]) -> str: + call_seed = _stable_seed(model, messages) + return f"tool-{call_seed:016x}" + + +def _tool_call_arguments(model: str, messages: list[dict[str, Any]]) -> dict[str, Any]: + rng = random.Random(_stable_seed(model, messages)) + return { + "query": f"{model}-lookup-{rng.randint(1000, 9999)}", + "limit": rng.randint(1, 3), + } + + +def _build_tool_call(model: str, messages: list[dict[str, Any]]) -> dict[str, Any]: + arguments = _tool_call_arguments(model, messages) + return { + "id": _tool_call_id(model, messages), + "type": "function", + "function": {"name": MOCK_TOOL_NAME, "arguments": json.dumps(arguments)}, + } + + +def _should_request_tool(messages: list[dict[str, Any]]) -> bool: + return not any(message.get("role") == "tool" for message in messages) + + +def _mock_tool_definition() -> MCPToolDefinition: + return MCPToolDefinition( + name=MOCK_TOOL_NAME, + description=MOCK_TOOL_DESCRIPTION, + input_schema=MOCK_TOOL_SCHEMA, + ) + + +def _mock_tool_result(tool_name: str, arguments: dict[str, Any], provider_name: str) -> MCPToolResult: + payload = { + "tool": tool_name, + "provider": provider_name, + "query": arguments.get("query", ""), + "limit": arguments.get("limit", 0), + "status": "ok", + } + return MCPToolResult(content=json.dumps(payload)) + + +def _fake_response(model: str, messages: list[dict[str, Any]], **kwargs: Any) -> FakeResponse: + if kwargs.get("tools") and _should_request_tool(messages): + tool_call = _build_tool_call(model, messages) + # Compute latency for tool-call responses using the same profile/seed mechanism. + profile = _profile_for_model(model) + rng = random.Random(_stable_seed(model, messages)) + latency_ms = float(int(rng.betavariate(profile.latency_alpha, profile.latency_beta) * 900.0)) + return FakeResponse( + choices=[FakeChoice(message=FakeMessage(content="Using tool.", tool_calls=[tool_call]))], + model=model, + latency_ms=latency_ms, + ) + response_text, latency_ms = _mock_response_text(model, messages) + return FakeResponse( + choices=[FakeChoice(message=FakeMessage(content=response_text))], + model=model, + latency_ms=latency_ms, + ) + + +@contextlib.contextmanager +def _patch_llm_responses(*, simulated_latency: bool = False) -> Iterator[None]: + # Imports are deferred so engine selection respects DATA_DESIGNER_ASYNC_ENGINE. + from data_designer.engine.models.litellm_overrides import CustomRouter + + original_completion = CustomRouter.completion + original_acompletion = getattr(CustomRouter, "acompletion", None) + + def fake_completion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> FakeResponse: + _ = self + response = _fake_response(model, messages, **kwargs) + if simulated_latency and response.latency_ms > 0: + time.sleep(response.latency_ms / 1000.0) + return response + + async def fake_acompletion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> FakeResponse: + _ = self + response = _fake_response(model, messages, **kwargs) + if simulated_latency and response.latency_ms > 0: + await asyncio.sleep(response.latency_ms / 1000.0) + return response + + CustomRouter.completion = fake_completion + CustomRouter.acompletion = fake_acompletion + try: + yield + finally: + CustomRouter.completion = original_completion + if original_acompletion is not None: + CustomRouter.acompletion = original_acompletion + else: + try: + delattr(CustomRouter, "acompletion") + except AttributeError: + pass + + +@contextlib.contextmanager +def _patch_mcp_io() -> Iterator[None]: + import data_designer.engine.mcp.io as mcp_io + + original_list_tools = mcp_io.list_tools + original_call_tools = mcp_io.call_tools + + def fake_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: + if getattr(provider, "name", None) != MOCK_MCP_PROVIDER_NAME: + return original_list_tools(provider, timeout_sec=timeout_sec) + return (_mock_tool_definition(),) + + def fake_call_tools( + calls: list[tuple[Any, str, dict[str, Any]]], + *, + timeout_sec: float | None = None, + ) -> list[MCPToolResult]: + if any(getattr(call[0], "name", None) != MOCK_MCP_PROVIDER_NAME for call in calls): + return original_call_tools(calls, timeout_sec=timeout_sec) + return [_mock_tool_result(tool_name, arguments, provider.name) for provider, tool_name, arguments in calls] + + mcp_io.list_tools = fake_list_tools + mcp_io.call_tools = fake_call_tools + try: + yield + finally: + mcp_io.list_tools = original_list_tools + mcp_io.call_tools = original_call_tools + + +def _extract_metric(text: str, key: str) -> float | None: + marker = f"{key}=" + start = text.find(marker) + if start == -1: + return None + start += len(marker) + end = start + while end < len(text) and (text[end].isdigit() or text[end] in {".", "-"}): + end += 1 + try: + return float(text[start:end]) + except ValueError: + return None + + +def _validate_recommendation(df: pd.DataFrame) -> pd.DataFrame: + series = df["llm_stage3"].astype(str) + scores = series.map(lambda text: _extract_metric(text, "score")) + latencies = series.map(lambda text: _extract_metric(text, "latency_ms")) + scores_numeric = pd.to_numeric(scores, errors="coerce") + latency_numeric = pd.to_numeric(latencies, errors="coerce") + is_valid = scores_numeric.between(0.0, 10.0) & latency_numeric.between(0.0, 900.0) + return pd.DataFrame( + { + "is_valid": is_valid.fillna(False).astype(bool), + "score": scores_numeric, + "latency_ms": latency_numeric, + } + ) + + +def _build_config(settings: BenchmarkSettings) -> DataDesignerConfigBuilder: + high_parallel, mid_parallel, low_parallel = _distinct_parallel_requests(settings.max_parallel_requests) + model_configs = [ + ModelConfig( + alias="mock-alpha", + model="mock-alpha", + provider="mock-provider", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=high_parallel), + skip_health_check=True, + ), + ModelConfig( + alias="mock-beta", + model="mock-beta", + provider="mock-provider", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=low_parallel), + skip_health_check=True, + ), + ModelConfig( + alias="mock-gamma", + model="mock-gamma", + provider="mock-provider", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=mid_parallel), + skip_health_check=True, + ), + ] + + builder = DataDesignerConfigBuilder(model_configs=model_configs) + builder.add_tool_config( + ToolConfig( + tool_alias=MOCK_TOOL_ALIAS, + providers=[MOCK_MCP_PROVIDER_NAME], + allow_tools=[MOCK_TOOL_NAME], + max_tool_call_turns=1, + timeout_sec=1.0, + ) + ) + builder.add_column( + SamplerColumnConfig( + name="seed_value", + sampler_type=SamplerType.UNIFORM, + params=UniformSamplerParams(low=0.0, high=100.0, decimal_places=3), + ) + ) + builder.add_column( + LLMTextColumnConfig( + name="llm_stage1", + model_alias="mock-alpha", + prompt="Summarize the signal for seed {{ seed_value }}.", + ) + ) + builder.add_column( + LLMTextColumnConfig( + name="llm_stage2", + model_alias="mock-beta", + tool_alias=MOCK_TOOL_ALIAS, + prompt="Analyze {{ llm_stage1 }} and produce a risk assessment.", + ) + ) + builder.add_column( + LLMTextColumnConfig( + name="llm_stage3", + model_alias="mock-gamma", + prompt="Generate a recommendation from {{ llm_stage2 }} with seed {{ seed_value }}.", + ) + ) + builder.add_column( + ValidationColumnConfig( + name="llm_stage3_validation", + target_columns=["llm_stage3"], + validator_type=ValidatorType.LOCAL_CALLABLE, + validator_params=LocalCallableValidatorParams(validation_function=_validate_recommendation), + batch_size=settings.validator_batch_size, + ) + ) + return builder + + +def _dataset_fingerprint(df: pd.DataFrame) -> str: + normalized = df.reset_index(drop=True) + normalized = normalized.reindex(sorted(normalized.columns), axis=1) + records = normalized.to_dict(orient="records") + payload = json.dumps( + records, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=True, + default=_json_default, + ) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def _run_single_benchmark(settings: BenchmarkSettings, engine_mode: str) -> BenchmarkResult: + # Imports are deferred so engine selection respects DATA_DESIGNER_ASYNC_ENGINE. + from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage + from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder + from data_designer.engine.model_provider import resolve_model_provider_registry + from data_designer.engine.resources.resource_provider import create_resource_provider + from data_designer.engine.resources.seed_reader import SeedReaderRegistry + from data_designer.engine.secret_resolver import CompositeResolver, EnvironmentResolver, PlaintextResolver + + random.seed(settings.seed) + np.random.seed(settings.seed) + + run_config = RunConfig( + buffer_size=settings.buffer_size, + disable_early_shutdown=True, + max_conversation_restarts=0, + max_conversation_correction_steps=0, + ) + builder = _build_config(settings) + + provider = ModelProvider( + name="mock-provider", + endpoint="https://mock.local", + provider_type="openai", + api_key="mock-key", + ) + mcp_provider = MCPProvider( + name=MOCK_MCP_PROVIDER_NAME, + endpoint="https://mock.local/mcp", + api_key="mock-mcp-key", + ) + model_provider_registry = resolve_model_provider_registry([provider], default_provider_name=provider.name) + secret_resolver = CompositeResolver([EnvironmentResolver(), PlaintextResolver()]) + + with tempfile.TemporaryDirectory() as temp_dir: + artifact_storage = ArtifactStorage(artifact_path=temp_dir, dataset_name=f"benchmark-{engine_mode}") + resource_provider = create_resource_provider( + artifact_storage=artifact_storage, + model_configs=builder.model_configs, + secret_resolver=secret_resolver, + model_provider_registry=model_provider_registry, + seed_reader_registry=SeedReaderRegistry(readers=[]), + blob_storage=None, + seed_dataset_source=None, + run_config=run_config, + mcp_providers=[mcp_provider], + tool_configs=builder.tool_configs, + ) + dataset_builder = ColumnWiseDatasetBuilder( + data_designer_config=builder.build(), + resource_provider=resource_provider, + ) + + total_start = time.perf_counter() + with _patch_llm_responses(simulated_latency=settings.simulated_latency), _patch_mcp_io(): + build_start = time.perf_counter() + dataset_builder.build(num_records=settings.num_records) + build_time = time.perf_counter() - build_start + dataset = dataset_builder.artifact_storage.load_dataset_with_dropped_columns() + + dataset_hash = _dataset_fingerprint(dataset) + total_time = time.perf_counter() - total_start + + return BenchmarkResult( + engine_mode=engine_mode, + num_records=settings.num_records, + buffer_size=settings.buffer_size, + build_time_sec=build_time, + total_time_sec=total_time, + dataset_hash=dataset_hash, + row_count=int(dataset.shape[0]), + column_count=int(dataset.shape[1]), + ) + + +def _run_subprocess(settings: BenchmarkSettings, engine_mode: str) -> BenchmarkResult: + env = os.environ.copy() + if engine_mode == "async": + env["DATA_DESIGNER_ASYNC_ENGINE"] = "1" + else: + env.pop("DATA_DESIGNER_ASYNC_ENGINE", None) + + script_path = Path(__file__).resolve() + cmd = [sys.executable, str(script_path), "--mode", "run", "--engine", engine_mode, *settings.to_cli_args()] + completed = subprocess.run(cmd, capture_output=True, text=True, env=env, check=False) + + if completed.returncode != 0: + raise RuntimeError(f"Benchmark subprocess failed.\nstdout:\n{completed.stdout}\nstderr:\n{completed.stderr}") + + for line in reversed(completed.stdout.splitlines()): + if line.startswith(RESULT_PREFIX): + payload = json.loads(line.removeprefix(RESULT_PREFIX)) + return BenchmarkResult.from_dict(payload) + + raise RuntimeError( + f"Benchmark subprocess did not emit a result payload.\nstdout:\n{completed.stdout}\nstderr:\n{completed.stderr}" + ) + + +def _format_speedup(sync_time: float, async_time: float) -> str: + if async_time <= 0: + return "n/a" + return f"{(sync_time / async_time):.2f}x" + + +def _run_with_progress(settings: BenchmarkSettings, engine_mode: str, iteration: int, total: int) -> BenchmarkResult: + print(f"[{iteration}/{total}] Running {engine_mode} benchmark...", end="", flush=True) + result = _run_subprocess(settings, engine_mode) + print(f" done ({result.total_time_sec:.3f}s)") + return result + + +def _compare_runs(settings: BenchmarkSettings, iterations: int) -> int: + sync_results: list[BenchmarkResult] = [] + async_results: list[BenchmarkResult] = [] + expected_hash: str | None = None + + for iteration in range(1, iterations + 1): + sync_result = _run_with_progress(settings, "sync", iteration, iterations) + async_result = _run_with_progress(settings, "async", iteration, iterations) + + if sync_result.dataset_hash != async_result.dataset_hash: + print( + "Content mismatch detected: " + f"sync hash {sync_result.dataset_hash} vs async hash {async_result.dataset_hash}" + ) + return 1 + + if expected_hash is None: + expected_hash = sync_result.dataset_hash + elif expected_hash != sync_result.dataset_hash or expected_hash != async_result.dataset_hash: + print("Content mismatch detected across iterations.") + return 1 + + sync_results.append(sync_result) + async_results.append(async_result) + + build_sync = [result.build_time_sec for result in sync_results] + build_async = [result.build_time_sec for result in async_results] + total_sync = [result.total_time_sec for result in sync_results] + total_async = [result.total_time_sec for result in async_results] + + build_speedups = [sync / async_ for sync, async_ in zip(build_sync, build_async)] + total_speedups = [sync / async_ for sync, async_ in zip(total_sync, total_async)] + build_diffs = [sync - async_ for sync, async_ in zip(build_sync, build_async)] + total_diffs = [sync - async_ for sync, async_ in zip(total_sync, total_async)] + + build_sync_stats = _compute_stats(build_sync) + build_async_stats = _compute_stats(build_async) + total_sync_stats = _compute_stats(total_sync) + total_async_stats = _compute_stats(total_async) + + build_speed_stats = _compute_stats(build_speedups) + total_speed_stats = _compute_stats(total_speedups) + build_diff_stats = _compute_stats(build_diffs) + total_diff_stats = _compute_stats(total_diffs) + + latency_label = "on" if settings.simulated_latency else "off" + print("\nEngine benchmark summary (95% CI)") + print(f"- runs: {iterations} | content match: yes | hash {expected_hash}") + print(f"- simulated latency: {latency_label}") + print(f"- build time sync: {_format_stats(build_sync_stats, unit='s')}") + print(f"- build time async: {_format_stats(build_async_stats, unit='s')}") + print( + f"- build speedup: {_format_speed_stats(build_speed_stats)} | " + f"paired diff {_format_stats(build_diff_stats, unit='s')} | " + f"significant: {'yes' if _significant_diff(build_diff_stats) else 'no'}" + ) + print(f"- total time sync: {_format_stats(total_sync_stats, unit='s')}") + print(f"- total time async: {_format_stats(total_async_stats, unit='s')}") + print( + f"- total speedup: {_format_speed_stats(total_speed_stats)} | " + f"paired diff {_format_stats(total_diff_stats, unit='s')} | " + f"significant: {'yes' if _significant_diff(total_diff_stats) else 'no'}" + ) + + return 0 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark DataDesigner engine with mock LLMs and compare async execution." + ) + parser.add_argument( + "--mode", + type=str, + choices=("compare", "run"), + default="compare", + help="Run both engines in subprocesses, or run once in the current process.", + ) + parser.add_argument( + "--engine", + type=str, + choices=("sync", "async"), + default="sync", + help="Engine mode for --mode run.", + ) + parser.add_argument("--num-records", type=int, default=DEFAULT_NUM_RECORDS, help="Records to generate.") + parser.add_argument("--buffer-size", type=int, default=DEFAULT_BUFFER_SIZE, help="Batch buffer size.") + parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="Random seed for determinism.") + parser.add_argument( + "--iterations", + type=int, + default=DEFAULT_ITERATIONS, + help="Number of sync/async runs to include in the compare mode.", + ) + parser.add_argument( + "--max-parallel-requests", + type=int, + default=DEFAULT_MAX_PARALLEL_REQUESTS, + help="Max parallel LLM requests per model.", + ) + parser.add_argument( + "--validator-batch-size", + type=int, + default=DEFAULT_VALIDATOR_BATCH_SIZE, + help="Batch size for the local validator.", + ) + parser.add_argument( + "--simulated-latency", + action="store_true", + default=False, + help="Simulate LLM response latency using beta-distributed delays.", + ) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + settings = BenchmarkSettings( + num_records=args.num_records, + buffer_size=args.buffer_size, + seed=args.seed, + max_parallel_requests=args.max_parallel_requests, + validator_batch_size=args.validator_batch_size, + simulated_latency=args.simulated_latency, + ) + + if args.mode == "compare": + sys.exit(_compare_runs(settings, args.iterations)) + + if args.engine == "async": + os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" + else: + os.environ.pop("DATA_DESIGNER_ASYNC_ENGINE", None) + + print(f"Running {args.engine} benchmark...") + result = _run_single_benchmark(settings, args.engine) + print(f"{RESULT_PREFIX}{json.dumps(result.to_dict(), sort_keys=True)}") + + +if __name__ == "__main__": + main()