diff --git a/LITELLM_REMOVAL_ANALYSIS.md b/LITELLM_REMOVAL_ANALYSIS.md
new file mode 100644
index 00000000..2b2b705f
--- /dev/null
+++ b/LITELLM_REMOVAL_ANALYSIS.md
@@ -0,0 +1,441 @@
+# LiteLLM Removal: Impact Analysis & Implementation Plan
+
+## Executive Summary
+
+LiteLLM serves as Data Designer's abstraction layer between model configuration and HTTP API calls. It provides multi-provider routing, retry/backoff logic, error normalization, and response type unification. The dependency is well-contained — 12 files touch it directly, all within the `engine/models/` and `engine/models_v2/` layers. Removing it is feasible but non-trivial. The work breaks down into five areas: HTTP client, error handling, retry/resilience, response normalization, and provider abstraction.
+
+**Dependency**: `litellm>=1.73.6,<1.80.12` (pinned in `packages/data-designer-engine/pyproject.toml`)
+
+---
+
+## What LiteLLM Provides Today
+
+### 1. Multi-Provider Routing
+
+LiteLLM's `Router` class accepts a deployment list and routes `completion()`/`embedding()` calls to the correct provider endpoint using the `{provider_type}/{model_name}` naming convention.
+
+**DD's usage**: Each `ModelFacade` gets its own `CustomRouter` with a single-element deployment list — so DD does not actually use multi-deployment load balancing or failover. The Router is effectively a single-provider HTTP client with retry logic.
+
+**Key takeaway**: DD uses the Router as a resilient HTTP client, not as a load balancer. This simplifies replacement significantly.
+
+### 2. Provider Abstraction (`provider_type`)
+
+`ModelProvider.provider_type` (default: `"openai"`) tells LiteLLM which API format to use. LiteLLM translates this into the correct HTTP request format, auth headers, and response parsing for each provider (OpenAI, Anthropic, Cohere, etc.).
+
+**DD's usage**: The `provider_type` is combined with the model name (`f"{provider_type}/{model_name}"`) and passed to the Router. DD also passes `api_base` (endpoint URL) and `api_key` per deployment.
+
+**Key question**: How many distinct `provider_type` values does DD actually use in production? If it's primarily `"openai"` (OpenAI-compatible APIs including NVIDIA NIM), the provider abstraction is low-value since most inference providers expose OpenAI-compatible endpoints. If it includes `"anthropic"`, `"cohere"`, etc., the translation layer is more valuable.
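+
+For concreteness, a sketch of the single-element deployment list DD hands the Router (field names follow litellm's `Router` deployment schema; all values shown are placeholders, not DD's real configuration):
+
+```python
+from litellm import Router
+
+deployment = {
+    "model_name": "my-model",                  # alias used when calling router.completion(model=...)
+    "litellm_params": {
+        "model": "openai/my-model",            # f"{provider_type}/{model_name}"
+        "api_base": "https://example.com/v1",  # ModelProvider endpoint
+        "api_key": "PLACEHOLDER",
+    },
+}
+
+# A single deployment, so no load balancing or failover happens in practice.
+router = Router(model_list=[deployment])
+```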
+
+### 3. Retry & Exponential Backoff
+
+DD's `CustomRouter` extends LiteLLM's `Router` with configurable exponential backoff:
+
+- **Retry policy**: 3 retries for `RateLimitError`, 3 retries for `Timeout`
+- **Backoff formula**: `initial_retry_after_s * 2^retry_count * (1 +/- jitter_pct)`
+- **Defaults**: initial=2s, jitter=20%, timeout=60s
+- **Server Retry-After**: Extracted from response headers, capped at 60s
+
+This is the most custom piece of DD's LiteLLM integration. The `_time_to_sleep_before_retry()` and `_extract_retry_delay_from_headers()` overrides are well-tested.
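+
+For reference, a minimal sketch of the backoff computation described above (function and parameter names are illustrative; defaults mirror the values listed):
+
+```python
+import random
+
+
+def backoff_delay(
+    retry_count: int,
+    *,
+    initial_retry_after_s: float = 2.0,
+    jitter_pct: float = 0.20,
+    server_retry_after_s: float | None = None,
+) -> float:
+    """Exponential backoff with +/- jitter; prefer the server's Retry-After hint, capped at 60s."""
+    if server_retry_after_s is not None:
+        return min(server_retry_after_s, 60.0)
+    jitter = random.uniform(-jitter_pct, jitter_pct)
+    return initial_retry_after_s * (2**retry_count) * (1 + jitter)
+```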
+
+### 4. Error Normalization (12 Exception Types)
+
+`handle_llm_exceptions()` in `models/errors.py` pattern-matches 12 LiteLLM exception types and maps them to DD-specific error classes:
+
+| LiteLLM Exception | DD Error | Notes |
+|---|---|---|
+| `APIError` | `ModelAPIError` | Special 403 detection |
+| `APIConnectionError` | `ModelAPIConnectionError` | Network issues |
+| `AuthenticationError` | `ModelAuthenticationError` | Invalid API key |
+| `ContextWindowExceededError` | `ModelContextWindowExceededError` | Parses OpenAI token details |
+| `UnsupportedParamsError` | `ModelUnsupportedParamsError` | |
+| `BadRequestError` | `ModelBadRequestError` | Detects multimodal rejection |
+| `InternalServerError` | `ModelInternalServerError` | |
+| `NotFoundError` | `ModelNotFoundError` | |
+| `PermissionDeniedError` | `ModelPermissionDeniedError` | |
+| `RateLimitError` | `ModelRateLimitError` | |
+| `Timeout` | `ModelTimeoutError` | |
+| `UnprocessableEntityError` | `ModelUnprocessableEntityError` | |
+
+The DD error types already exist and carry user-friendly `FormattedLLMErrorMessage(cause, solution)` payloads. The only LiteLLM-specific part is the `match` statement that catches LiteLLM's exception classes.
+
+### 5. Response Type Normalization
+
+DD accesses these fields from LiteLLM responses:
+
+**Chat completions** (`ModelResponse`):
+- `response.choices[0].message.content` — generated text
+- `response.choices[0].message.reasoning_content` — extended thinking (via `getattr`)
+- `response.choices[0].message.tool_calls` — MCP tool calls
+- `response.usage.prompt_tokens` / `response.usage.completion_tokens` — token counts
+
+**Embeddings** (`EmbeddingResponse`):
+- `response.data[i]["embedding"]` — float arrays
+- `response.usage.prompt_tokens` — input token count
+
+These match the OpenAI API response format exactly. LiteLLM normalizes responses from non-OpenAI providers into this shape.
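+
+The full access surface is small enough to show in one place (a summary sketch; the `response` argument is whatever the facade received from the Router):
+
+```python
+def extract_completion_fields(response) -> dict:
+    """Every field DD reads from a chat completion (OpenAI shape)."""
+    message = response.choices[0].message
+    return {
+        "content": message.content,
+        "reasoning_content": getattr(message, "reasoning_content", None),
+        "tool_calls": message.tool_calls,
+        "prompt_tokens": response.usage.prompt_tokens,
+        "completion_tokens": response.usage.completion_tokens,
+    }
+
+
+def extract_embedding_fields(response) -> dict:
+    """Every field DD reads from an embedding response."""
+    return {
+        "embeddings": [item["embedding"] for item in response.data],
+        "prompt_tokens": response.usage.prompt_tokens,
+    }
+```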
+
+### 6. Global Patches
+
+`apply_litellm_patches()` modifies LiteLLM's global state at startup:
+- Replaces the in-memory client cache with a thread-safe version (`ThreadSafeCache`)
+- Increases `LoggingCallbackManager.MAX_CALLBACKS` to 1000 (workaround for litellm#9792)
+- Suppresses verbose logging from httpx, LiteLLM, and LiteLLM Router
+
+These are workarounds for LiteLLM's rough edges — they disappear entirely if LiteLLM is removed.
+
+---
+
+## Arguments For Removal
+
+### Reduced dependency weight
+LiteLLM is a large dependency (~50+ transitive packages) with frequent releases and breaking changes. The version pin (`>=1.73.6,<1.80.12`) is already narrow, indicating past compatibility issues. Every LiteLLM upgrade carries regression risk.
+
+### Simpler async story
+The event loop issue we just fixed (LiteLLM's `LoggingWorker` queue binding to a stale loop) is symptomatic of LiteLLM's internal async state being opaque and hard to reason about. With a direct HTTP client, DD controls all async state.
+
+### Thread-safety workarounds become unnecessary
+`ThreadSafeCache`, the `LoggingCallbackManager.MAX_CALLBACKS` patch, and verbose logger suppression are all workarounds for LiteLLM behavior. They represent ongoing maintenance burden.
+
+### Overfit abstraction
+DD uses the Router as a single-deployment client with retry logic. The multi-model routing, caching, and callback infrastructure of LiteLLM's Router class is unused overhead.
+
+### Import performance
+LiteLLM's import time is significant, which is why it is already lazy-loaded via `lazy_heavy_imports.py`. Removing it improves cold-start time.
+
+### Better error messages
+DD already defines its own error types. Currently, LiteLLM's exceptions are caught and re-raised as DD errors. With a direct client, DD can produce better error messages without an intermediate translation layer.
+
+---
+
+## Arguments Against Removal
+
+### Provider format translation
+LiteLLM handles API format differences between providers. If DD only targets OpenAI-compatible endpoints, this is irrelevant. If it supports Anthropic, Cohere, etc., this is significant work to reimplement.
+
+### Battle-tested retry logic
+LiteLLM's Router has been hardened over many releases for rate limiting, retry-after headers, connection pooling, and edge cases. Reimplementing this from scratch risks regressions.
+
+### Maintained by others
+LiteLLM receives frequent updates for new models, API changes, and provider additions. DD's replacement would need to be maintained by the DD team.
+
+### Feature velocity risk
+If DD later needs streaming, function calling improvements, vision model support, or new provider integrations, LiteLLM provides these incrementally. A custom client requires explicit implementation for each.
+
+---
+
+## Blast Radius (by Phase)
+
+All paths relative to `packages/data-designer-engine/src/data_designer/engine/` unless noted.
+
+### Phase 1: Replace Router with ModelClient in `models_v2/`
+
+**Response type strategy:** Keep the OpenAI response format (`choices[0].message.content`, `choices[0].message.tool_calls`, `usage.prompt_tokens`, etc.) as DD's canonical model response type. The `openai` SDK already returns typed objects in this shape — use them directly or define DD-owned dataclasses with the same structure. This means **zero changes** to response access sites in the facade, MCP facade, or anywhere else that reads model responses. Non-OpenAI providers (Phase 3) will be responsible for translating their native responses into this format within their adapter.
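+
+A sketch of DD-owned types with that structural parity (dataclass and field names are hypothetical; the flatter alternative appears under "What Needs to Be Built" below, and the Review Findings discuss the trade-off):
+
+```python
+from dataclasses import dataclass, field
+
+
+@dataclass
+class Usage:
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+
+
+@dataclass
+class Message:
+    content: str | None = None
+    reasoning_content: str | None = None
+    tool_calls: list[dict] | None = None
+
+
+@dataclass
+class Choice:
+    message: Message = field(default_factory=Message)
+
+
+@dataclass
+class CompletionResponse:
+    choices: list[Choice] = field(default_factory=list)
+    usage: Usage | None = None
+
+
+# Existing call sites keep working unchanged:
+# response.choices[0].message.content
+```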
+
+**New files:**
+- `models_v2/client.py` — `ModelClient` protocol, `OpenAIModelClient` adapter wrapping the `openai` SDK
+
+**Delete:**
+- `models_v2/litellm_overrides.py` — `CustomRouter`, `ThreadSafeCache`, `apply_litellm_patches()` no longer needed in `models_v2/`
+
+**Heavy modification:**
+- `models_v2/facade.py` — Replace `self._router.completion()` / `self._router.acompletion()` with `ModelClient` calls. Response access patterns (`response.choices[0].message.content`, etc.) stay the same since we keep the OpenAI format. Replace `_get_litellm_deployment()` with adapter construction.
+- `models_v2/errors.py` — Replace `litellm.exceptions.*` matching with `openai` SDK exception matching in `handle_llm_exceptions()`
+
+**Light modification:**
+- `models_v2/factory.py` — Remove `apply_litellm_patches()` call, remove litellm imports, construct `ModelClient` adapter instead of `CustomRouter`
+
+**Tests (for `models_v2/` path):**
+- `tests/engine/models/test_facade.py` — Medium rewrite: 26 tests, replace `CustomRouter` patches with `ModelClient` mocks. Response object construction in mocks can use `openai` SDK types directly (same shape as today's litellm types).
+- `tests/engine/models/test_model_errors.py` — Medium rewrite: 7 tests, replace `litellm.exceptions.*` with `openai` SDK exceptions
+- `tests/engine/models/test_model_registry.py` — Light: remove `apply_litellm_patches` mock
+- `tests/engine/models/conftest.py` — Replace 2 fixtures that construct LiteLLM response objects (use `openai` SDK types instead — same shape)
+- `scripts/benchmarks/benchmark_engine_v2.py` — Replace `CustomRouter` import/patches with `ModelClient` mocks
+
+**Dependency:**
+- `pyproject.toml` — Add `openai` as direct dependency (already transitive via litellm; no new weight)
+
+**NOT touched:** The `models/` directory is entirely unchanged, as are `engine/mcp/`, column generators, dataset builders, the config layer, and validators. Because the response format does not change, there is no cross-layer ripple.
+
+---
+
+### Phase 2: Validate
+
+**No code changes.** Validation only:
+- Benchmark: `--mode compare` between `models/` (litellm, env var off) and `models_v2/` (direct SDK, env var on)
+- Full test suite: `uv run pytest packages/data-designer-engine/tests -x -q`
+- Real inference: pdf_qa recipe or equivalent with `DATA_DESIGNER_ASYNC_ENGINE=1`
+
+---
+
+### Phase 3: Additional provider adapters
+
+**New files:**
+- `models_v2/adapters/anthropic.py` — `AnthropicModelClient`
+- `models_v2/adapters/bedrock.py` — `BedrockModelClient`
+
+**Modification:**
+- `config/models.py` — `ModelProvider.provider_type: str` → `ProviderType` enum with Pydantic string coercion
+- `models_v2/factory.py` — Adapter selection: `match provider_type` → construct appropriate `ModelClient`
+- `engine/mcp/facade.py` — If Anthropic's flat tool_use blocks need different extraction than OpenAI's nested format, the tool call normalization logic needs updating. **This is the highest-risk cross-layer change.**
+- `models/utils.py` — `ChatMessage.to_dict()` may need to support Anthropic's message format for multi-turn conversations with tool calls
+
+**Tests:**
+- New test files for Anthropic and Bedrock adapters
+- Update tool call extraction tests if MCP facade changes
+
+**Dependency:**
+- `pyproject.toml` — Add `anthropic`. Bedrock: add `boto3`/`aiobotocore` or use `asyncio.to_thread()` with sync boto3.
+
+**NOT touched:** `models/` still unchanged — litellm fallback remains available.
+
+---
+
+### Phase 4: Consolidate and drop dependency
+
+**Delete entirely:**
+- `models/` directory (all files: `facade.py`, `errors.py`, `factory.py`, `litellm_overrides.py`, `registry.py`, `usage.py`, `telemetry.py`, `utils.py`, `parsers/`, `recipes/`)
+- `tests/engine/models/test_litellm_overrides.py`
+
+**Modification:**
+- `models/__init__.py` — Remove the path redirect hack; `models_v2/` becomes the sole implementation (or rename to `models/`)
+- `lazy_heavy_imports.py` — Remove `litellm` from lazy import registry
+- `pyproject.toml` — Remove `litellm>=1.73.6,<1.80.12`
+- `uv.lock` — Regenerate
+
+**Cleanup (non-functional):**
+- `config/column_configs.py` — Docstring mentions "via LiteLLM"
+- `engine/resources/resource_provider.py` — Comment mentions "heavy dependencies like litellm"
+- `engine/mcp/facade.py` — Type hint comment references `litellm.ModelResponse`
+- `README.md`, `AGENTS.md` — Documentation references to LiteLLM
+- `async_concurrency.py` — Comment mentions "libraries (like LiteLLM)"
+
+---
+
+## What Needs to Be Built
+
+### 1. `ModelClient` Interface + Provider Adapters
+
+Replace LiteLLM Router with a `ModelClient` protocol and thin adapter classes that wrap the official provider SDKs. **No raw HTTP** — the SDKs handle networking, connection pooling, retry/backoff, and rate limiting internally.
+
+```python
+class ModelClient(Protocol):
+ def completion(self, messages: list[dict], **kwargs) -> CompletionResponse: ...
+ async def acompletion(self, messages: list[dict], **kwargs) -> CompletionResponse: ...
+ def embedding(self, input_texts: list[str], **kwargs) -> EmbeddingResponse: ...
+ async def aembedding(self, input_texts: list[str], **kwargs) -> EmbeddingResponse: ...
+```
+
+Each adapter's job is purely **translation**:
+- Translate DD's `ModelConfig` + `InferenceParams` → SDK-specific call parameters
+- Call the SDK method (e.g., `await self._client.chat.completions.create(...)`)
+- Translate the SDK response → DD's `CompletionResponse` / `EmbeddingResponse`
+- Catch SDK-specific exceptions → DD error types
+
+The SDK client is created once in the factory (same lifecycle as `CustomRouter` today) and reused for all calls. No need for a dedicated I/O service like `mcp/io.py` — the official SDKs already manage connection pools, event loops, and request lifecycle internally.
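+
+A sketch of the thinnest adapter under these assumptions (class and parameter names are hypothetical; the per-call `model` parameter follows the Review Findings below):
+
+```python
+from typing import Any
+
+from openai import AsyncOpenAI
+
+
+class OpenAIModelClient:
+    """Illustrative OpenAI-compatible adapter (OpenAI, NVIDIA NIM, other OpenAI-style endpoints)."""
+
+    def __init__(self, *, api_base: str, api_key: str, timeout: float = 60.0, max_retries: int = 3) -> None:
+        # The SDK owns connection pooling, timeouts, and retry/backoff.
+        self._client = AsyncOpenAI(base_url=api_base, api_key=api_key, timeout=timeout, max_retries=max_retries)
+
+    async def acompletion(self, model: str, messages: list[dict], **kwargs: Any):
+        # The response already has the shape DD expects:
+        # choices[0].message.content, usage.prompt_tokens, ...
+        return await self._client.chat.completions.create(model=model, messages=messages, **kwargs)
+
+    async def aembedding(self, model: str, input_texts: list[str], **kwargs: Any):
+        # response.data[i].embedding, response.usage.prompt_tokens
+        return await self._client.embeddings.create(model=model, input=input_texts, **kwargs)
+```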
+
+### 2. Response Types
+
+Define lightweight response dataclasses replacing LiteLLM's `ModelResponse` and `EmbeddingResponse`:
+
+```python
+@dataclass
+class CompletionResponse:
+ content: str
+ reasoning_content: str | None
+ tool_calls: list[ToolCall] | None
+ usage: UsageInfo | None
+
+@dataclass
+class EmbeddingResponse:
+ embeddings: list[list[float]]
+ usage: UsageInfo | None
+
+@dataclass
+class UsageInfo:
+ prompt_tokens: int
+ completion_tokens: int
+```
+
+**Scope**: ~50 lines. The existing code already accesses these fields — the dataclass just formalizes the contract.
+
+### 3. Error Handling
+
+Each SDK has its own exception hierarchy. The adapter for each provider catches SDK-specific exceptions and maps them to DD's existing error types.
+
+**OpenAI SDK** — exception types map almost 1:1 to DD errors (LiteLLM's exception hierarchy was modeled on OpenAI's):
+`BadRequestError(400)`, `AuthenticationError(401)`, `PermissionDeniedError(403)`, `NotFoundError(404)`, `RateLimitError(429)`, `InternalServerError(5xx)`, `APIConnectionError`, `APITimeoutError`
+
+**Anthropic SDK** — simpler hierarchy:
+`APIStatusError` (with `.status_code`), `RateLimitError`, `APIConnectionError`, `APITimeoutError`. Adapter checks status code for finer-grained mapping.
+
+**Bedrock** — all errors via `botocore.exceptions.ClientError`:
+Check `response['Error']['Code']` for `ValidationException(400)`, `AccessDeniedException(403)`, `ThrottlingException(429)`, `InternalServerException(500)`, etc.
+
+**Nuance**: Some providers encode context window errors as 400 with specific error messages. The existing `parse_context_window_exceeded_error()` and `parse_bad_request_error()` logic handles this — it would need to match on response body strings, same as today.
+
+The DD error types and `FormattedLLMErrorMessage` formatting already exist. Only the matching logic changes per adapter.
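+
+For the OpenAI adapter, the matching logic is a near drop-in replacement for the litellm `match` statement (a partial sketch; the function name is hypothetical, the DD error types are the existing ones from `models_v2/errors.py`):
+
+```python
+import openai
+
+from data_designer.engine.models_v2.errors import (
+    FormattedLLMErrorMessage,
+    ModelAPIConnectionError,
+    ModelAuthenticationError,
+    ModelRateLimitError,
+    ModelTimeoutError,
+)
+
+
+def handle_openai_sdk_exception(exc: Exception, model_name: str, provider_name: str) -> None:
+    """Partial sketch: map openai SDK exceptions onto the existing DD error types."""
+    match exc:
+        case openai.AuthenticationError():
+            raise ModelAuthenticationError(
+                FormattedLLMErrorMessage(
+                    cause=f"The API key for model {model_name!r} was rejected.",
+                    solution=f"Verify the API key configured for provider {provider_name!r}.",
+                )
+            ) from None
+        case openai.RateLimitError():
+            raise ModelRateLimitError(
+                FormattedLLMErrorMessage(
+                    cause=f"Rate limit exceeded for model {model_name!r}.",
+                    solution="Wait and try again in a few moments.",
+                )
+            ) from None
+        # APITimeoutError subclasses APIConnectionError, so it must be matched first.
+        case openai.APITimeoutError():
+            raise ModelTimeoutError(
+                FormattedLLMErrorMessage(
+                    cause=f"The request to model {model_name!r} timed out.",
+                    solution="Check your connection or increase the model's timeout setting.",
+                )
+            ) from None
+        case openai.APIConnectionError():
+            raise ModelAPIConnectionError(
+                FormattedLLMErrorMessage(
+                    cause=f"Connection to model {model_name!r} on provider {provider_name!r} failed.",
+                    solution="Check your network/proxy/firewall settings.",
+                )
+            ) from None
+        # ... remaining cases (BadRequestError, NotFoundError, PermissionDeniedError,
+        # InternalServerError, UnprocessableEntityError) map 1:1 as in the table above.
+```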
+
+### 4. Retry & Backoff
+
+**This varies significantly by provider — the claim that "each SDK handles its own retries" is only partially true.**
+
+**OpenAI SDK**: Built-in. Default 2 retries, exponential backoff + jitter (0.5s initial, 8s cap). Auto-retries 408, 409, 429, 5xx. Respects `Retry-After` headers (capped at 60s). Configurable via `max_retries` on client or per-request. DD's `CustomRouter` defaults (2s initial, 20% jitter, 3 retries) become client configuration.
+
+**Anthropic SDK**: Built-in. Same defaults and behavior as OpenAI SDK (0.5s initial, 8s cap, 2 retries). Auto-retries connection errors, 408, 409, 429, 5xx. Same `max_retries` configuration.
+
+**Bedrock**: **Retry is NOT fully handled.** boto3 has three retry modes (legacy, standard, adaptive), but `ThrottlingException` (429 rate limiting) is **not auto-retried in any mode**. Only `ModelNotReadyException` (also 429, for cold-start) is auto-retried. DD must implement its own retry logic for Bedrock throttling — exponential backoff with jitter, same as `CustomRouter` does today.
+
+**Bottom line**: For OpenAI and Anthropic, DD can rely on the SDK's built-in retry. For Bedrock, DD needs a standalone retry utility (port of `CustomRouter`'s backoff logic) or a wrapper around the Bedrock adapter.
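+
+A minimal sketch of the standalone wrapper the Bedrock adapter would need (function name and defaults are illustrative; the error-code check follows botocore's `ClientError` shape):
+
+```python
+import asyncio
+import random
+
+from botocore.exceptions import ClientError
+
+# ModelNotReadyException is already auto-retried by boto3; only throttling needs handling here.
+_RETRYABLE_CODES = {"ThrottlingException"}
+
+
+async def call_with_throttle_retry(fn, /, *args, max_retries: int = 3,
+                                   initial_retry_after_s: float = 2.0,
+                                   jitter_pct: float = 0.20, **kwargs):
+    """Retry Bedrock throttling with exponential backoff + jitter."""
+    for attempt in range(max_retries + 1):
+        try:
+            # boto3 is sync-only; run the call off the event loop.
+            return await asyncio.to_thread(fn, *args, **kwargs)
+        except ClientError as exc:
+            code = exc.response.get("Error", {}).get("Code")
+            if code not in _RETRYABLE_CODES or attempt == max_retries:
+                raise
+            delay = initial_retry_after_s * (2**attempt)
+            await asyncio.sleep(delay * (1 + random.uniform(-jitter_pct, jitter_pct)))
+```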
+
+### 5. Provider Adapters (use official SDKs directly)
+
+**Decision**: DD needs multi-provider support, but the set is bounded: OpenAI, Anthropic, and Bedrock. Use each provider's official SDK directly:
+
+- **OpenAI-compatible** (`openai`): Covers OpenAI, NVIDIA NIM, Azure OpenAI, and any OpenAI-compatible endpoint. Native async via `AsyncOpenAI`. Built-in retry + rate limit handling. Response format is what DD already expects (`choices[0].message.content`). **Thinnest adapter — mostly passthrough.**
+- **Anthropic** (`anthropic`): Native async via `AsyncAnthropic`. Built-in retry. But response format differs: content is an array of typed blocks (`text`, `tool_use`, `thinking`), not a single string. Extended thinking is native but structurally different. **Adapter must translate content blocks → DD's expected format.**
+- **Bedrock** (`boto3` / `aiobotocore`): Sync-only in boto3; async requires `aiobotocore` as an additional dependency (or `asyncio.to_thread()` as a simpler fallback). No auto-retry for throttling. Response format is Bedrock-native (`output.message.content[].text`), not OpenAI-compatible. **Most adapter work: retry, async wrapper, and response translation.** AWS does offer an `/openai/v1` endpoint that returns OpenAI-compatible responses, which could reduce translation work.
+
+Each adapter implements the `ModelClient` interface, translating DD's `ModelConfig` into the SDK-specific call and normalizing the response back to DD's `CompletionResponse` / `EmbeddingResponse` types.
+
+**Key insight**: This approach also simplifies retry/backoff for OpenAI and Anthropic, whose SDKs handle retries natively; for those providers, DD's `CustomRouter` backoff logic reduces to configuration on the underlying client rather than a reimplementation. Bedrock is the exception and needs its own retry wrapper (see section 4 above).
+
+---
+
+## Recommended Approach
+
+### Architecture: Parallel stack in `models_v2/`
+
+The `engine/models/` and `engine/models_v2/` directories are near-complete copies, switchable at runtime via an `__init__.py` path redirect keyed on `DATA_DESIGNER_ASYNC_ENGINE`. Only 2 of 14 files actually differ (`facade.py` adds async methods, `errors.py` adds an async decorator). The other 12 are pure copy-paste duplicates.
+
+**Strategy**: Build the new non-litellm implementation entirely within `models_v2/`. Leave `models/` untouched as the stable litellm-backed fallback. The existing env var switch (`DATA_DESIGNER_ASYNC_ENGINE=1`) already gates which module path is used. Once `models_v2/` is validated in production, consolidate by deleting `models/` and dropping the litellm dependency.
+
+This approach:
+- Avoids a risky big-bang swap — litellm remains available as fallback
+- Contains all new work to `models_v2/` (6 files to modify, not 12)
+- Reuses the existing runtime switching mechanism
+- Defers consolidation and dep removal to a clean follow-up
+
+### Phase 1: Replace Router with ModelClient in `models_v2/`
+- Define `ModelClient` protocol and DD-owned response types in `models_v2/`
+- Implement `OpenAIModelClient` using the `openai` SDK (already a transitive dep)
+- Rewrite `models_v2/facade.py` to use `ModelClient` instead of `CustomRouter`
+- Rewrite `models_v2/errors.py` to match on OpenAI SDK exceptions instead of litellm exceptions
+- Remove `models_v2/litellm_overrides.py` and litellm imports from `models_v2/factory.py`
+- Update response access sites within `models_v2/` (and any shared code that receives responses)
+- **Result**: `models_v2/` is litellm-free, `models/` is unchanged
+
+### Phase 2: Validate
+- Run benchmark: `--mode compare` to verify identical output between `models/` (litellm) and `models_v2/` (direct SDK)
+- Run full test suite
+- Run real inference (pdf_qa recipe or equivalent) with `DATA_DESIGNER_ASYNC_ENGINE=1`
+- **Result**: Confidence that the new stack is correct
+
+### Phase 3: Additional provider adapters
+- `AnthropicModelClient`: HIGH risk — content block → string translation, tool_use block → OpenAI tool_calls format, thinking block → reasoning_content. Requires changes to MCP facade tool extraction and ChatMessage serialization.
+- `BedrockModelClient`: HIGH risk — manual throttle retry, async via `to_thread` or `aiobotocore`, response format translation from Converse API shape.
+- `ProviderType` enum in config with Pydantic string coercion for backwards compatibility
+- Each adapter raises explicit `UnsupportedFeatureError` for capabilities the provider doesn't support
+- **Result**: Full provider coverage; `models/` (litellm) still available as fallback until all adapters are proven
+
+### Phase 4: Consolidate and drop dependency
+- Delete `models/` directory
+- Remove `__init__.py` path redirect hack
+- Remove `litellm` from `pyproject.toml`
+- Remove `litellm` from `lazy_heavy_imports.py`
+- Clean up `uv.lock`
+- Update documentation references
+- **Result**: Single implementation, litellm fully removed. Only after all provider adapters are validated in production.
+
+---
+
+## Design Decisions (Answered)
+
+1. ~~**Which `provider_type` values are used in production?**~~ **Answered**: OpenAI (including NIM), Anthropic, and Bedrock. Bounded set — use official SDKs directly behind a `ModelClient` interface.
+
+2. ~~**Is streaming on the roadmap?**~~ **Answered**: No. Streaming will not be supported. This simplifies the `ModelClient` interface — no need for `stream_completion()` or async iterator return types. Each method is a simple request/response call.
+
+3. ~~**How does `ModelConfig.provider_type` map to adapter selection?**~~ **Answered**: `provider_type` should become an **enum**, not remain a free-form string. The enum values determine which `ModelClient` adapter is instantiated. This makes the mapping explicit and catches misconfiguration at validation time rather than at runtime.
+
+4. ~~**What about provider-specific features?**~~ **Answered**: The `ModelClient` interface targets the **OpenAI feature superset** — completions, embeddings, tool calling, extended thinking, etc. OpenAI's adapter is the full implementation. Other providers implement what they can and **raise explicit incompatibility errors** for unsupported features (e.g., if a provider doesn't support tool calling, the adapter raises a clear error rather than silently degrading). This means:
+ - `ModelClient` defines the full interface anchored to OpenAI's capabilities
+ - `OpenAIModelClient` implements everything
+ - `AnthropicModelClient` implements everything but translates response formats (content blocks → string, thinking blocks, tool_use blocks)
+ - `BedrockModelClient` implements core completions/embeddings; raises `UnsupportedFeatureError` for anything Bedrock's Converse API doesn't support
+
+---
+
+## Risk Assessment
+
+| Component | Risk | Why |
+|---|---|---|
+| `ModelClient` interface | Low | Well-understood contract; OpenAI response shape is the reference |
+| OpenAI adapter | Low | Thinnest adapter — response format matches, retry built into SDK, `openai` already a transitive dep |
+| Anthropic adapter | **High** | Content blocks vs strings, flat tool_use vs nested function format, thinking blocks vs reasoning_content field. Leaks into MCP facade and trace storage. |
+| Bedrock adapter | **High** | Manual throttle retry, async requires extra dep, non-OpenAI response format |
+| Error mapping (OpenAI) | Low | SDK exceptions map 1:1 to DD errors (LiteLLM modeled its hierarchy on OpenAI's) |
+| Response type migration | Medium | Existing code accesses `response.choices[0].message.content` everywhere — either keep structural parity or coordinate refactor of all access sites |
+| Test migration | Medium | ~56 test functions need modification; 26 require significant rework (all in `tests/engine/models/`) |
+| Parallel stack validation | Low | Same env-var gating pattern as the async engine; benchmark already validates output correctness |
+
+OpenAI adapter + validation (Phases 1-2) is low risk. The high-risk work is the Anthropic and Bedrock adapters (Phase 3), which can be deferred and tackled independently.
+
+---
+
+## Review Findings (Moth Swarm)
+
+10 independent reviewers examined this report against the actual codebase. Key corrections and additions:
+
+### Blast radius is larger than stated
+
+The original count of 2 test files is incomplete. The full test impact:
+
+| File | Impact | Details |
+|---|---|---|
+| `test_facade.py` | Heavy rewrite | 26 test functions, imports `ModelResponse`/`EmbeddingResponse`/`Choices`/`Message` from litellm, 6 `CustomRouter` patches |
+| `test_litellm_overrides.py` | Delete | 11 tests for `CustomRouter`, `ThreadSafeCache`, `apply_litellm_patches()` |
+| `test_model_errors.py` | Medium rewrite | 7 tests, imports all 12 `litellm.exceptions.*` types for parametrized error mapping tests |
+| `test_model_registry.py` | Light touch | 1 test patches `apply_litellm_patches()` |
+| `conftest.py` (models) | Light touch | 2 fixtures construct `ModelResponse` and `EmbeddingResponse` objects |
+| `benchmark_engine_v2.py` | Medium rewrite | Imports `CustomRouter`, patches `completion` and `acompletion` |
+
+Total: ~56 test functions need modification, with 26 requiring significant rework. All contained within `tests/engine/models/`.
+
+### Anthropic adapter is HIGH risk, not medium
+
+The Anthropic SDK's response format is structurally incompatible with DD's assumptions:
+
+- **Content is an array of typed blocks** (`text`, `thinking`, `tool_use`), not `choices[0].message.content` (string). DD accesses `response.choices[0].message.content` in facade.py, models_v2/facade.py, and mcp/facade.py.
+- **Tool calls are flat content blocks** (`name`, `input`), not nested OpenAI format (`function.name`, `function.arguments`). The MCP facade's tool extraction logic (`mcp/facade.py:340-353`) assumes the nested structure.
+- **Reasoning is a content block**, not a field. DD uses `getattr(message, "reasoning_content", None)` — Anthropic returns `ThinkingBlock(type="thinking", thinking="...")` in the content array.
+- **This leaks beyond the adapter** into the MCP facade (tool extraction), the generation loop (content access), ChatMessage serialization (trace storage), and multimodal content formatting.
+
+The Anthropic adapter requires refactoring core response handling, not just wrapping the SDK.
+
+### ModelClient interface needs more specificity
+
+The proposed 4-method interface is too thin:
+- Missing explicit `model` parameter (currently passed per-request to Router)
+- `ToolCall` type is undefined — needs `id`, `type`, `function.name`, `function.arguments`
+- Response type structure decision: the proposed flat `CompletionResponse` breaks all existing `response.choices[0].message.content` access sites. Either keep structural parity with OpenAI's response shape (nested `choices[0].message`) or coordinate a refactor of every access site.
+- `consolidate_kwargs()` merges `inference_parameters.generate_kwargs`, `extra_body`, and `extra_headers` before calling the Router — the adapter contract should document what's in `**kwargs` when it arrives.
+
+### Retry-after header extraction is LiteLLM-specific
+
+`CustomRouter._extract_retry_delay_from_headers()` uses:
+- `exception.litellm_response_headers` (LiteLLM-specific attribute)
+- `exception.response.headers` (httpx attribute, works with OpenAI/Anthropic SDKs)
+- `litellm.utils._get_retry_after_from_exception_header()` (LiteLLM utility)
+
+For OpenAI/Anthropic, the SDKs handle retry internally — this logic is only needed for Bedrock. But the Retry-After header parsing utility (`_get_retry_after_from_exception_header`) needs reimplementation as a standalone function.
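+
+A standalone replacement is small. A sketch (function name hypothetical) handling both the delta-seconds and HTTP-date forms of `Retry-After`, with the 60s cap used today:
+
+```python
+from datetime import datetime, timezone
+from email.utils import parsedate_to_datetime
+
+
+def parse_retry_after(headers: dict[str, str], *, max_delay_s: float = 60.0) -> float | None:
+    """Return the server-suggested delay in seconds, or None if absent/unparseable."""
+    value = headers.get("retry-after") or headers.get("Retry-After")
+    if value is None:
+        return None
+    try:
+        delay = float(value)  # delta-seconds form
+    except ValueError:
+        try:
+            delay = (parsedate_to_datetime(value) - datetime.now(timezone.utc)).total_seconds()
+        except (TypeError, ValueError):
+            return None
+    return min(max(delay, 0.0), max_delay_s)
+```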
+
+### Dependency impact is lighter than expected
+
+- `openai` is already a transitive dependency via litellm — promoting to direct dep adds zero new weight
+- `httpx` stays regardless — DD depends on it directly for `engine/models/telemetry.py` and `engine/validators/remote.py`
+- Net dependency change for OpenAI-only: a reduction (~10-15 packages removed with litellm)
+- Adding Anthropic is lightweight (most deps already present via openai)
+- Bedrock (boto3/botocore) is the heaviest new addition
+
+### Config migration is clean
+
+`provider_type` is only consumed in one place (`_get_litellm_deployment()` → `f"{provider_type}/{model_name}"`). Pydantic handles string → enum coercion automatically, so existing YAML/JSON configs and programmatic construction continue to work. The CLI text field for provider_type would become a select/choice field. All predefined providers and examples use `"openai"` — no existing users need migration.
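+
+For illustration, the coercion behavior relied on above (enum values follow the answered design decisions; the real `ModelProvider` config model has more fields):
+
+```python
+from enum import Enum
+
+from pydantic import BaseModel
+
+
+class ProviderType(str, Enum):
+    OPENAI = "openai"
+    ANTHROPIC = "anthropic"
+    BEDROCK = "bedrock"
+
+
+class ModelProvider(BaseModel):
+    name: str
+    provider_type: ProviderType = ProviderType.OPENAI
+
+
+# Existing configs keep working: Pydantic coerces the plain string into the enum member.
+provider = ModelProvider(name="nim", provider_type="openai")
+assert provider.provider_type is ProviderType.OPENAI
+```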
diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py
index 42e551be..54b8fd58 100644
--- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py
+++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py
@@ -96,6 +96,45 @@ def generate(self, data: dict) -> dict:
return data
+ async def agenerate(self, data: dict) -> dict:
+ deserialized_record = deserialize_json_values(data)
+
+ multi_modal_context = None
+ if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0:
+ multi_modal_context = []
+ for context in self.config.multi_modal_context:
+ multi_modal_context.extend(context.get_contexts(deserialized_record))
+
+ response, trace = await self.model.agenerate(
+ prompt=self.prompt_renderer.render(
+ record=deserialized_record,
+ prompt_template=self.config.prompt,
+ prompt_type=PromptType.USER_PROMPT,
+ ),
+ system_prompt=self.prompt_renderer.render(
+ record=deserialized_record,
+ prompt_template=self.config.system_prompt,
+ prompt_type=PromptType.SYSTEM_PROMPT,
+ ),
+ parser=self.response_recipe.parse,
+ multi_modal_context=multi_modal_context,
+ tool_alias=self.config.tool_alias,
+ max_correction_steps=self.max_conversation_correction_steps,
+ max_conversation_restarts=self.max_conversation_restarts,
+ purpose=f"running generation for column '{self.config.name}'",
+ )
+
+ serialized_output = self.response_recipe.serialize_output(response)
+ data[self.config.name] = self._process_serialized_output(serialized_output)
+
+ should_save_trace = (
+ self.config.with_trace or self.resource_provider.run_config.debug_override_save_all_column_traces
+ )
+ if should_save_trace:
+ data[self.config.name + TRACE_COLUMN_POSTFIX] = [message.to_dict() for message in trace]
+
+ return data
+
def _process_serialized_output(self, serialized_output: str) -> str | dict | list:
"""Process the serialized output from the model. Subclasses can override to customize deserialization."""
return serialized_output
diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py
index 577c2976..9de339f5 100644
--- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py
+++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py
@@ -7,6 +7,7 @@
import importlib.metadata
import json
import logging
+import os
import time
import uuid
from pathlib import Path
@@ -31,6 +32,7 @@
from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage
from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
+from data_designer.engine.dataset_builders.utils.async_concurrency import AsyncConcurrentExecutor
from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
@@ -50,6 +52,11 @@
logger = logging.getLogger(__name__)
+DATA_DESIGNER_ASYNC_ENGINE = os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "0") == "1"
+
+if DATA_DESIGNER_ASYNC_ENGINE:
+ logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled — using async concurrency")
+
_CLIENT_VERSION: str = importlib.metadata.version("data-designer-engine")
@@ -205,7 +212,11 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
max_workers = self._resource_provider.run_config.non_inference_max_parallel_workers
if isinstance(generator, ColumnGeneratorWithModel):
max_workers = generator.inference_parameters.max_parallel_requests
- self._fan_out_with_threads(generator, max_workers=max_workers)
+ if DATA_DESIGNER_ASYNC_ENGINE:
+ logger.info("⚡ Using async engine for concurrent execution")
+ self._fan_out_with_async(generator, max_workers=max_workers)
+ else:
+ self._fan_out_with_threads(generator, max_workers=max_workers)
def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
@@ -227,6 +238,41 @@ def _run_mcp_tool_check_if_needed(self) -> None:
raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.")
self._resource_provider.mcp_registry.run_health_check(tool_aliases)
+ def _fan_out_with_async(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
+ if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
+ raise DatasetGenerationError(
+ f"Generator {generator.name} is not a {GenerationStrategy.CELL_BY_CELL} "
+ "generator so concurrency through async is not supported."
+ )
+
+ progress_tracker = ProgressTracker(
+ total_records=self.batch_manager.num_records_batch,
+ label=f"{generator.config.column_type} column '{generator.config.name}'",
+ )
+ progress_tracker.log_start(max_workers)
+
+ settings = self._resource_provider.run_config
+ executor = AsyncConcurrentExecutor(
+ max_workers=max_workers,
+ column_name=generator.config.name,
+ result_callback=self._make_result_callback(progress_tracker),
+ error_callback=self._make_error_callback(progress_tracker),
+ shutdown_error_rate=settings.shutdown_error_rate,
+ shutdown_error_window=settings.shutdown_error_window,
+ disable_early_shutdown=settings.disable_early_shutdown,
+ )
+
+ work_items = [
+ (generator.agenerate(record), {"index": i}) for i, record in self.batch_manager.iter_current_batch()
+ ]
+ executor.run(work_items)
+
+ progress_tracker.log_final()
+
+ if len(self._records_to_drop) > 0:
+ self.batch_manager.drop_records(self._records_to_drop)
+ self._records_to_drop.clear()
+
def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
raise DatasetGenerationError(
diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py
new file mode 100644
index 00000000..896a99cb
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py
@@ -0,0 +1,168 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import threading
+from collections.abc import Coroutine
+from dataclasses import dataclass
+from typing import Any, Generic, TypeVar
+
+from data_designer.engine.dataset_builders.utils.concurrency import (
+ CallbackWithContext,
+ ErrorCallbackWithContext,
+ ExecutorResults,
+)
+from data_designer.engine.errors import DataDesignerRuntimeError
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass(frozen=True, slots=True)
+class Success(Generic[T]):
+ index: int
+ value: T
+
+
+@dataclass(frozen=True, slots=True)
+class Failure:
+ index: int
+ error: Exception
+
+
+TaskResult = Success[T] | Failure
+
+_loop: asyncio.AbstractEventLoop | None = None
+_thread: threading.Thread | None = None
+_lock = threading.Lock()
+
+
+def _run_loop(loop: asyncio.AbstractEventLoop) -> None:
+ asyncio.set_event_loop(loop)
+ loop.run_forever()
+
+
+def _ensure_async_engine_loop() -> asyncio.AbstractEventLoop:
+ """Get or create a persistent event loop for async engine work.
+
+ A single event loop is shared across all AsyncConcurrentExecutor instances
+ to avoid breaking libraries (like LiteLLM) that bind internal async state
+ to a specific event loop.
+ """
+ global _loop, _thread
+ with _lock:
+ if _loop is None or not _loop.is_running():
+ _loop = asyncio.new_event_loop()
+ _thread = threading.Thread(target=_run_loop, args=(_loop,), daemon=True, name="AsyncEngine-EventLoop")
+ _thread.start()
+ return _loop
+
+
+class AsyncConcurrentExecutor:
+ """Async equivalent of ConcurrentThreadExecutor.
+
+ Executes a batch of coroutines with bounded concurrency, error rate
+ monitoring, and early shutdown semantics. Callers remain synchronous —
+ the ``run()`` method submits work to a persistent background event loop.
+
+ No locks are needed because asyncio tasks run cooperatively on a
+ single thread — mutations to ``_results`` are always sequential.
+ """
+
+ def __init__(
+ self,
+ *,
+ max_workers: int,
+ column_name: str,
+ result_callback: CallbackWithContext | None = None,
+ error_callback: ErrorCallbackWithContext | None = None,
+ shutdown_error_rate: float = 0.50,
+ shutdown_error_window: int = 10,
+ disable_early_shutdown: bool = False,
+ ) -> None:
+ self._column_name = column_name
+ self._max_workers = max_workers
+ self._result_callback = result_callback
+ self._error_callback = error_callback
+ self._shutdown_error_rate = shutdown_error_rate
+ self._shutdown_window_size = shutdown_error_window
+ self._disable_early_shutdown = disable_early_shutdown
+ self._results = ExecutorResults(failure_threshold=shutdown_error_rate)
+
+ @property
+ def results(self) -> ExecutorResults:
+ return self._results
+
+ @property
+ def max_workers(self) -> int:
+ return self._max_workers
+
+ @property
+ def shutdown_error_rate(self) -> float:
+ return self._shutdown_error_rate
+
+ @property
+ def shutdown_window_size(self) -> int:
+ return self._shutdown_window_size
+
+ def run(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None:
+ """Execute all work items concurrently. Callers remain synchronous."""
+ logger.debug(
+ f"AsyncConcurrentExecutor: launching {len(work_items)} tasks "
+ f"with max_workers={self._max_workers} for column '{self._column_name}'"
+ )
+ loop = _ensure_async_engine_loop()
+ future = asyncio.run_coroutine_threadsafe(self._run_all(work_items), loop)
+ future.result()
+
+ async def _run_all(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None:
+ self._semaphore = asyncio.Semaphore(self._max_workers)
+ self._shutdown_event = asyncio.Event()
+
+ async with asyncio.TaskGroup() as tg:
+ for i, (coro, context) in enumerate(work_items):
+ tg.create_task(self._run_task(i, coro, context))
+
+ if not self._disable_early_shutdown and self._results.early_shutdown:
+ self._raise_task_error()
+
+    async def _run_task(self, index: int, coro: Coroutine[Any, Any, Any], context: dict | None) -> None:
+        if self._shutdown_event.is_set():
+            # Close skipped coroutines so they don't emit "coroutine was never awaited" warnings.
+            coro.close()
+            return
+
+        async with self._semaphore:
+            if self._shutdown_event.is_set():
+                coro.close()
+                return
+
+ try:
+ result = await coro
+ self._results.completed_count += 1
+ self._results.success_count += 1
+ if self._result_callback is not None:
+ self._result_callback(result, context=context)
+ except Exception as err:
+ self._results.completed_count += 1
+ self._results.error_trap.handle_error(err)
+ if not self._disable_early_shutdown and self._results.is_error_rate_exceeded(
+ self._shutdown_window_size
+ ):
+ if not self._results.early_shutdown:
+ self._results.early_shutdown = True
+ self._shutdown_event.set()
+ if self._error_callback is not None:
+ self._error_callback(err, context=context)
+
+ def _raise_task_error(self) -> None:
+ raise DataDesignerRuntimeError(
+ "\n".join(
+ [
+ " |-- Data generation was terminated early due to error rate exceeding threshold.",
+ f" |-- The summary of encountered errors is: \n{json.dumps(self._results.summary, indent=4)}",
+ ]
+ )
+ )
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/__init__.py
index 52a7a9da..ba053a37 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/__init__.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/__init__.py
@@ -1,2 +1,27 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+_ASYNC_ENGINE_ENV_VAR = "DATA_DESIGNER_ASYNC_ENGINE"
+_TRUTHY_ENV_VALUES = {"1", "true", "yes"}
+
+
+def _is_async_engine_enabled() -> bool:
+ return os.getenv(_ASYNC_ENGINE_ENV_VAR, "").lower() in _TRUTHY_ENV_VALUES
+
+
+def _redirect_to_models_v2() -> None:
+ models_v2_path = Path(__file__).resolve().parent.parent / "models_v2"
+ # Set DATA_DESIGNER_ASYNC_ENGINE before importing this package for it to take effect.
+ global __path__
+ __path__ = [str(models_v2_path)]
+ if __spec__ is not None:
+ __spec__.submodule_search_locations = [str(models_v2_path)]
+
+
+if __name__ == "data_designer.engine.models" and _is_async_engine_enabled():
+ _redirect_to_models_v2()
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/__init__.py
new file mode 100644
index 00000000..52a7a9da
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/errors.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/errors.py
new file mode 100644
index 00000000..0bbb26b4
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/errors.py
@@ -0,0 +1,325 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Callable
+from functools import wraps
+from typing import TYPE_CHECKING, Any
+
+from pydantic import BaseModel
+
+from data_designer.engine.errors import DataDesignerError
+from data_designer.lazy_heavy_imports import litellm
+
+if TYPE_CHECKING:
+ import litellm
+
+logger = logging.getLogger(__name__)
+
+
+def get_exception_primary_cause(exception: BaseException) -> BaseException:
+ """Returns the primary cause of an exception by walking backwards.
+
+ This recursive walkback halts when it arrives at an exception which
+ has no provided __cause__ (e.g. __cause__ is None).
+
+ Args:
+ exception (Exception): An exception to start from.
+
+ Raises:
+ RecursionError: if for some reason exceptions have circular
+ dependencies (seems impossible in practice).
+ """
+ if exception.__cause__ is None:
+ return exception
+ else:
+ return get_exception_primary_cause(exception.__cause__)
+
+
+class GenerationValidationFailureError(Exception): ...
+
+
+class ModelRateLimitError(DataDesignerError): ...
+
+
+class ModelTimeoutError(DataDesignerError): ...
+
+
+class ModelContextWindowExceededError(DataDesignerError): ...
+
+
+class ModelAuthenticationError(DataDesignerError): ...
+
+
+class ModelPermissionDeniedError(DataDesignerError): ...
+
+
+class ModelNotFoundError(DataDesignerError): ...
+
+
+class ModelUnsupportedParamsError(DataDesignerError): ...
+
+
+class ModelBadRequestError(DataDesignerError): ...
+
+
+class ModelInternalServerError(DataDesignerError): ...
+
+
+class ModelAPIError(DataDesignerError): ...
+
+
+class ModelUnprocessableEntityError(DataDesignerError): ...
+
+
+class ModelAPIConnectionError(DataDesignerError): ...
+
+
+class ModelStructuredOutputError(DataDesignerError): ...
+
+
+class ModelGenerationValidationFailureError(DataDesignerError): ...
+
+
+class FormattedLLMErrorMessage(BaseModel):
+ cause: str
+ solution: str
+
+ def __str__(self) -> str:
+ return "\n".join(
+ [
+ " |----------",
+ f" | Cause: {self.cause}",
+ f" | Solution: {self.solution}",
+ " |----------",
+ ]
+ )
+
+
+def handle_llm_exceptions(
+ exception: Exception, model_name: str, model_provider_name: str, purpose: str | None = None
+) -> None:
+ """Handle LLM-related exceptions and convert them to appropriate DataDesignerError errors.
+
+ This method centralizes the exception handling logic for LLM operations,
+ making it reusable across different contexts.
+
+ Args:
+ exception: The exception that was raised
+ model_name: Name of the model that was being used
+ model_provider_name: Name of the model provider that was being used
+ purpose: The purpose of the model usage to show as context in the error message
+ Raises:
+ DataDesignerError: A more user-friendly error with appropriate error type and message
+ """
+ purpose = purpose or "running generation"
+ authentication_error = FormattedLLMErrorMessage(
+ cause=f"The API key provided for model {model_name!r} was found to be invalid or expired while {purpose}.",
+ solution=f"Verify your API key for model provider and update it in your settings for model provider {model_provider_name!r}.",
+ )
+ err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose)
+ match exception:
+ # Common errors that can come from LiteLLM
+ case litellm.exceptions.APIError():
+ raise err_msg_parser.parse_api_error(exception, authentication_error) from None
+
+ case litellm.exceptions.APIConnectionError():
+ raise ModelAPIConnectionError(
+ FormattedLLMErrorMessage(
+ cause=f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.",
+ solution="Check your network/proxy/firewall settings.",
+ )
+ ) from None
+
+ case litellm.exceptions.AuthenticationError():
+ raise ModelAuthenticationError(authentication_error) from None
+
+ case litellm.exceptions.ContextWindowExceededError():
+ raise err_msg_parser.parse_context_window_exceeded_error(exception) from None
+
+ case litellm.exceptions.UnsupportedParamsError():
+ raise ModelUnsupportedParamsError(
+ FormattedLLMErrorMessage(
+ cause=f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.",
+ solution=f"Review the documentation for model provider {model_provider_name!r} and adjust your request.",
+ )
+ ) from None
+
+ case litellm.exceptions.BadRequestError():
+ raise err_msg_parser.parse_bad_request_error(exception) from None
+
+ case litellm.exceptions.InternalServerError():
+ raise ModelInternalServerError(
+ FormattedLLMErrorMessage(
+ cause=f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.",
+ solution=f"Try again in a few moments. Check with your model provider {model_provider_name!r} if the issue persists.",
+ )
+ ) from None
+
+ case litellm.exceptions.NotFoundError():
+ raise ModelNotFoundError(
+ FormattedLLMErrorMessage(
+ cause=f"The specified model {model_name!r} could not be found while {purpose}.",
+ solution=f"Check that the model name is correct and supported by your model provider {model_provider_name!r} and try again.",
+ )
+ ) from None
+
+ case litellm.exceptions.PermissionDeniedError():
+ raise ModelPermissionDeniedError(
+ FormattedLLMErrorMessage(
+ cause=f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.",
+ solution=f"Use an API key that has the right permissions for the model or use a model the API key in use has access to in model provider {model_provider_name!r}.",
+ )
+ ) from None
+
+ case litellm.exceptions.RateLimitError():
+ raise ModelRateLimitError(
+ FormattedLLMErrorMessage(
+ cause=f"You have exceeded the rate limit for model {model_name!r} while {purpose}.",
+ solution="Wait and try again in a few moments.",
+ )
+ ) from None
+
+ case litellm.exceptions.Timeout():
+ raise ModelTimeoutError(
+ FormattedLLMErrorMessage(
+ cause=f"The request to model {model_name!r} timed out while {purpose}.",
+ solution="Check your connection and try again. You may need to increase the timeout setting for the model.",
+ )
+ ) from None
+
+ case litellm.exceptions.UnprocessableEntityError():
+ raise ModelUnprocessableEntityError(
+ FormattedLLMErrorMessage(
+ cause=f"The request to model {model_name!r} failed despite correct request format while {purpose}.",
+ solution="This is most likely temporary. Try again in a few moments.",
+ )
+ ) from None
+
+ # Parsing and validation errors
+ case GenerationValidationFailureError():
+ raise ModelGenerationValidationFailureError(
+ FormattedLLMErrorMessage(
+ cause=f"The provided output schema was unable to be parsed from model {model_name!r} responses while {purpose}.",
+ solution="This is most likely temporary as we make additional attempts. If you continue to see more of this, simplify or modify the output schema for structured output and try again. If you are attempting token-intensive tasks like generations with high-reasoning effort, ensure that max_tokens in the model config is high enough to reach completion.",
+ )
+ ) from None
+
+ case DataDesignerError():
+ raise exception from None
+
+ case _:
+ raise DataDesignerError(
+ FormattedLLMErrorMessage(
+ cause=f"An unexpected error occurred while {purpose}.",
+ solution=f"Review the stack trace for more details: {exception}",
+ )
+ ) from exception
+
+
+def catch_llm_exceptions(func: Callable) -> Callable:
+ """This decorator should be used on any `ModelFacade` method that could potentially raise
+ exceptions that should turn into upstream user-facing errors.
+ """
+
+ @wraps(func)
+ def wrapper(model_facade: Any, *args, **kwargs):
+ try:
+ return func(model_facade, *args, **kwargs)
+ except Exception as e:
+ logger.debug(
+ "\n".join(
+ [
+ "",
+ "|----------",
+ f"| Caught an exception downstream of type {type(e)!r}. Re-raising it below as a custom error with more context.",
+ "|----------",
+ ]
+ ),
+ exc_info=True,
+ stack_info=True,
+ )
+ handle_llm_exceptions(
+ e, model_facade.model_name, model_facade.model_provider_name, purpose=kwargs.get("purpose")
+ )
+
+ return wrapper
+
+
+def acatch_llm_exceptions(func: Callable) -> Callable:
+ @wraps(func)
+ async def wrapper(model_facade: Any, *args: Any, **kwargs: Any) -> Any:
+ try:
+ return await func(model_facade, *args, **kwargs)
+ except Exception as e:
+ logger.debug(
+ "\n".join(
+ [
+ "",
+ "|----------",
+ f"| Caught an exception downstream of type {type(e)!r}. Re-raising it below as a custom error with more context.",
+ "|----------",
+ ]
+ ),
+ exc_info=True,
+ stack_info=True,
+ )
+ handle_llm_exceptions(
+ e, model_facade.model_name, model_facade.model_provider_name, purpose=kwargs.get("purpose")
+ )
+
+ return wrapper
+
+
+class DownstreamLLMExceptionMessageParser:
+ def __init__(self, model_name: str, model_provider_name: str, purpose: str):
+ self.model_name = model_name
+ self.model_provider_name = model_provider_name
+ self.purpose = purpose
+
+ def parse_bad_request_error(self, exception: litellm.exceptions.BadRequestError) -> DataDesignerError:
+ err_msg = FormattedLLMErrorMessage(
+ cause=f"The request for model {self.model_name!r} was found to be malformed or missing required parameters while {self.purpose}.",
+ solution="Check your request parameters and try again.",
+ )
+ if "is not a multimodal model" in str(exception):
+ err_msg = FormattedLLMErrorMessage(
+ cause=f"Model {self.model_name!r} is not a multimodal model, but it looks like you are trying to provide multimodal context while {self.purpose}.",
+ solution="Check your request parameters and try again.",
+ )
+ return ModelBadRequestError(err_msg)
+
+ def parse_context_window_exceeded_error(
+ self, exception: litellm.exceptions.ContextWindowExceededError
+ ) -> DataDesignerError:
+ cause = f"The input data for model '{self.model_name}' was found to exceed its supported context width while {self.purpose}."
+ try:
+ if "OpenAIException - This model's maximum context length is " in str(exception):
+ openai_exception_cause = (
+ str(exception).split("OpenAIException - ")[1].split("\n")[0].split(" Please reduce ")[0]
+ )
+ cause = f"{cause} {openai_exception_cause}"
+ except Exception:
+ pass
+ finally:
+ return ModelContextWindowExceededError(
+ FormattedLLMErrorMessage(
+ cause=cause,
+ solution="Check the model's supported max context width. Adjust the length of your input along with completions and try again.",
+ )
+ )
+
+ def parse_api_error(
+        self, exception: litellm.exceptions.APIError, auth_error_msg: FormattedLLMErrorMessage
+ ) -> DataDesignerError:
+ if "Error code: 403" in str(exception):
+ return ModelAuthenticationError(auth_error_msg)
+
+ return ModelAPIError(
+ FormattedLLMErrorMessage(
+ cause=f"An unexpected API error occurred with model {self.model_name!r} while {self.purpose}.",
+ solution=f"Try again in a few moments. Check with your model provider {self.model_provider_name!r} if the issue persists.",
+ )
+ )
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/facade.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/facade.py
new file mode 100644
index 00000000..031ab919
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/facade.py
@@ -0,0 +1,495 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from collections.abc import Callable
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any
+
+from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
+from data_designer.engine.mcp.errors import MCPConfigurationError
+from data_designer.engine.model_provider import ModelProviderRegistry
+from data_designer.engine.models.errors import (
+ GenerationValidationFailureError,
+ catch_llm_exceptions,
+ get_exception_primary_cause,
+)
+from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
+from data_designer.engine.models.parsers.errors import ParserException
+from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
+from data_designer.engine.models.utils import ChatMessage, prompt_to_messages
+from data_designer.engine.models_v2.errors import acatch_llm_exceptions
+from data_designer.engine.secret_resolver import SecretResolver
+from data_designer.lazy_heavy_imports import litellm
+
+if TYPE_CHECKING:
+ import litellm
+
+ from data_designer.engine.mcp.facade import MCPFacade
+ from data_designer.engine.mcp.registry import MCPRegistry
+
+logger = logging.getLogger(__name__)
+
+
+class ModelFacade:
+ def __init__(
+ self,
+ model_config: ModelConfig,
+ secret_resolver: SecretResolver,
+ model_provider_registry: ModelProviderRegistry,
+ *,
+ mcp_registry: MCPRegistry | None = None,
+ ) -> None:
+ self._model_config = model_config
+ self._secret_resolver = secret_resolver
+ self._model_provider_registry = model_provider_registry
+ self._mcp_registry = mcp_registry
+ self._litellm_deployment = self._get_litellm_deployment(model_config)
+ self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump())
+ self._usage_stats = ModelUsageStats()
+
+ @property
+ def model_name(self) -> str:
+ return self._model_config.model
+
+ @property
+ def model_provider(self) -> ModelProvider:
+ return self._model_provider_registry.get_provider(self._model_config.provider)
+
+ @property
+ def model_generation_type(self) -> GenerationType:
+ return self._model_config.generation_type
+
+ @property
+ def model_provider_name(self) -> str:
+ return self.model_provider.name
+
+ @property
+ def model_alias(self) -> str:
+ return self._model_config.alias
+
+ @property
+ def usage_stats(self) -> ModelUsageStats:
+ return self._usage_stats
+
+ def completion(
+ self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs
+ ) -> litellm.ModelResponse:
+ message_payloads = [message.to_dict() for message in messages]
+ logger.debug(
+ f"Prompting model {self.model_name!r}...",
+ extra={"model": self.model_name, "messages": message_payloads},
+ )
+ response = None
+ kwargs = self.consolidate_kwargs(**kwargs)
+ try:
+ response = self._router.completion(model=self.model_name, messages=message_payloads, **kwargs)
+ logger.debug(
+ f"Received completion from model {self.model_name!r}",
+ extra={
+ "model": self.model_name,
+ "response": response,
+ "text": response.choices[0].message.content,
+ "usage": self._usage_stats.model_dump(),
+ },
+ )
+ return response
+ finally:
+ if not skip_usage_tracking and response is not None:
+ self._track_usage(response)
+
+ def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
+ # Remove purpose from kwargs to avoid passing it to the model
+ kwargs.pop("purpose", None)
+ kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
+ if self.model_provider.extra_body:
+ kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
+ if self.model_provider.extra_headers:
+ kwargs["extra_headers"] = self.model_provider.extra_headers
+ return kwargs
+
+ def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None:
+ if tool_alias is None:
+ return None
+ if self._mcp_registry is None:
+ raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.")
+
+ try:
+ return self._mcp_registry.get_mcp(tool_alias=tool_alias)
+ except ValueError as exc:
+ raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc
+
+ @catch_llm_exceptions
+ def generate_text_embeddings(
+ self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
+ ) -> list[list[float]]:
+ logger.debug(
+ f"Generating embeddings with model {self.model_name!r}...",
+ extra={
+ "model": self.model_name,
+ "input_count": len(input_texts),
+ },
+ )
+ kwargs = self.consolidate_kwargs(**kwargs)
+ response = None
+ try:
+ response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
+ logger.debug(
+ f"Received embeddings from model {self.model_name!r}",
+ extra={
+ "model": self.model_name,
+ "embedding_count": len(response.data) if response.data else 0,
+ "usage": self._usage_stats.model_dump(),
+ },
+ )
+ if response.data and len(response.data) == len(input_texts):
+ return [data["embedding"] for data in response.data]
+ else:
+ raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
+ finally:
+ if not skip_usage_tracking and response is not None:
+ self._track_usage_from_embedding(response)
+
+ @catch_llm_exceptions
+ def generate(
+ self,
+ prompt: str,
+ *,
+ parser: Callable[[str], Any],
+ system_prompt: str | None = None,
+ multi_modal_context: list[dict[str, Any]] | None = None,
+ tool_alias: str | None = None,
+ max_correction_steps: int = 0,
+ max_conversation_restarts: int = 0,
+ skip_usage_tracking: bool = False,
+ purpose: str | None = None,
+ **kwargs,
+ ) -> tuple[Any, list[ChatMessage]]:
+ """Generate a parsed output with correction steps.
+
+ This generation call will attempt to generate an output which is
+ valid according to the specified parser, where "valid" implies
+ that the parser can process the LLM response without raising
+ an exception.
+
+        `ParserException`s are routed back to the LLM as new rounds in the
+        conversation: the LLM is shown its earlier response, followed by a "user"
+        message containing the exception string (not the traceback). This continues
+        for at most the number of rounds specified by `max_correction_steps`.
+
+ Args:
+ prompt (str): Task prompt.
+ system_prompt (str, optional): Optional system instructions. If not specified,
+ no system message is provided and the model should use its default system
+ prompt.
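+            multi_modal_context (list[dict], optional): Optional multimodal content
+                entries (e.g. image parts) included alongside the user prompt.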
+ parser (func(str) -> Any): A function applied to the LLM response which processes
+ an LLM response into some output object.
+ tool_alias (str | None): Optional tool configuration alias. When provided,
+ the model may call permitted tools from the configured MCP providers.
+ The alias must reference a ToolConfig registered in the MCPRegistry.
+ max_correction_steps (int): Maximum number of correction rounds permitted
+ within a single conversation. Note, many rounds can lead to increasing
+ context size without necessarily improving performance -- small language
+ models can enter repeated cycles which will not be solved with more steps.
+ Default: `0` (no correction).
+ max_conversation_restarts (int): Maximum number of full conversation restarts permitted
+ if generation fails. Default: `0` (no restarts).
+ skip_usage_tracking (bool): Whether to skip usage tracking. Default: `False`.
+            purpose (str, optional): Short description of the model usage, included as
+                context in error messages raised by the @catch_llm_exceptions decorator.
+ **kwargs: Additional arguments to pass to the model.
+
+ Returns:
+ A tuple containing:
+ - The parsed output object from the parser.
+ - The full trace of ChatMessage entries in the conversation, including any tool calls,
+ corrections, and reasoning traces. Callers can decide whether to store this.
+
+ Raises:
+            GenerationValidationFailureError: If the maximum number of restarts and
+                correction steps is reached and the last response still fails
+                generation validation.
+ MCPConfigurationError: If tool_alias is specified but no MCPRegistry is configured.
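+
+        Example (illustrative sketch; assumes ``json`` is imported and ``model`` is a
+        configured ``ModelFacade``; any parser works as long as it raises
+        ``ParserException`` on invalid output):
+
+            def parse_json(text: str) -> dict:
+                try:
+                    return json.loads(text)
+                except json.JSONDecodeError as exc:
+                    raise ParserException(str(exc), source=text) from exc
+
+            record, trace = model.generate(
+                prompt="Return a JSON object with a 'name' field.",
+                parser=parse_json,
+                max_correction_steps=2,
+            )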
+ """
+ output_obj = None
+ tool_schemas = None
+ tool_call_turns = 0
+ curr_num_correction_steps = 0
+ curr_num_restarts = 0
+
+ mcp_facade = self._get_mcp_facade(tool_alias)
+
+ # Checkpoint for restarts - updated after tool calls so we don't repeat them
+ restart_checkpoint = prompt_to_messages(
+ user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
+ )
+ checkpoint_tool_call_turns = 0
+ messages: list[ChatMessage] = deepcopy(restart_checkpoint)
+
+ if mcp_facade is not None:
+ tool_schemas = mcp_facade.get_tool_schemas()
+
+ while True:
+ completion_kwargs = dict(kwargs)
+ if tool_schemas is not None:
+ completion_kwargs["tools"] = tool_schemas
+
+ completion_response = self.completion(
+ messages,
+ skip_usage_tracking=skip_usage_tracking,
+ **completion_kwargs,
+ )
+
+ # Process any tool calls in the response (handles parallel tool calling)
+ if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response):
+ tool_call_turns += 1
+
+ if tool_call_turns > mcp_facade.max_tool_call_turns:
+ # Gracefully refuse tool calls when budget is exhausted
+ messages.extend(mcp_facade.refuse_completion_response(completion_response))
+ else:
+ messages.extend(mcp_facade.process_completion_response(completion_response))
+
+ # Update checkpoint so restarts don't repeat tool calls
+ restart_checkpoint = deepcopy(messages)
+ checkpoint_tool_call_turns = tool_call_turns
+
+ continue # Back to top
+
+ # No tool calls remaining to process
+ response = completion_response.choices[0].message.content or ""
+ reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
+ messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None))
+ curr_num_correction_steps += 1
+
+ try:
+ output_obj = parser(response) # type: ignore - if not a string will cause a ParserException below
+ break
+ except ParserException as exc:
+ if max_correction_steps == 0 and max_conversation_restarts == 0:
+ raise GenerationValidationFailureError(
+ "Unsuccessful generation attempt. No retries were attempted."
+ ) from exc
+
+ if curr_num_correction_steps <= max_correction_steps:
+ # Add user message with error for correction
+ messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc))))
+
+ elif curr_num_restarts < max_conversation_restarts:
+ curr_num_correction_steps = 0
+ curr_num_restarts += 1
+ messages = deepcopy(restart_checkpoint)
+ tool_call_turns = checkpoint_tool_call_turns
+
+ else:
+ raise GenerationValidationFailureError(
+ f"Unsuccessful generation despite {max_correction_steps} correction steps "
+ f"and {max_conversation_restarts} conversation restarts."
+ ) from exc
+
+ return output_obj, messages
+
+ def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
+ provider = self._model_provider_registry.get_provider(model_config.provider)
+ api_key = None
+ if provider.api_key:
+ api_key = self._secret_resolver.resolve(provider.api_key)
+ api_key = api_key or "not-used-but-required"
+
+ litellm_params = litellm.LiteLLM_Params(
+ model=f"{provider.provider_type}/{model_config.model}",
+ api_base=provider.endpoint,
+ api_key=api_key,
+ )
+ return {
+ "model_name": model_config.model,
+ "litellm_params": litellm_params.model_dump(),
+ }
+
+ def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None:
+ if response is None:
+ self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
+ return
+ if (
+ response.usage is not None
+ and response.usage.prompt_tokens is not None
+ and response.usage.completion_tokens is not None
+ ):
+ self._usage_stats.extend(
+ token_usage=TokenUsageStats(
+ input_tokens=response.usage.prompt_tokens,
+ output_tokens=response.usage.completion_tokens,
+ ),
+ request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
+ )
+
+ def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None:
+ if response is None:
+ self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
+ return
+ if response.usage is not None and response.usage.prompt_tokens is not None:
+ self._usage_stats.extend(
+ token_usage=TokenUsageStats(
+ input_tokens=response.usage.prompt_tokens,
+ output_tokens=0,
+ ),
+ request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
+ )
+
+ async def acompletion(
+ self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any
+ ) -> litellm.ModelResponse:
+ message_payloads = [message.to_dict() for message in messages]
+ logger.debug(
+ f"Prompting model {self.model_name!r}...",
+ extra={"model": self.model_name, "messages": message_payloads},
+ )
+ response = None
+ kwargs = self.consolidate_kwargs(**kwargs)
+ try:
+ response = await self._router.acompletion(model=self.model_name, messages=message_payloads, **kwargs)
+ logger.debug(
+ f"Received completion from model {self.model_name!r}",
+ extra={
+ "model": self.model_name,
+ "response": response,
+ "text": response.choices[0].message.content,
+ "usage": self._usage_stats.model_dump(),
+ },
+ )
+ return response
+ finally:
+ if not skip_usage_tracking and response is not None:
+ self._track_usage(response)
+
+ @acatch_llm_exceptions
+ async def agenerate_text_embeddings(
+ self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs: Any
+ ) -> list[list[float]]:
+ logger.debug(
+ f"Generating embeddings with model {self.model_name!r}...",
+ extra={
+ "model": self.model_name,
+ "input_count": len(input_texts),
+ },
+ )
+ kwargs = self.consolidate_kwargs(**kwargs)
+ response = None
+ try:
+ response = await self._router.aembedding(model=self.model_name, input=input_texts, **kwargs)
+ logger.debug(
+ f"Received embeddings from model {self.model_name!r}",
+ extra={
+ "model": self.model_name,
+ "embedding_count": len(response.data) if response.data else 0,
+ "usage": self._usage_stats.model_dump(),
+ },
+ )
+ if response.data and len(response.data) == len(input_texts):
+ return [data["embedding"] for data in response.data]
+ else:
+ raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
+ finally:
+ if not skip_usage_tracking and response is not None:
+ self._track_usage_from_embedding(response)
+
+ @acatch_llm_exceptions
+ async def agenerate(
+ self,
+ prompt: str,
+ *,
+ parser: Callable[[str], Any],
+ system_prompt: str | None = None,
+ multi_modal_context: list[dict[str, Any]] | None = None,
+ tool_alias: str | None = None,
+ max_correction_steps: int = 0,
+ max_conversation_restarts: int = 0,
+ skip_usage_tracking: bool = False,
+ purpose: str | None = None,
+ **kwargs: Any,
+ ) -> tuple[Any, list[ChatMessage]]:
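+        """Async variant of `generate`; see that method for full parameter documentation."""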
+ output_obj = None
+ tool_schemas = None
+ tool_call_turns = 0
+ curr_num_correction_steps = 0
+ curr_num_restarts = 0
+
+ mcp_facade = self._get_mcp_facade(tool_alias)
+
+ restart_checkpoint = prompt_to_messages(
+ user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
+ )
+ checkpoint_tool_call_turns = 0
+ messages: list[ChatMessage] = deepcopy(restart_checkpoint)
+
+ if mcp_facade is not None:
+ tool_schemas = await asyncio.to_thread(mcp_facade.get_tool_schemas)
+
+ while True:
+ completion_kwargs = dict(kwargs)
+ if tool_schemas is not None:
+ completion_kwargs["tools"] = tool_schemas
+
+ completion_response = await self.acompletion(
+ messages,
+ skip_usage_tracking=skip_usage_tracking,
+ **completion_kwargs,
+ )
+
+ if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response):
+ tool_call_turns += 1
+
+ if tool_call_turns > mcp_facade.max_tool_call_turns:
+ messages.extend(mcp_facade.refuse_completion_response(completion_response))
+ else:
+ messages.extend(
+ await asyncio.to_thread(mcp_facade.process_completion_response, completion_response)
+ )
+
+ restart_checkpoint = deepcopy(messages)
+ checkpoint_tool_call_turns = tool_call_turns
+
+ continue
+
+ response = completion_response.choices[0].message.content or ""
+ reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
+ messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None))
+ curr_num_correction_steps += 1
+
+ try:
+ output_obj = parser(response)
+ break
+ except ParserException as exc:
+ if max_correction_steps == 0 and max_conversation_restarts == 0:
+ raise GenerationValidationFailureError(
+ "Unsuccessful generation attempt. No retries were attempted."
+ ) from exc
+
+ if curr_num_correction_steps <= max_correction_steps:
+ messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc))))
+
+ elif curr_num_restarts < max_conversation_restarts:
+ curr_num_correction_steps = 0
+ curr_num_restarts += 1
+ messages = deepcopy(restart_checkpoint)
+ tool_call_turns = checkpoint_tool_call_turns
+
+ else:
+ raise GenerationValidationFailureError(
+ f"Unsuccessful generation despite {max_correction_steps} correction steps "
+ f"and {max_conversation_restarts} conversation restarts."
+ ) from exc
+
+ return output_obj, messages
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/factory.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/factory.py
new file mode 100644
index 00000000..fb3b2e1d
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/factory.py
@@ -0,0 +1,59 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from data_designer.config.models import ModelConfig
+from data_designer.engine.model_provider import ModelProviderRegistry
+from data_designer.engine.secret_resolver import SecretResolver
+
+if TYPE_CHECKING:
+ from data_designer.engine.mcp.registry import MCPRegistry
+ from data_designer.engine.models.registry import ModelRegistry
+
+
+def create_model_registry(
+ *,
+ model_configs: list[ModelConfig] | None = None,
+ secret_resolver: SecretResolver,
+ model_provider_registry: ModelProviderRegistry,
+ mcp_registry: MCPRegistry | None = None,
+) -> ModelRegistry:
+ """Factory function for creating a ModelRegistry instance.
+
+    Heavy dependencies (litellm, httpx) are imported inside this function so they
+    are only loaded when a model registry is actually created (lazy initialization).
+
+ Args:
+ model_configs: Optional list of model configurations to register.
+ secret_resolver: Resolver for secrets referenced in provider configs.
+ model_provider_registry: Registry of model provider configurations.
+ mcp_registry: Optional MCP registry for tool operations. When provided,
+ ModelFacades can look up MCPFacades by tool_alias for tool-enabled generation.
+
+ Returns:
+ A configured ModelRegistry instance.
+ """
+ from data_designer.engine.models.facade import ModelFacade
+ from data_designer.engine.models.litellm_overrides import apply_litellm_patches
+ from data_designer.engine.models.registry import ModelRegistry
+
+ apply_litellm_patches()
+
+ def model_facade_factory(model_config, secret_resolver, model_provider_registry):
+ return ModelFacade(
+ model_config,
+ secret_resolver,
+ model_provider_registry,
+ mcp_registry=mcp_registry,
+ )
+
+ return ModelRegistry(
+ model_configs=model_configs,
+ secret_resolver=secret_resolver,
+ model_provider_registry=model_provider_registry,
+ model_facade_factory=model_facade_factory,
+ )
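+
+
+# Example (illustrative; the config, resolver, and provider-registry values below are
+# placeholders, not a tested configuration):
+#
+#     registry = create_model_registry(
+#         model_configs=[my_model_config],
+#         secret_resolver=my_secret_resolver,
+#         model_provider_registry=my_provider_registry,
+#     )
+#     facade = registry.get_model(model_alias=my_model_config.alias)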
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/litellm_overrides.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/litellm_overrides.py
new file mode 100644
index 00000000..92070def
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/litellm_overrides.py
@@ -0,0 +1,179 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""
+LiteLLM overrides and customizations.
+
+Note on imports: This module uses direct (eager) imports for litellm rather than lazy loading.
+This is intentional because:
+
+1. Class inheritance requires base classes to be resolved at class definition time,
+ making lazy imports incompatible with our ThreadSafeCache and CustomRouter classes.
+
+2. This module is already lazily loaded at the application level - it's only imported
+ by facade.py, which itself is imported inside the create_model_registry() factory
+ function. So litellm is only loaded when models are actually needed.
+
+3. Attempting to use lazy imports here causes intermittent ImportErrors.
+"""
+
+from __future__ import annotations
+
+import random
+import threading
+
+import httpx
+import litellm
+from litellm import RetryPolicy
+from litellm.caching.in_memory_cache import InMemoryCache
+from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
+from litellm.router import Router
+from pydantic import BaseModel, Field
+from typing_extensions import override
+
+from data_designer.logging import quiet_noisy_logger
+
+DEFAULT_MAX_CALLBACKS = 1000
+
+
+class LiteLLMRouterDefaultKwargs(BaseModel):
+ ## Number of seconds to wait initially after a connection
+ ## failure.
+ initial_retry_after_s: float = 2.0
+
+ ## Jitter percentage added during exponential backoff to
+ ## smooth repeated retries over time.
+ jitter_pct: float = 0.2
+
+ ## Maximum number of seconds to wait for an API request
+ ## before letting it die. Will trigger a retry.
+ timeout: float = 60.0
+
+ ## Sets the default retry policy, including the number
+ ## of retries to use in particular scenarios.
+ retry_policy: RetryPolicy = Field(
+ default_factory=lambda: RetryPolicy(
+ RateLimitErrorRetries=3,
+ TimeoutErrorRetries=3,
+ )
+ )
+
+
+class ThreadSafeCache(InMemoryCache):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self._lock = threading.RLock()
+
+ def get_cache(self, key, **kwargs):
+ with self._lock:
+ return super().get_cache(key, **kwargs)
+
+ def set_cache(self, key, value, **kwargs):
+ with self._lock:
+ super().set_cache(key, value, **kwargs)
+
+ def batch_get_cache(self, keys: list, **kwargs):
+ with self._lock:
+ return super().batch_get_cache(keys, **kwargs)
+
+ def delete_cache(self, key):
+ with self._lock:
+ super().delete_cache(key)
+
+ def evict_cache(self):
+ with self._lock:
+ super().evict_cache()
+
+ def increment_cache(self, key, value: int, **kwargs) -> int:
+ with self._lock:
+ return super().increment_cache(key, value, **kwargs)
+
+ def flush_cache(self):
+ with self._lock:
+ super().flush_cache()
+
+
+class CustomRouter(Router):
+ def __init__(
+ self,
+ *args,
+ initial_retry_after_s: float,
+ jitter_pct: float,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self._initial_retry_after_s = initial_retry_after_s
+ self._jitter_pct = jitter_pct
+
+ def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None:
+ """
+ Most of this code logic was extracted directly from the parent
+ `Router`'s `_time_to_sleep_before_retry` function. Our override
+ of that method below should only affect requests where the server
+ didn't explicitly return a desired retry-delay. If the server did
+ return this info, we'll simply use that retry value returned here.
+ """
+
+ response_headers: httpx.Headers | None = None
+ if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
+ response_headers = e.response.headers # type: ignore
+ if hasattr(e, "litellm_response_headers"):
+ response_headers = e.litellm_response_headers # type: ignore
+
+ retry_after = litellm.utils._get_retry_after_from_exception_header(response_headers)
+
+ # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
+ if retry_after is not None and 0 < retry_after <= 60:
+ return retry_after
+ else:
+ return None
+
+ @override
+ def _time_to_sleep_before_retry(
+ self,
+ e: Exception,
+ remaining_retries: int,
+ num_retries: int,
+ healthy_deployments: list | None = None,
+ all_deployments: list | None = None,
+ ) -> int | float:
+ """
+ Implements exponential backoff for retries.
+
+ Technically, litellm's `Router` already implements some
+ form of exponential backoff. However, that backoff
+ is not customizable w.r.t jitter and initial delay
+ timing. For that reason, we override this method to
+ utilize our own custom instance variables, deferring
+ to the existing implementation wherever we can.
+ """
+
+ # If the response headers indicated how long we should wait,
+ # use that information.
+ if retry_after := self._extract_retry_delay_from_headers(e):
+ return retry_after
+
+ return self.calculate_exponential_backoff(
+ initial_retry_after_s=self._initial_retry_after_s,
+ current_retry=num_retries - remaining_retries,
+ jitter_pct=self._jitter_pct,
+ )
+
+ @staticmethod
+ def calculate_exponential_backoff(initial_retry_after_s: float, current_retry: int, jitter_pct: float) -> float:
+ sleep_s = initial_retry_after_s * (pow(2.0, current_retry))
+ jitter = 1.0 + random.uniform(-jitter_pct, jitter_pct)
+ return sleep_s * jitter
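+
+    # Example (illustrative): with the defaults in LiteLLMRouterDefaultKwargs
+    # (initial_retry_after_s=2.0, jitter_pct=0.2), successive retries sleep roughly
+    # 2s, 4s, 8s, ... each scaled by a random factor in [0.8, 1.2]:
+    #
+    #     calculate_exponential_backoff(2.0, current_retry=0, jitter_pct=0.2)  # ~1.6-2.4s
+    #     calculate_exponential_backoff(2.0, current_retry=1, jitter_pct=0.2)  # ~3.2-4.8s
+    #     calculate_exponential_backoff(2.0, current_retry=2, jitter_pct=0.2)  # ~6.4-9.6s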
+
+
+def apply_litellm_patches():
+ litellm.in_memory_llm_clients_cache = ThreadSafeCache()
+
+ # Workaround for the litellm issue described in https://github.com/BerriAI/litellm/issues/9792
+ LoggingCallbackManager.MAX_CALLBACKS = DEFAULT_MAX_CALLBACKS
+
+ quiet_noisy_logger("httpx")
+ quiet_noisy_logger("LiteLLM")
+ quiet_noisy_logger("LiteLLM Router")
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/__init__.py
new file mode 100644
index 00000000..52a7a9da
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/errors.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/errors.py
new file mode 100644
index 00000000..7d1db351
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/errors.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+
+class ParserException(Exception):
+ """Identifies errors resulting from generic parser errors.
+
+ Attributes:
+ source (str | None): The source string that the parser
+ attempted to parse.
+ """
+
+ source: str | None
+
+ @staticmethod
+ def _log_format(source: str) -> str:
+ ## NOTE: The point of this was to be able to report offending
+ ## failure cases to the logs. This might not be what we want
+ ## to do in all cases. In the meantime, this note is left
+ ## for later review.
+ #
+ # return f"
,
element node.
+
+ This parser handles the special case of Markdown->HTML conversion
+ for fenced code blocks. These take on the form:
+
+ ```xx
+ ...
+ ```
+
+ ...
+
+ This parser is intended to be attached to the special case of "pre.code"
+ tag hierarchies.
+
+ Syntax Handling
+
+    If the syntax is not specified, e.g. the ``code`` element has no
+    ``class`` attribute, then the syntax field is returned
+    as None. However, the parser does not _enforce_ the prefix
+    `language-` on the value of the class attribute:
+    if the prefix is not present, then the entire value of the
+    ``class`` attribute is used as the syntax identifier.
+
+ Args:
+ element (lxml.etree._Element): An element of the lxml-parsed
+ element tree.
+
+ Returns:
+        CodeBlock: Data structure containing both the body of the code
+            and the specified syntax of the code block.
+
+ """
+ prefix = "language-"
+ language_identifier = element.attrib.get("class", "")
+ language_identifier = language_identifier.removeprefix(prefix)
+ return CodeBlock(
+ code=element.text.strip() if element.text else "",
+ code_lang=language_identifier if language_identifier else None,
+ )
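+
+
+# Example (illustrative): an element equivalent to
+#     <code class="language-python">print("hi")</code>
+# parses to CodeBlock(code='print("hi")', code_lang="python"); if the class
+# attribute is absent, code_lang is None.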
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/types.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/types.py
new file mode 100644
index 00000000..81575ef9
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/types.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import Any, Protocol, runtime_checkable
+
+from lxml.etree import _Element
+from pydantic import BaseModel, Field
+from typing_extensions import Self
+
+
+class LLMStructuredResponse(BaseModel):
+ """Output format for the LLM Response Parser."""
+
+ response: str = Field(description="Raw Markdown/Markup response received from the LLM and input to the parser.")
+ markup: str = Field(description="Markup/HTML resulting from running Markdown parsing on response.")
+ parsed: list[BaseModel] = Field(
+ default_factory=list,
+ description="Structured content parsed from markup. Elements of this list are in document-order.",
+ )
+
+ def head(self, n: int) -> Self:
+ """Retain only the first n elements of the parsed response."""
+ out = self.model_copy()
+ out.parsed = out.parsed[:n]
+ return out
+
+ def tail(self, n: int) -> Self:
+ """Retain only the last n elements of the parsed response."""
+ out = self.model_copy()
+ out.parsed = out.parsed[-n:]
+ return out
+
+ def filter(self, block_types: list[type[BaseModel]]) -> Self:
+ out = self.model_copy()
+ out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))]
+ return out
+
+
+@runtime_checkable
+class TagParser(Protocol):
+ """Protocol for tag parsing implementations.
+
+    All TagParsers are callables which take as input an `lxml`
+    element, do some computation, and return structured
+    output as an instance of a Pydantic `BaseModel` subclass.
+    Because this is a runtime-checkable Protocol, both classes and
+    curried functions (e.g. `functools.partial`) can serve as parsers.
+ """
+
+ def __call__(self, element: _Element) -> BaseModel: ...
+
+
+@runtime_checkable
+class PostProcessor(Protocol):
+ """Protocol for parsed output postprocessing implementations.
+
+ Implementations of this protocol are used to transform the results of
+ the LLM response parser while retaining the same output structure.
+ This is done so that PostProcessor implementations can be chained
+ together.
+ """
+
+ def __call__(self, structured_response: LLMStructuredResponse) -> LLMStructuredResponse: ...
+
+
+class TextBlock(BaseModel):
+ text: str
+
+
+class CodeBlock(BaseModel):
+ code: str
+ code_lang: str | None = None
+
+
+class StructuredDataBlock(BaseModel):
+ serialized: str
+ obj: Any
+
+
+class PydanticTypeBlock(BaseModel):
+ serialized: str
+ obj: BaseModel
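+
+
+# Example (illustrative): a PostProcessor is any callable matching the protocol above,
+# e.g. a function that drops empty text blocks while preserving document order.
+#
+#     def drop_empty_text_blocks(structured_response: LLMStructuredResponse) -> LLMStructuredResponse:
+#         out = structured_response.model_copy()
+#         out.parsed = [
+#             block
+#             for block in out.parsed
+#             if not (isinstance(block, TextBlock) and not block.text.strip())
+#         ]
+#         return out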
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/recipes/base.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/recipes/base.py
new file mode 100644
index 00000000..ab3e313a
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/recipes/base.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import abc
+from collections.abc import Callable
+from typing import Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class ResponseRecipe(abc.ABC, Generic[T]):
+ """Base class for defining response recipes.
+
+ Response recipes contain all necessary information for
+ getting an LLM to perform a particular common task,
+ like outputting code in a desired format or following
+ structured output.
+ """
+
+ @abc.abstractmethod
+ def _build_parser_fn(self) -> Callable[[str], T]:
+ """Build the recipe's output parser function."""
+ ...
+
+ @property
+ @abc.abstractmethod
+ def example_template(self) -> str: ...
+
+ @abc.abstractmethod
+ def serialize_output(self, output: T) -> str:
+ """Serialize an instance of the parser output."""
+ ...
+
+ @abc.abstractmethod
+ def deserialize_output(self, serialized_output: str) -> T:
+ """Deserialize a serialized instance of the parser output."""
+ ...
+
+ def __init__(self):
+ self._parse_fn = self._build_parser_fn()
+
+ @property
+ def task_instructions(self) -> str | None:
+ """Specifies task instructions.
+
+ These instructions lay out the particular task information the
+ LLM requires in order to carry out the function of the recipe.
+ """
+ return None
+
+ def parse(self, response: str) -> T:
+ """Apply the recipe's parser to a raw model output."""
+ return self._parse_fn(response)
+
+ def generate_response_example(self, example: T) -> str:
+ """Create a serialized response example that the parser would admit."""
+ return self.example_template.format(example=example)
+
+ def apply_recipe_to_user_prompt(self, user_prompt: str) -> str:
+ """Appends recipe specific task instructions if applicable.
+
+ Args:
+ user_prompt (str): User prompt to be appended with recipe specific task instructions if applicable.
+
+ Returns:
+ str: Final user prompt
+ """
+ return f"{user_prompt}\n\n{self.task_instructions}" if self.task_instructions else user_prompt
+
+    def apply_recipe_to_system_prompt(self, system_prompt: str | None) -> str | None:
+        """Appends recipe-specific task instructions if applicable.
+
+        Args:
+            system_prompt (str | None): System prompt to which recipe-specific task instructions are appended, if applicable.
+
+        Returns:
+            str | None: Final system prompt (returned unchanged by the base implementation).
+        """
+        return system_prompt
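+
+
+# Example (illustrative): a minimal concrete recipe that passes text through unchanged,
+# showing which members a subclass must provide.
+#
+#     class PassthroughRecipe(ResponseRecipe[str]):
+#         @property
+#         def example_template(self) -> str:
+#             return "{example}"
+#
+#         def serialize_output(self, output: str) -> str:
+#             return output
+#
+#         def deserialize_output(self, serialized_output: str) -> str:
+#             return serialized_output
+#
+#         def _build_parser_fn(self) -> Callable[[str], str]:
+#             return lambda response: response.strip()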
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/recipes/response_recipes.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/recipes/response_recipes.py
new file mode 100644
index 00000000..deba050e
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/recipes/response_recipes.py
@@ -0,0 +1,293 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import json
+from collections.abc import Callable
+
+from pydantic import BaseModel
+
+from data_designer.config.utils.code_lang import CodeLang
+from data_designer.engine.models.parsers.errors import ParserException
+from data_designer.engine.models.parsers.parser import LLMResponseParser
+from data_designer.engine.models.parsers.postprocessors import (
+ StructuredDataBlock,
+ deserialize_json_code,
+ merge_text_blocks,
+)
+from data_designer.engine.models.parsers.types import CodeBlock
+from data_designer.engine.models.recipes.base import (
+ ResponseRecipe,
+)
+from data_designer.engine.processing.gsonschema.validators import JSONSchemaValidationError, validate
+
+
+class TextResponseRecipe(ResponseRecipe[str]):
+ """Default text-parser.
+
+ This parser is meant to cover the "pass-through" case of natural language LLM responses.
+ """
+
+ @property
+ def example_template(self) -> str:
+ return "{example}"
+
+ def serialize_output(self, output: str) -> str:
+ return output
+
+ def deserialize_output(self, serialized_output: str) -> str:
+ return serialized_output
+
+ def _build_parser_fn(self) -> Callable[[str], str]:
+ parser = LLMResponseParser(
+ postprocessors=[
+ merge_text_blocks,
+ ]
+ )
+
+ return lambda x: parser.parse(x).response
+
+
+class StructuredResponseRecipe(ResponseRecipe[dict]):
+ """Recipe for structured responses.
+
+ This recipe is intended to cover the generic case of
+ prompting-based requests for structured data outputs,
+ and the structure in question is determined by a
+ provided JSON Schema.
+
+    The LLM's response is validated against the provided
+    JSON Schema; however, the object returned is a Python
+    dictionary obtained from deserializing the LLM's
+    JSON response.
+ """
+
+ json_schema: dict
+ pruning: bool
+ no_extra_properties: bool
+
+ def __init__(
+ self,
+ json_schema: dict,
+ pruning: bool = True,
+ no_extra_properties: bool = True,
+ **kwargs,
+ ):
+ """Initialize StructuredResponseRecipe.
+
+ Args:
+ json_schema (dict): A target JSON schema that the LLM
+ should adhere to when making its response.
+ pruning (bool): If `True`, then any extra fields in the returned
+ JSON object will be removed. Otherwise, they are retained,
+ which could raise validation errors. Default=True
+            no_extra_properties (bool): If `True`, then validation will fail
+ if extra properties are encountered in the returned JSON response.
+ Default=True.
+ """
+ super().__init__(**kwargs)
+ self.json_schema = json_schema
+ self.pruning = pruning
+ self.no_extra_properties = no_extra_properties
+
+ @property
+ def task_instructions(self) -> str:
+ return (
+ "* Your response must be in JSON format.\n"
+ "* Your JSON response must be returned within a Markdown, ```json code fence.\n"
+ "* The JSON format is given as a JSON Schema description within markup tags.\n\n"
+ f"\n{self.schema}\n "
+ )
+
+ @property
+ def example_template(self) -> str:
+ return "```json\n{example}\n```"
+
+ def generate_response_example(self, example: dict) -> str:
+ return self.example_template.format(example=json.dumps(example))
+
+ @property
+ def schema(self) -> str:
+ return json.dumps(self.json_schema)
+
+ def serialize_output(self, output: dict) -> str:
+ return json.dumps(output, ensure_ascii=False)
+
+ def deserialize_output(self, serialized_output: str) -> dict:
+ return json.loads(serialized_output)
+
+ @property
+ def _validate_args(self):
+ return {
+ "schema": self.json_schema,
+ "pruning": self.pruning,
+ "no_extra_properties": self.no_extra_properties,
+ }
+
+ def _build_parser_fn(self) -> Callable[[str], dict]:
+ parser = LLMResponseParser(
+ postprocessors=[
+ merge_text_blocks,
+ deserialize_json_code,
+ ]
+ )
+
+ def parse_fn(response: str) -> dict:
+ try:
+ obj = parser.parse(response).filter([StructuredDataBlock]).parsed.pop().obj
+ return validate(obj, **self._validate_args)
+ except IndexError:
+ raise ParserException(
+ "No parsable JSON structure within ```json markdown fence.",
+ source=response,
+ ) from None
+ except JSONSchemaValidationError as exc:
+ raise ParserException(
+ "Response doesn't match requested \n" + str(exc),
+ source=response,
+ ) from None
+
+ return parse_fn
+
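+# Example (illustrative; `facade` is a configured ModelFacade):
+#
+#     recipe = StructuredResponseRecipe(
+#         json_schema={"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
+#     )
+#     record, _ = facade.generate(
+#         prompt=recipe.apply_recipe_to_user_prompt("Invent a character."),
+#         parser=recipe.parse,
+#         max_correction_steps=1,
+#     )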
+
+class PydanticResponseRecipe(ResponseRecipe[BaseModel]):
+ """Recipe for Pydantic responses.
+
+ This recipe covers the case that we have a Pydantic
+ data type (BaseModel) already specified in the runtime
+ making LLM calls, and we want to obtain an object of
+ that same data type as the output from the parser.
+
+ This recipe operates in a very similar fashion to
+ `StructuredResponseRecipe` except that it is initialized
+ from a Pydantic `BaseModel` and does the extra step of
+ validating against that `BaseModel` using
+ `BaseModel.model_validate` for its return.
+ """
+
+ data_type: type[BaseModel]
+
+ def __init__(self, data_type: type[BaseModel], **kwargs):
+ """Initialize a PydanticResponseRecipe.
+
+ Args:
+ data_type (type(BaseModel)): The target Pydantic BaseModel
+ subclass that the LLM should adhere to in its response,
+ and defines the output type of the parser.
+ """
+ super().__init__(**kwargs)
+ self.data_type = data_type
+
+ @property
+ def schema(self) -> str:
+ return json.dumps(self.data_type.model_json_schema())
+
+ @property
+ def task_instructions(self) -> str:
+ return (
+ "* Your response must be in JSON format.\n"
+ "* Your JSON response must be returned within a Markdown, ```json code fence.\n"
+ "* The JSON format is given as a JSON Schema description within markup tags.\n\n"
+ f"\n{self.schema}\n "
+ )
+
+ @property
+ def example_template(self) -> str:
+ return "```json\n{example}\n```"
+
+ def generate_response_example(self, example: BaseModel) -> str:
+ return self.example_template.format(example=example.model_dump_json())
+
+ def serialize_output(self, output: BaseModel) -> str:
+ return output.model_dump_json()
+
+ def deserialize_output(self, serialized_output: str) -> BaseModel:
+ return self.data_type.model_validate_json(serialized_output)
+
+ def _build_parser_fn(self) -> Callable[[str], BaseModel]:
+ parser = LLMResponseParser(
+ postprocessors=[
+ merge_text_blocks,
+ deserialize_json_code,
+ ]
+ )
+
+ def parse_fn(response: str) -> BaseModel:
+ try:
+ obj = parser.parse(response).filter([StructuredDataBlock]).parsed.pop().obj
+ return self.data_type.model_validate(obj)
+ except IndexError:
+ raise ParserException(
+ "No parsable JSON structure within ```json markdown fence.",
+ source=response,
+ ) from None
+ except Exception as exc:
+ raise ParserException(
+ "Response doesn't match requested \n" + str(exc),
+ source=response,
+ ) from None
+
+ return parse_fn
+
+
+class CodeResponseRecipe(ResponseRecipe[str]):
+ """Obtain a code snippet from an LLM."""
+
+ def __init__(self, syntax: str | CodeLang, **kwargs):
+ """Initialize a CodeResponseRecipe.
+
+ Args:
+ syntax (str | CodeLang): The code syntax that the
+ LLM should adhere to, e.g. `"python"`, `"sql"`, etc.
+ """
+ super().__init__(**kwargs)
+ self.syntax = CodeLang.parse_lang(syntax)
+
+ @property
+ def task_instructions(self) -> str:
+ return (
+ f"* Your response must be code written in {self.syntax}.\n"
+ "* You will follow accepted and common syntax and best-practices.\n"
+ f"* Your response will be given in markdown code fences specifying the correct language.\n"
+ "* Only respond with a SINGLE code block."
+ )
+
+ @property
+ def example_template(self) -> str:
+ return f"```{self.syntax}\n{{example}}\n```\n"
+
+ def serialize_output(self, output: str) -> str:
+ return output
+
+ def deserialize_output(self, serialized_output: str) -> str:
+ return serialized_output
+
+ def _build_parser_fn(self) -> Callable[[str], str]:
+ parser = LLMResponseParser(
+ postprocessors=[
+ merge_text_blocks,
+ ]
+ )
+
+ def parse_fn(response: str) -> str:
+ try:
+ code_block = parser.parse(response).filter([CodeBlock]).parsed.pop()
+ # For the type checker -- should always pass
+ assert isinstance(code_block, CodeBlock)
+ except IndexError:
+ raise ParserException(
+ "No parsable code response.",
+ source=response,
+ ) from None
+
+ # Only report this as a parser error if there was a mismatch.
+ if code_block.code_lang and code_block.code_lang != self.syntax:
+ raise ParserException(
+ f"Responded with code not matching the requested syntax ({self.syntax}).",
+ source=response,
+ )
+
+ return code_block.code.strip()
+
+ return parse_fn
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/registry.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/registry.py
new file mode 100644
index 00000000..8f6a0e9d
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/registry.py
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Callable
+from typing import TYPE_CHECKING
+
+from data_designer.config.models import GenerationType, ModelConfig
+from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
+from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
+from data_designer.engine.secret_resolver import SecretResolver
+
+if TYPE_CHECKING:
+ from data_designer.engine.models.facade import ModelFacade
+
+logger = logging.getLogger(__name__)
+
+
+class ModelRegistry:
+ def __init__(
+ self,
+ *,
+ secret_resolver: SecretResolver,
+ model_provider_registry: ModelProviderRegistry,
+ model_configs: list[ModelConfig] | None = None,
+ model_facade_factory: Callable[[ModelConfig, SecretResolver, ModelProviderRegistry], ModelFacade] | None = None,
+ ):
+ self._secret_resolver = secret_resolver
+ self._model_provider_registry = model_provider_registry
+ self._model_facade_factory = model_facade_factory
+ self._model_configs: dict[str, ModelConfig] = {}
+ self._models: dict[str, ModelFacade] = {}
+ self._set_model_configs(model_configs)
+
+ @property
+ def model_configs(self) -> dict[str, ModelConfig]:
+ return self._model_configs
+
+ @property
+ def models(self) -> dict[str, ModelFacade]:
+ return self._models
+
+ def register_model_configs(self, model_configs: list[ModelConfig]) -> None:
+ """Register a new Model configuration at runtime.
+
+ Args:
+ model_config: A new Model configuration to register. If an
+ Model configuration already exists in the registry
+ with the same name, then it will be overwritten.
+ """
+ self._set_model_configs(list(self._model_configs.values()) + model_configs)
+
+ def get_model(self, *, model_alias: str) -> ModelFacade:
+ # Check if model config exists first
+ if model_alias not in self._model_configs:
+ raise ValueError(f"No model config with alias {model_alias!r} found!")
+
+ # Lazy initialization: only create model facade when first requested
+ if model_alias not in self._models:
+ self._models[model_alias] = self._get_model(self._model_configs[model_alias])
+
+ return self._models[model_alias]
+
+ def get_model_config(self, *, model_alias: str) -> ModelConfig:
+ if model_alias not in self._model_configs:
+ raise ValueError(f"No model config with alias {model_alias!r} found!")
+ return self._model_configs[model_alias]
+
+ def get_model_usage_stats(self, total_time_elapsed: float) -> dict[str, dict]:
+ return {
+ model.model_name: model.usage_stats.get_usage_stats(total_time_elapsed=total_time_elapsed)
+ for model in self._models.values()
+ if model.usage_stats.has_usage
+ }
+
+ def get_model_usage_snapshot(self) -> dict[str, ModelUsageStats]:
+ return {
+ model.model_name: model.usage_stats.model_copy(deep=True)
+ for model in self._models.values()
+ if model.usage_stats.has_usage
+ }
+
+ def get_usage_deltas(self, snapshot: dict[str, ModelUsageStats]) -> dict[str, ModelUsageStats]:
+ deltas = {}
+ for model_name, current in self.get_model_usage_snapshot().items():
+ prev = snapshot.get(model_name)
+ delta_input = current.token_usage.input_tokens - (prev.token_usage.input_tokens if prev else 0)
+ delta_output = current.token_usage.output_tokens - (prev.token_usage.output_tokens if prev else 0)
+ delta_successful = current.request_usage.successful_requests - (
+ prev.request_usage.successful_requests if prev else 0
+ )
+ delta_failed = current.request_usage.failed_requests - (prev.request_usage.failed_requests if prev else 0)
+
+ if delta_input > 0 or delta_output > 0 or delta_successful > 0 or delta_failed > 0:
+ deltas[model_name] = ModelUsageStats(
+ token_usage=TokenUsageStats(input_tokens=delta_input, output_tokens=delta_output),
+ request_usage=RequestUsageStats(successful_requests=delta_successful, failed_requests=delta_failed),
+ )
+ return deltas
+
+ def get_model_provider(self, *, model_alias: str) -> ModelProvider:
+ model_config = self.get_model_config(model_alias=model_alias)
+ return self._model_provider_registry.get_provider(model_config.provider)
+
+ def run_health_check(self, model_aliases: list[str]) -> None:
+ logger.info("🩺 Running health checks for models...")
+ for model_alias in model_aliases:
+ model_config = self.get_model_config(model_alias=model_alias)
+ if model_config.skip_health_check:
+ logger.info(f" |-- ⏭️ Skipping health check for model alias {model_alias!r} (skip_health_check=True)")
+ continue
+
+ model = self.get_model(model_alias=model_alias)
+ logger.info(
+ f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
+ )
+ try:
+ if model.model_generation_type == GenerationType.EMBEDDING:
+ model.generate_text_embeddings(
+ input_texts=["Hello!"],
+ skip_usage_tracking=True,
+ purpose="running health checks",
+ )
+ elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
+ model.generate(
+ prompt="Hello!",
+ parser=lambda x: x,
+ system_prompt="You are a helpful assistant.",
+ max_correction_steps=0,
+ max_conversation_restarts=0,
+ skip_usage_tracking=True,
+ purpose="running health checks",
+ )
+ else:
+ raise ValueError(f"Unsupported generation type: {model.model_generation_type}")
+ logger.info(" |-- ✅ Passed!")
+ except Exception as e:
+ logger.error(" |-- ❌ Failed!")
+ raise e
+
+    def _set_model_configs(self, model_configs: list[ModelConfig] | None) -> None:
+ model_configs = model_configs or []
+ self._model_configs = {mc.alias: mc for mc in model_configs}
+ # Models are now lazily initialized in get_model() when first requested
+
+ def _get_model(self, model_config: ModelConfig) -> ModelFacade:
+ if self._model_facade_factory is None:
+ raise RuntimeError("ModelRegistry was not initialized with a model_facade_factory")
+ return self._model_facade_factory(model_config, self._secret_resolver, self._model_provider_registry)
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/telemetry.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/telemetry.py
new file mode 100644
index 00000000..f90e39c0
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/telemetry.py
@@ -0,0 +1,362 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Telemetry handler for NeMo products.
+
+Environment variables:
+- NEMO_TELEMETRY_ENABLED: Whether telemetry is enabled.
+- NEMO_DEPLOYMENT_TYPE: The deployment type the event came from.
+- NEMO_TELEMETRY_ENDPOINT: The endpoint to send the telemetry events to.
+- NEMO_SESSION_PREFIX: Optional prefix to add to session IDs.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import os
+import platform
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Any, ClassVar
+
+from pydantic import BaseModel, Field
+
+from data_designer.lazy_heavy_imports import httpx
+
+TELEMETRY_ENABLED = os.getenv("NEMO_TELEMETRY_ENABLED", "true").lower() in ("1", "true", "yes")
+CLIENT_ID = "184482118588404"
+NEMO_TELEMETRY_VERSION = "nemo-telemetry/1.0"
+MAX_RETRIES = 3
+NEMO_TELEMETRY_ENDPOINT = os.getenv(
+ "NEMO_TELEMETRY_ENDPOINT", "https://events.telemetry.data.nvidia.com/v1.1/events/json"
+).lower()
+CPU_ARCHITECTURE = platform.uname().machine
+SESSION_PREFIX = os.getenv("NEMO_SESSION_PREFIX")
+
+
+class NemoSourceEnum(str, Enum):
+ INFERENCE = "inference"
+ AUDITOR = "auditor"
+ DATADESIGNER = "datadesigner"
+ EVALUATOR = "evaluator"
+ GUARDRAILS = "guardrails"
+ UNDEFINED = "undefined"
+
+
+class DeploymentTypeEnum(str, Enum):
+ LIBRARY = "library"
+ API = "api"
+ UNDEFINED = "undefined"
+
+
+_deployment_type_raw = os.getenv("NEMO_DEPLOYMENT_TYPE", "library").lower()
+try:
+ DEPLOYMENT_TYPE = DeploymentTypeEnum(_deployment_type_raw)
+except ValueError:
+ valid_values = [e.value for e in DeploymentTypeEnum]
+ raise ValueError(
+ f"Invalid NEMO_DEPLOYMENT_TYPE: {_deployment_type_raw!r}. Must be one of: {valid_values}"
+ ) from None
+
+
+class TaskStatusEnum(str, Enum):
+ SUCCESS = "success"
+ FAILURE = "failure"
+ UNDEFINED = "undefined"
+
+
+class TelemetryEvent(BaseModel):
+ _event_name: ClassVar[str] # Subclasses must define this
+ _schema_version: ClassVar[str] = "1.3"
+
+ def __init_subclass__(cls, **kwargs: Any) -> None:
+ super().__init_subclass__(**kwargs)
+ if "_event_name" not in cls.__dict__:
+ raise TypeError(f"{cls.__name__} must define '_event_name' class variable")
+
+
+class InferenceEvent(TelemetryEvent):
+ _event_name: ClassVar[str] = "inference_event"
+
+ nemo_source: NemoSourceEnum = Field(
+ ...,
+ alias="nemoSource",
+ description="The NeMo product that created the event (i.e. data-designer).",
+ )
+ task: str = Field(
+ ...,
+ description="The type of task that was performed that generated the inference event (i.e. preview-job, batch-job).",
+ )
+ task_status: TaskStatusEnum = Field(
+ ...,
+ alias="taskStatus",
+ description="The status of the task.",
+ )
+ deployment_type: DeploymentTypeEnum = Field(
+ default=DEPLOYMENT_TYPE,
+ alias="deploymentType",
+ description="The deployment type the event came from.",
+ )
+ model: str = Field(
+ ...,
+ description="The name of the model that was used.",
+ )
+ model_group: str = Field(
+ default="undefined",
+ alias="modelGroup",
+ description="An optional identifier to group models together.",
+ )
+ input_bytes: int = Field(
+ default=-1,
+ alias="inputBytes",
+ description="Number of bytes provided as input to the model. -1 if not available.",
+ ge=-9223372036854775808,
+ le=9223372036854775807,
+ )
+ input_tokens: int = Field(
+ default=-1,
+ alias="inputTokens",
+ description="Number of tokens provided as input to the model. -1 if not available.",
+ ge=-9223372036854775808,
+ le=9223372036854775807,
+ )
+ output_bytes: int = Field(
+ default=-1,
+ alias="outputBytes",
+ description="Number of bytes returned by the model. -1 if not available.",
+ ge=-9223372036854775808,
+ le=9223372036854775807,
+ )
+ output_tokens: int = Field(
+ default=-1,
+ alias="outputTokens",
+ description="Number of tokens returned by the model. -1 if not available.",
+ ge=-9223372036854775808,
+ le=9223372036854775807,
+ )
+
+ model_config = {"populate_by_name": True}
+
+
+@dataclass
+class QueuedEvent:
+ event: TelemetryEvent
+ timestamp: datetime
+ retry_count: int = 0
+
+
+def _get_iso_timestamp(dt: datetime | None = None) -> str:
+ if dt is None:
+ dt = datetime.now(timezone.utc)
+ return dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{dt.microsecond // 1000:03d}Z"
+
+
+def build_payload(
+ events: list[QueuedEvent], *, source_client_version: str, session_id: str = "undefined"
+) -> dict[str, Any]:
+ return {
+ "browserType": "undefined", # do not change
+ "clientId": CLIENT_ID,
+ "clientType": "Native", # do not change
+ "clientVariant": "Release", # do not change
+ "clientVer": source_client_version,
+ "cpuArchitecture": CPU_ARCHITECTURE,
+ "deviceGdprBehOptIn": "None", # do not change
+ "deviceGdprFuncOptIn": "None", # do not change
+ "deviceGdprTechOptIn": "None", # do not change
+ "deviceId": "undefined", # do not change
+ "deviceMake": "undefined", # do not change
+ "deviceModel": "undefined", # do not change
+ "deviceOS": "undefined", # do not change
+ "deviceOSVersion": "undefined", # do not change
+ "deviceType": "undefined", # do not change
+ "eventProtocol": "1.6", # do not change
+ "eventSchemaVer": events[0].event._schema_version,
+ "eventSysVer": NEMO_TELEMETRY_VERSION,
+ "externalUserId": "undefined", # do not change
+ "gdprBehOptIn": "None", # do not change
+ "gdprFuncOptIn": "None", # do not change
+ "gdprTechOptIn": "None", # do not change
+ "idpId": "undefined", # do not change
+ "integrationId": "undefined", # do not change
+ "productName": "undefined", # do not change
+ "productVersion": "undefined", # do not change
+ "sentTs": _get_iso_timestamp(),
+ "sessionId": session_id,
+ "userId": "undefined", # do not change
+ "events": [
+ {
+ "ts": _get_iso_timestamp(queued.timestamp),
+ "parameters": queued.event.model_dump(by_alias=True),
+ "name": queued.event._event_name,
+ }
+ for queued in events
+ ],
+ }
+
+
+class TelemetryHandler:
+ """
+ Handles telemetry event batching, flushing, and retry logic for NeMo products.
+
+ Args:
+ flush_interval_seconds (float): The interval in seconds to flush the events.
+ max_queue_size (int): The maximum number of events to queue before flushing.
+ max_retries (int): The maximum number of times to retry sending an event.
+ source_client_version (str): The version of the source client. This should be the version of
+ the actual NeMo product that is sending the events, typically the same as the version of
+ a PyPi package that a user would install.
+ session_id (str): An optional session ID to associate with the events.
+ This should be a unique identifier for the session, such as a UUID.
+ It is used to group events together.
+ """
+
+ def __init__(
+ self,
+ flush_interval_seconds: float = 120.0,
+ max_queue_size: int = 50,
+ max_retries: int = MAX_RETRIES,
+ source_client_version: str = "undefined",
+ session_id: str = "undefined",
+ ):
+ self._flush_interval = flush_interval_seconds
+ self._max_queue_size = max_queue_size
+ self._max_retries = max_retries
+ self._events: list[QueuedEvent] = []
+ self._dlq: list[QueuedEvent] = [] # Dead letter queue for retry
+ self._flush_signal = asyncio.Event()
+ self._timer_task: asyncio.Task | None = None
+ self._running = False
+ self._source_client_version = source_client_version
+ # Apply session prefix if environment variable is set
+ if SESSION_PREFIX:
+ self._session_id = f"{SESSION_PREFIX}{session_id}"
+ else:
+ self._session_id = session_id
+
+ async def astart(self) -> None:
+ if self._running:
+ return
+ self._running = True
+ self._timer_task = asyncio.create_task(self._timer_loop())
+
+ async def astop(self) -> None:
+ self._running = False
+ self._flush_signal.set()
+ if self._timer_task:
+ self._timer_task.cancel()
+ try:
+ await self._timer_task
+ except asyncio.CancelledError:
+ pass
+ self._timer_task = None
+ await self._flush_events()
+
+ async def aflush(self) -> None:
+ self._flush_signal.set()
+
+ def start(self) -> None:
+ self._run_sync(self.astart())
+
+ def stop(self) -> None:
+ self._run_sync(self.astop())
+
+ def flush(self) -> None:
+ self._flush_signal.set()
+
+ def enqueue(self, event: TelemetryEvent) -> None:
+ if not TELEMETRY_ENABLED:
+ return
+ if not isinstance(event, TelemetryEvent):
+ # Silently fail as we prioritize not disrupting upstream call sites and telemetry is best effort
+ return
+ queued = QueuedEvent(event=event, timestamp=datetime.now(timezone.utc))
+ self._events.append(queued)
+ if len(self._events) >= self._max_queue_size:
+ self._flush_signal.set()
+
+ def _run_sync(self, coro: Any) -> Any:
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = None
+
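+        # asyncio.run() cannot be called from inside a running event loop, so in that case
+        # run the coroutine on a worker thread that owns its own loop.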
+ if loop and loop.is_running():
+ import concurrent.futures
+
+ with concurrent.futures.ThreadPoolExecutor() as pool:
+ future = pool.submit(asyncio.run, coro)
+ return future.result()
+ else:
+ return asyncio.run(coro)
+
+ def __enter__(self) -> TelemetryHandler:
+ self.start()
+ return self
+
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+ self.stop()
+
+ async def __aenter__(self) -> TelemetryHandler:
+ await self.astart()
+ return self
+
+ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+ await self.astop()
+
+ async def _timer_loop(self) -> None:
+ while self._running:
+ try:
+ await asyncio.wait_for(
+ self._flush_signal.wait(),
+ timeout=self._flush_interval,
+ )
+ except asyncio.TimeoutError:
+ pass
+ self._flush_signal.clear()
+ await self._flush_events()
+
+ async def _flush_events(self) -> None:
+ dlq_events, self._dlq = self._dlq, []
+ new_events, self._events = self._events, []
+ events_to_send = dlq_events + new_events
+ if events_to_send:
+ await self._send_events(events_to_send)
+
+ async def _send_events(self, events: list[QueuedEvent]) -> None:
+ async with httpx.AsyncClient() as client:
+ await self._send_events_with_client(client, events)
+
+ async def _send_events_with_client(self, client: httpx.AsyncClient, events: list[QueuedEvent]) -> None:
+ if not events:
+ return
+
+ payload = build_payload(events, source_client_version=self._source_client_version, session_id=self._session_id)
+ try:
+ response = await client.post(NEMO_TELEMETRY_ENDPOINT, json=payload)
+ # 2xx, 400, 422 are all considered complete (no retry)
+ # 400/422 indicate bad payload which retrying won't fix
+ if response.status_code in (400, 422) or response.is_success:
+ return
+ # 413 (payload too large) - split and retry
+ if response.status_code == 413:
+ if len(events) == 1:
+ # Can't split further, drop the event
+ return
+ mid = len(events) // 2
+ await self._send_events_with_client(client, events[:mid])
+ await self._send_events_with_client(client, events[mid:])
+ return
+ if response.status_code == 408 or response.status_code >= 500:
+ self._add_to_dlq(events)
+ except httpx.HTTPError:
+ self._add_to_dlq(events)
+
+ def _add_to_dlq(self, events: list[QueuedEvent]) -> None:
+ for queued in events:
+ queued.retry_count += 1
+ if queued.retry_count > self._max_retries:
+ continue
+ self._dlq.append(queued)
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/usage.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/usage.py
new file mode 100644
index 00000000..b239a9f3
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/usage.py
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+
+from pydantic import BaseModel, computed_field
+
+logger = logging.getLogger(__name__)
+
+
+class TokenUsageStats(BaseModel):
+ input_tokens: int = 0
+ output_tokens: int = 0
+
+ @computed_field
+ def total_tokens(self) -> int:
+ return self.input_tokens + self.output_tokens
+
+ @property
+ def has_usage(self) -> bool:
+ return self.total_tokens > 0
+
+ def extend(self, *, input_tokens: int, output_tokens: int) -> None:
+ self.input_tokens += input_tokens
+ self.output_tokens += output_tokens
+
+
+class RequestUsageStats(BaseModel):
+ successful_requests: int = 0
+ failed_requests: int = 0
+
+ @computed_field
+ def total_requests(self) -> int:
+ return self.successful_requests + self.failed_requests
+
+ @property
+ def has_usage(self) -> bool:
+ return self.total_requests > 0
+
+ def extend(self, *, successful_requests: int, failed_requests: int) -> None:
+ self.successful_requests += successful_requests
+ self.failed_requests += failed_requests
+
+
+class ModelUsageStats(BaseModel):
+ token_usage: TokenUsageStats = TokenUsageStats()
+ request_usage: RequestUsageStats = RequestUsageStats()
+
+ @property
+ def has_usage(self) -> bool:
+ return self.token_usage.has_usage and self.request_usage.has_usage
+
+ def extend(
+ self, *, token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None
+ ) -> None:
+ if token_usage is not None:
+ self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens)
+ if request_usage is not None:
+ self.request_usage.extend(
+ successful_requests=request_usage.successful_requests, failed_requests=request_usage.failed_requests
+ )
+
+ def get_usage_stats(self, *, total_time_elapsed: float) -> dict:
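+        # Merge the serialized usage counters with derived throughput metrics,
+        # guarding against division by zero when no time has elapsed.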
+ return self.model_dump() | {
+ "tokens_per_second": int(self.token_usage.total_tokens / total_time_elapsed)
+ if total_time_elapsed > 0
+ else 0,
+ "requests_per_minute": int(self.request_usage.total_requests / total_time_elapsed * 60)
+ if total_time_elapsed > 0
+ else 0,
+ }
diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/utils.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/utils.py
new file mode 100644
index 00000000..6c0418c4
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/utils.py
@@ -0,0 +1,101 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Literal
+
+
+@dataclass
+class ChatMessage:
+ """A chat message in an LLM conversation.
+
+ This dataclass represents messages exchanged in a conversation with an LLM,
+ supporting various message types including user prompts, assistant responses,
+ system instructions, and tool interactions.
+
+ Attributes:
+ role: The role of the message sender. One of 'user', 'assistant', 'system', or 'tool'.
+ content: The message content. Can be a string or a list of content blocks
+ for multimodal messages (e.g., text + images).
+ reasoning_content: Optional reasoning/thinking content from the assistant,
+ typically from extended thinking or chain-of-thought models.
+ tool_calls: Optional list of tool calls requested by the assistant.
+ Each tool call contains 'id', 'type', and 'function' keys.
+ tool_call_id: Optional ID linking a tool response to its corresponding
+ tool call. Required for messages with role='tool'.
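+
+    Example (illustrative):
+        >>> ChatMessage.as_user("Hello").to_dict()
+        {'role': 'user', 'content': 'Hello'}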
+ """
+
+ role: Literal["user", "assistant", "system", "tool"]
+ content: str | list[dict[str, Any]] = ""
+ reasoning_content: str | None = None
+ tool_calls: list[dict[str, Any]] = field(default_factory=list)
+ tool_call_id: str | None = None
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert the message to a dictionary format for API calls.
+
+ Returns:
+ A dictionary containing the message fields. Only includes non-empty
+ optional fields to keep the output clean.
+ """
+ result: dict[str, Any] = {"role": self.role, "content": self.content}
+ if self.reasoning_content:
+ result["reasoning_content"] = self.reasoning_content
+ if self.tool_calls:
+ result["tool_calls"] = self.tool_calls
+ if self.tool_call_id:
+ result["tool_call_id"] = self.tool_call_id
+ return result
+
+ @classmethod
+ def as_user(cls, content: str | list[dict[str, Any]]) -> ChatMessage:
+ """Create a user message."""
+ return cls(role="user", content=content)
+
+ @classmethod
+ def as_assistant(
+ cls,
+ content: str = "",
+ reasoning_content: str | None = None,
+ tool_calls: list[dict[str, Any]] | None = None,
+ ) -> ChatMessage:
+ """Create an assistant message."""
+ return cls(
+ role="assistant",
+ content=content,
+ reasoning_content=reasoning_content,
+ tool_calls=tool_calls or [],
+ )
+
+ @classmethod
+ def as_system(cls, content: str) -> ChatMessage:
+ """Create a system message."""
+ return cls(role="system", content=content)
+
+ @classmethod
+ def as_tool(cls, content: str, tool_call_id: str) -> ChatMessage:
+ """Create a tool response message."""
+ return cls(role="tool", content=content, tool_call_id=tool_call_id)
+
+
+def prompt_to_messages(
+ *,
+ user_prompt: str,
+ system_prompt: str | None = None,
+ multi_modal_context: list[dict[str, Any]] | None = None,
+) -> list[ChatMessage]:
+ """Convert a user and system prompt into ChatMessage list.
+
+ Args:
+ user_prompt (str): A user prompt.
+ system_prompt (str, optional): An optional system prompt.
+ """
+ user_content: str | list[dict[str, Any]] = user_prompt
+ if multi_modal_context:
+ user_content = [*multi_modal_context, {"type": "text", "text": user_prompt}]
+
+ if system_prompt:
+ return [ChatMessage.as_system(system_prompt), ChatMessage.as_user(user_content)]
+ return [ChatMessage.as_user(user_content)]
diff --git a/packages/data-designer-engine/tests/engine/models/test_async_engine_switch.py b/packages/data-designer-engine/tests/engine/models/test_async_engine_switch.py
new file mode 100644
index 00000000..951efa07
--- /dev/null
+++ b/packages/data-designer-engine/tests/engine/models/test_async_engine_switch.py
@@ -0,0 +1,61 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import importlib
+import sys
+from pathlib import Path
+from types import ModuleType
+
+import pytest
+
+_MODEL_PREFIXES = (
+ "data_designer.engine.models",
+ "data_designer.engine.models_v2",
+)
+
+
+def _matches_models_namespace(module_name: str) -> bool:
+ return any(module_name == prefix or module_name.startswith(f"{prefix}.") for prefix in _MODEL_PREFIXES)
+
+
+def _purge_models_modules() -> dict[str, ModuleType]:
+ saved: dict[str, ModuleType] = {}
+ for module_name in list(sys.modules):
+ if _matches_models_namespace(module_name):
+ module = sys.modules.pop(module_name, None)
+ if isinstance(module, ModuleType):
+ saved[module_name] = module
+ return saved
+
+
+def _restore_models_modules(saved: dict[str, ModuleType]) -> None:
+ for module_name in list(sys.modules):
+ if _matches_models_namespace(module_name):
+ sys.modules.pop(module_name, None)
+ sys.modules.update(saved)
+
+
+def _module_path_parts(module: ModuleType) -> tuple[str, ...]:
+ module_file = module.__file__
+ assert module_file is not None
+ return Path(module_file).parts
+
+
+def test_async_engine_env_switch(monkeypatch: pytest.MonkeyPatch) -> None:
+ saved_modules = _purge_models_modules()
+ try:
+ monkeypatch.delenv("DATA_DESIGNER_ASYNC_ENGINE", raising=False)
+ default_facade = importlib.import_module("data_designer.engine.models.facade")
+ default_parts = _module_path_parts(default_facade)
+ assert "models_v2" not in default_parts
+ assert "models" in default_parts
+
+ _purge_models_modules()
+ monkeypatch.setenv("DATA_DESIGNER_ASYNC_ENGINE", "1")
+ async_facade = importlib.import_module("data_designer.engine.models.facade")
+ async_parts = _module_path_parts(async_facade)
+ assert "models_v2" in async_parts
+ finally:
+ _restore_models_modules(saved_modules)
diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py
index 89eb9b10..81810392 100644
--- a/packages/data-designer-engine/tests/engine/models/test_facade.py
+++ b/packages/data-designer-engine/tests/engine/models/test_facade.py
@@ -9,7 +9,7 @@
from data_designer.engine.mcp.errors import MCPConfigurationError
from data_designer.engine.models.errors import ModelGenerationValidationFailureError
-from data_designer.engine.models.facade import ModelFacade
+from data_designer.engine.models.facade import CustomRouter, ModelFacade
from data_designer.engine.models.parsers.errors import ParserException
from data_designer.engine.models.utils import ChatMessage
@@ -77,18 +77,18 @@ def stub_expected_embedding_response():
(3, 3, 16),
],
)
-@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+@patch.object(ModelFacade, "completion", autospec=True)
def test_generate(
- mock_completion,
- stub_model_facade,
- max_correction_steps,
- max_conversation_restarts,
- total_calls,
-):
+ mock_completion: Any,
+ stub_model_facade: ModelFacade,
+ max_correction_steps: int,
+ max_conversation_restarts: int,
+ total_calls: int,
+) -> None:
bad_response = mock_oai_response_object("bad response")
mock_completion.side_effect = lambda *args, **kwargs: bad_response
- def _failing_parser(response: str):
+ def _failing_parser(response: str) -> str:
raise ParserException("parser exception")
with pytest.raises(ModelGenerationValidationFailureError):
@@ -119,7 +119,7 @@ def _failing_parser(response: str):
("hello!", [ChatMessage.as_system("hello!"), ChatMessage.as_user("does not matter")]),
],
)
-@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+@patch.object(ModelFacade, "completion", autospec=True)
def test_generate_with_system_prompt(
mock_completion: Any,
stub_model_facade: ModelFacade,
@@ -183,7 +183,7 @@ def test_consolidate_kwargs(stub_model_configs, stub_model_facade):
True,
],
)
-@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
+@patch.object(CustomRouter, "completion", autospec=True)
def test_completion_success(
mock_router_completion: Any,
stub_completion_messages: list[ChatMessage],
@@ -204,7 +204,7 @@ def test_completion_success(
}
-@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
+@patch.object(CustomRouter, "completion", autospec=True)
def test_completion_with_exception(
mock_router_completion: Any,
stub_completion_messages: list[ChatMessage],
@@ -216,7 +216,7 @@ def test_completion_with_exception(
stub_model_facade.completion(stub_completion_messages)
-@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
+@patch.object(CustomRouter, "completion", autospec=True)
def test_completion_with_kwargs(
mock_router_completion: Any,
stub_completion_messages: list[ChatMessage],
@@ -240,29 +240,36 @@ def mock_completion(self: Any, model: str, messages: list[dict[str, Any]], **kwa
assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs}
-@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True)
-def test_generate_text_embeddings_success(mock_router_embedding, stub_model_facade, stub_expected_embedding_response):
+@patch.object(CustomRouter, "embedding", autospec=True)
+def test_generate_text_embeddings_success(
+ mock_router_embedding: Any,
+ stub_model_facade: ModelFacade,
+ stub_expected_embedding_response: EmbeddingResponse,
+) -> None:
mock_router_embedding.side_effect = lambda self, model, input, **kwargs: stub_expected_embedding_response
input_texts = ["test1", "test2"]
result = stub_model_facade.generate_text_embeddings(input_texts)
assert result == [data["embedding"] for data in stub_expected_embedding_response.data]
-@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True)
-def test_generate_text_embeddings_with_exception(mock_router_embedding, stub_model_facade):
+@patch.object(CustomRouter, "embedding", autospec=True)
+def test_generate_text_embeddings_with_exception(mock_router_embedding: Any, stub_model_facade: ModelFacade) -> None:
mock_router_embedding.side_effect = Exception("Router error")
with pytest.raises(Exception, match="Router error"):
stub_model_facade.generate_text_embeddings(["test1", "test2"])
-@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True)
+@patch.object(CustomRouter, "embedding", autospec=True)
def test_generate_text_embeddings_with_kwargs(
- mock_router_embedding, stub_model_configs, stub_model_facade, stub_expected_embedding_response
-):
+ mock_router_embedding: Any,
+ stub_model_configs: Any,
+ stub_model_facade: ModelFacade,
+ stub_expected_embedding_response: EmbeddingResponse,
+) -> None:
captured_kwargs = {}
- def mock_embedding(self, model, input, **kwargs):
+ def mock_embedding(self: Any, model: str, input: list[str], **kwargs: Any) -> EmbeddingResponse:
captured_kwargs.update(kwargs)
return stub_expected_embedding_response
diff --git a/packages/data-designer-engine/tests/engine/models/test_litellm_overrides.py b/packages/data-designer-engine/tests/engine/models/test_litellm_overrides.py
index 05f3ed79..6ccbe63a 100644
--- a/packages/data-designer-engine/tests/engine/models/test_litellm_overrides.py
+++ b/packages/data-designer-engine/tests/engine/models/test_litellm_overrides.py
@@ -8,6 +8,7 @@
import litellm
import pytest
+from data_designer.engine.models import litellm_overrides
from data_designer.engine.models.litellm_overrides import (
DEFAULT_MAX_CALLBACKS,
CustomRouter,
@@ -56,9 +57,9 @@ def test_apply_litellm_patches_no_exceptions():
pytest.fail(f"apply_litellm_patches() raised an unexpected exception: {e}")
-@patch("data_designer.engine.models.litellm_overrides.quiet_noisy_logger", autospec=True)
-def test_apply_litellm_patches(mock_quiet_noisy_logger):
- apply_litellm_patches()
+@patch.object(litellm_overrides, "quiet_noisy_logger", autospec=True)
+def test_apply_litellm_patches(mock_quiet_noisy_logger: object) -> None:
+ litellm_overrides.apply_litellm_patches()
assert isinstance(litellm.in_memory_llm_clients_cache, ThreadSafeCache)
assert (
litellm.litellm_core_utils.logging_callback_manager.LoggingCallbackManager.MAX_CALLBACKS
diff --git a/packages/data-designer-engine/tests/engine/models/test_model_registry.py b/packages/data-designer-engine/tests/engine/models/test_model_registry.py
index be109a01..a4bfc8e6 100644
--- a/packages/data-designer-engine/tests/engine/models/test_model_registry.py
+++ b/packages/data-designer-engine/tests/engine/models/test_model_registry.py
@@ -6,6 +6,7 @@
import pytest
from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig
+from data_designer.engine.models import litellm_overrides
from data_designer.engine.models.errors import ModelAuthenticationError
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.factory import create_model_registry
@@ -40,10 +41,13 @@ def stub_no_usage_config():
)
-@patch("data_designer.engine.models.litellm_overrides.apply_litellm_patches", autospec=True)
+@patch.object(litellm_overrides, "apply_litellm_patches", autospec=True)
def test_create_model_registry(
- mock_apply_litellm_patches, stub_model_configs, stub_secrets_resolver, stub_model_provider_registry
-):
+ mock_apply_litellm_patches: object,
+ stub_model_configs: list[ModelConfig],
+ stub_secrets_resolver: object,
+ stub_model_provider_registry: object,
+) -> None:
model_registry = create_model_registry(
model_configs=stub_model_configs,
secret_resolver=stub_secrets_resolver,
@@ -272,20 +276,26 @@ def test_get_usage_deltas(
assert deltas == {}
-@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True)
-@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
-def test_run_health_check_success(mock_completion, mock_generate_text_embeddings, stub_model_registry):
+@patch.object(ModelFacade, "generate_text_embeddings", autospec=True)
+@patch.object(ModelFacade, "completion", autospec=True)
+def test_run_health_check_success(
+ mock_completion: object,
+ mock_generate_text_embeddings: object,
+ stub_model_registry: ModelRegistry,
+) -> None:
model_aliases = {"stub-text", "stub-reasoning", "stub-embedding"}
stub_model_registry.run_health_check(model_aliases)
assert mock_completion.call_count == 2
assert mock_generate_text_embeddings.call_count == 1
-@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True)
-@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+@patch.object(ModelFacade, "generate_text_embeddings", autospec=True)
+@patch.object(ModelFacade, "completion", autospec=True)
def test_run_health_check_completion_authentication_error(
- mock_completion, mock_generate_text_embeddings, stub_model_registry
-):
+ mock_completion: object,
+ mock_generate_text_embeddings: object,
+ stub_model_registry: ModelRegistry,
+) -> None:
auth_error = ModelAuthenticationError("Invalid API key for completion model")
mock_completion.side_effect = auth_error
model_aliases = ["stub-text", "stub-reasoning", "stub-embedding"]
@@ -297,11 +307,13 @@ def test_run_health_check_completion_authentication_error(
mock_generate_text_embeddings.assert_not_called()
-@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True)
-@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+@patch.object(ModelFacade, "generate_text_embeddings", autospec=True)
+@patch.object(ModelFacade, "completion", autospec=True)
def test_run_health_check_embedding_authentication_error(
- mock_completion, mock_generate_text_embeddings, stub_model_registry
-):
+ mock_completion: object,
+ mock_generate_text_embeddings: object,
+ stub_model_registry: ModelRegistry,
+) -> None:
auth_error = ModelAuthenticationError("Invalid API key for embedding model")
mock_generate_text_embeddings.side_effect = auth_error
model_aliases = ["stub-text", "stub-reasoning", "stub-embedding"]
@@ -313,12 +325,12 @@ def test_run_health_check_embedding_authentication_error(
mock_generate_text_embeddings.assert_called_once()
-@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+@patch.object(ModelFacade, "completion", autospec=True)
def test_run_health_check_skip_health_check_flag(
- mock_completion,
- stub_secrets_resolver,
- stub_model_provider_registry,
-):
+ mock_completion: object,
+ stub_secrets_resolver: object,
+ stub_model_provider_registry: object,
+) -> None:
# Create model configs: one with skip_health_check=True, others with default (False)
model_configs = [
ModelConfig(
diff --git a/packages/data-designer-engine/tests/engine/validators/test_sql.py b/packages/data-designer-engine/tests/engine/validators/test_sql.py
index 756d0856..3ab599e1 100644
--- a/packages/data-designer-engine/tests/engine/validators/test_sql.py
+++ b/packages/data-designer-engine/tests/engine/validators/test_sql.py
@@ -1,12 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
+from unittest.mock import patch
+
+import pytest
+
from data_designer.config.utils.code_lang import CodeLang
from data_designer.config.validator_params import CodeValidatorParams
+from data_designer.engine.validators import sql as sql_validator_module
from data_designer.engine.validators.sql import SQLValidator
-def test_valid_ansi_sql_code():
+def test_valid_ansi_sql_code() -> None:
sql_validator = SQLValidator(CodeValidatorParams(code_lang=CodeLang.SQL_ANSI))
code = "SELECT category, COUNT(*) as total_incidents FROM security_incidents_2 GROUP BY category;"
result = sql_validator.run_validation([{"sql": code}])
@@ -14,9 +19,31 @@ def test_valid_ansi_sql_code():
assert result.data[0].error_messages == ""
-def test_invalid_ansi_sql_code():
+def test_invalid_ansi_sql_code() -> None:
sql_validator = SQLValidator(CodeValidatorParams(code_lang=CodeLang.SQL_ANSI))
code = "NOT SQL"
result = sql_validator.run_validation([{"sql": code}])
assert not result.data[0].is_valid
assert result.data[0].error_messages == "PRS: Line 1, Position 1: Found unparsable section: 'NOT SQL'"
+
+
+def test_sql_validator_multi_column_input_raises() -> None:
+ sql_validator = SQLValidator(CodeValidatorParams(code_lang=CodeLang.SQL_ANSI))
+ with pytest.raises(ValueError, match="single column input"):
+ sql_validator.run_validation([{"sql": "SELECT 1", "extra": "ignored"}])
+
+
+def test_sql_validator_decimal_without_scale_fails() -> None:
+ sql_validator = SQLValidator(CodeValidatorParams(code_lang=CodeLang.SQL_ANSI))
+ code = "CREATE TABLE example (amount DECIMAL(10));"
+ result = sql_validator.run_validation([{"sql": code}])
+ assert not result.data[0].is_valid
+ assert "DECIMAL definitions without a scale" in result.data[0].error_messages
+
+
+def test_sql_validator_handles_lint_exception() -> None:
+ sql_validator = SQLValidator(CodeValidatorParams(code_lang=CodeLang.SQL_ANSI))
+ with patch.object(sql_validator_module.sqlfluff, "lint", side_effect=RuntimeError("boom")):
+ result = sql_validator.run_validation([{"sql": "SELECT 1"}])
+ assert not result.data[0].is_valid
+ assert "Exception during SQL parsing" in result.data[0].error_messages
diff --git a/scripts/benchmarks/benchmark_engine_v2.py b/scripts/benchmarks/benchmark_engine_v2.py
new file mode 100644
index 00000000..f6249021
--- /dev/null
+++ b/scripts/benchmarks/benchmark_engine_v2.py
@@ -0,0 +1,822 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Benchmark DataDesigner engine performance with mock LLMs."""
+
+from __future__ import annotations
+
+import argparse
+import contextlib
+import hashlib
+import json
+import math
+import os
+import random
+import statistics
+import subprocess
+import sys
+import tempfile
+import time
+from collections.abc import Iterator
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig, ValidationColumnConfig
+from data_designer.config.config_builder import DataDesignerConfigBuilder
+from data_designer.config.mcp import MCPProvider, ToolConfig
+from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider
+from data_designer.config.run_config import RunConfig
+from data_designer.config.sampler_params import SamplerType, UniformSamplerParams
+from data_designer.config.validator_params import LocalCallableValidatorParams, ValidatorType
+from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult
+from data_designer.lazy_heavy_imports import np, pd
+
+if TYPE_CHECKING:
+ import numpy as np
+ import pandas as pd
+
+
+RESULT_PREFIX = "BENCHMARK_RESULT="
+DEFAULT_NUM_RECORDS = 1024
+DEFAULT_BUFFER_SIZE = 1024
+DEFAULT_SEED = 11
+DEFAULT_MAX_PARALLEL_REQUESTS = 16
+DEFAULT_VALIDATOR_BATCH_SIZE = 256
+DEFAULT_ITERATIONS = 5
+
+MOCK_MCP_PROVIDER_NAME = "mock-mcp"
+MOCK_TOOL_ALIAS = "mock-tools"
+MOCK_TOOL_NAME = "mock_lookup"
+MOCK_TOOL_DESCRIPTION = "Mock lookup tool for benchmark runs."
+MOCK_TOOL_SCHEMA = {
+ "type": "object",
+ "properties": {
+ "query": {"type": "string"},
+ "limit": {"type": "integer"},
+ },
+ "required": ["query"],
+}
+
+
+@dataclass(frozen=True)
+class BenchmarkSettings:
+ num_records: int
+ buffer_size: int
+ seed: int
+ max_parallel_requests: int
+ validator_batch_size: int
+
+ def to_cli_args(self) -> list[str]:
+ return [
+ "--num-records",
+ str(self.num_records),
+ "--buffer-size",
+ str(self.buffer_size),
+ "--seed",
+ str(self.seed),
+ "--max-parallel-requests",
+ str(self.max_parallel_requests),
+ "--validator-batch-size",
+ str(self.validator_batch_size),
+ ]
+
+
+@dataclass(frozen=True)
+class BenchmarkResult:
+ engine_mode: str
+ num_records: int
+ buffer_size: int
+ build_time_sec: float
+ total_time_sec: float
+ dataset_hash: str
+ row_count: int
+ column_count: int
+
+ def to_dict(self) -> dict[str, Any]:
+ return {
+ "engine_mode": self.engine_mode,
+ "num_records": self.num_records,
+ "buffer_size": self.buffer_size,
+ "build_time_sec": self.build_time_sec,
+ "total_time_sec": self.total_time_sec,
+ "dataset_hash": self.dataset_hash,
+ "row_count": self.row_count,
+ "column_count": self.column_count,
+ }
+
+ @classmethod
+ def from_dict(cls, payload: dict[str, Any]) -> BenchmarkResult:
+ return cls(
+ engine_mode=str(payload["engine_mode"]),
+ num_records=int(payload["num_records"]),
+ buffer_size=int(payload["buffer_size"]),
+ build_time_sec=float(payload["build_time_sec"]),
+ total_time_sec=float(payload["total_time_sec"]),
+ dataset_hash=str(payload["dataset_hash"]),
+ row_count=int(payload["row_count"]),
+ column_count=int(payload["column_count"]),
+ )
+
+
+@dataclass(frozen=True)
+class MetricStats:
+ mean: float
+ stdev: float
+ ci_half_width: float
+ n: int
+
+ @property
+ def ci_low(self) -> float:
+ return self.mean - self.ci_half_width
+
+ @property
+ def ci_high(self) -> float:
+ return self.mean + self.ci_half_width
+
+
+@dataclass(frozen=True)
+class ResponseProfile:
+ label: str
+ score_mu: float
+ score_sigma: float
+ latency_alpha: float
+ latency_beta: float
+ volatility_sigma: float
+ categories: tuple[str, ...]
+ category_weights: tuple[float, ...]
+
+
+MODEL_PROFILES: dict[str, ResponseProfile] = {
+ "mock-alpha": ResponseProfile(
+ label="alpha",
+ score_mu=0.1,
+ score_sigma=0.35,
+ latency_alpha=2.2,
+ latency_beta=6.0,
+ volatility_sigma=0.25,
+ categories=("low", "mid", "high"),
+ category_weights=(0.25, 0.55, 0.2),
+ ),
+ "mock-beta": ResponseProfile(
+ label="beta",
+ score_mu=0.3,
+ score_sigma=0.45,
+ latency_alpha=2.6,
+ latency_beta=4.8,
+ volatility_sigma=0.3,
+ categories=("low", "mid", "high"),
+ category_weights=(0.2, 0.5, 0.3),
+ ),
+ "mock-gamma": ResponseProfile(
+ label="gamma",
+ score_mu=0.5,
+ score_sigma=0.5,
+ latency_alpha=3.0,
+ latency_beta=3.6,
+ volatility_sigma=0.35,
+ categories=("low", "mid", "high"),
+ category_weights=(0.15, 0.45, 0.4),
+ ),
+}
+
+DEFAULT_PROFILE = ResponseProfile(
+ label="default",
+ score_mu=0.2,
+ score_sigma=0.4,
+ latency_alpha=2.4,
+ latency_beta=5.0,
+ volatility_sigma=0.3,
+ categories=("low", "mid", "high"),
+ category_weights=(0.3, 0.5, 0.2),
+)
+
+
+@dataclass(frozen=True)
+class FakeMessage:
+ content: str
+ tool_calls: list[dict[str, Any]] | None = None
+ reasoning_content: str | None = None
+
+
+@dataclass(frozen=True)
+class FakeChoice:
+ message: FakeMessage
+
+
+@dataclass(frozen=True)
+class FakeResponse:
+ choices: list[FakeChoice]
+ usage: Any | None = None
+ model: str | None = None
+
+
+def _distinct_parallel_requests(base: int) -> tuple[int, int, int]:
+ if base < 3:
+ raise ValueError("max_parallel_requests must be >= 3 to create distinct per-model limits.")
+ high = base
+ mid = max(1, int(round(high / 2)))
+ low = max(1, int(round(high / 5)))
+
+ if mid >= high:
+ mid = high - 1
+ if low >= mid:
+ low = max(1, mid - 1)
+
+ return high, mid, low
+
+
+def _t_critical_95(df: int) -> float:
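+    # Two-sided 95% Student's t critical values by degrees of freedom; ~1.96 (normal approximation) beyond df=30.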
+ table = {
+ 1: 12.706,
+ 2: 4.303,
+ 3: 3.182,
+ 4: 2.776,
+ 5: 2.571,
+ 6: 2.447,
+ 7: 2.365,
+ 8: 2.306,
+ 9: 2.262,
+ 10: 2.228,
+ 11: 2.201,
+ 12: 2.179,
+ 13: 2.160,
+ 14: 2.145,
+ 15: 2.131,
+ 16: 2.120,
+ 17: 2.110,
+ 18: 2.101,
+ 19: 2.093,
+ 20: 2.086,
+ 21: 2.080,
+ 22: 2.074,
+ 23: 2.069,
+ 24: 2.064,
+ 25: 2.060,
+ 26: 2.056,
+ 27: 2.052,
+ 28: 2.048,
+ 29: 2.045,
+ 30: 2.042,
+ }
+ return table.get(df, 1.96)
+
+
+def _compute_stats(values: list[float]) -> MetricStats:
+ if not values:
+ return MetricStats(mean=0.0, stdev=0.0, ci_half_width=0.0, n=0)
+ if len(values) == 1:
+ return MetricStats(mean=values[0], stdev=0.0, ci_half_width=0.0, n=1)
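+    # 95% confidence-interval half-width for the mean: t(0.975, n-1) * stdev / sqrt(n).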
+ stdev = statistics.stdev(values)
+ mean = statistics.mean(values)
+ t_value = _t_critical_95(len(values) - 1)
+ ci_half_width = t_value * stdev / math.sqrt(len(values))
+ return MetricStats(mean=mean, stdev=stdev, ci_half_width=ci_half_width, n=len(values))
+
+
+def _format_stats(stats: MetricStats, *, unit: str, precision: int = 3) -> str:
+ fmt = f"{{:.{precision}f}}"
+ mean = fmt.format(stats.mean)
+ ci = fmt.format(stats.ci_half_width)
+ stdev = fmt.format(stats.stdev)
+ return f"{mean}{unit} ± {ci}{unit} (stdev {stdev}{unit}, n={stats.n})"
+
+
+def _format_speed_stats(stats: MetricStats, *, precision: int = 2) -> str:
+ fmt = f"{{:.{precision}f}}"
+ mean = fmt.format(stats.mean)
+ ci = fmt.format(stats.ci_half_width)
+ stdev = fmt.format(stats.stdev)
+ return f"{mean}x ± {ci}x (stdev {stdev}x, n={stats.n})"
+
+
+def _significant_diff(stats: MetricStats) -> bool:
+ return stats.n > 1 and abs(stats.mean) > stats.ci_half_width
+
+
+def _json_default(value: Any) -> Any:
+ if isinstance(value, np.generic):
+ return value.item()
+ if isinstance(value, np.ndarray):
+ return value.tolist()
+ if isinstance(value, (pd.Timestamp, pd.Timedelta)):
+ return value.isoformat()
+ if isinstance(value, set):
+ return sorted(value)
+ if isinstance(value, bytes):
+ return value.decode("utf-8", errors="replace")
+ return str(value)
+
+
+def _stable_seed(model: str, messages: list[dict[str, Any]]) -> int:
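+    # Hash the canonicalized request (model + messages) into a deterministic 64-bit seed so that
+    # mock responses are reproducible across runs and across engine modes.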
+ payload = json.dumps(
+ {"model": model, "messages": messages},
+ sort_keys=True,
+ separators=(",", ":"),
+ ensure_ascii=True,
+ default=_json_default,
+ )
+ digest = hashlib.sha256(payload.encode("utf-8")).digest()
+ return int.from_bytes(digest[:8], "big")
+
+
+def _profile_for_model(model: str) -> ResponseProfile:
+ for key, profile in MODEL_PROFILES.items():
+ if key in model:
+ return profile
+ return DEFAULT_PROFILE
+
+
+def _mock_response_text(model: str, messages: list[dict[str, Any]]) -> str:
+ profile = _profile_for_model(model)
+ rng = random.Random(_stable_seed(model, messages))
+ category = rng.choices(profile.categories, weights=profile.category_weights, k=1)[0]
+ score = rng.lognormvariate(profile.score_mu, profile.score_sigma)
+ latency_ms = int(rng.betavariate(profile.latency_alpha, profile.latency_beta) * 900.0)
+ volatility = rng.gauss(0.0, profile.volatility_sigma)
+ return f"{profile.label}:{category}|score={score:.3f}|latency_ms={latency_ms}|vol={volatility:.3f}"
+
+
+def _tool_call_id(model: str, messages: list[dict[str, Any]]) -> str:
+ call_seed = _stable_seed(model, messages)
+ return f"tool-{call_seed:016x}"
+
+
+def _tool_call_arguments(model: str, messages: list[dict[str, Any]]) -> dict[str, Any]:
+ rng = random.Random(_stable_seed(model, messages))
+ return {
+ "query": f"{model}-lookup-{rng.randint(1000, 9999)}",
+ "limit": rng.randint(1, 3),
+ }
+
+
+def _build_tool_call(model: str, messages: list[dict[str, Any]]) -> dict[str, Any]:
+ arguments = _tool_call_arguments(model, messages)
+ return {
+ "id": _tool_call_id(model, messages),
+ "type": "function",
+ "function": {"name": MOCK_TOOL_NAME, "arguments": json.dumps(arguments)},
+ }
+
+
+def _should_request_tool(messages: list[dict[str, Any]]) -> bool:
+ return not any(message.get("role") == "tool" for message in messages)
+
+
+def _mock_tool_definition() -> MCPToolDefinition:
+ return MCPToolDefinition(
+ name=MOCK_TOOL_NAME,
+ description=MOCK_TOOL_DESCRIPTION,
+ input_schema=MOCK_TOOL_SCHEMA,
+ )
+
+
+def _mock_tool_result(tool_name: str, arguments: dict[str, Any], provider_name: str) -> MCPToolResult:
+ payload = {
+ "tool": tool_name,
+ "provider": provider_name,
+ "query": arguments.get("query", ""),
+ "limit": arguments.get("limit", 0),
+ "status": "ok",
+ }
+ return MCPToolResult(content=json.dumps(payload))
+
+
+def _fake_response(model: str, messages: list[dict[str, Any]], **kwargs: Any) -> FakeResponse:
+ if kwargs.get("tools") and _should_request_tool(messages):
+ tool_call = _build_tool_call(model, messages)
+ return FakeResponse(
+ choices=[FakeChoice(message=FakeMessage(content="Using tool.", tool_calls=[tool_call]))],
+ model=model,
+ )
+ response_text = _mock_response_text(model, messages)
+ return FakeResponse(choices=[FakeChoice(message=FakeMessage(content=response_text))], model=model)
+
+
+@contextlib.contextmanager
+def _patch_llm_responses() -> Iterator[None]:
+ # Imports are deferred so engine selection respects DATA_DESIGNER_ASYNC_ENGINE.
+ from data_designer.engine.models.litellm_overrides import CustomRouter
+
+ original_completion = CustomRouter.completion
+ original_acompletion = getattr(CustomRouter, "acompletion", None)
+
+ def fake_completion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> FakeResponse:
+ _ = self
+ return _fake_response(model, messages, **kwargs)
+
+ async def fake_acompletion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> FakeResponse:
+ _ = self
+ return _fake_response(model, messages, **kwargs)
+
+ CustomRouter.completion = fake_completion
+ CustomRouter.acompletion = fake_acompletion
+ try:
+ yield
+ finally:
+ CustomRouter.completion = original_completion
+ if original_acompletion is not None:
+ CustomRouter.acompletion = original_acompletion
+ else:
+ try:
+ delattr(CustomRouter, "acompletion")
+ except AttributeError:
+ pass
+
+
+@contextlib.contextmanager
+def _patch_mcp_io() -> Iterator[None]:
+ import data_designer.engine.mcp.io as mcp_io
+
+ original_list_tools = mcp_io.list_tools
+ original_call_tools = mcp_io.call_tools
+
+ def fake_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]:
+ if getattr(provider, "name", None) != MOCK_MCP_PROVIDER_NAME:
+ return original_list_tools(provider, timeout_sec=timeout_sec)
+ return (_mock_tool_definition(),)
+
+ def fake_call_tools(
+ calls: list[tuple[Any, str, dict[str, Any]]],
+ *,
+ timeout_sec: float | None = None,
+ ) -> list[MCPToolResult]:
+ if any(getattr(call[0], "name", None) != MOCK_MCP_PROVIDER_NAME for call in calls):
+ return original_call_tools(calls, timeout_sec=timeout_sec)
+ return [_mock_tool_result(tool_name, arguments, provider.name) for provider, tool_name, arguments in calls]
+
+ mcp_io.list_tools = fake_list_tools
+ mcp_io.call_tools = fake_call_tools
+ try:
+ yield
+ finally:
+ mcp_io.list_tools = original_list_tools
+ mcp_io.call_tools = original_call_tools
+
+
+def _extract_metric(text: str, key: str) -> float | None:
+ marker = f"{key}="
+ start = text.find(marker)
+ if start == -1:
+ return None
+ start += len(marker)
+ end = start
+ while end < len(text) and (text[end].isdigit() or text[end] in {".", "-"}):
+ end += 1
+ try:
+ return float(text[start:end])
+ except ValueError:
+ return None
+
+
+def _validate_recommendation(df: pd.DataFrame) -> pd.DataFrame:
+ series = df["llm_stage3"].astype(str)
+ scores = series.map(lambda text: _extract_metric(text, "score"))
+ latencies = series.map(lambda text: _extract_metric(text, "latency_ms"))
+ scores_numeric = pd.to_numeric(scores, errors="coerce")
+ latency_numeric = pd.to_numeric(latencies, errors="coerce")
+ is_valid = scores_numeric.between(0.0, 10.0) & latency_numeric.between(0.0, 900.0)
+ return pd.DataFrame(
+ {
+ "is_valid": is_valid.fillna(False).astype(bool),
+ "score": scores_numeric,
+ "latency_ms": latency_numeric,
+ }
+ )
+
+
+def _build_config(settings: BenchmarkSettings) -> DataDesignerConfigBuilder:
+ high_parallel, mid_parallel, low_parallel = _distinct_parallel_requests(settings.max_parallel_requests)
+ model_configs = [
+ ModelConfig(
+ alias="mock-alpha",
+ model="mock-alpha",
+ provider="mock-provider",
+ inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=high_parallel),
+ skip_health_check=True,
+ ),
+ ModelConfig(
+ alias="mock-beta",
+ model="mock-beta",
+ provider="mock-provider",
+ inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=low_parallel),
+ skip_health_check=True,
+ ),
+ ModelConfig(
+ alias="mock-gamma",
+ model="mock-gamma",
+ provider="mock-provider",
+ inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=mid_parallel),
+ skip_health_check=True,
+ ),
+ ]
+
+ builder = DataDesignerConfigBuilder(model_configs=model_configs)
+ builder.add_tool_config(
+ ToolConfig(
+ tool_alias=MOCK_TOOL_ALIAS,
+ providers=[MOCK_MCP_PROVIDER_NAME],
+ allow_tools=[MOCK_TOOL_NAME],
+ max_tool_call_turns=1,
+ timeout_sec=1.0,
+ )
+ )
+ builder.add_column(
+ SamplerColumnConfig(
+ name="seed_value",
+ sampler_type=SamplerType.UNIFORM,
+ params=UniformSamplerParams(low=0.0, high=100.0, decimal_places=3),
+ )
+ )
+ builder.add_column(
+ LLMTextColumnConfig(
+ name="llm_stage1",
+ model_alias="mock-alpha",
+ prompt="Summarize the signal for seed {{ seed_value }}.",
+ )
+ )
+ builder.add_column(
+ LLMTextColumnConfig(
+ name="llm_stage2",
+ model_alias="mock-beta",
+ tool_alias=MOCK_TOOL_ALIAS,
+ prompt="Analyze {{ llm_stage1 }} and produce a risk assessment.",
+ )
+ )
+ builder.add_column(
+ LLMTextColumnConfig(
+ name="llm_stage3",
+ model_alias="mock-gamma",
+ prompt="Generate a recommendation from {{ llm_stage2 }} with seed {{ seed_value }}.",
+ )
+ )
+ builder.add_column(
+ ValidationColumnConfig(
+ name="llm_stage3_validation",
+ target_columns=["llm_stage3"],
+ validator_type=ValidatorType.LOCAL_CALLABLE,
+ validator_params=LocalCallableValidatorParams(validation_function=_validate_recommendation),
+ batch_size=settings.validator_batch_size,
+ )
+ )
+ return builder
+
+
+def _dataset_fingerprint(df: pd.DataFrame) -> str:
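+    # Order-insensitive SHA-256 of the dataset contents (index reset, columns sorted),
+    # used to verify that sync and async runs produce identical data.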
+ normalized = df.reset_index(drop=True)
+ normalized = normalized.reindex(sorted(normalized.columns), axis=1)
+ records = normalized.to_dict(orient="records")
+ payload = json.dumps(
+ records,
+ sort_keys=True,
+ separators=(",", ":"),
+ ensure_ascii=True,
+ default=_json_default,
+ )
+ return hashlib.sha256(payload.encode("utf-8")).hexdigest()
+
+
+def _run_single_benchmark(settings: BenchmarkSettings, engine_mode: str) -> BenchmarkResult:
+ # Imports are deferred so engine selection respects DATA_DESIGNER_ASYNC_ENGINE.
+ from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
+ from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder
+ from data_designer.engine.model_provider import resolve_model_provider_registry
+ from data_designer.engine.resources.resource_provider import create_resource_provider
+ from data_designer.engine.resources.seed_reader import SeedReaderRegistry
+ from data_designer.engine.secret_resolver import CompositeResolver, EnvironmentResolver, PlaintextResolver
+
+ random.seed(settings.seed)
+ np.random.seed(settings.seed)
+
+ run_config = RunConfig(
+ buffer_size=settings.buffer_size,
+ disable_early_shutdown=True,
+ max_conversation_restarts=0,
+ max_conversation_correction_steps=0,
+ )
+ builder = _build_config(settings)
+
+ provider = ModelProvider(
+ name="mock-provider",
+ endpoint="https://mock.local",
+ provider_type="openai",
+ api_key="mock-key",
+ )
+ mcp_provider = MCPProvider(
+ name=MOCK_MCP_PROVIDER_NAME,
+ endpoint="https://mock.local/mcp",
+ api_key="mock-mcp-key",
+ )
+ model_provider_registry = resolve_model_provider_registry([provider], default_provider_name=provider.name)
+ secret_resolver = CompositeResolver([EnvironmentResolver(), PlaintextResolver()])
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ artifact_storage = ArtifactStorage(artifact_path=temp_dir, dataset_name=f"benchmark-{engine_mode}")
+ resource_provider = create_resource_provider(
+ artifact_storage=artifact_storage,
+ model_configs=builder.model_configs,
+ secret_resolver=secret_resolver,
+ model_provider_registry=model_provider_registry,
+ seed_reader_registry=SeedReaderRegistry(readers=[]),
+ blob_storage=None,
+ seed_dataset_source=None,
+ run_config=run_config,
+ mcp_providers=[mcp_provider],
+ tool_configs=builder.tool_configs,
+ )
+ dataset_builder = ColumnWiseDatasetBuilder(
+ data_designer_config=builder.build(),
+ resource_provider=resource_provider,
+ )
+
+ total_start = time.perf_counter()
+ with _patch_llm_responses(), _patch_mcp_io():
+ build_start = time.perf_counter()
+ dataset_builder.build(num_records=settings.num_records)
+ build_time = time.perf_counter() - build_start
+ dataset = dataset_builder.artifact_storage.load_dataset_with_dropped_columns()
+
+ dataset_hash = _dataset_fingerprint(dataset)
+ total_time = time.perf_counter() - total_start
+
+ return BenchmarkResult(
+ engine_mode=engine_mode,
+ num_records=settings.num_records,
+ buffer_size=settings.buffer_size,
+ build_time_sec=build_time,
+ total_time_sec=total_time,
+ dataset_hash=dataset_hash,
+ row_count=int(dataset.shape[0]),
+ column_count=int(dataset.shape[1]),
+ )
+
+
+def _run_subprocess(settings: BenchmarkSettings, engine_mode: str) -> BenchmarkResult:
+ env = os.environ.copy()
+ if engine_mode == "async":
+ env["DATA_DESIGNER_ASYNC_ENGINE"] = "1"
+ else:
+ env.pop("DATA_DESIGNER_ASYNC_ENGINE", None)
+
+ script_path = Path(__file__).resolve()
+ cmd = [sys.executable, str(script_path), "--mode", "run", "--engine", engine_mode, *settings.to_cli_args()]
+ completed = subprocess.run(cmd, capture_output=True, text=True, env=env, check=False)
+
+ if completed.returncode != 0:
+ raise RuntimeError(f"Benchmark subprocess failed.\nstdout:\n{completed.stdout}\nstderr:\n{completed.stderr}")
+
+ for line in reversed(completed.stdout.splitlines()):
+ if line.startswith(RESULT_PREFIX):
+ payload = json.loads(line.removeprefix(RESULT_PREFIX))
+ return BenchmarkResult.from_dict(payload)
+
+ raise RuntimeError(
+ f"Benchmark subprocess did not emit a result payload.\nstdout:\n{completed.stdout}\nstderr:\n{completed.stderr}"
+ )
+
+
+def _format_speedup(sync_time: float, async_time: float) -> str:
+ if async_time <= 0:
+ return "n/a"
+ return f"{(sync_time / async_time):.2f}x"
+
+
+def _run_with_progress(settings: BenchmarkSettings, engine_mode: str, iteration: int, total: int) -> BenchmarkResult:
+ print(f"[{iteration}/{total}] Running {engine_mode} benchmark...", end="", flush=True)
+ result = _run_subprocess(settings, engine_mode)
+ print(f" done ({result.total_time_sec:.3f}s)")
+ return result
+
+
+def _compare_runs(settings: BenchmarkSettings, iterations: int) -> int:
+ sync_results: list[BenchmarkResult] = []
+ async_results: list[BenchmarkResult] = []
+ expected_hash: str | None = None
+
+ for iteration in range(1, iterations + 1):
+ sync_result = _run_with_progress(settings, "sync", iteration, iterations)
+ async_result = _run_with_progress(settings, "async", iteration, iterations)
+
+ if sync_result.dataset_hash != async_result.dataset_hash:
+ print(
+ "Content mismatch detected: "
+ f"sync hash {sync_result.dataset_hash} vs async hash {async_result.dataset_hash}"
+ )
+ return 1
+
+ if expected_hash is None:
+ expected_hash = sync_result.dataset_hash
+ elif expected_hash != sync_result.dataset_hash or expected_hash != async_result.dataset_hash:
+ print("Content mismatch detected across iterations.")
+ return 1
+
+ sync_results.append(sync_result)
+ async_results.append(async_result)
+
+ build_sync = [result.build_time_sec for result in sync_results]
+ build_async = [result.build_time_sec for result in async_results]
+ total_sync = [result.total_time_sec for result in sync_results]
+ total_async = [result.total_time_sec for result in async_results]
+
+ build_speedups = [sync / async_ for sync, async_ in zip(build_sync, build_async)]
+ total_speedups = [sync / async_ for sync, async_ in zip(total_sync, total_async)]
+ build_diffs = [sync - async_ for sync, async_ in zip(build_sync, build_async)]
+ total_diffs = [sync - async_ for sync, async_ in zip(total_sync, total_async)]
+
+ build_sync_stats = _compute_stats(build_sync)
+ build_async_stats = _compute_stats(build_async)
+ total_sync_stats = _compute_stats(total_sync)
+ total_async_stats = _compute_stats(total_async)
+
+ build_speed_stats = _compute_stats(build_speedups)
+ total_speed_stats = _compute_stats(total_speedups)
+ build_diff_stats = _compute_stats(build_diffs)
+ total_diff_stats = _compute_stats(total_diffs)
+
+ print("\nEngine benchmark summary (95% CI)")
+ print(f"- runs: {iterations} | content match: yes | hash {expected_hash}")
+ print(f"- build time sync: {_format_stats(build_sync_stats, unit='s')}")
+ print(f"- build time async: {_format_stats(build_async_stats, unit='s')}")
+ print(
+ f"- build speedup: {_format_speed_stats(build_speed_stats)} | "
+ f"paired diff {_format_stats(build_diff_stats, unit='s')} | "
+ f"significant: {'yes' if _significant_diff(build_diff_stats) else 'no'}"
+ )
+ print(f"- total time sync: {_format_stats(total_sync_stats, unit='s')}")
+ print(f"- total time async: {_format_stats(total_async_stats, unit='s')}")
+ print(
+ f"- total speedup: {_format_speed_stats(total_speed_stats)} | "
+ f"paired diff {_format_stats(total_diff_stats, unit='s')} | "
+ f"significant: {'yes' if _significant_diff(total_diff_stats) else 'no'}"
+ )
+
+ return 0
+
+
+def _parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="Benchmark DataDesigner engine with mock LLMs and compare async execution."
+ )
+ parser.add_argument(
+ "--mode",
+ type=str,
+ choices=("compare", "run"),
+ default="compare",
+ help="Run both engines in subprocesses, or run once in the current process.",
+ )
+ parser.add_argument(
+ "--engine",
+ type=str,
+ choices=("sync", "async"),
+ default="sync",
+ help="Engine mode for --mode run.",
+ )
+ parser.add_argument("--num-records", type=int, default=DEFAULT_NUM_RECORDS, help="Records to generate.")
+ parser.add_argument("--buffer-size", type=int, default=DEFAULT_BUFFER_SIZE, help="Batch buffer size.")
+ parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="Random seed for determinism.")
+ parser.add_argument(
+ "--iterations",
+ type=int,
+ default=DEFAULT_ITERATIONS,
+ help="Number of sync/async runs to include in the compare mode.",
+ )
+ parser.add_argument(
+ "--max-parallel-requests",
+ type=int,
+ default=DEFAULT_MAX_PARALLEL_REQUESTS,
+ help="Max parallel LLM requests per model.",
+ )
+ parser.add_argument(
+ "--validator-batch-size",
+ type=int,
+ default=DEFAULT_VALIDATOR_BATCH_SIZE,
+ help="Batch size for the local validator.",
+ )
+ return parser.parse_args()
+
+
+def main() -> None:
+ args = _parse_args()
+ settings = BenchmarkSettings(
+ num_records=args.num_records,
+ buffer_size=args.buffer_size,
+ seed=args.seed,
+ max_parallel_requests=args.max_parallel_requests,
+ validator_batch_size=args.validator_batch_size,
+ )
+
+ if args.mode == "compare":
+ sys.exit(_compare_runs(settings, args.iterations))
+
+ if args.engine == "async":
+ os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1"
+ else:
+ os.environ.pop("DATA_DESIGNER_ASYNC_ENGINE", None)
+
+ print(f"Running {args.engine} benchmark...")
+ result = _run_single_benchmark(settings, args.engine)
+ print(f"{RESULT_PREFIX}{json.dumps(result.to_dict(), sort_keys=True)}")
+
+
+if __name__ == "__main__":
+ main()