diff --git a/LITELLM_REMOVAL_ANALYSIS.md b/LITELLM_REMOVAL_ANALYSIS.md new file mode 100644 index 00000000..2b2b705f --- /dev/null +++ b/LITELLM_REMOVAL_ANALYSIS.md @@ -0,0 +1,441 @@ +# LiteLLM Removal: Impact Analysis & Implementation Plan + +## Executive Summary + +LiteLLM serves as Data Designer's abstraction layer between model configuration and HTTP API calls. It provides multi-provider routing, retry/backoff logic, error normalization, and response type unification. The dependency is well-contained — 12 files touch it directly, all within the `engine/models/` and `engine/models_v2/` layers. Removing it is feasible but non-trivial. The work breaks down into five areas: HTTP client, error handling, retry/resilience, response normalization, and provider abstraction. + +**Dependency**: `litellm>=1.73.6,<1.80.12` (pinned in `packages/data-designer-engine/pyproject.toml`) + +--- + +## What LiteLLM Provides Today + +### 1. Multi-Provider Routing + +LiteLLM's `Router` class accepts a deployment list and routes `completion()`/`embedding()` calls to the correct provider endpoint using the `{provider_type}/{model_name}` naming convention. + +**DD's usage**: Each `ModelFacade` gets its own `CustomRouter` with a single-element deployment list — so DD does not actually use multi-deployment load balancing or failover. The Router is effectively a single-provider HTTP client with retry logic. + +**Key takeaway**: DD uses the Router as a resilient HTTP client, not as a load balancer. This simplifies replacement significantly. + +### 2. Provider Abstraction (`provider_type`) + +`ModelProvider.provider_type` (default: `"openai"`) tells LiteLLM which API format to use. LiteLLM translates this into the correct HTTP request format, auth headers, and response parsing for each provider (OpenAI, Anthropic, Cohere, etc.). + +**DD's usage**: The `provider_type` is combined with the model name (`f"{provider_type}/{model_name}"`) and passed to the Router. DD also passes `api_base` (endpoint URL) and `api_key` per deployment. + +**Key question**: How many distinct `provider_type` values does DD actually use in production? If it's primarily `"openai"` (OpenAI-compatible APIs including NVIDIA NIM), the provider abstraction is low-value since most inference providers expose OpenAI-compatible endpoints. If it includes `"anthropic"`, `"cohere"`, etc., the translation layer is more valuable. + +### 3. Retry & Exponential Backoff + +DD's `CustomRouter` extends LiteLLM's `Router` with configurable exponential backoff: + +- **Retry policy**: 3 retries for `RateLimitError`, 3 retries for `Timeout` +- **Backoff formula**: `initial_retry_after_s * 2^retry_count * (1 +/- jitter_pct)` +- **Defaults**: initial=2s, jitter=20%, timeout=60s +- **Server Retry-After**: Extracted from response headers, capped at 60s + +This is the most custom piece of DD's LiteLLM integration. The `_time_to_sleep_before_retry()` and `_extract_retry_delay_from_headers()` overrides are well-tested. + +### 4. 
Error Normalization (12 Exception Types) + +`handle_llm_exceptions()` in `models/errors.py` pattern-matches 12 LiteLLM exception types and maps them to DD-specific error classes: + +| LiteLLM Exception | DD Error | Notes | +|---|---|---| +| `APIError` | `ModelAPIError` | Special 403 detection | +| `APIConnectionError` | `ModelAPIConnectionError` | Network issues | +| `AuthenticationError` | `ModelAuthenticationError` | Invalid API key | +| `ContextWindowExceededError` | `ModelContextWindowExceededError` | Parses OpenAI token details | +| `UnsupportedParamsError` | `ModelUnsupportedParamsError` | | +| `BadRequestError` | `ModelBadRequestError` | Detects multimodal rejection | +| `InternalServerError` | `ModelInternalServerError` | | +| `NotFoundError` | `ModelNotFoundError` | | +| `PermissionDeniedError` | `ModelPermissionDeniedError` | | +| `RateLimitError` | `ModelRateLimitError` | | +| `Timeout` | `ModelTimeoutError` | | +| `UnprocessableEntityError` | `ModelUnprocessableEntityError` | | + +The DD error types already exist and carry user-friendly `FormattedLLMErrorMessage(cause, solution)` payloads. The only LiteLLM-specific part is the `match` statement that catches LiteLLM's exception classes. + +### 5. Response Type Normalization + +DD accesses these fields from LiteLLM responses: + +**Chat completions** (`ModelResponse`): +- `response.choices[0].message.content` — generated text +- `response.choices[0].message.reasoning_content` — extended thinking (via `getattr`) +- `response.choices[0].message.tool_calls` — MCP tool calls +- `response.usage.prompt_tokens` / `response.usage.completion_tokens` — token counts + +**Embeddings** (`EmbeddingResponse`): +- `response.data[i]["embedding"]` — float arrays +- `response.usage.prompt_tokens` — input token count + +These match the OpenAI API response format exactly. LiteLLM normalizes responses from non-OpenAI providers into this shape. + +### 6. Global Patches + +`apply_litellm_patches()` modifies LiteLLM's global state at startup: +- Replaces the in-memory client cache with a thread-safe version (`ThreadSafeCache`) +- Increases `LoggingCallbackManager.MAX_CALLBACKS` to 1000 (workaround for litellm#9792) +- Suppresses verbose logging from httpx, LiteLLM, and LiteLLM Router + +These are workarounds for LiteLLM's rough edges — they disappear entirely if LiteLLM is removed. + +--- + +## Arguments For Removal + +### Reduced dependency weight +LiteLLM is a large dependency (~50+ transitive packages) with frequent releases and breaking changes. The version pin (`>=1.73.6,<1.80.12`) is already narrow, indicating past compatibility issues. Every LiteLLM upgrade is a risk surface. + +### Simpler async story +The event loop issue we just fixed (LiteLLM's `LoggingWorker` queue binding to a stale loop) is symptomatic of LiteLLM's internal async state being opaque and hard to reason about. With a direct HTTP client, DD controls all async state. + +### Thread-safety workarounds become unnecessary +`ThreadSafeCache`, the `LoggingCallbackManager.MAX_CALLBACKS` patch, and verbose logger suppression are all workarounds for LiteLLM behavior. They represent ongoing maintenance burden. + +### Overfit abstraction +DD uses the Router as a single-deployment client with retry logic. The multi-model routing, caching, and callback infrastructure of LiteLLM's Router class is unused overhead. + +### Import performance +LiteLLM's import time is significant (already lazy-loaded via `lazy_heavy_imports.py`). Removing it improves cold start time. 
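+
+To gauge the porting cost, the backoff behavior described in "Retry & Exponential Backoff" above reduces to a few lines. The sketch below is illustrative only, using the documented defaults rather than the actual `CustomRouter` code:
+
+```python
+import random
+
+
+def backoff_delay_s(
+    retry_count: int,
+    *,
+    initial_retry_after_s: float = 2.0,
+    jitter_pct: float = 0.20,
+    server_retry_after_s: float | None = None,
+    max_server_retry_after_s: float = 60.0,
+) -> float:
+    """Illustrative sketch of the documented backoff formula (not the CustomRouter implementation)."""
+    if server_retry_after_s is not None:
+        # Honor the server's Retry-After header, capped at 60 seconds.
+        return min(server_retry_after_s, max_server_retry_after_s)
+    jitter = random.uniform(1.0 - jitter_pct, 1.0 + jitter_pct)
+    return initial_retry_after_s * (2**retry_count) * jitter
+```
+
+Porting this formula plus the Retry-After header parsing covers the custom behavior `CustomRouter` layers on top of LiteLLM.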
+ +### Better error messages +DD already defines its own error types. Currently, LiteLLM's exceptions are caught and re-raised as DD errors. With a direct client, DD can produce better error messages without an intermediate translation layer. + +--- + +## Arguments Against Removal + +### Provider format translation +LiteLLM handles API format differences between providers. If DD only targets OpenAI-compatible endpoints, this is irrelevant. If it supports Anthropic, Cohere, etc., this is significant work to reimplement. + +### Battle-tested retry logic +LiteLLM's Router has been hardened over many releases for rate limiting, retry-after headers, connection pooling, and edge cases. Reimplementing this from scratch risks regressions. + +### Maintained by others +LiteLLM receives frequent updates for new models, API changes, and provider additions. DD's replacement would need to be maintained by the DD team. + +### Feature velocity risk +If DD later needs streaming, function calling improvements, vision model support, or new provider integrations, LiteLLM provides these incrementally. A custom client requires explicit implementation for each. + +--- + +## Blast Radius (by Phase) + +All paths relative to `packages/data-designer-engine/src/data_designer/engine/` unless noted. + +### Phase 1: Replace Router with ModelClient in `models_v2/` + +**Response type strategy:** Keep the OpenAI response format (`choices[0].message.content`, `choices[0].message.tool_calls`, `usage.prompt_tokens`, etc.) as DD's canonical model response type. The `openai` SDK already returns typed objects in this shape — use them directly or define DD-owned dataclasses with the same structure. This means **zero changes** to response access sites in the facade, MCP facade, or anywhere else that reads model responses. Non-OpenAI providers (Phase 3) will be responsible for translating their native responses into this format within their adapter. + +**New files:** +- `models_v2/client.py` — `ModelClient` protocol, `OpenAIModelClient` adapter wrapping the `openai` SDK + +**Delete:** +- `models_v2/litellm_overrides.py` — `CustomRouter`, `ThreadSafeCache`, `apply_litellm_patches()` no longer needed in `models_v2/` + +**Heavy modification:** +- `models_v2/facade.py` — Replace `self._router.completion()` / `self._router.acompletion()` with `ModelClient` calls. Response access patterns (`response.choices[0].message.content`, etc.) stay the same since we keep the OpenAI format. Replace `_get_litellm_deployment()` with adapter construction. +- `models_v2/errors.py` — Replace `litellm.exceptions.*` matching with `openai` SDK exception matching in `handle_llm_exceptions()` + +**Light modification:** +- `models_v2/factory.py` — Remove `apply_litellm_patches()` call, remove litellm imports, construct `ModelClient` adapter instead of `CustomRouter` + +**Tests (for `models_v2/` path):** +- `tests/engine/models/test_facade.py` — Medium rewrite: 26 tests, replace `CustomRouter` patches with `ModelClient` mocks. Response object construction in mocks can use `openai` SDK types directly (same shape as today's litellm types). 
+- `tests/engine/models/test_model_errors.py` — Medium rewrite: 7 tests, replace `litellm.exceptions.*` with `openai` SDK exceptions +- `tests/engine/models/test_model_registry.py` — Light: remove `apply_litellm_patches` mock +- `tests/engine/models/conftest.py` — Replace 2 fixtures that construct LiteLLM response objects (use `openai` SDK types instead — same shape) +- `scripts/benchmarks/benchmark_engine_v2.py` — Replace `CustomRouter` import/patches with `ModelClient` mocks + +**Dependency:** +- `pyproject.toml` — Add `openai` as direct dependency (already transitive via litellm; no new weight) + +**NOT touched:** `models/` directory is entirely unchanged. `engine/mcp/`, column generators, dataset builders, config layer, validators — all unchanged. No response format changes means no cross-layer ripple. + +--- + +### Phase 2: Validate + +**No code changes.** Validation only: +- Benchmark: `--mode compare` between `models/` (litellm, env var off) and `models_v2/` (direct SDK, env var on) +- Full test suite: `uv run pytest packages/data-designer-engine/tests -x -q` +- Real inference: pdf_qa recipe or equivalent with `DATA_DESIGNER_ASYNC_ENGINE=1` + +--- + +### Phase 3: Additional provider adapters + +**New files:** +- `models_v2/adapters/anthropic.py` — `AnthropicModelClient` +- `models_v2/adapters/bedrock.py` — `BedrockModelClient` + +**Modification:** +- `config/models.py` — `ModelProvider.provider_type: str` → `ProviderType` enum with Pydantic string coercion +- `models_v2/factory.py` — Adapter selection: `match provider_type` → construct appropriate `ModelClient` +- `engine/mcp/facade.py` — If Anthropic's flat tool_use blocks need different extraction than OpenAI's nested format, the tool call normalization logic needs updating. **This is the highest-risk cross-layer change.** +- `models/utils.py` — `ChatMessage.to_dict()` may need to support Anthropic's message format for multi-turn conversations with tool calls + +**Tests:** +- New test files for Anthropic and Bedrock adapters +- Update tool call extraction tests if MCP facade changes + +**Dependency:** +- `pyproject.toml` — Add `anthropic`. Bedrock: add `boto3`/`aiobotocore` or use `asyncio.to_thread()` with sync boto3. + +**NOT touched:** `models/` still unchanged — litellm fallback remains available. + +--- + +### Phase 4: Consolidate and drop dependency + +**Delete entirely:** +- `models/` directory (all files: `facade.py`, `errors.py`, `factory.py`, `litellm_overrides.py`, `registry.py`, `usage.py`, `telemetry.py`, `utils.py`, `parsers/`, `recipes/`) +- `tests/engine/models/test_litellm_overrides.py` + +**Modification:** +- `models/__init__.py` — Remove the path redirect hack; `models_v2/` becomes the sole implementation (or rename to `models/`) +- `lazy_heavy_imports.py` — Remove `litellm` from lazy import registry +- `pyproject.toml` — Remove `litellm>=1.73.6,<1.80.12` +- `uv.lock` — Regenerate + +**Cleanup (non-functional):** +- `config/column_configs.py` — Docstring mentions "via LiteLLM" +- `engine/resources/resource_provider.py` — Comment mentions "heavy dependencies like litellm" +- `engine/mcp/facade.py` — Type hint comment references `litellm.ModelResponse` +- `README.md`, `AGENTS.md` — Documentation references to LiteLLM +- `async_concurrency.py` — Comment mentions "libraries (like LiteLLM)" + +--- + +## What Needs to Be Built + +### 1. `ModelClient` Interface + Provider Adapters + +Replace LiteLLM Router with a `ModelClient` protocol and thin adapter classes that wrap the official provider SDKs. 
**No raw HTTP** — the SDKs handle networking, connection pooling, retry/backoff, and rate limiting internally. + +```python +class ModelClient(Protocol): + def completion(self, messages: list[dict], **kwargs) -> CompletionResponse: ... + async def acompletion(self, messages: list[dict], **kwargs) -> CompletionResponse: ... + def embedding(self, input_texts: list[str], **kwargs) -> EmbeddingResponse: ... + async def aembedding(self, input_texts: list[str], **kwargs) -> EmbeddingResponse: ... +``` + +Each adapter's job is purely **translation**: +- Translate DD's `ModelConfig` + `InferenceParams` → SDK-specific call parameters +- Call the SDK method (e.g., `await self._client.chat.completions.create(...)`) +- Translate the SDK response → DD's `CompletionResponse` / `EmbeddingResponse` +- Catch SDK-specific exceptions → DD error types + +The SDK client is created once in the factory (same lifecycle as `CustomRouter` today) and reused for all calls. No need for a dedicated I/O service like `mcp/io.py` — the official SDKs already manage connection pools, event loops, and request lifecycle internally. + +### 2. Response Types + +Define lightweight response dataclasses replacing LiteLLM's `ModelResponse` and `EmbeddingResponse`: + +```python +@dataclass +class CompletionResponse: + content: str + reasoning_content: str | None + tool_calls: list[ToolCall] | None + usage: UsageInfo | None + +@dataclass +class EmbeddingResponse: + embeddings: list[list[float]] + usage: UsageInfo | None + +@dataclass +class UsageInfo: + prompt_tokens: int + completion_tokens: int +``` + +**Scope**: ~50 lines. The existing code already accesses these fields — the dataclass just formalizes the contract. + +### 3. Error Handling + +Each SDK has its own exception hierarchy. The adapter for each provider catches SDK-specific exceptions and maps them to DD's existing error types. + +**OpenAI SDK** — exception types map almost 1:1 to DD errors (LiteLLM's exception hierarchy was modeled on OpenAI's): +`BadRequestError(400)`, `AuthenticationError(401)`, `PermissionDeniedError(403)`, `NotFoundError(404)`, `RateLimitError(429)`, `InternalServerError(5xx)`, `APIConnectionError`, `APITimeoutError` + +**Anthropic SDK** — simpler hierarchy: +`APIStatusError` (with `.status_code`), `RateLimitError`, `APIConnectionError`, `APITimeoutError`. Adapter checks status code for finer-grained mapping. + +**Bedrock** — all errors via `botocore.exceptions.ClientError`: +Check `response['Error']['Code']` for `ValidationException(400)`, `AccessDeniedException(403)`, `ThrottlingException(429)`, `InternalServerException(500)`, etc. + +**Nuance**: Some providers encode context window errors as 400 with specific error messages. The existing `parse_context_window_exceeded_error()` and `parse_bad_request_error()` logic handles this — it would need to match on response body strings, same as today. + +The DD error types and `FormattedLLMErrorMessage` formatting already exist. Only the matching logic changes per adapter. + +### 4. Retry & Backoff + +**This varies significantly by provider — the claim that "each SDK handles its own retries" is only partially true.** + +**OpenAI SDK**: Built-in. Default 2 retries, exponential backoff + jitter (0.5s initial, 8s cap). Auto-retries 408, 409, 429, 5xx. Respects `Retry-After` headers (capped at 60s). Configurable via `max_retries` on client or per-request. DD's `CustomRouter` defaults (2s initial, 20% jitter, 3 retries) become client configuration. + +**Anthropic SDK**: Built-in. 
Same defaults and behavior as OpenAI SDK (0.5s initial, 8s cap, 2 retries). Auto-retries connection errors, 408, 409, 429, 5xx. Same `max_retries` configuration. + +**Bedrock**: **Retry is NOT fully handled.** boto3 has three retry modes (legacy, standard, adaptive), but `ThrottlingException` (429 rate limiting) is **not auto-retried in any mode**. Only `ModelNotReadyException` (also 429, for cold-start) is auto-retried. DD must implement its own retry logic for Bedrock throttling — exponential backoff with jitter, same as `CustomRouter` does today. + +**Bottom line**: For OpenAI and Anthropic, DD can rely on the SDK's built-in retry. For Bedrock, DD needs a standalone retry utility (port of `CustomRouter`'s backoff logic) or a wrapper around the Bedrock adapter. + +### 5. Provider Adapters (use official SDKs directly) + +**Decision**: DD needs multi-provider support, but the set is bounded: OpenAI, Anthropic, and Bedrock. Use each provider's official SDK directly: + +- **OpenAI-compatible** (`openai`): Covers OpenAI, NVIDIA NIM, Azure OpenAI, and any OpenAI-compatible endpoint. Native async via `AsyncOpenAI`. Built-in retry + rate limit handling. Response format is what DD already expects (`choices[0].message.content`). **Thinnest adapter — mostly passthrough.** +- **Anthropic** (`anthropic`): Native async via `AsyncAnthropic`. Built-in retry. But response format differs: content is an array of typed blocks (`text`, `tool_use`, `thinking`), not a single string. Extended thinking is native but structurally different. **Adapter must translate content blocks → DD's expected format.** +- **Bedrock** (`boto3` / `aiobotocore`): Sync-only in boto3; async requires `aiobotocore` as an additional dependency (or `asyncio.to_thread()` as a simpler fallback). No auto-retry for throttling. Response format is Bedrock-native (`output.message.content[].text`), not OpenAI-compatible. **Most adapter work: retry, async wrapper, and response translation.** AWS does offer an `/openai/v1` endpoint that returns OpenAI-compatible responses, which could reduce translation work. + +Each adapter implements the `ModelClient` interface, translating DD's `ModelConfig` into the SDK-specific call and normalizing the response back to DD's `CompletionResponse` / `EmbeddingResponse` types. + +**Key insight**: This approach also simplifies retry/backoff — each SDK handles its own retries natively. DD's `CustomRouter` backoff logic may reduce to just configuration on the underlying client, rather than a reimplementation. + +--- + +## Recommended Approach + +### Architecture: Parallel stack in `models_v2/` + +The `engine/models/` and `engine/models_v2/` directories are near-complete copies, switchable at runtime via a `__init__.py` path redirect on `DATA_DESIGNER_ASYNC_ENGINE`. Only 2 of 14 files actually differ (`facade.py` adds async methods, `errors.py` adds async decorator). The other 12 are pure copy-paste duplicates. + +**Strategy**: Build the new non-litellm implementation entirely within `models_v2/`. Leave `models/` untouched as the stable litellm-backed fallback. The existing env var switch (`DATA_DESIGNER_ASYNC_ENGINE=1`) already gates which module path is used. Once `models_v2/` is validated in production, consolidate by deleting `models/` and dropping the litellm dependency. 
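+
+To make Phase 1 concrete, a minimal `OpenAIModelClient` might look like the sketch below. Names and structure are assumptions from this plan (proposed home: `models_v2/client.py`), not existing code; it returns the raw SDK response, which already has the `choices[0].message.content` shape DD reads today:
+
+```python
+from typing import Any
+
+from openai import AsyncOpenAI, OpenAI
+
+
+class OpenAIModelClient:
+    """Illustrative adapter over the official OpenAI SDK (assumed Phase 1 shape)."""
+
+    def __init__(self, api_base: str, api_key: str, *, timeout: float = 60.0, max_retries: int = 3) -> None:
+        # Clients are created once (same lifecycle as CustomRouter today) and reused for all calls.
+        self._sync = OpenAI(base_url=api_base, api_key=api_key, timeout=timeout, max_retries=max_retries)
+        self._async = AsyncOpenAI(base_url=api_base, api_key=api_key, timeout=timeout, max_retries=max_retries)
+
+    def completion(self, messages: list[dict], *, model: str, **kwargs: Any) -> Any:
+        # The SDK response already matches the shape DD reads (choices[0].message.content, usage, ...).
+        return self._sync.chat.completions.create(model=model, messages=messages, **kwargs)
+
+    async def acompletion(self, messages: list[dict], *, model: str, **kwargs: Any) -> Any:
+        return await self._async.chat.completions.create(model=model, messages=messages, **kwargs)
+
+    def embedding(self, input_texts: list[str], *, model: str, **kwargs: Any) -> Any:
+        return self._sync.embeddings.create(model=model, input=input_texts, **kwargs)
+
+    async def aembedding(self, input_texts: list[str], *, model: str, **kwargs: Any) -> Any:
+        return await self._async.embeddings.create(model=model, input=input_texts, **kwargs)
+```
+
+The facade would pass `model=self.model_name` per call, mirroring how it invokes the Router today.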
+ +This approach: +- Avoids a risky big-bang swap — litellm remains available as fallback +- Contains all new work to `models_v2/` (6 files to modify, not 12) +- Reuses the existing runtime switching mechanism +- Defers consolidation and dep removal to a clean follow-up + +### Phase 1: Replace Router with ModelClient in `models_v2/` +- Define `ModelClient` protocol and DD-owned response types in `models_v2/` +- Implement `OpenAIModelClient` using the `openai` SDK (already a transitive dep) +- Rewrite `models_v2/facade.py` to use `ModelClient` instead of `CustomRouter` +- Rewrite `models_v2/errors.py` to match on OpenAI SDK exceptions instead of litellm exceptions +- Remove `models_v2/litellm_overrides.py` and litellm imports from `models_v2/factory.py` +- Update response access sites within `models_v2/` (and any shared code that receives responses) +- **Result**: `models_v2/` is litellm-free, `models/` is unchanged + +### Phase 2: Validate +- Run benchmark: `--mode compare` to verify identical output between `models/` (litellm) and `models_v2/` (direct SDK) +- Run full test suite +- Run real inference (pdf_qa recipe or equivalent) with `DATA_DESIGNER_ASYNC_ENGINE=1` +- **Result**: Confidence that the new stack is correct + +### Phase 3: Additional provider adapters +- `AnthropicModelClient`: HIGH risk — content block → string translation, tool_use block → OpenAI tool_calls format, thinking block → reasoning_content. Requires changes to MCP facade tool extraction and ChatMessage serialization. +- `BedrockModelClient`: HIGH risk — manual throttle retry, async via `to_thread` or `aiobotocore`, response format translation from Converse API shape. +- `ProviderType` enum in config with Pydantic string coercion for backwards compatibility +- Each adapter raises explicit `UnsupportedFeatureError` for capabilities the provider doesn't support +- **Result**: Full provider coverage; `models/` (litellm) still available as fallback until all adapters are proven + +### Phase 4: Consolidate and drop dependency +- Delete `models/` directory +- Remove `__init__.py` path redirect hack +- Remove `litellm` from `pyproject.toml` +- Remove `litellm` from `lazy_heavy_imports.py` +- Clean up `uv.lock` +- Update documentation references +- **Result**: Single implementation, litellm fully removed. Only after all provider adapters are validated in production. + +--- + +## Design Decisions (Answered) + +1. ~~**Which `provider_type` values are used in production?**~~ **Answered**: OpenAI (including NIM), Anthropic, and Bedrock. Bounded set — use official SDKs directly behind a `ModelClient` interface. + +2. ~~**Is streaming on the roadmap?**~~ **Answered**: No. Streaming will not be supported. This simplifies the `ModelClient` interface — no need for `stream_completion()` or async iterator return types. Each method is a simple request/response call. + +3. ~~**How does `ModelConfig.provider_type` map to adapter selection?**~~ **Answered**: `provider_type` should become an **enum**, not remain a free-form string. The enum values determine which `ModelClient` adapter is instantiated. This makes the mapping explicit and catches misconfiguration at validation time rather than at runtime. + +4. ~~**What about provider-specific features?**~~ **Answered**: The `ModelClient` interface targets the **OpenAI feature superset** — completions, embeddings, tool calling, extended thinking, etc. OpenAI's adapter is the full implementation. 
Other providers implement what they can and **raise explicit incompatibility errors** for unsupported features (e.g., if a provider doesn't support tool calling, the adapter raises a clear error rather than silently degrading). This means: + - `ModelClient` defines the full interface anchored to OpenAI's capabilities + - `OpenAIModelClient` implements everything + - `AnthropicModelClient` implements everything but translates response formats (content blocks → string, thinking blocks, tool_use blocks) + - `BedrockModelClient` implements core completions/embeddings; raises `UnsupportedFeatureError` for anything Bedrock's Converse API doesn't support + +--- + +## Risk Assessment + +| Component | Risk | Why | +|---|---|---| +| `ModelClient` interface | Low | Well-understood contract; OpenAI response shape is the reference | +| OpenAI adapter | Low | Thinnest adapter — response format matches, retry built into SDK, `openai` already a transitive dep | +| Anthropic adapter | **High** | Content blocks vs strings, flat tool_use vs nested function format, thinking blocks vs reasoning_content field. Leaks into MCP facade and trace storage. | +| Bedrock adapter | **High** | Manual throttle retry, async requires extra dep, non-OpenAI response format | +| Error mapping (OpenAI) | Low | SDK exceptions map 1:1 to DD errors (LiteLLM modeled its hierarchy on OpenAI's) | +| Response type migration | Medium | Existing code accesses `response.choices[0].message.content` everywhere — either keep structural parity or coordinate refactor of all access sites | +| Test migration | Medium | ~56 test functions need modification; 26 require significant rework (all in `tests/engine/models/`) | +| Parallel stack validation | Low | Same env-var gating pattern as the async engine; benchmark already validates output correctness | + +OpenAI adapter + validation (Phases 1-2) is low risk. The high-risk work is the Anthropic and Bedrock adapters (Phase 3), which can be deferred and tackled independently. + +--- + +## Review Findings (Moth Swarm) + +10 independent reviewers examined this report against the actual codebase. Key corrections and additions: + +### Blast radius is larger than stated + +The original count of 2 test files is incomplete. The full test impact: + +| File | Impact | Details | +|---|---|---| +| `test_facade.py` | Heavy rewrite | 26 test functions, imports `ModelResponse`/`EmbeddingResponse`/`Choices`/`Message` from litellm, 6 `CustomRouter` patches | +| `test_litellm_overrides.py` | Delete | 11 tests for `CustomRouter`, `ThreadSafeCache`, `apply_litellm_patches()` | +| `test_model_errors.py` | Medium rewrite | 7 tests, imports all 12 `litellm.exceptions.*` types for parametrized error mapping tests | +| `test_model_registry.py` | Light touch | 1 test patches `apply_litellm_patches()` | +| `conftest.py` (models) | Light touch | 2 fixtures construct `ModelResponse` and `EmbeddingResponse` objects | +| `benchmark_engine_v2.py` | Medium rewrite | Imports `CustomRouter`, patches `completion` and `acompletion` | + +Total: ~56 test functions need modification, with 26 requiring significant rework. All contained within `tests/engine/models/`. + +### Anthropic adapter is HIGH risk, not medium + +The Anthropic SDK's response format is structurally incompatible with DD's assumptions: + +- **Content is an array of typed blocks** (`text`, `thinking`, `tool_use`), not `choices[0].message.content` (string). DD accesses `response.choices[0].message.content` in facade.py, models_v2/facade.py, and mcp/facade.py.
+- **Tool calls are flat content blocks** (`name`, `input`), not nested OpenAI format (`function.name`, `function.arguments`). The MCP facade's tool extraction logic (`mcp/facade.py:340-353`) assumes the nested structure. +- **Reasoning is a content block**, not a field. DD uses `getattr(message, "reasoning_content", None)` — Anthropic returns `ThinkingBlock(type="thinking", thinking="...")` in the content array. +- **This leaks beyond the adapter** into the MCP facade (tool extraction), the generation loop (content access), ChatMessage serialization (trace storage), and multimodal content formatting. + +The Anthropic adapter requires refactoring core response handling, not just wrapping the SDK. + +### ModelClient interface needs more specificity + +The proposed 4-method interface is too thin: +- Missing explicit `model` parameter (currently passed per-request to Router) +- `ToolCall` type is undefined — needs `id`, `type`, `function.name`, `function.arguments` +- Response type structure decision: the proposed flat `CompletionResponse` breaks all existing `response.choices[0].message.content` access sites. Either keep structural parity with OpenAI's response shape (nested `choices[0].message`) or coordinate a refactor of every access site. +- `consolidate_kwargs()` merges `inference_parameters.generate_kwargs`, `extra_body`, and `extra_headers` before calling the Router — the adapter contract should document what's in `**kwargs` when it arrives. + +### Retry-after header extraction is LiteLLM-specific + +`CustomRouter._extract_retry_delay_from_headers()` uses: +- `exception.litellm_response_headers` (LiteLLM-specific attribute) +- `exception.response.headers` (httpx attribute, works with OpenAI/Anthropic SDKs) +- `litellm.utils._get_retry_after_from_exception_header()` (LiteLLM utility) + +For OpenAI/Anthropic, the SDKs handle retry internally — this logic is only needed for Bedrock. But the Retry-After header parsing utility (`_get_retry_after_from_exception_header`) needs reimplementation as a standalone function. + +### Dependency impact is lighter than expected + +- `openai` is already a transitive dependency via litellm — promoting to direct dep adds zero new weight +- `httpx` stays regardless — DD depends on it directly for `engine/models/telemetry.py` and `engine/validators/remote.py` +- Net dependency change for OpenAI-only: a reduction (~10-15 packages removed with litellm) +- Adding Anthropic is lightweight (most deps already present via openai) +- Bedrock (boto3/botocore) is the heaviest new addition + +### Config migration is clean + +`provider_type` is only consumed in one place (`_get_litellm_deployment()` → `f"{provider_type}/{model_name}"`). Pydantic handles string → enum coercion automatically, so existing YAML/JSON configs and programmatic construction continue to work. The CLI text field for provider_type would become a select/choice field. All predefined providers and examples use `"openai"` — no existing users need migration. 
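+
+A minimal sketch of that enum migration (simplified; the real `ModelProvider` in `config/models.py` carries more fields):
+
+```python
+from enum import Enum
+
+from pydantic import BaseModel
+
+
+class ProviderType(str, Enum):
+    OPENAI = "openai"
+    ANTHROPIC = "anthropic"
+    BEDROCK = "bedrock"
+
+
+class ModelProvider(BaseModel):
+    # Simplified stand-in for the real config model.
+    name: str
+    provider_type: ProviderType = ProviderType.OPENAI
+
+
+# Existing configs that pass provider_type as a plain string keep working:
+# Pydantic coerces "openai" to ProviderType.OPENAI automatically.
+provider = ModelProvider(name="nim", provider_type="openai")
+assert provider.provider_type is ProviderType.OPENAI
+```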
diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py index 42e551be..54b8fd58 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py @@ -96,6 +96,45 @@ def generate(self, data: dict) -> dict: return data + async def agenerate(self, data: dict) -> dict: + deserialized_record = deserialize_json_values(data) + + multi_modal_context = None + if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0: + multi_modal_context = [] + for context in self.config.multi_modal_context: + multi_modal_context.extend(context.get_contexts(deserialized_record)) + + response, trace = await self.model.agenerate( + prompt=self.prompt_renderer.render( + record=deserialized_record, + prompt_template=self.config.prompt, + prompt_type=PromptType.USER_PROMPT, + ), + system_prompt=self.prompt_renderer.render( + record=deserialized_record, + prompt_template=self.config.system_prompt, + prompt_type=PromptType.SYSTEM_PROMPT, + ), + parser=self.response_recipe.parse, + multi_modal_context=multi_modal_context, + tool_alias=self.config.tool_alias, + max_correction_steps=self.max_conversation_correction_steps, + max_conversation_restarts=self.max_conversation_restarts, + purpose=f"running generation for column '{self.config.name}'", + ) + + serialized_output = self.response_recipe.serialize_output(response) + data[self.config.name] = self._process_serialized_output(serialized_output) + + should_save_trace = ( + self.config.with_trace or self.resource_provider.run_config.debug_override_save_all_column_traces + ) + if should_save_trace: + data[self.config.name + TRACE_COLUMN_POSTFIX] = [message.to_dict() for message in trace] + + return data + def _process_serialized_output(self, serialized_output: str) -> str | dict | list: """Process the serialized output from the model. 
Subclasses can override to customize deserialization.""" return serialized_output diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 577c2976..9de339f5 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -7,6 +7,7 @@ import importlib.metadata import json import logging +import os import time import uuid from pathlib import Path @@ -31,6 +32,7 @@ from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig +from data_designer.engine.dataset_builders.utils.async_concurrency import AsyncConcurrentExecutor from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager @@ -50,6 +52,11 @@ logger = logging.getLogger(__name__) +DATA_DESIGNER_ASYNC_ENGINE = os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "0") == "1" + +if DATA_DESIGNER_ASYNC_ENGINE: + logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled — using async concurrency") + _CLIENT_VERSION: str = importlib.metadata.version("data-designer-engine") @@ -205,7 +212,11 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None: max_workers = self._resource_provider.run_config.non_inference_max_parallel_workers if isinstance(generator, ColumnGeneratorWithModel): max_workers = generator.inference_parameters.max_parallel_requests - self._fan_out_with_threads(generator, max_workers=max_workers) + if DATA_DESIGNER_ASYNC_ENGINE: + logger.info("⚡ Using async engine for concurrent execution") + self._fan_out_with_async(generator, max_workers=max_workers) + else: + self._fan_out_with_threads(generator, max_workers=max_workers) def _run_full_column_generator(self, generator: ColumnGenerator) -> None: df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True)) @@ -227,6 +238,41 @@ def _run_mcp_tool_check_if_needed(self) -> None: raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.") self._resource_provider.mcp_registry.run_health_check(tool_aliases) + def _fan_out_with_async(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None: + if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL: + raise DatasetGenerationError( + f"Generator {generator.name} is not a {GenerationStrategy.CELL_BY_CELL} " + "generator so concurrency through async is not supported." 
+ ) + + progress_tracker = ProgressTracker( + total_records=self.batch_manager.num_records_batch, + label=f"{generator.config.column_type} column '{generator.config.name}'", + ) + progress_tracker.log_start(max_workers) + + settings = self._resource_provider.run_config + executor = AsyncConcurrentExecutor( + max_workers=max_workers, + column_name=generator.config.name, + result_callback=self._make_result_callback(progress_tracker), + error_callback=self._make_error_callback(progress_tracker), + shutdown_error_rate=settings.shutdown_error_rate, + shutdown_error_window=settings.shutdown_error_window, + disable_early_shutdown=settings.disable_early_shutdown, + ) + + work_items = [ + (generator.agenerate(record), {"index": i}) for i, record in self.batch_manager.iter_current_batch() + ] + executor.run(work_items) + + progress_tracker.log_final() + + if len(self._records_to_drop) > 0: + self.batch_manager.drop_records(self._records_to_drop) + self._records_to_drop.clear() + def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None: if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL: raise DatasetGenerationError( diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py new file mode 100644 index 00000000..896a99cb --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import json +import logging +import threading +from collections.abc import Coroutine +from dataclasses import dataclass +from typing import Any, Generic, TypeVar + +from data_designer.engine.dataset_builders.utils.concurrency import ( + CallbackWithContext, + ErrorCallbackWithContext, + ExecutorResults, +) +from data_designer.engine.errors import DataDesignerRuntimeError + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +@dataclass(frozen=True, slots=True) +class Success(Generic[T]): + index: int + value: T + + +@dataclass(frozen=True, slots=True) +class Failure: + index: int + error: Exception + + +TaskResult = Success[T] | Failure + +_loop: asyncio.AbstractEventLoop | None = None +_thread: threading.Thread | None = None +_lock = threading.Lock() + + +def _run_loop(loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + +def _ensure_async_engine_loop() -> asyncio.AbstractEventLoop: + """Get or create a persistent event loop for async engine work. + + A single event loop is shared across all AsyncConcurrentExecutor instances + to avoid breaking libraries (like LiteLLM) that bind internal async state + to a specific event loop. + """ + global _loop, _thread + with _lock: + if _loop is None or not _loop.is_running(): + _loop = asyncio.new_event_loop() + _thread = threading.Thread(target=_run_loop, args=(_loop,), daemon=True, name="AsyncEngine-EventLoop") + _thread.start() + return _loop + + +class AsyncConcurrentExecutor: + """Async equivalent of ConcurrentThreadExecutor. + + Executes a batch of coroutines with bounded concurrency, error rate + monitoring, and early shutdown semantics. 
Callers remain synchronous — + the ``run()`` method submits work to a persistent background event loop. + + No locks are needed because asyncio tasks run cooperatively on a + single thread — mutations to ``_results`` are always sequential. + """ + + def __init__( + self, + *, + max_workers: int, + column_name: str, + result_callback: CallbackWithContext | None = None, + error_callback: ErrorCallbackWithContext | None = None, + shutdown_error_rate: float = 0.50, + shutdown_error_window: int = 10, + disable_early_shutdown: bool = False, + ) -> None: + self._column_name = column_name + self._max_workers = max_workers + self._result_callback = result_callback + self._error_callback = error_callback + self._shutdown_error_rate = shutdown_error_rate + self._shutdown_window_size = shutdown_error_window + self._disable_early_shutdown = disable_early_shutdown + self._results = ExecutorResults(failure_threshold=shutdown_error_rate) + + @property + def results(self) -> ExecutorResults: + return self._results + + @property + def max_workers(self) -> int: + return self._max_workers + + @property + def shutdown_error_rate(self) -> float: + return self._shutdown_error_rate + + @property + def shutdown_window_size(self) -> int: + return self._shutdown_window_size + + def run(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None: + """Execute all work items concurrently. Callers remain synchronous.""" + logger.debug( + f"AsyncConcurrentExecutor: launching {len(work_items)} tasks " + f"with max_workers={self._max_workers} for column '{self._column_name}'" + ) + loop = _ensure_async_engine_loop() + future = asyncio.run_coroutine_threadsafe(self._run_all(work_items), loop) + future.result() + + async def _run_all(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None: + self._semaphore = asyncio.Semaphore(self._max_workers) + self._shutdown_event = asyncio.Event() + + async with asyncio.TaskGroup() as tg: + for i, (coro, context) in enumerate(work_items): + tg.create_task(self._run_task(i, coro, context)) + + if not self._disable_early_shutdown and self._results.early_shutdown: + self._raise_task_error() + + async def _run_task(self, index: int, coro: Coroutine[Any, Any, Any], context: dict | None) -> None: + if self._shutdown_event.is_set(): + return + + async with self._semaphore: + if self._shutdown_event.is_set(): + return + + try: + result = await coro + self._results.completed_count += 1 + self._results.success_count += 1 + if self._result_callback is not None: + self._result_callback(result, context=context) + except Exception as err: + self._results.completed_count += 1 + self._results.error_trap.handle_error(err) + if not self._disable_early_shutdown and self._results.is_error_rate_exceeded( + self._shutdown_window_size + ): + if not self._results.early_shutdown: + self._results.early_shutdown = True + self._shutdown_event.set() + if self._error_callback is not None: + self._error_callback(err, context=context) + + def _raise_task_error(self) -> None: + raise DataDesignerRuntimeError( + "\n".join( + [ + " |-- Data generation was terminated early due to error rate exceeding threshold.", + f" |-- The summary of encountered errors is: \n{json.dumps(self._results.summary, indent=4)}", + ] + ) + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/__init__.py index 52a7a9da..ba053a37 100644 --- 
a/packages/data-designer-engine/src/data_designer/engine/models/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/__init__.py @@ -1,2 +1,27 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +from pathlib import Path + +_ASYNC_ENGINE_ENV_VAR = "DATA_DESIGNER_ASYNC_ENGINE" +_TRUTHY_ENV_VALUES = {"1", "true", "yes"} + + +def _is_async_engine_enabled() -> bool: + return os.getenv(_ASYNC_ENGINE_ENV_VAR, "").lower() in _TRUTHY_ENV_VALUES + + +def _redirect_to_models_v2() -> None: + models_v2_path = Path(__file__).resolve().parent.parent / "models_v2" + # Set DATA_DESIGNER_ASYNC_ENGINE before importing this package for it to take effect. + global __path__ + __path__ = [str(models_v2_path)] + if __spec__ is not None: + __spec__.submodule_search_locations = [str(models_v2_path)] + + +if __name__ == "data_designer.engine.models" and _is_async_engine_enabled(): + _redirect_to_models_v2() diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/__init__.py new file mode 100644 index 00000000..52a7a9da --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/errors.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/errors.py new file mode 100644 index 00000000..0bbb26b4 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/errors.py @@ -0,0 +1,325 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +from collections.abc import Callable +from functools import wraps +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel + +from data_designer.engine.errors import DataDesignerError +from data_designer.lazy_heavy_imports import litellm + +if TYPE_CHECKING: + import litellm + +logger = logging.getLogger(__name__) + + +def get_exception_primary_cause(exception: BaseException) -> BaseException: + """Returns the primary cause of an exception by walking backwards. + + This recursive walkback halts when it arrives at an exception which + has no provided __cause__ (e.g. __cause__ is None). + + Args: + exception (Exception): An exception to start from. + + Raises: + RecursionError: if for some reason exceptions have circular + dependencies (seems impossible in practice). + """ + if exception.__cause__ is None: + return exception + else: + return get_exception_primary_cause(exception.__cause__) + + +class GenerationValidationFailureError(Exception): ... + + +class ModelRateLimitError(DataDesignerError): ... + + +class ModelTimeoutError(DataDesignerError): ... + + +class ModelContextWindowExceededError(DataDesignerError): ... + + +class ModelAuthenticationError(DataDesignerError): ... + + +class ModelPermissionDeniedError(DataDesignerError): ... + + +class ModelNotFoundError(DataDesignerError): ... + + +class ModelUnsupportedParamsError(DataDesignerError): ... + + +class ModelBadRequestError(DataDesignerError): ... 
+ + +class ModelInternalServerError(DataDesignerError): ... + + +class ModelAPIError(DataDesignerError): ... + + +class ModelUnprocessableEntityError(DataDesignerError): ... + + +class ModelAPIConnectionError(DataDesignerError): ... + + +class ModelStructuredOutputError(DataDesignerError): ... + + +class ModelGenerationValidationFailureError(DataDesignerError): ... + + +class FormattedLLMErrorMessage(BaseModel): + cause: str + solution: str + + def __str__(self) -> str: + return "\n".join( + [ + " |----------", + f" | Cause: {self.cause}", + f" | Solution: {self.solution}", + " |----------", + ] + ) + + +def handle_llm_exceptions( + exception: Exception, model_name: str, model_provider_name: str, purpose: str | None = None +) -> None: + """Handle LLM-related exceptions and convert them to appropriate DataDesignerError errors. + + This method centralizes the exception handling logic for LLM operations, + making it reusable across different contexts. + + Args: + exception: The exception that was raised + model_name: Name of the model that was being used + model_provider_name: Name of the model provider that was being used + purpose: The purpose of the model usage to show as context in the error message + Raises: + DataDesignerError: A more user-friendly error with appropriate error type and message + """ + purpose = purpose or "running generation" + authentication_error = FormattedLLMErrorMessage( + cause=f"The API key provided for model {model_name!r} was found to be invalid or expired while {purpose}.", + solution=f"Verify your API key for model provider and update it in your settings for model provider {model_provider_name!r}.", + ) + err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose) + match exception: + # Common errors that can come from LiteLLM + case litellm.exceptions.APIError(): + raise err_msg_parser.parse_api_error(exception, authentication_error) from None + + case litellm.exceptions.APIConnectionError(): + raise ModelAPIConnectionError( + FormattedLLMErrorMessage( + cause=f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.", + solution="Check your network/proxy/firewall settings.", + ) + ) from None + + case litellm.exceptions.AuthenticationError(): + raise ModelAuthenticationError(authentication_error) from None + + case litellm.exceptions.ContextWindowExceededError(): + raise err_msg_parser.parse_context_window_exceeded_error(exception) from None + + case litellm.exceptions.UnsupportedParamsError(): + raise ModelUnsupportedParamsError( + FormattedLLMErrorMessage( + cause=f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.", + solution=f"Review the documentation for model provider {model_provider_name!r} and adjust your request.", + ) + ) from None + + case litellm.exceptions.BadRequestError(): + raise err_msg_parser.parse_bad_request_error(exception) from None + + case litellm.exceptions.InternalServerError(): + raise ModelInternalServerError( + FormattedLLMErrorMessage( + cause=f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.", + solution=f"Try again in a few moments. 
Check with your model provider {model_provider_name!r} if the issue persists.", + ) + ) from None + + case litellm.exceptions.NotFoundError(): + raise ModelNotFoundError( + FormattedLLMErrorMessage( + cause=f"The specified model {model_name!r} could not be found while {purpose}.", + solution=f"Check that the model name is correct and supported by your model provider {model_provider_name!r} and try again.", + ) + ) from None + + case litellm.exceptions.PermissionDeniedError(): + raise ModelPermissionDeniedError( + FormattedLLMErrorMessage( + cause=f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.", + solution=f"Use an API key that has the right permissions for the model or use a model the API key in use has access to in model provider {model_provider_name!r}.", + ) + ) from None + + case litellm.exceptions.RateLimitError(): + raise ModelRateLimitError( + FormattedLLMErrorMessage( + cause=f"You have exceeded the rate limit for model {model_name!r} while {purpose}.", + solution="Wait and try again in a few moments.", + ) + ) from None + + case litellm.exceptions.Timeout(): + raise ModelTimeoutError( + FormattedLLMErrorMessage( + cause=f"The request to model {model_name!r} timed out while {purpose}.", + solution="Check your connection and try again. You may need to increase the timeout setting for the model.", + ) + ) from None + + case litellm.exceptions.UnprocessableEntityError(): + raise ModelUnprocessableEntityError( + FormattedLLMErrorMessage( + cause=f"The request to model {model_name!r} failed despite correct request format while {purpose}.", + solution="This is most likely temporary. Try again in a few moments.", + ) + ) from None + + # Parsing and validation errors + case GenerationValidationFailureError(): + raise ModelGenerationValidationFailureError( + FormattedLLMErrorMessage( + cause=f"The provided output schema was unable to be parsed from model {model_name!r} responses while {purpose}.", + solution="This is most likely temporary as we make additional attempts. If you continue to see more of this, simplify or modify the output schema for structured output and try again. If you are attempting token-intensive tasks like generations with high-reasoning effort, ensure that max_tokens in the model config is high enough to reach completion.", + ) + ) from None + + case DataDesignerError(): + raise exception from None + + case _: + raise DataDesignerError( + FormattedLLMErrorMessage( + cause=f"An unexpected error occurred while {purpose}.", + solution=f"Review the stack trace for more details: {exception}", + ) + ) from exception + + +def catch_llm_exceptions(func: Callable) -> Callable: + """This decorator should be used on any `ModelFacade` method that could potentially raise + exceptions that should turn into upstream user-facing errors. + """ + + @wraps(func) + def wrapper(model_facade: Any, *args, **kwargs): + try: + return func(model_facade, *args, **kwargs) + except Exception as e: + logger.debug( + "\n".join( + [ + "", + "|----------", + f"| Caught an exception downstream of type {type(e)!r}. 
Re-raising it below as a custom error with more context.", + "|----------", + ] + ), + exc_info=True, + stack_info=True, + ) + handle_llm_exceptions( + e, model_facade.model_name, model_facade.model_provider_name, purpose=kwargs.get("purpose") + ) + + return wrapper + + +def acatch_llm_exceptions(func: Callable) -> Callable: + @wraps(func) + async def wrapper(model_facade: Any, *args: Any, **kwargs: Any) -> Any: + try: + return await func(model_facade, *args, **kwargs) + except Exception as e: + logger.debug( + "\n".join( + [ + "", + "|----------", + f"| Caught an exception downstream of type {type(e)!r}. Re-raising it below as a custom error with more context.", + "|----------", + ] + ), + exc_info=True, + stack_info=True, + ) + handle_llm_exceptions( + e, model_facade.model_name, model_facade.model_provider_name, purpose=kwargs.get("purpose") + ) + + return wrapper + + +class DownstreamLLMExceptionMessageParser: + def __init__(self, model_name: str, model_provider_name: str, purpose: str): + self.model_name = model_name + self.model_provider_name = model_provider_name + self.purpose = purpose + + def parse_bad_request_error(self, exception: litellm.exceptions.BadRequestError) -> DataDesignerError: + err_msg = FormattedLLMErrorMessage( + cause=f"The request for model {self.model_name!r} was found to be malformed or missing required parameters while {self.purpose}.", + solution="Check your request parameters and try again.", + ) + if "is not a multimodal model" in str(exception): + err_msg = FormattedLLMErrorMessage( + cause=f"Model {self.model_name!r} is not a multimodal model, but it looks like you are trying to provide multimodal context while {self.purpose}.", + solution="Check your request parameters and try again.", + ) + return ModelBadRequestError(err_msg) + + def parse_context_window_exceeded_error( + self, exception: litellm.exceptions.ContextWindowExceededError + ) -> DataDesignerError: + cause = f"The input data for model '{self.model_name}' was found to exceed its supported context width while {self.purpose}." + try: + if "OpenAIException - This model's maximum context length is " in str(exception): + openai_exception_cause = ( + str(exception).split("OpenAIException - ")[1].split("\n")[0].split(" Please reduce ")[0] + ) + cause = f"{cause} {openai_exception_cause}" + except Exception: + pass + finally: + return ModelContextWindowExceededError( + FormattedLLMErrorMessage( + cause=cause, + solution="Check the model's supported max context width. Adjust the length of your input along with completions and try again.", + ) + ) + + def parse_api_error( + self, exception: litellm.exceptions.InternalServerError, auth_error_msg: FormattedLLMErrorMessage + ) -> DataDesignerError: + if "Error code: 403" in str(exception): + return ModelAuthenticationError(auth_error_msg) + + return ModelAPIError( + FormattedLLMErrorMessage( + cause=f"An unexpected API error occurred with model {self.model_name!r} while {self.purpose}.", + solution=f"Try again in a few moments. Check with your model provider {self.model_provider_name!r} if the issue persists.", + ) + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/facade.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/facade.py new file mode 100644 index 00000000..031ab919 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/facade.py @@ -0,0 +1,495 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Callable +from copy import deepcopy +from typing import TYPE_CHECKING, Any + +from data_designer.config.models import GenerationType, ModelConfig, ModelProvider +from data_designer.engine.mcp.errors import MCPConfigurationError +from data_designer.engine.model_provider import ModelProviderRegistry +from data_designer.engine.models.errors import ( + GenerationValidationFailureError, + catch_llm_exceptions, + get_exception_primary_cause, +) +from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs +from data_designer.engine.models.parsers.errors import ParserException +from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats +from data_designer.engine.models.utils import ChatMessage, prompt_to_messages +from data_designer.engine.models_v2.errors import acatch_llm_exceptions +from data_designer.engine.secret_resolver import SecretResolver +from data_designer.lazy_heavy_imports import litellm + +if TYPE_CHECKING: + import litellm + + from data_designer.engine.mcp.facade import MCPFacade + from data_designer.engine.mcp.registry import MCPRegistry + +logger = logging.getLogger(__name__) + + +class ModelFacade: + def __init__( + self, + model_config: ModelConfig, + secret_resolver: SecretResolver, + model_provider_registry: ModelProviderRegistry, + *, + mcp_registry: MCPRegistry | None = None, + ) -> None: + self._model_config = model_config + self._secret_resolver = secret_resolver + self._model_provider_registry = model_provider_registry + self._mcp_registry = mcp_registry + self._litellm_deployment = self._get_litellm_deployment(model_config) + self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump()) + self._usage_stats = ModelUsageStats() + + @property + def model_name(self) -> str: + return self._model_config.model + + @property + def model_provider(self) -> ModelProvider: + return self._model_provider_registry.get_provider(self._model_config.provider) + + @property + def model_generation_type(self) -> GenerationType: + return self._model_config.generation_type + + @property + def model_provider_name(self) -> str: + return self.model_provider.name + + @property + def model_alias(self) -> str: + return self._model_config.alias + + @property + def usage_stats(self) -> ModelUsageStats: + return self._usage_stats + + def completion( + self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs + ) -> litellm.ModelResponse: + message_payloads = [message.to_dict() for message in messages] + logger.debug( + f"Prompting model {self.model_name!r}...", + extra={"model": self.model_name, "messages": message_payloads}, + ) + response = None + kwargs = self.consolidate_kwargs(**kwargs) + try: + response = self._router.completion(model=self.model_name, messages=message_payloads, **kwargs) + logger.debug( + f"Received completion from model {self.model_name!r}", + extra={ + "model": self.model_name, + "response": response, + "text": response.choices[0].message.content, + "usage": self._usage_stats.model_dump(), + }, + ) + return response + except Exception as e: + raise e + finally: + if not skip_usage_tracking and response is not None: + self._track_usage(response) + + def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: + # Remove purpose from kwargs to avoid passing it to the model + kwargs.pop("purpose", None) + 
kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} + if self.model_provider.extra_body: + kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + if self.model_provider.extra_headers: + kwargs["extra_headers"] = self.model_provider.extra_headers + return kwargs + + def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: + if tool_alias is None: + return None + if self._mcp_registry is None: + raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.") + + try: + return self._mcp_registry.get_mcp(tool_alias=tool_alias) + except ValueError as exc: + raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc + + @catch_llm_exceptions + def generate_text_embeddings( + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs + ) -> list[list[float]]: + logger.debug( + f"Generating embeddings with model {self.model_name!r}...", + extra={ + "model": self.model_name, + "input_count": len(input_texts), + }, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response = None + try: + response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs) + logger.debug( + f"Received embeddings from model {self.model_name!r}", + extra={ + "model": self.model_name, + "embedding_count": len(response.data) if response.data else 0, + "usage": self._usage_stats.model_dump(), + }, + ) + if response.data and len(response.data) == len(input_texts): + return [data["embedding"] for data in response.data] + else: + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") + except Exception as e: + raise e + finally: + if not skip_usage_tracking and response is not None: + self._track_usage_from_embedding(response) + + @catch_llm_exceptions + def generate( + self, + prompt: str, + *, + parser: Callable[[str], Any], + system_prompt: str | None = None, + multi_modal_context: list[dict[str, Any]] | None = None, + tool_alias: str | None = None, + max_correction_steps: int = 0, + max_conversation_restarts: int = 0, + skip_usage_tracking: bool = False, + purpose: str | None = None, + **kwargs, + ) -> tuple[Any, list[ChatMessage]]: + """Generate a parsed output with correction steps. + + This generation call will attempt to generate an output which is + valid according to the specified parser, where "valid" implies + that the parser can process the LLM response without raising + an exception. + + `ParserExceptions` are routed back + to the LLM as new rounds in the conversation, where the LLM is provided its + earlier response along with the "user" role responding with the exception string + (not traceback). This will continue for the number of rounds specified by + `max_correction_steps`. + + Args: + prompt (str): Task prompt. + system_prompt (str, optional): Optional system instructions. If not specified, + no system message is provided and the model should use its default system + prompt. + parser (func(str) -> Any): A function applied to the LLM response which processes + an LLM response into some output object. + tool_alias (str | None): Optional tool configuration alias. When provided, + the model may call permitted tools from the configured MCP providers. + The alias must reference a ToolConfig registered in the MCPRegistry. + max_correction_steps (int): Maximum number of correction rounds permitted + within a single conversation. 
Note, many rounds can lead to increasing + context size without necessarily improving performance -- small language + models can enter repeated cycles which will not be solved with more steps. + Default: `0` (no correction). + max_conversation_restarts (int): Maximum number of full conversation restarts permitted + if generation fails. Default: `0` (no restarts). + skip_usage_tracking (bool): Whether to skip usage tracking. Default: `False`. + purpose (str): The purpose of the model usage to show as context in the error message. + It is expected to be used by the @catch_llm_exceptions decorator. + **kwargs: Additional arguments to pass to the model. + + Returns: + A tuple containing: + - The parsed output object from the parser. + - The full trace of ChatMessage entries in the conversation, including any tool calls, + corrections, and reasoning traces. Callers can decide whether to store this. + + Raises: + GenerationValidationFailureError: If the maximum number of retries or + correction steps are met and the last response failures on + generation validation. + MCPConfigurationError: If tool_alias is specified but no MCPRegistry is configured. + """ + output_obj = None + tool_schemas = None + tool_call_turns = 0 + curr_num_correction_steps = 0 + curr_num_restarts = 0 + + mcp_facade = self._get_mcp_facade(tool_alias) + + # Checkpoint for restarts - updated after tool calls so we don't repeat them + restart_checkpoint = prompt_to_messages( + user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context + ) + checkpoint_tool_call_turns = 0 + messages: list[ChatMessage] = deepcopy(restart_checkpoint) + + if mcp_facade is not None: + tool_schemas = mcp_facade.get_tool_schemas() + + while True: + completion_kwargs = dict(kwargs) + if tool_schemas is not None: + completion_kwargs["tools"] = tool_schemas + + completion_response = self.completion( + messages, + skip_usage_tracking=skip_usage_tracking, + **completion_kwargs, + ) + + # Process any tool calls in the response (handles parallel tool calling) + if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response): + tool_call_turns += 1 + + if tool_call_turns > mcp_facade.max_tool_call_turns: + # Gracefully refuse tool calls when budget is exhausted + messages.extend(mcp_facade.refuse_completion_response(completion_response)) + else: + messages.extend(mcp_facade.process_completion_response(completion_response)) + + # Update checkpoint so restarts don't repeat tool calls + restart_checkpoint = deepcopy(messages) + checkpoint_tool_call_turns = tool_call_turns + + continue # Back to top + + # No tool calls remaining to process + response = completion_response.choices[0].message.content or "" + reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None) + messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) + curr_num_correction_steps += 1 + + try: + output_obj = parser(response) # type: ignore - if not a string will cause a ParserException below + break + except ParserException as exc: + if max_correction_steps == 0 and max_conversation_restarts == 0: + raise GenerationValidationFailureError( + "Unsuccessful generation attempt. No retries were attempted." 
+ ) from exc + + if curr_num_correction_steps <= max_correction_steps: + # Add user message with error for correction + messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc)))) + + elif curr_num_restarts < max_conversation_restarts: + curr_num_correction_steps = 0 + curr_num_restarts += 1 + messages = deepcopy(restart_checkpoint) + tool_call_turns = checkpoint_tool_call_turns + + else: + raise GenerationValidationFailureError( + f"Unsuccessful generation despite {max_correction_steps} correction steps " + f"and {max_conversation_restarts} conversation restarts." + ) from exc + + return output_obj, messages + + def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict: + provider = self._model_provider_registry.get_provider(model_config.provider) + api_key = None + if provider.api_key: + api_key = self._secret_resolver.resolve(provider.api_key) + api_key = api_key or "not-used-but-required" + + litellm_params = litellm.LiteLLM_Params( + model=f"{provider.provider_type}/{model_config.model}", + api_base=provider.endpoint, + api_key=api_key, + ) + return { + "model_name": model_config.model, + "litellm_params": litellm_params.model_dump(), + } + + def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None: + if response is None: + self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) + return + if ( + response.usage is not None + and response.usage.prompt_tokens is not None + and response.usage.completion_tokens is not None + ): + self._usage_stats.extend( + token_usage=TokenUsageStats( + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + ), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) + + def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None: + if response is None: + self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) + return + if response.usage is not None and response.usage.prompt_tokens is not None: + self._usage_stats.extend( + token_usage=TokenUsageStats( + input_tokens=response.usage.prompt_tokens, + output_tokens=0, + ), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) + + async def acompletion( + self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any + ) -> litellm.ModelResponse: + message_payloads = [message.to_dict() for message in messages] + logger.debug( + f"Prompting model {self.model_name!r}...", + extra={"model": self.model_name, "messages": message_payloads}, + ) + response = None + kwargs = self.consolidate_kwargs(**kwargs) + try: + response = await self._router.acompletion(model=self.model_name, messages=message_payloads, **kwargs) + logger.debug( + f"Received completion from model {self.model_name!r}", + extra={ + "model": self.model_name, + "response": response, + "text": response.choices[0].message.content, + "usage": self._usage_stats.model_dump(), + }, + ) + return response + except Exception as e: + raise e + finally: + if not skip_usage_tracking and response is not None: + self._track_usage(response) + + @acatch_llm_exceptions + async def agenerate_text_embeddings( + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs: Any + ) -> list[list[float]]: + logger.debug( + f"Generating embeddings with model {self.model_name!r}...", + extra={ + "model": self.model_name, + "input_count": 
len(input_texts), + }, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response = None + try: + response = await self._router.aembedding(model=self.model_name, input=input_texts, **kwargs) + logger.debug( + f"Received embeddings from model {self.model_name!r}", + extra={ + "model": self.model_name, + "embedding_count": len(response.data) if response.data else 0, + "usage": self._usage_stats.model_dump(), + }, + ) + if response.data and len(response.data) == len(input_texts): + return [data["embedding"] for data in response.data] + else: + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") + except Exception as e: + raise e + finally: + if not skip_usage_tracking and response is not None: + self._track_usage_from_embedding(response) + + @acatch_llm_exceptions + async def agenerate( + self, + prompt: str, + *, + parser: Callable[[str], Any], + system_prompt: str | None = None, + multi_modal_context: list[dict[str, Any]] | None = None, + tool_alias: str | None = None, + max_correction_steps: int = 0, + max_conversation_restarts: int = 0, + skip_usage_tracking: bool = False, + purpose: str | None = None, + **kwargs: Any, + ) -> tuple[Any, list[ChatMessage]]: + output_obj = None + tool_schemas = None + tool_call_turns = 0 + curr_num_correction_steps = 0 + curr_num_restarts = 0 + + mcp_facade = self._get_mcp_facade(tool_alias) + + restart_checkpoint = prompt_to_messages( + user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context + ) + checkpoint_tool_call_turns = 0 + messages: list[ChatMessage] = deepcopy(restart_checkpoint) + + if mcp_facade is not None: + tool_schemas = await asyncio.to_thread(mcp_facade.get_tool_schemas) + + while True: + completion_kwargs = dict(kwargs) + if tool_schemas is not None: + completion_kwargs["tools"] = tool_schemas + + completion_response = await self.acompletion( + messages, + skip_usage_tracking=skip_usage_tracking, + **completion_kwargs, + ) + + if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response): + tool_call_turns += 1 + + if tool_call_turns > mcp_facade.max_tool_call_turns: + messages.extend(mcp_facade.refuse_completion_response(completion_response)) + else: + messages.extend( + await asyncio.to_thread(mcp_facade.process_completion_response, completion_response) + ) + + restart_checkpoint = deepcopy(messages) + checkpoint_tool_call_turns = tool_call_turns + + continue + + response = completion_response.choices[0].message.content or "" + reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None) + messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) + curr_num_correction_steps += 1 + + try: + output_obj = parser(response) + break + except ParserException as exc: + if max_correction_steps == 0 and max_conversation_restarts == 0: + raise GenerationValidationFailureError( + "Unsuccessful generation attempt. No retries were attempted." + ) from exc + + if curr_num_correction_steps <= max_correction_steps: + messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc)))) + + elif curr_num_restarts < max_conversation_restarts: + curr_num_correction_steps = 0 + curr_num_restarts += 1 + messages = deepcopy(restart_checkpoint) + tool_call_turns = checkpoint_tool_call_turns + + else: + raise GenerationValidationFailureError( + f"Unsuccessful generation despite {max_correction_steps} correction steps " + f"and {max_conversation_restarts} conversation restarts." 
+ ) from exc + + return output_obj, messages diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/factory.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/factory.py new file mode 100644 index 00000000..fb3b2e1d --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/factory.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from data_designer.config.models import ModelConfig +from data_designer.engine.model_provider import ModelProviderRegistry +from data_designer.engine.secret_resolver import SecretResolver + +if TYPE_CHECKING: + from data_designer.engine.mcp.registry import MCPRegistry + from data_designer.engine.models.registry import ModelRegistry + + +def create_model_registry( + *, + model_configs: list[ModelConfig] | None = None, + secret_resolver: SecretResolver, + model_provider_registry: ModelProviderRegistry, + mcp_registry: MCPRegistry | None = None, +) -> ModelRegistry: + """Factory function for creating a ModelRegistry instance. + + Heavy dependencies (litellm, httpx) are deferred until this function is called. + This is a factory function pattern - imports inside factories are idiomatic Python + for lazy initialization. + + Args: + model_configs: Optional list of model configurations to register. + secret_resolver: Resolver for secrets referenced in provider configs. + model_provider_registry: Registry of model provider configurations. + mcp_registry: Optional MCP registry for tool operations. When provided, + ModelFacades can look up MCPFacades by tool_alias for tool-enabled generation. + + Returns: + A configured ModelRegistry instance. + """ + from data_designer.engine.models.facade import ModelFacade + from data_designer.engine.models.litellm_overrides import apply_litellm_patches + from data_designer.engine.models.registry import ModelRegistry + + apply_litellm_patches() + + def model_facade_factory(model_config, secret_resolver, model_provider_registry): + return ModelFacade( + model_config, + secret_resolver, + model_provider_registry, + mcp_registry=mcp_registry, + ) + + return ModelRegistry( + model_configs=model_configs, + secret_resolver=secret_resolver, + model_provider_registry=model_provider_registry, + model_facade_factory=model_facade_factory, + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/litellm_overrides.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/litellm_overrides.py new file mode 100644 index 00000000..92070def --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/litellm_overrides.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +""" +LiteLLM overrides and customizations. + +Note on imports: This module uses direct (eager) imports for litellm rather than lazy loading. +This is intentional because: + +1. Class inheritance requires base classes to be resolved at class definition time, + making lazy imports incompatible with our ThreadSafeCache and CustomRouter classes. + +2. This module is already lazily loaded at the application level - it's only imported + by facade.py, which itself is imported inside the create_model_registry() factory + function. 
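To make the facade wiring above concrete, here is a rough usage sketch. It assumes you already have `ModelConfig`, `SecretResolver`, and `ModelProviderRegistry` instances in hand (their construction is outside this diff), that the existing `ParserException` mirrors the `models_v2` version added later in this diff, and that the JSON-extracting parser is a hypothetical example rather than an existing helper.

```python
import json

from data_designer.engine.models.parsers.errors import ParserException
from data_designer.engine.models_v2.facade import ModelFacade

def json_parser(text: str) -> dict:
    # Raise ParserException so generate() treats the failure as correctable,
    # feeding the error message back to the model as a new user turn.
    try:
        return json.loads(text)
    except json.JSONDecodeError as exc:
        raise ParserException(f"Response was not valid JSON: {exc}", source=text) from exc

# Hypothetical wiring sketch; model_config, secret_resolver, and provider_registry
# are assumed to be already-constructed instances.
facade = ModelFacade(model_config, secret_resolver, provider_registry)
parsed, trace = facade.generate(
    "Return a JSON object with keys 'city' and 'population'.",
    parser=json_parser,
    max_correction_steps=2,       # up to two corrective user turns per conversation
    max_conversation_restarts=1,  # then one full restart from the checkpoint
    purpose="generating a structured record",
)
```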
So litellm is only loaded when models are actually needed. + +3. Attempting to use lazy imports here causes intermittent ImportErrors. +""" + +from __future__ import annotations + +import random +import threading + +import httpx +import litellm +from litellm import RetryPolicy +from litellm.caching.in_memory_cache import InMemoryCache +from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager +from litellm.router import Router +from pydantic import BaseModel, Field +from typing_extensions import override + +from data_designer.logging import quiet_noisy_logger + +DEFAULT_MAX_CALLBACKS = 1000 + + +class LiteLLMRouterDefaultKwargs(BaseModel): + ## Number of seconds to wait initially after a connection + ## failure. + initial_retry_after_s: float = 2.0 + + ## Jitter percentage added during exponential backoff to + ## smooth repeated retries over time. + jitter_pct: float = 0.2 + + ## Maximum number of seconds to wait for an API request + ## before letting it die. Will trigger a retry. + timeout: float = 60.0 + + ## Sets the default retry policy, including the number + ## of retries to use in particular scenarios. + retry_policy: RetryPolicy = Field( + default_factory=lambda: RetryPolicy( + RateLimitErrorRetries=3, + TimeoutErrorRetries=3, + ) + ) + + +class ThreadSafeCache(InMemoryCache): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._lock = threading.RLock() + + def get_cache(self, key, **kwargs): + with self._lock: + return super().get_cache(key, **kwargs) + + def set_cache(self, key, value, **kwargs): + with self._lock: + super().set_cache(key, value, **kwargs) + + def batch_get_cache(self, keys: list, **kwargs): + with self._lock: + return super().batch_get_cache(keys, **kwargs) + + def delete_cache(self, key): + with self._lock: + super().delete_cache(key) + + def evict_cache(self): + with self._lock: + super().evict_cache() + + def increment_cache(self, key, value: int, **kwargs) -> int: + with self._lock: + return super().increment_cache(key, value, **kwargs) + + def flush_cache(self): + with self._lock: + super().flush_cache() + + +class CustomRouter(Router): + def __init__( + self, + *args, + initial_retry_after_s: float, + jitter_pct: float, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._initial_retry_after_s = initial_retry_after_s + self._jitter_pct = jitter_pct + + def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None: + """ + Most of this code logic was extracted directly from the parent + `Router`'s `_time_to_sleep_before_retry` function. Our override + of that method below should only affect requests where the server + didn't explicitly return a desired retry-delay. If the server did + return this info, we'll simply use that retry value returned here. + """ + + response_headers: httpx.Headers | None = None + if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore + response_headers = e.response.headers # type: ignore + if hasattr(e, "litellm_response_headers"): + response_headers = e.litellm_response_headers # type: ignore + + retry_after = litellm.utils._get_retry_after_from_exception_header(response_headers) + + # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. 
+ if retry_after is not None and 0 < retry_after <= 60: + return retry_after + else: + return None + + @override + def _time_to_sleep_before_retry( + self, + e: Exception, + remaining_retries: int, + num_retries: int, + healthy_deployments: list | None = None, + all_deployments: list | None = None, + ) -> int | float: + """ + Implements exponential backoff for retries. + + Technically, litellm's `Router` already implements some + form of exponential backoff. However, that backoff + is not customizable w.r.t jitter and initial delay + timing. For that reason, we override this method to + utilize our own custom instance variables, deferring + to the existing implementation wherever we can. + """ + + # If the response headers indicated how long we should wait, + # use that information. + if retry_after := self._extract_retry_delay_from_headers(e): + return retry_after + + return self.calculate_exponential_backoff( + initial_retry_after_s=self._initial_retry_after_s, + current_retry=num_retries - remaining_retries, + jitter_pct=self._jitter_pct, + ) + + @staticmethod + def calculate_exponential_backoff(initial_retry_after_s: float, current_retry: int, jitter_pct: float) -> float: + sleep_s = initial_retry_after_s * (pow(2.0, current_retry)) + jitter = 1.0 + random.uniform(-jitter_pct, jitter_pct) + return sleep_s * jitter + + +def apply_litellm_patches(): + litellm.in_memory_llm_clients_cache = ThreadSafeCache() + + # Workaround for the litellm issue described in https://github.com/BerriAI/litellm/issues/9792 + LoggingCallbackManager.MAX_CALLBACKS = DEFAULT_MAX_CALLBACKS + + quiet_noisy_logger("httpx") + quiet_noisy_logger("LiteLLM") + quiet_noisy_logger("LiteLLM Router") diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/__init__.py new file mode 100644 index 00000000..52a7a9da --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/errors.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/errors.py new file mode 100644 index 00000000..7d1db351 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/errors.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + + +class ParserException(Exception): + """Identifies errors resulting from generic parser errors. + + Attributes: + source (str | None): The source string that the parser + attempted to parse. + """ + + source: str | None + + @staticmethod + def _log_format(source: str) -> str: + ## NOTE: The point of this was to be able to report offending + ## failure cases to the logs. This might not be what we want + ## to do in all cases. In the meantime, this note is left + ## for later review. 
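Stepping back to the `CustomRouter` backoff defined above: the short sketch below reproduces the `calculate_exponential_backoff` formula with the defaults from `LiteLLMRouterDefaultKwargs` (2s initial delay, 20% jitter). The printed values are illustrative since the jitter term is random, and a server-supplied Retry-After of 60 seconds or less takes precedence over this formula.

```python
import random

def backoff(initial_retry_after_s: float, current_retry: int, jitter_pct: float) -> float:
    # Same formula as CustomRouter.calculate_exponential_backoff:
    # initial * 2^retry, scaled by a random factor in [1 - jitter, 1 + jitter].
    return initial_retry_after_s * (2.0 ** current_retry) * (1.0 + random.uniform(-jitter_pct, jitter_pct))

for retry in range(3):
    print(f"retry {retry}: ~{backoff(2.0, retry, 0.2):.1f}s")
# With the defaults this lands near 2s, 4s, and 8s, each +/- 20%.
```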
+ # + # return f"{source}" + return "" + + def __init__(self, msg: str | None = None, source: str | None = None): + msg = "" if msg is None else msg.strip() + + if source is not None: + msg += self._log_format(source) + + super().__init__(msg) + self.source = source diff --git a/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/parser.py b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/parser.py new file mode 100644 index 00000000..18e95bac --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models_v2/parsers/parser.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from functools import reduce + +import marko +from lxml import etree +from lxml.etree import _Element + +import data_designer.engine.models.parsers.tag_parsers as tp +from data_designer.engine.models.parsers.postprocessors import merge_text_blocks +from data_designer.engine.models.parsers.types import ( + LLMStructuredResponse, + PostProcessor, + TagParser, +) + +DEFAULT_TAG_PARSERS = { + "pre.code": tp.code_block_parser, + "p.code": tp.inline_code_parser, + "p": tp.text_parser, + "pre": tp.text_parser, + "": tp.text_parser_keep_markup, +} + +DEFAULT_POST_PROCESSORS = [merge_text_blocks] + + +def _patch_tags_before_code_fences(response: str) -> str: + """Patch to add a linebreak between a tag prior to a code block. + + Marko conversion of MD->HTML has a quirk. If there is a case like + the following, it will not convert the code block at all: + + ... + + ```syntax + ... + + We want to find these cases and simply introduce an additional + line break. + """ + + return response.replace(">\n```", ">\n\n```") + + +class LLMResponseParser: + """ + Parses Language Model (LLM) responses containing a mixture of Markdown and custom markup into structured data. + + The `LLMResponseParser` class facilitates the translation of LLM-generated responses, which may include + Markdown and custom markup tags, into a structured format using ElementTree. It allows for customizable + parsing behavior through the registration of tag-specific parsers and post-processors. + + ## Description + + The core functionality of this class enables LLMs to respond using Markdown along with any custom + prompted markup specified by the system or task. The parsing process involves converting the Markdown + and markup into an ElementTree, then processing each element using registered tag parsers to produce + a list of structured `BaseModel` instances. Post-processors can further refine the structured response. + + ### Tag Parsers + + Tag parsers are responsible for handling specific markup tags within the LLM response. They can be + registered with the parser using dot-path notation to manage hierarchical tag structures. This allows + downstream tasks to customize how specific elements are processed into `BaseModel` instances. + + ### Post-Processors + + Post-processors are functions that operate on the list of parsed blocks to perform additional + transformations or aggregations. They are applied after the initial parsing of the response. + + Attributes: + tag_parsers (dict[str, TagParser]): A dictionary mapping tag paths to their corresponding `TagParser` instances. + postprocessors (list[PostProcessor]): A list of post-processing functions to apply to the structured response. 
+ + Example: + ```python + class CodeBlock(BaseModel): + code: str + syntax: Optional[str] = None + + class CodeBlockParser: + def __call__(self, element: _Element) -> CodeBlock: + # Implementation details... + return CodeBlock(code=element.text, syntax=element.get("class")) + + parser = LLMResponseParser( + tag_parsers={ + "pre.code": CodeBlockParser(), + } + ) + + out = parser.parse('```json\n{"answer": 42}\n```') + print(out.parsed) + # Output: [CodeBlock(code='{"answer": 42}\n', syntax='json')] + ``` + """ + + tag_parsers: dict[str, TagParser] + postprocessors: list[PostProcessor] + + def __init__( + self, + tag_parsers: dict[str, TagParser] | None = None, + postprocessors: list[PostProcessor] | None = None, + ): + """ + Initializes the LLMResponseParser with optional tag parsers and post-processors. + + Args: + tag_parsers (Optional[dict[str, TagParser]]): A dictionary mapping tag paths to `TagParser` instances. + If provided, these parsers will be merged with the default tag parsers. + postprocessors (Optional[list[PostProcessor]]): A list of post-processing functions to apply + to the structured response. If not provided, a default post-processor `merge_text_blocks` + is used. + + Attributes: + tag_parsers (dict[str, TagParser]): Initialized with default tag parsers, updated with any provided. + postprocessors (list[PostProcessor]): Initialized with default post-processors or the provided list. + """ + self.tag_parsers = {**DEFAULT_TAG_PARSERS} + if tag_parsers: + self.tag_parsers.update(tag_parsers) + + self.postprocessors = [ + merge_text_blocks, + ] + if postprocessors is not None: + self.postprocessors = postprocessors + + def lookup_parser(self, element: _Element) -> TagParser: + """ + Resolves and retrieves the appropriate `TagParser` for a given XML element based on its tag hierarchy. + + The method constructs the dot-path lineage of the element's tags, starting from the root and moving + towards the specific element. It then attempts to find the most specific matching `TagParser` by + progressively reducing the specificity of the tag path until a matching parser is found. + + Args: + element (_Element): The XML element for which to find the corresponding `TagParser`. + + Returns: + TagParser: The `TagParser` instance that matches the element's tag path. + + Raises: + KeyError: If no matching `TagParser` is found for the element's tag path. + """ + # Get the dot path lineage of this tag, sans root. + # Note that the lineage comes back in reverse order. + parents = [e.tag for e in element.iterancestors()][::-1] + lineage = [*parents, element.tag] + + # Now attempt to matchup with the tag parsers name. + # Starts from the full linear (most specific), and + # breaks on the first hit. So this should properly + # prioritize specific parsers over general ones. + while lineage: + tag_path = ".".join(lineage) + if tag_path not in self.tag_parsers: + lineage.pop(0) + else: + break + + # Tag path can be an empty string, which hits the + # default parsing option specified by the "" entry + # of the tag parsers dict. + tag_path = ".".join(lineage) + return self.tag_parsers[tag_path] + + def postprocess(self, structured_response: LLMStructuredResponse) -> LLMStructuredResponse: + """ + Applies post-processing functions to the structured response. + + If no post-processors are registered, the original structured response is returned. + Otherwise, each post-processor is applied in sequence to transform the response. 
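To illustrate the dot-path resolution performed by `lookup_parser` above, the sketch below parses a small fabricated response with the default tag parsers. The fenced block's `<pre><code>` element resolves to the `"pre.code"` entry of `DEFAULT_TAG_PARSERS` (most specific lineage wins), surrounding prose resolves via `"p"`, and anything without a registered lineage falls through to the `""` default; the concrete block types printed depend on the default parsers, whose outputs are not shown in this diff.

```python
from data_designer.engine.models_v2.parsers.parser import LLMResponseParser

parser = LLMResponseParser()
out = parser.parse("Here is the result:\n\n```json\n{\"answer\": 42}\n```\n")

# lookup_parser walks the element's lineage from most to least specific,
# e.g. "html.body.pre.code" -> "body.pre.code" -> "pre.code" (registered hit),
# and falls back to the "" entry when nothing along the lineage matches.
for block in out.parsed:
    print(type(block).__name__, block)
```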
+ + Args: + structured_response (LLMStructuredResponse): The initial structured response to be post-processed. + + Returns: + LLMStructuredResponse: The post-processed structured response. + """ + if not self.postprocessors: + return structured_response + + return reduce(lambda acc, func: func(acc), self.postprocessors, structured_response) + + def parse(self, md_response: str) -> LLMStructuredResponse: + """ + Parses a Markdown-formatted LLM response into a structured `LLMStructuredResponse`. + + The parsing process involves converting the Markdown and custom markup into an XML tree, + iterating over each element in a depth-first traversal to apply the appropriate + `TagParser`, and then applying any registered post-processors to the resulting structured data. + + Args: + md_response (str): The Markdown-formatted response from the LLM, potentially containing custom markup. + + Returns: + LLMStructuredResponse: The structured representation of the parsed response, containing parsed blocks. + + Raises: + etree.XMLSyntaxError: If the provided Markdown cannot be converted into a valid XML structure. + """ + response = marko.convert(_patch_tags_before_code_fences(md_response)) + output = LLMStructuredResponse(response=md_response, markup=response) + + # Generate document tree + parser = etree.HTMLParser(recover=True, remove_blank_text=True) + root = etree.fromstring(response, parser=parser) + tags = root.iter() if root is not None else [] + + # Iterate over tags, depth first + for element in tags: + if element == root or element.tag == "body": + continue + + parsed_block = self.lookup_parser(element)(element) + + # Make a quick check for dead text blocks, which + # can happen with container tags like
,
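Because the marko quirk handled by `_patch_tags_before_code_fences` is subtle, here is a tiny before/after illustration on a fabricated snippet of tagged output (the `scratchpad` tag is made up for the example); only the blank line between the closing tag and the fence changes.

```python
# Fabricated input: a custom markup tag immediately followed by a code fence.
raw = "<scratchpad>\nthinking...\n</scratchpad>\n```json\n{\"x\": 1}\n```"

# Same substitution as _patch_tags_before_code_fences.
patched = raw.replace(">\n```", ">\n\n```")
print(patched)
# The extra blank line lets marko recognize the fenced block, so it is converted
# to a <pre><code> element instead of being swallowed into the preceding tag.
```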