Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import asyncio
import functools
import logging
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -43,6 +44,14 @@ def generate(self, data: pd.DataFrame) -> pd.DataFrame: ...
@abstractmethod
def generate(self, data: DataT) -> DataT: ...

async def agenerate(self, data: dict) -> dict:
    """Asynchronous counterpart of `generate`.

    Default implementation offloads the synchronous `generate` call to a
    worker thread so callers can await it without blocking the event loop.
    Subclasses with a native async path (e.g. chat-completion generators)
    should override this with a direct async implementation.
    """
    outcome = await asyncio.to_thread(self.generate, data)
    return outcome

def log_pre_generation(self) -> None:
"""A shared method to log info before the generator's `generate` method is called.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import functools
import logging
from typing import TYPE_CHECKING, Any

from data_designer.config.column_configs import (
LLMCodeColumnConfig,
Expand All @@ -24,6 +25,9 @@
from data_designer.engine.models.recipes.base import ResponseRecipe
from data_designer.engine.processing.utils import deserialize_json_values

if TYPE_CHECKING:
from data_designer.engine.models.utils import ChatMessage

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -56,36 +60,55 @@ def prompt_renderer(self) -> RecordBasedPromptRenderer:
)

def generate(self, data: dict) -> dict:
    """Generate this column's value for one record via the sync model API."""
    call_kwargs = self._prepare_generation_kwargs(data)
    raw_response, message_trace = self.model.generate(**call_kwargs)
    return self._process_generation_result(data, raw_response, message_trace)

async def agenerate(self, data: dict) -> dict:
    """Async variant of `generate`; awaits the model's native async API."""
    call_kwargs = self._prepare_generation_kwargs(data)
    raw_response, message_trace = await self.model.agenerate(**call_kwargs)
    return self._process_generation_result(data, raw_response, message_trace)

def _prepare_generation_kwargs(self, data: dict) -> dict[str, Any]:
    """Prepare keyword arguments for model.generate() / model.agenerate().

    Deserializes input data, builds multi-modal context, and renders the
    user and system prompts.

    Args:
        data: The record built up by previous columns.

    Returns:
        Keyword arguments shared by the sync and async model entry points.
    """
    # Deserialize input data from previous columns so Jinja2 templates can access nested fields.
    # Example: if a prior column stored '{"key": "value"}', templates can use {{ prev_column.key }}.
    # Note: this creates a new dict and does not mutate the original `data` argument.
    deserialized_record = deserialize_json_values(data)

    multi_modal_context: list[dict[str, Any]] | None = None
    if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0:
        multi_modal_context = []
        for context in self.config.multi_modal_context:
            multi_modal_context.extend(context.get_contexts(deserialized_record))

    return {
        "prompt": self.prompt_renderer.render(
            record=deserialized_record,
            prompt_template=self.config.prompt,
            prompt_type=PromptType.USER_PROMPT,
        ),
        "system_prompt": self.prompt_renderer.render(
            record=deserialized_record,
            prompt_template=self.config.system_prompt,
            prompt_type=PromptType.SYSTEM_PROMPT,
        ),
        "parser": self.response_recipe.parse,
        "multi_modal_context": multi_modal_context,
        "tool_alias": self.config.tool_alias,
        "max_correction_steps": self.max_conversation_correction_steps,
        "max_conversation_restarts": self.max_conversation_restarts,
        "purpose": f"running generation for column '{self.config.name}'",
    }

def _process_generation_result(self, data: dict, response: Any, trace: list[ChatMessage]) -> dict:
"""Process model response and trace into the output data dict.

Serializes the response, applies trace column logic, and extracts reasoning content.
"""
serialized_output = self.response_recipe.serialize_output(response)
data[self.config.name] = self._process_serialized_output(serialized_output)

Expand All @@ -102,7 +125,7 @@ def generate(self, data: dict) -> dict:

return data

def _extract_reasoning_content(self, trace: list) -> str | None:
def _extract_reasoning_content(self, trace: list[ChatMessage]) -> str | None:
"""Extract reasoning_content from the final assistant message in the trace.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import functools
import logging
import os
import time
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable

from data_designer.config.column_configs import CustomColumnConfig
from data_designer.config.column_types import ColumnConfigT
Expand All @@ -31,6 +32,7 @@
from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage
from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
from data_designer.engine.dataset_builders.utils.async_concurrency import AsyncConcurrentExecutor
from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
Expand All @@ -50,6 +52,11 @@

logger = logging.getLogger(__name__)

# Feature flag: opt in to the experimental async fan-out engine via env var.
DATA_DESIGNER_ASYNC_ENGINE = os.getenv("DATA_DESIGNER_ASYNC_ENGINE", "0") == "1"

# Emit a one-time notice at import when the async engine is active.
if DATA_DESIGNER_ASYNC_ENGINE:
    logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled — using async concurrency")

_CLIENT_VERSION: str = get_library_version()


Expand Down Expand Up @@ -199,7 +206,11 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
max_workers = self._resource_provider.run_config.non_inference_max_parallel_workers
if isinstance(generator, ColumnGeneratorWithModel):
max_workers = generator.inference_parameters.max_parallel_requests
self._fan_out_with_threads(generator, max_workers=max_workers)
if DATA_DESIGNER_ASYNC_ENGINE:
logger.info("⚡ Using async engine for concurrent execution")
self._fan_out_with_async(generator, max_workers=max_workers)
else:
self._fan_out_with_threads(generator, max_workers=max_workers)
Comment on lines +209 to +213
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we leave this dataset builder untouched since perf gain with async isn't there. When we work on the new builder, we can start to use it there.

Copy link
Contributor Author

@eric-tramel eric-tramel Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This switch is critical for being able to test end-to-end correctness of the async implementations. Without this we do not have a clean way to say that the async stack is correct. However, yes, once moving to async tasks we would need to hoist this context to some higher step and it wouldn't exist within column_wise_builder.


def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
Expand All @@ -226,41 +237,57 @@ def _run_mcp_tool_check_if_needed(self) -> None:
raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.")
self._resource_provider.mcp_registry.run_health_check(tool_aliases)

def _setup_fan_out(
    self, generator: ColumnGeneratorWithModelRegistry, max_workers: int
) -> tuple[ProgressTracker, dict[str, Any]]:
    """Validate the generator and build shared state for concurrent fan-out.

    Args:
        generator: Column generator; must use the cell-by-cell strategy.
        max_workers: Maximum concurrency, used only for start-of-run logging.

    Returns:
        A started ProgressTracker and the keyword arguments common to both
        the thread-based and async executors.

    Raises:
        DatasetGenerationError: If the generator is not cell-by-cell.
    """
    if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
        raise DatasetGenerationError(
            f"Generator {generator.name} is not a {GenerationStrategy.CELL_BY_CELL} "
            "generator so concurrent fan-out is not supported."
        )

    progress_tracker = ProgressTracker(
        total_records=self.batch_manager.num_records_batch,
        label=f"{generator.config.column_type} column '{generator.config.name}'",
    )
    progress_tracker.log_start(max_workers)

    settings = self._resource_provider.run_config
    executor_kwargs: dict = {
        "column_name": generator.config.name,
        "result_callback": self._make_result_callback(progress_tracker),
        "error_callback": self._make_error_callback(progress_tracker),
        "shutdown_error_rate": settings.shutdown_error_rate,
        "shutdown_error_window": settings.shutdown_error_window,
        "disable_early_shutdown": settings.disable_early_shutdown,
    }

    return progress_tracker, executor_kwargs

def _finalize_fan_out(self, progress_tracker: ProgressTracker) -> None:
    """Log completion stats and purge any records flagged for removal."""
    progress_tracker.log_final()

    if self._records_to_drop:
        self.batch_manager.drop_records(self._records_to_drop)
        self._records_to_drop.clear()

def _fan_out_with_async(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
    """Fan out cell-by-cell generation across the async concurrent executor."""
    progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers)
    runner = AsyncConcurrentExecutor(max_workers=max_workers, **executor_kwargs)
    # One (coroutine, context) pair per record in the current batch.
    work_items = [
        (generator.agenerate(rec), {"index": idx})
        for idx, rec in self.batch_manager.iter_current_batch()
    ]
    runner.run(work_items)
    self._finalize_fan_out(progress_tracker)

def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
    """Fan out cell-by-cell generation across a thread pool.

    Args:
        generator: Cell-by-cell column generator run once per record.
        max_workers: Maximum number of concurrent worker threads.
    """
    if getattr(generator.config, "tool_alias", None):
        logger.info("🛠️ Tool calling enabled")
    progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers)
    with ConcurrentThreadExecutor(max_workers=max_workers, **executor_kwargs) as executor:
        for i, record in self.batch_manager.iter_current_batch():
            # Submit the bound method directly; the previous `lambda record: ...`
            # wrapper added no behavior, only an extra frame per record.
            executor.submit(generator.generate, record, context={"index": i})
    self._finalize_fan_out(progress_tracker)

def _make_result_callback(self, progress_tracker: ProgressTracker) -> Callable[[dict], None]:
def callback(result: dict, *, context: dict | None = None) -> None:
self._worker_result_callback(result, context=context)
Expand Down
Loading