Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions docs/concepts/custom_columns.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,40 @@ This gives you direct access to all `ModelFacade` capabilities: custom parsers,
| `generator_function` | Callable | Yes | Decorated function |
| `generation_strategy` | GenerationStrategy | No | `CELL_BY_CELL` or `FULL_COLUMN` |
| `generator_params` | BaseModel | No | Typed params passed to function |
| `allow_resize` | bool | No | Allow 1:N or N:1 generation. Requires `FULL_COLUMN` strategy |

### Resizing (1:N and N:1)

With `full_column` strategy, you can produce more or fewer records than the input using `allow_resize=True`:

```python
@dd.custom_column_generator(
required_columns=["topic"],
side_effect_columns=["variation_id"],
)
def expand_topics(df: pd.DataFrame, params: None, models: dict) -> pd.DataFrame:
rows = []
for _, row in df.iterrows():
for i in range(3): # Generate 3 variations per input
rows.append({
"topic": row["topic"],
"question": f"Question {i+1} about {row['topic']}",
"variation_id": i,
})
return pd.DataFrame(rows)

dd.CustomColumnConfig(
name="question",
generator_function=expand_topics,
generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
allow_resize=True,
)
```

Use cases:

- **Expansion (1:N)**: Generate multiple variations per input
- **Retraction (N:1)**: Filter, aggregate, or deduplicate records

## Multi-Turn Example

Expand Down
108 changes: 108 additions & 0 deletions example_allow_resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Example: Chaining expand -> retract -> expand resize operations.

Pipeline: 5 topics -> 15 questions (3 per topic) -> 10 medium/hard questions (filter easy)
-> 30 answer variants (3 per question)
"""

from __future__ import annotations

import data_designer.config as dd
from data_designer.interface import DataDesigner
from data_designer.lazy_heavy_imports import pd


# Step 1: Expand — 1:N, generate 3 questions per topic
@dd.custom_column_generator(required_columns=["topic"], side_effect_columns=["question_id", "difficulty"])
def expand_to_questions(df: pd.DataFrame) -> pd.DataFrame:
    """1:N expansion: emit three questions (easy/medium/hard) per input topic."""
    levels = ("easy", "medium", "hard")
    expanded = [
        {
            "topic": record["topic"],
            "question": f"Q{idx + 1} about {record['topic']}?",
            "question_id": idx,
            "difficulty": levels[idx],
        }
        for _, record in df.iterrows()
        for idx in range(3)
    ]
    return pd.DataFrame(expanded)


# Step 2: Retract — N:1, keep only medium/hard questions
@dd.custom_column_generator(required_columns=["difficulty"])
def filter_non_easy(df: pd.DataFrame) -> pd.DataFrame:
    """N:1 retraction: keep only non-"easy" rows and flag them with filtered=True."""
    kept = df.loc[df["difficulty"] != "easy"].copy()
    kept["filtered"] = True
    return kept


# Step 3: Expand again — 1:N, generate 3 answer variants per surviving question
@dd.custom_column_generator(required_columns=["question"], side_effect_columns=["variant"])
def expand_to_answers(df: pd.DataFrame) -> pd.DataFrame:
    """1:N expansion: emit three answer variants for each surviving question."""
    out: list[dict] = []
    for _, rec in df.iterrows():
        base = rec.to_dict()
        out.extend(
            {**base, "answer": f"Answer v{v} to: {rec['question']}", "variant": v} for v in range(3)
        )
    return pd.DataFrame(out)


def main() -> None:
    """Demonstrate chained resize operations: expand -> retract -> expand."""
    designer = DataDesigner()
    builder = dd.DataDesignerConfigBuilder()

    # Seed column: sample topics from a fixed category list.
    builder.add_column(
        dd.SamplerColumnConfig(
            name="topic",
            sampler_type=dd.SamplerType.CATEGORY,
            params=dd.CategorySamplerParams(values=["Python", "ML", "Data", "Stats", "SQL"]),
        )
    )

    # 1:N expansion: 5 topics -> 15 questions.
    builder.add_column(
        dd.CustomColumnConfig(
            name="question",
            generator_function=expand_to_questions,
            generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
            allow_resize=True,
        )
    )

    # N:1 retraction: 15 -> 10 (the "easy" third is dropped).
    builder.add_column(
        dd.CustomColumnConfig(
            name="filtered",
            generator_function=filter_non_easy,
            generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
            allow_resize=True,
        )
    )

    # 1:N expansion again: 10 -> 30 answer variants.
    builder.add_column(
        dd.CustomColumnConfig(
            name="answer",
            generator_function=expand_to_answers,
            generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
            allow_resize=True,
        )
    )

    # Preview runs the whole pipeline in a single batch.
    preview = designer.preview(config_builder=builder, num_records=5)
    print(f"Preview: 5 topics -> {len(preview.dataset)} answer variants")
    print(preview.dataset[["topic", "difficulty", "question", "variant", "answer"]].to_string())
    print()

    # A full build exercises multiple batches: 10 records at buffer_size=3 -> 4 batches.
    designer.set_run_config(dd.RunConfig(buffer_size=3))
    results = designer.create(config_builder=builder, num_records=10)
    final_df = results.load_dataset()
    print(f"Build: 10 topics (4 batches of 3+3+3+1) -> {len(final_df)} answer variants")
    print(final_df[["topic", "difficulty", "question", "variant"]].to_string())


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class SingleColumnConfig(ConfigBase, ABC):

name: str
drop: bool = False
allow_resize: bool = False
column_type: str

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,12 @@ def validate_generator_function(self) -> Self:
f"Expected a function decorated with @custom_column_generator."
)
return self

@model_validator(mode="after")
def validate_allow_resize_requires_full_column(self) -> Self:
if self.allow_resize and self.generation_strategy != GenerationStrategy.FULL_COLUMN:
raise InvalidConfigError(
f"🛑 `allow_resize=True` requires `generation_strategy='full_column'` for column '{self.name}'. "
f"Cell-by-cell strategy processes one row at a time and cannot change record count."
)
return self
27 changes: 16 additions & 11 deletions packages/data-designer-config/tests/config/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,19 +516,24 @@ def test_sampler_column_config_discriminated_union_wrong_params_type():
)


def test_default_column_emoji_for_custom_column_type() -> None:
"""Ensure the base get_column_emoji implementation is used when not overridden."""
class StubColumnConfigWithoutEmoji(SingleColumnConfig):
column_type: Literal["stub-without-emoji"] = "stub-without-emoji"

class StubColumnConfigWithoutEmoji(SingleColumnConfig):
column_type: Literal["stub-without-emoji"] = "stub-without-emoji"
value: str
@property
def required_columns(self) -> list[str]:
return []

@property
def required_columns(self) -> list[str]:
return []
@property
def side_effect_columns(self) -> list[str]:
return []

@property
def side_effect_columns(self) -> list[str]:
return []

def test_default_column_emoji_for_custom_column_type() -> None:
    """Ensure the base get_column_emoji implementation is used when not overridden."""
    emoji = StubColumnConfigWithoutEmoji.get_column_emoji()
    assert emoji == "🎨"


def test_allow_resize_inherited_by_subclasses() -> None:
    """Subclasses inherit allow_resize from SingleColumnConfig."""
    default_cfg = StubColumnConfigWithoutEmoji(name="test")
    resized_cfg = StubColumnConfigWithoutEmoji(name="test", allow_resize=True)
    assert default_cfg.allow_resize is False
    assert resized_cfg.allow_resize is True
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,5 @@ def log_pre_generation(self) -> None:
logger.info(f"{LOG_INDENT}model_aliases: {self.config.model_aliases}")
if self.config.generator_params:
logger.info(f"{LOG_INDENT}generator_params: {self.config.generator_params}")
if self.config.allow_resize:
logger.info(f"{LOG_INDENT}allow_resize: {self.config.allow_resize}")
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,23 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
self._fan_out_with_threads(generator, max_workers=max_workers)

def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
    """Run a FULL_COLUMN generator over the whole batch, honoring allow_resize.

    When the generator's config opts into resizing, the buffer is replaced
    even if the record count changed (1:N or N:1); size changes are logged.
    """
    before = self.batch_manager.num_records_in_buffer
    result_df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
    resize_allowed = getattr(generator.config, "allow_resize", False)
    after = len(result_df)

    if resize_allowed and after != before:
        if after == 0:
            # An empty result means there is nothing left to write for this batch.
            logger.warning(
                f"⚠️ Column '{generator.config.name}' reduced batch to 0 records. This batch will be skipped."
            )
        else:
            marker = "💥" if after > before else "✂️"
            logger.info(
                f"{marker} Column '{generator.config.name}' resized batch: {before} -> {after} records."
            )

    self.batch_manager.replace_buffer(result_df.to_dict(orient="records"), allow_resize=resize_allowed)

def _run_model_health_check_if_needed(self) -> None:
model_aliases: set[str] = set()
Expand Down Expand Up @@ -304,7 +319,7 @@ def callback(exc: Exception, *, context: dict | None = None) -> None:
return callback

def _write_processed_batch(self, dataframe: pd.DataFrame) -> None:
    """Swap the processed records into the buffer (strict 1:1), then flush to disk."""
    records = dataframe.to_dict(orient="records")
    self.batch_manager.replace_buffer(records)
    self.batch_manager.write()

def _validate_column_configs(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, artifact_storage: ArtifactStorage):
self._current_batch_number = 0
self._num_records_list: list[int] | None = None
self._buffer_size: int | None = None
self._actual_num_records: int = 0
self.artifact_storage = artifact_storage

@property
Expand Down Expand Up @@ -83,11 +84,13 @@ def finish_batch(self, on_complete: Callable[[Path], None] | None = None) -> Pat
raise DatasetBatchManagementError("🛑 All batches have been processed.")

if self.write() is not None:
self._actual_num_records += len(self._buffer)
final_file_path = self.artifact_storage.move_partial_result_to_final_file_path(self._current_batch_number)

self.artifact_storage.write_metadata(
{
"target_num_records": sum(self.num_records_list),
"actual_num_records": self._actual_num_records,
"total_num_batches": self.num_batches,
"buffer_size": self._buffer_size,
"schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)},
Expand Down Expand Up @@ -141,6 +144,7 @@ def iter_current_batch(self) -> Iterator[tuple[int, dict]]:
def reset(self, delete_files: bool = False) -> None:
self._current_batch_number = 0
self._buffer: list[dict] = []
self._actual_num_records = 0
if delete_files:
for dir_path in [
self.artifact_storage.final_dataset_path,
Expand Down Expand Up @@ -191,16 +195,18 @@ def update_record(self, index: int, record: dict) -> None:
raise IndexError(f"🛑 Index {index} is out of bounds for buffer of size {len(self._buffer)}.")
self._buffer[index] = record

def update_records(self, records: list[dict]) -> None:
if len(records) != len(self._buffer):
def replace_buffer(self, records: list[dict], *, allow_resize: bool = False) -> None:
"""Replace the buffer contents.

Args:
records: New records to replace the buffer.
allow_resize: If True, allows the number of records to differ from the current
buffer size (1:N or N:1 patterns). Defaults to False for strict 1:1 mapping.
"""
if not allow_resize and len(records) != len(self._buffer):
raise DatasetBatchManagementError(
f"🛑 Number of records to update ({len(records)}) must match "
f"the number of records in the buffer ({len(self._buffer)})."
f"🛑 Number of records ({len(records)}) must match the current buffer size ({len(self._buffer)})."
)
self._buffer = records

def replace_buffer(self, records: list[dict]) -> None:
"""Replace the buffer contents, updating the current batch size."""
self._buffer = records
if self._num_records_list is not None:
if allow_resize and self._num_records_list is not None:
self._num_records_list[self._current_batch_number] = len(records)
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def run_pre_batch(self, batch_manager: DatasetBatchManager) -> None:

df = batch_manager.get_current_batch(as_dataframe=True)
df = self._run_stage(df, ProcessorStage.PRE_BATCH)
batch_manager.replace_buffer(df.to_dict(orient="records"))
batch_manager.replace_buffer(df.to_dict(orient="records"), allow_resize=True)

def run_post_batch(self, df: pd.DataFrame, current_batch_number: int | None) -> pd.DataFrame:
"""Run process_after_batch() on processors that implement it."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy
from data_designer.config.custom_column import custom_column_generator
from data_designer.config.errors import InvalidConfigError
from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator
from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError
from data_designer.engine.resources.resource_provider import ResourceProvider
Expand Down Expand Up @@ -113,6 +114,31 @@ def test_config_validation_non_callable() -> None:
CustomColumnConfig(name="test", generator_function="not_a_function")


def test_config_validation_allow_resize_requires_full_column() -> None:
    """Test that allow_resize=True requires generation_strategy=FULL_COLUMN."""

    @custom_column_generator()
    def identity_fn(row: dict) -> dict:
        return row

    common_kwargs = {
        "name": "test",
        "generator_function": identity_fn,
        "allow_resize": True,
    }

    # Cell-by-cell cannot change the record count, so this must be rejected.
    with pytest.raises(InvalidConfigError, match="allow_resize=True.*requires.*full_column"):
        CustomColumnConfig(generation_strategy=GenerationStrategy.CELL_BY_CELL, **common_kwargs)

    # Should work with FULL_COLUMN
    config = CustomColumnConfig(generation_strategy=GenerationStrategy.FULL_COLUMN, **common_kwargs)
    assert config.allow_resize is True


# Cell-by-cell generation tests


Expand Down
Loading