From 8d2fdce75282b2fa71defdd5de93bcf86f0dad3a Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Mon, 2 Feb 2026 22:05:01 -0300 Subject: [PATCH 1/8] feat: add allow_resize for 1:N and N:1 generation patterns Adds support for generators that produce a different number of records than the input (expansion or retraction). This addresses GitHub issue #265. Changes: - Add `allow_resize` parameter to `update_records()` in DatasetBatchManager - Add `allow_resize` field to CustomColumnConfig - Add validation requiring FULL_COLUMN strategy when allow_resize=True - Track and report actual_num_records in metadata (may differ from target) - Add logging when batch size changes - Add example_allow_resize.py demonstrating the feature - Add comprehensive tests --- example_allow_resize.py | 98 +++++++++++++++++++ .../data_designer/config/column_configs.py | 17 ++++ .../dataset_builders/column_wise_builder.py | 17 +++- .../utils/dataset_batch_manager.py | 16 ++- .../generators/test_custom.py | 26 +++++ .../utils/test_dataset_batch_manager.py | 56 +++++++++++ 6 files changed, 227 insertions(+), 3 deletions(-) create mode 100644 example_allow_resize.py diff --git a/example_allow_resize.py b/example_allow_resize.py new file mode 100644 index 00000000..3a668311 --- /dev/null +++ b/example_allow_resize.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Example: Using allow_resize for 1:N expansion and N:1 retraction.""" + +from __future__ import annotations + +import data_designer.config as dd +from data_designer.interface import DataDesigner +from data_designer.lazy_heavy_imports import pd + + +@dd.custom_column_generator(required_columns=["topic"], side_effect_columns=["variation_id"]) +def expand_to_questions(df: pd.DataFrame, params: None, ctx: dd.CustomColumnContext) -> pd.DataFrame: + """Generate 3 questions per topic (1:N expansion).""" + rows = [] + for _, row in df.iterrows(): + for i in range(3): + rows.append( + { + "topic": row["topic"], + ctx.column_name: f"Question {i + 1} about {row['topic']}?", + "variation_id": i, + } + ) + return pd.DataFrame(rows) + + +@dd.custom_column_generator(required_columns=["topic", "score"]) +def filter_high_scores(df: pd.DataFrame, params: None, ctx: dd.CustomColumnContext) -> pd.DataFrame: + """Keep only records with score > 0.5 (N:1 retraction).""" + filtered = df[df["score"] > 0.5].copy() + filtered[ctx.column_name] = "passed" + return filtered + + +def run_expansion_example() -> None: + """3 topics -> 9 questions.""" + data_designer = DataDesigner() + config_builder = dd.DataDesignerConfigBuilder() + + config_builder.add_column( + dd.SamplerColumnConfig( + name="topic", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams(values=["Python", "ML", "Data"]), + ) + ) + config_builder.add_column( + dd.CustomColumnConfig( + name="question", + generator_function=expand_to_questions, + generation_strategy=dd.GenerationStrategy.FULL_COLUMN, + allow_resize=True, + ) + ) + + preview = data_designer.preview(config_builder=config_builder, num_records=3) + print(f"Expansion: 3 -> {len(preview.dataset)} records") + print(preview.dataset.to_string()) + + +def run_retraction_example() -> None: + """10 records -> ~5 (filtered).""" + data_designer = DataDesigner() + config_builder = dd.DataDesignerConfigBuilder() + + 
config_builder.add_column( + dd.SamplerColumnConfig( + name="topic", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams(values=["A", "B", "C", "D", "E"]), + ) + ) + config_builder.add_column( + dd.SamplerColumnConfig( + name="score", + sampler_type=dd.SamplerType.UNIFORM, + params=dd.UniformSamplerParams(low=0.0, high=1.0), + ) + ) + config_builder.add_column( + dd.CustomColumnConfig( + name="status", + generator_function=filter_high_scores, + generation_strategy=dd.GenerationStrategy.FULL_COLUMN, + allow_resize=True, + ) + ) + + preview = data_designer.preview(config_builder=config_builder, num_records=10) + print(f"Retraction: 10 -> {len(preview.dataset)} records") + print(preview.dataset.to_string()) + + +if __name__ == "__main__": + run_expansion_example() + # run_retraction_example() diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index b2eefd26..f564d5b0 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -509,6 +509,14 @@ class CustomColumnConfig(SingleColumnConfig): default=None, description="Optional typed configuration object passed as second argument to generator function", ) + allow_resize: bool = Field( + default=False, + description=( + "If True, allows the generator to produce a different number of records than the input. " + "Use for 1:N (expansion) or N:1 (retraction) generation patterns. " + "Only applicable when generation_strategy is 'full_column'." + ), + ) column_type: Literal["custom"] = "custom" @field_validator("generator_function") @@ -560,3 +568,12 @@ def validate_generator_function(self) -> Self: f"Expected a function decorated with @custom_column_generator." 
) return self + + @model_validator(mode="after") + def validate_allow_resize_requires_full_column(self) -> Self: + if self.allow_resize and self.generation_strategy != GenerationStrategy.FULL_COLUMN: + raise InvalidConfigError( + f"🛑 `allow_resize=True` requires `generation_strategy='full_column'` for column '{self.name}'. " + f"Cell-by-cell strategy processes one row at a time and cannot change record count." + ) + return self diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index e5404d49..a8b2c9ab 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -230,8 +230,23 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None: self._fan_out_with_threads(generator, max_workers=max_workers) def _run_full_column_generator(self, generator: ColumnGenerator) -> None: + original_count = self.batch_manager.num_records_in_buffer df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True)) - self.batch_manager.update_records(df.to_dict(orient="records")) + allow_resize = getattr(generator.config, "allow_resize", False) + new_count = len(df) + + if allow_resize and new_count != original_count: + if new_count == 0: + logger.warning( + f"⚠️ Column '{generator.config.name}' reduced batch to 0 records. This batch will be skipped." + ) + else: + logger.info( + f"📊 Column '{generator.config.name}' resized batch: {original_count} -> {new_count} records. " + f"Subsequent columns will operate on the new record count." 
+ ) + + self.batch_manager.update_records(df.to_dict(orient="records"), allow_resize=allow_resize) def _run_model_health_check_if_needed(self) -> None: model_aliases: set[str] = set() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py index a60d52b9..10b70f55 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py @@ -25,6 +25,7 @@ def __init__(self, artifact_storage: ArtifactStorage): self._current_batch_number = 0 self._num_records_list: list[int] | None = None self._buffer_size: int | None = None + self._actual_num_records: int = 0 self.artifact_storage = artifact_storage @property @@ -83,11 +84,13 @@ def finish_batch(self, on_complete: Callable[[Path], None] | None = None) -> Pat raise DatasetBatchManagementError("🛑 All batches have been processed.") if self.write() is not None: + self._actual_num_records += len(self._buffer) final_file_path = self.artifact_storage.move_partial_result_to_final_file_path(self._current_batch_number) self.artifact_storage.write_metadata( { "target_num_records": sum(self.num_records_list), + "actual_num_records": self._actual_num_records, "total_num_batches": self.num_batches, "buffer_size": self._buffer_size, "schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)}, @@ -141,6 +144,7 @@ def iter_current_batch(self) -> Iterator[tuple[int, dict]]: def reset(self, delete_files: bool = False) -> None: self._current_batch_number = 0 self._buffer: list[dict] = [] + self._actual_num_records = 0 if delete_files: for dir_path in [ self.artifact_storage.final_dataset_path, @@ -191,8 +195,16 @@ def update_record(self, index: int, record: dict) -> None: raise IndexError(f"🛑 Index 
{index} is out of bounds for buffer of size {len(self._buffer)}.") self._buffer[index] = record - def update_records(self, records: list[dict]) -> None: - if len(records) != len(self._buffer): + def update_records(self, records: list[dict], *, allow_resize: bool = False) -> None: + """Update all records in the buffer. + + Args: + records: New records to replace the buffer. + allow_resize: If True, allows the number of records to differ from the current + buffer size. Use for 1:N (expansion) or N:1 (retraction) generation patterns. + Defaults to False for strict 1:1 mapping. + """ + if not allow_resize and len(records) != len(self._buffer): raise DatasetBatchManagementError( f"🛑 Number of records to update ({len(records)}) must match " f"the number of records in the buffer ({len(self._buffer)})." diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py index 383636ba..1f3d1a03 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py @@ -18,6 +18,7 @@ from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy from data_designer.config.custom_column import custom_column_generator +from data_designer.config.errors import InvalidConfigError from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError from data_designer.engine.resources.resource_provider import ResourceProvider @@ -113,6 +114,31 @@ def test_config_validation_non_callable() -> None: CustomColumnConfig(name="test", generator_function="not_a_function") +def test_config_validation_allow_resize_requires_full_column() -> None: + """Test that allow_resize=True requires 
generation_strategy=FULL_COLUMN.""" + + @custom_column_generator() + def dummy_fn(row: dict) -> dict: + return row + + with pytest.raises(InvalidConfigError, match="allow_resize=True.*requires.*full_column"): + CustomColumnConfig( + name="test", + generator_function=dummy_fn, + allow_resize=True, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + ) + + # Should work with FULL_COLUMN + config = CustomColumnConfig( + name="test", + generator_function=dummy_fn, + allow_resize=True, + generation_strategy=GenerationStrategy.FULL_COLUMN, + ) + assert config.allow_resize is True + + # Cell-by-cell generation tests diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py index 07bd0ebb..7793c1d1 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py @@ -173,6 +173,61 @@ def test_update_records_wrong_length(stub_batch_manager_with_data): stub_batch_manager_with_data.update_records(wrong_length_records) +def test_update_records_allow_resize_expansion(stub_batch_manager_with_data): + """Test that allow_resize=True permits expanding the record count (1:N).""" + records = [{"id": i, "name": f"test{i}"} for i in range(3)] + stub_batch_manager_with_data.add_records(records) + + # Expand from 3 to 6 records + expanded_records = [{"id": i, "name": f"expanded{i}"} for i in range(6)] + stub_batch_manager_with_data.update_records(expanded_records, allow_resize=True) + + assert stub_batch_manager_with_data.num_records_in_buffer == 6 + assert stub_batch_manager_with_data._buffer == expanded_records + + +def test_update_records_allow_resize_retraction(stub_batch_manager_with_data): + """Test that allow_resize=True permits reducing the record count (N:1).""" + records = [{"id": i, 
"name": f"test{i}"} for i in range(3)] + stub_batch_manager_with_data.add_records(records) + + # Retract from 3 to 1 record + retracted_records = [{"id": 0, "name": "aggregated"}] + stub_batch_manager_with_data.update_records(retracted_records, allow_resize=True) + + assert stub_batch_manager_with_data.num_records_in_buffer == 1 + assert stub_batch_manager_with_data._buffer == retracted_records + + +def test_update_records_allow_resize_to_empty(stub_batch_manager_with_data): + """Test that allow_resize=True permits reducing to zero records.""" + records = [{"id": i, "name": f"test{i}"} for i in range(3)] + stub_batch_manager_with_data.add_records(records) + + stub_batch_manager_with_data.update_records([], allow_resize=True) + + assert stub_batch_manager_with_data.num_records_in_buffer == 0 + assert stub_batch_manager_with_data.buffer_is_empty + + +def test_actual_num_records_tracks_expansion(stub_batch_manager_with_data): + """Test that actual_num_records correctly tracks when buffer is resized.""" + # Add 3 records, then expand to 6 + records = [{"id": i} for i in range(3)] + stub_batch_manager_with_data.add_records(records) + expanded = [{"id": i} for i in range(6)] + stub_batch_manager_with_data.update_records(expanded, allow_resize=True) + + # Finish batch and check metadata + stub_batch_manager_with_data.finish_batch() + + with open(stub_batch_manager_with_data.artifact_storage.metadata_file_path) as f: + metadata = json.load(f) + + assert metadata["target_num_records"] == 10 # original target + assert metadata["actual_num_records"] == 6 # actual expanded count + + # Test write method def test_write_empty_buffer(stub_batch_manager_with_data): result = stub_batch_manager_with_data.write() @@ -271,6 +326,7 @@ def test_finish_batch_metadata_content(stub_batch_manager_with_data): metadata = json.load(f) assert metadata["target_num_records"] == 10 + assert metadata["actual_num_records"] == 3 # actual records written in this batch assert 
metadata["total_num_batches"] == 4 assert metadata["buffer_size"] == 3 assert metadata["num_completed_batches"] == 1 From 7b31cea62b1c16c61f7f4f3feb4d1c7ade546d44 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Mon, 2 Feb 2026 22:35:16 -0300 Subject: [PATCH 2/8] docs: add allow_resize to custom columns documentation --- docs/concepts/custom_columns.md | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/docs/concepts/custom_columns.md b/docs/concepts/custom_columns.md index d080bd70..f62eebf6 100644 --- a/docs/concepts/custom_columns.md +++ b/docs/concepts/custom_columns.md @@ -93,6 +93,40 @@ This gives you direct access to all `ModelFacade` capabilities: custom parsers, | `generator_function` | Callable | Yes | Decorated function | | `generation_strategy` | GenerationStrategy | No | `CELL_BY_CELL` or `FULL_COLUMN` | | `generator_params` | BaseModel | No | Typed params passed to function | +| `allow_resize` | bool | No | Allow 1:N or N:1 generation. Requires `FULL_COLUMN` strategy | + +### Resizing (1:N and N:1) + +With `full_column` strategy, you can produce more or fewer records than the input using `allow_resize=True`: + +```python +@dd.custom_column_generator( + required_columns=["topic"], + side_effect_columns=["variation_id"], +) +def expand_topics(df: pd.DataFrame, params: None, models: dict) -> pd.DataFrame: + rows = [] + for _, row in df.iterrows(): + for i in range(3): # Generate 3 variations per input + rows.append({ + "topic": row["topic"], + "question": f"Question {i+1} about {row['topic']}", + "variation_id": i, + }) + return pd.DataFrame(rows) + +dd.CustomColumnConfig( + name="question", + generator_function=expand_topics, + generation_strategy=dd.GenerationStrategy.FULL_COLUMN, + allow_resize=True, +) +``` + +Use cases: + +- **Expansion (1:N)**: Generate multiple variations per input +- **Retraction (N:1)**: Filter, aggregate, or deduplicate records ## Multi-Turn Example From 399eb4a0d73044de7c33292e9c2c0d5baa25305e 
Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 4 Feb 2026 17:55:27 -0300 Subject: [PATCH 3/8] refactor: consolidate buffer API and elevate allow_resize to base config - Merge update_records and replace_buffer into a single replace_buffer method with allow_resize parameter on DatasetBatchManager - Move allow_resize field from CustomColumnConfig to SingleColumnConfig so plugins inherit it without needing a mixin - Align example and logging with final CustomColumn API - Parametrize resize tests and extract shared stub in test_columns --- example_allow_resize.py | 8 +-- .../src/data_designer/config/base.py | 1 + .../data_designer/config/column_configs.py | 8 --- .../tests/config/test_columns.py | 31 ++++++---- .../column_generators/generators/custom.py | 2 + .../dataset_builders/column_wise_builder.py | 4 +- .../utils/dataset_batch_manager.py | 16 ++--- .../utils/processor_runner.py | 2 +- .../test_column_wise_builder.py | 2 +- .../utils/test_dataset_batch_manager.py | 60 +++++++------------ 10 files changed, 55 insertions(+), 79 deletions(-) diff --git a/example_allow_resize.py b/example_allow_resize.py index 3a668311..f8e40945 100644 --- a/example_allow_resize.py +++ b/example_allow_resize.py @@ -11,7 +11,7 @@ @dd.custom_column_generator(required_columns=["topic"], side_effect_columns=["variation_id"]) -def expand_to_questions(df: pd.DataFrame, params: None, ctx: dd.CustomColumnContext) -> pd.DataFrame: +def expand_to_questions(df: pd.DataFrame) -> pd.DataFrame: """Generate 3 questions per topic (1:N expansion).""" rows = [] for _, row in df.iterrows(): @@ -19,7 +19,7 @@ def expand_to_questions(df: pd.DataFrame, params: None, ctx: dd.CustomColumnCont rows.append( { "topic": row["topic"], - ctx.column_name: f"Question {i + 1} about {row['topic']}?", + "question": f"Question {i + 1} about {row['topic']}?", "variation_id": i, } ) @@ -27,10 +27,10 @@ def expand_to_questions(df: pd.DataFrame, params: None, ctx: dd.CustomColumnCont 
@dd.custom_column_generator(required_columns=["topic", "score"]) -def filter_high_scores(df: pd.DataFrame, params: None, ctx: dd.CustomColumnContext) -> pd.DataFrame: +def filter_high_scores(df: pd.DataFrame) -> pd.DataFrame: """Keep only records with score > 0.5 (N:1 retraction).""" filtered = df[df["score"] > 0.5].copy() - filtered[ctx.column_name] = "passed" + filtered["status"] = "passed" return filtered diff --git a/packages/data-designer-config/src/data_designer/config/base.py b/packages/data-designer-config/src/data_designer/config/base.py index 3bf5b6fa..bcf4ac6d 100644 --- a/packages/data-designer-config/src/data_designer/config/base.py +++ b/packages/data-designer-config/src/data_designer/config/base.py @@ -37,6 +37,7 @@ class SingleColumnConfig(ConfigBase, ABC): name: str drop: bool = False + allow_resize: bool = False column_type: str @staticmethod diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index f564d5b0..9095b44e 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -509,14 +509,6 @@ class CustomColumnConfig(SingleColumnConfig): default=None, description="Optional typed configuration object passed as second argument to generator function", ) - allow_resize: bool = Field( - default=False, - description=( - "If True, allows the generator to produce a different number of records than the input. " - "Use for 1:N (expansion) or N:1 (retraction) generation patterns. " - "Only applicable when generation_strategy is 'full_column'." 
- ), - ) column_type: Literal["custom"] = "custom" @field_validator("generator_function") diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index a069fea0..d7f8e66c 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -516,19 +516,26 @@ def test_sampler_column_config_discriminated_union_wrong_params_type(): ) -def test_default_column_emoji_for_custom_column_type() -> None: - """Ensure the base get_column_emoji implementation is used when not overridden.""" +class _StubPluginConfig(SingleColumnConfig): + """Minimal plugin config for testing base class behavior.""" + + column_type: Literal["stub-plugin"] = "stub-plugin" + + @property + def required_columns(self) -> list[str]: + return [] - class StubColumnConfigWithoutEmoji(SingleColumnConfig): - column_type: Literal["stub-without-emoji"] = "stub-without-emoji" - value: str + @property + def side_effect_columns(self) -> list[str]: + return [] - @property - def required_columns(self) -> list[str]: - return [] - @property - def side_effect_columns(self) -> list[str]: - return [] +def test_default_column_emoji_for_custom_column_type() -> None: + """Ensure the base get_column_emoji implementation is used when not overridden.""" + assert _StubPluginConfig.get_column_emoji() == "🎨" + - assert StubColumnConfigWithoutEmoji.get_column_emoji() == "🎨" +def test_allow_resize_inherited_by_plugin_configs() -> None: + """Plugin configs inherit allow_resize from SingleColumnConfig.""" + assert _StubPluginConfig(name="test").allow_resize is False + assert _StubPluginConfig(name="test", allow_resize=True).allow_resize is True diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index f0a942e7..8b9ff6f3 100644 --- 
a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -199,3 +199,5 @@ def log_pre_generation(self) -> None: logger.info(f"{LOG_INDENT}model_aliases: {self.config.model_aliases}") if self.config.generator_params: logger.info(f"{LOG_INDENT}generator_params: {self.config.generator_params}") + if self.config.allow_resize: + logger.info(f"{LOG_INDENT}allow_resize: {self.config.allow_resize}") diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index a8b2c9ab..5ca02cc9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -246,7 +246,7 @@ def _run_full_column_generator(self, generator: ColumnGenerator) -> None: f"Subsequent columns will operate on the new record count." 
) - self.batch_manager.update_records(df.to_dict(orient="records"), allow_resize=allow_resize) + self.batch_manager.replace_buffer(df.to_dict(orient="records"), allow_resize=allow_resize) def _run_model_health_check_if_needed(self) -> None: model_aliases: set[str] = set() @@ -319,7 +319,7 @@ def callback(exc: Exception, *, context: dict | None = None) -> None: return callback def _write_processed_batch(self, dataframe: pd.DataFrame) -> None: - self.batch_manager.update_records(dataframe.to_dict(orient="records")) + self.batch_manager.replace_buffer(dataframe.to_dict(orient="records")) self.batch_manager.write() def _validate_column_configs(self) -> None: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py index 10b70f55..9da0e17a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py @@ -195,24 +195,18 @@ def update_record(self, index: int, record: dict) -> None: raise IndexError(f"🛑 Index {index} is out of bounds for buffer of size {len(self._buffer)}.") self._buffer[index] = record - def update_records(self, records: list[dict], *, allow_resize: bool = False) -> None: - """Update all records in the buffer. + def replace_buffer(self, records: list[dict], *, allow_resize: bool = False) -> None: + """Replace the buffer contents. Args: records: New records to replace the buffer. allow_resize: If True, allows the number of records to differ from the current - buffer size. Use for 1:N (expansion) or N:1 (retraction) generation patterns. - Defaults to False for strict 1:1 mapping. + buffer size (1:N or N:1 patterns). Defaults to False for strict 1:1 mapping. 
""" if not allow_resize and len(records) != len(self._buffer): raise DatasetBatchManagementError( - f"🛑 Number of records to update ({len(records)}) must match " - f"the number of records in the buffer ({len(self._buffer)})." + f"🛑 Number of records ({len(records)}) must match the current buffer size ({len(self._buffer)})." ) self._buffer = records - - def replace_buffer(self, records: list[dict]) -> None: - """Replace the buffer contents, updating the current batch size.""" - self._buffer = records - if self._num_records_list is not None: + if allow_resize and self._num_records_list is not None: self._num_records_list[self._current_batch_number] = len(records) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py index 284d45f5..26143e6d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py @@ -70,7 +70,7 @@ def run_pre_batch(self, batch_manager: DatasetBatchManager) -> None: df = batch_manager.get_current_batch(as_dataframe=True) df = self._run_stage(df, ProcessorStage.PRE_BATCH) - batch_manager.replace_buffer(df.to_dict(orient="records")) + batch_manager.replace_buffer(df.to_dict(orient="records"), allow_resize=True) def run_post_batch(self, df: pd.DataFrame, current_batch_number: int | None) -> pd.DataFrame: """Run process_after_batch() on processors that implement it.""" diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py index c512871a..c741d586 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py +++ 
b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py @@ -64,7 +64,7 @@ def stub_batch_manager(): mock_batch_manager.finish = Mock() mock_batch_manager.write = Mock() mock_batch_manager.add_records = Mock() - mock_batch_manager.update_records = Mock() + mock_batch_manager.replace_buffer = Mock() mock_batch_manager.update_record = Mock() mock_batch_manager.get_current_batch = Mock() mock_batch_manager.get_current_batch.side_effect = [ diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py index 7793c1d1..987da48e 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py @@ -153,61 +153,41 @@ def test_update_record_invalid_index(stub_batch_manager_with_data): stub_batch_manager_with_data.update_record(-1, {"id": -1, "name": "test"}) -def test_update_records(stub_batch_manager_with_data): +def test_replace_buffer(stub_batch_manager_with_data): records = [{"id": i, "name": f"test{i}"} for i in range(3)] stub_batch_manager_with_data.add_records(records) new_records = [{"id": i, "name": f"updated{i}"} for i in range(3)] - stub_batch_manager_with_data.update_records(new_records) + stub_batch_manager_with_data.replace_buffer(new_records) assert stub_batch_manager_with_data._buffer == new_records -def test_update_records_wrong_length(stub_batch_manager_with_data): +def test_replace_buffer_wrong_length(stub_batch_manager_with_data): records = [{"id": i, "name": f"test{i}"} for i in range(3)] stub_batch_manager_with_data.add_records(records) wrong_length_records = [{"id": i, "name": f"test{i}"} for i in range(2)] - with pytest.raises(DatasetBatchManagementError, match="Number of records to update.*must match"): - 
stub_batch_manager_with_data.update_records(wrong_length_records) + with pytest.raises(DatasetBatchManagementError, match="Number of records.*must match"): + stub_batch_manager_with_data.replace_buffer(wrong_length_records) -def test_update_records_allow_resize_expansion(stub_batch_manager_with_data): - """Test that allow_resize=True permits expanding the record count (1:N).""" - records = [{"id": i, "name": f"test{i}"} for i in range(3)] - stub_batch_manager_with_data.add_records(records) - - # Expand from 3 to 6 records - expanded_records = [{"id": i, "name": f"expanded{i}"} for i in range(6)] - stub_batch_manager_with_data.update_records(expanded_records, allow_resize=True) +@pytest.mark.parametrize( + "new_size", + [6, 1, 0], + ids=["expansion", "retraction", "empty"], +) +def test_replace_buffer_allow_resize(stub_batch_manager_with_data, new_size): + """allow_resize=True permits any record count change and updates bookkeeping.""" + stub_batch_manager_with_data.add_records([{"id": i} for i in range(3)]) - assert stub_batch_manager_with_data.num_records_in_buffer == 6 - assert stub_batch_manager_with_data._buffer == expanded_records + new_records = [{"id": i} for i in range(new_size)] + stub_batch_manager_with_data.replace_buffer(new_records, allow_resize=True) - -def test_update_records_allow_resize_retraction(stub_batch_manager_with_data): - """Test that allow_resize=True permits reducing the record count (N:1).""" - records = [{"id": i, "name": f"test{i}"} for i in range(3)] - stub_batch_manager_with_data.add_records(records) - - # Retract from 3 to 1 record - retracted_records = [{"id": 0, "name": "aggregated"}] - stub_batch_manager_with_data.update_records(retracted_records, allow_resize=True) - - assert stub_batch_manager_with_data.num_records_in_buffer == 1 - assert stub_batch_manager_with_data._buffer == retracted_records - - -def test_update_records_allow_resize_to_empty(stub_batch_manager_with_data): - """Test that allow_resize=True permits reducing to 
zero records.""" - records = [{"id": i, "name": f"test{i}"} for i in range(3)] - stub_batch_manager_with_data.add_records(records) - - stub_batch_manager_with_data.update_records([], allow_resize=True) - - assert stub_batch_manager_with_data.num_records_in_buffer == 0 - assert stub_batch_manager_with_data.buffer_is_empty + assert stub_batch_manager_with_data.num_records_in_buffer == new_size + assert stub_batch_manager_with_data._buffer == new_records + assert stub_batch_manager_with_data.num_records_batch == new_size def test_actual_num_records_tracks_expansion(stub_batch_manager_with_data): @@ -216,7 +196,7 @@ def test_actual_num_records_tracks_expansion(stub_batch_manager_with_data): records = [{"id": i} for i in range(3)] stub_batch_manager_with_data.add_records(records) expanded = [{"id": i} for i in range(6)] - stub_batch_manager_with_data.update_records(expanded, allow_resize=True) + stub_batch_manager_with_data.replace_buffer(expanded, allow_resize=True) # Finish batch and check metadata stub_batch_manager_with_data.finish_batch() @@ -224,7 +204,7 @@ def test_actual_num_records_tracks_expansion(stub_batch_manager_with_data): with open(stub_batch_manager_with_data.artifact_storage.metadata_file_path) as f: metadata = json.load(f) - assert metadata["target_num_records"] == 10 # original target + assert metadata["target_num_records"] == 13 # [6, 3, 3, 1] after resize from [3, 3, 3, 1] assert metadata["actual_num_records"] == 6 # actual expanded count From 1a1690618ea7e558e82b6ba4f39f5deeea16cbd6 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 12 Feb 2026 12:25:58 -0300 Subject: [PATCH 4/8] test: add chained resize and multi-batch integration tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add expand->retract->expand chaining test (single batch) - Add multi-batch resize test verifying combined parquet output - Update example to chain expand/retract/expand with preview+build - Use 💥/✂️ emojis for resize 
logging (expand/retract) --- example_allow_resize.py | 94 ++++++++++-------- .../tests/config/test_columns.py | 16 ++-- .../dataset_builders/column_wise_builder.py | 4 +- .../test_column_wise_builder.py | 95 +++++++++++++++++++ 4 files changed, 156 insertions(+), 53 deletions(-) diff --git a/example_allow_resize.py b/example_allow_resize.py index f8e40945..43366429 100644 --- a/example_allow_resize.py +++ b/example_allow_resize.py @@ -1,7 +1,11 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Example: Using allow_resize for 1:N expansion and N:1 retraction.""" +"""Example: Chaining expand -> retract -> expand resize operations. + +Pipeline: 5 topics -> 15 questions (3 per topic) -> ~8 hard questions (filter easy) + -> ~24 answer variants (3 per question) +""" from __future__ import annotations @@ -10,42 +14,53 @@ from data_designer.lazy_heavy_imports import pd -@dd.custom_column_generator(required_columns=["topic"], side_effect_columns=["variation_id"]) +# Step 1: Expand — 1:N, generate 3 questions per topic +@dd.custom_column_generator(required_columns=["topic"], side_effect_columns=["question_id", "difficulty"]) def expand_to_questions(df: pd.DataFrame) -> pd.DataFrame: - """Generate 3 questions per topic (1:N expansion).""" rows = [] for _, row in df.iterrows(): for i in range(3): rows.append( { "topic": row["topic"], - "question": f"Question {i + 1} about {row['topic']}?", - "variation_id": i, + "question": f"Q{i + 1} about {row['topic']}?", + "question_id": i, + "difficulty": ["easy", "medium", "hard"][i], } ) return pd.DataFrame(rows) -@dd.custom_column_generator(required_columns=["topic", "score"]) -def filter_high_scores(df: pd.DataFrame) -> pd.DataFrame: - """Keep only records with score > 0.5 (N:1 retraction).""" - filtered = df[df["score"] > 0.5].copy() - filtered["status"] = "passed" - return filtered +# Step 2: Retract — N:1, keep only medium/hard questions 
+@dd.custom_column_generator(required_columns=["difficulty"]) +def filter_non_easy(df: pd.DataFrame) -> pd.DataFrame: + return df[df["difficulty"] != "easy"].copy().assign(filtered=True) + + +# Step 3: Expand again — 1:N, generate 3 answer variants per surviving question +@dd.custom_column_generator(required_columns=["question"], side_effect_columns=["variant"]) +def expand_to_answers(df: pd.DataFrame) -> pd.DataFrame: + rows = [] + for _, row in df.iterrows(): + for v in range(3): + rows.append({**row.to_dict(), "answer": f"Answer v{v} to: {row['question']}", "variant": v}) + return pd.DataFrame(rows) -def run_expansion_example() -> None: - """3 topics -> 9 questions.""" +def main() -> None: data_designer = DataDesigner() config_builder = dd.DataDesignerConfigBuilder() + # Seed: 5 topics config_builder.add_column( dd.SamplerColumnConfig( name="topic", sampler_type=dd.SamplerType.CATEGORY, - params=dd.CategorySamplerParams(values=["Python", "ML", "Data"]), + params=dd.CategorySamplerParams(values=["Python", "ML", "Data", "Stats", "SQL"]), ) ) + + # Expand: 5 topics -> 15 questions config_builder.add_column( dd.CustomColumnConfig( name="question", @@ -55,44 +70,39 @@ def run_expansion_example() -> None: ) ) - preview = data_designer.preview(config_builder=config_builder, num_records=3) - print(f"Expansion: 3 -> {len(preview.dataset)} records") - print(preview.dataset.to_string()) - - -def run_retraction_example() -> None: - """10 records -> ~5 (filtered).""" - data_designer = DataDesigner() - config_builder = dd.DataDesignerConfigBuilder() - + # Retract: 15 -> 10 (drop "easy" questions) config_builder.add_column( - dd.SamplerColumnConfig( - name="topic", - sampler_type=dd.SamplerType.CATEGORY, - params=dd.CategorySamplerParams(values=["A", "B", "C", "D", "E"]), - ) - ) - config_builder.add_column( - dd.SamplerColumnConfig( - name="score", - sampler_type=dd.SamplerType.UNIFORM, - params=dd.UniformSamplerParams(low=0.0, high=1.0), + dd.CustomColumnConfig( + 
name="filtered", + generator_function=filter_non_easy, + generation_strategy=dd.GenerationStrategy.FULL_COLUMN, + allow_resize=True, ) ) + + # Expand again: 10 -> 30 answer variants config_builder.add_column( dd.CustomColumnConfig( - name="status", - generator_function=filter_high_scores, + name="answer", + generator_function=expand_to_answers, generation_strategy=dd.GenerationStrategy.FULL_COLUMN, allow_resize=True, ) ) - preview = data_designer.preview(config_builder=config_builder, num_records=10) - print(f"Retraction: 10 -> {len(preview.dataset)} records") - print(preview.dataset.to_string()) + # Preview (single batch) + preview = data_designer.preview(config_builder=config_builder, num_records=5) + print(f"Preview: 5 topics -> {len(preview.dataset)} answer variants") + print(preview.dataset[["topic", "difficulty", "question", "variant", "answer"]].to_string()) + print() + + # Build (multiple batches: 10 records with buffer_size=3 -> 4 batches) + data_designer.set_run_config(dd.RunConfig(buffer_size=3)) + results = data_designer.create(config_builder=config_builder, num_records=10) + df = results.load_dataset() + print(f"Build: 10 topics (4 batches of 3+3+3+1) -> {len(df)} answer variants") + print(df[["topic", "difficulty", "question", "variant"]].to_string()) if __name__ == "__main__": - run_expansion_example() - # run_retraction_example() + main() diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index d7f8e66c..a369e9d3 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -516,10 +516,8 @@ def test_sampler_column_config_discriminated_union_wrong_params_type(): ) -class _StubPluginConfig(SingleColumnConfig): - """Minimal plugin config for testing base class behavior.""" - - column_type: Literal["stub-plugin"] = "stub-plugin" +class StubColumnConfigWithoutEmoji(SingleColumnConfig): + 
column_type: Literal["stub-without-emoji"] = "stub-without-emoji" @property def required_columns(self) -> list[str]: @@ -532,10 +530,10 @@ def side_effect_columns(self) -> list[str]: def test_default_column_emoji_for_custom_column_type() -> None: """Ensure the base get_column_emoji implementation is used when not overridden.""" - assert _StubPluginConfig.get_column_emoji() == "🎨" + assert StubColumnConfigWithoutEmoji.get_column_emoji() == "🎨" -def test_allow_resize_inherited_by_plugin_configs() -> None: - """Plugin configs inherit allow_resize from SingleColumnConfig.""" - assert _StubPluginConfig(name="test").allow_resize is False - assert _StubPluginConfig(name="test", allow_resize=True).allow_resize is True +def test_allow_resize_inherited_by_subclasses() -> None: + """Subclasses inherit allow_resize from SingleColumnConfig.""" + assert StubColumnConfigWithoutEmoji(name="test").allow_resize is False + assert StubColumnConfigWithoutEmoji(name="test", allow_resize=True).allow_resize is True diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 5ca02cc9..1ee1a57b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -241,9 +241,9 @@ def _run_full_column_generator(self, generator: ColumnGenerator) -> None: f"⚠️ Column '{generator.config.name}' reduced batch to 0 records. This batch will be skipped." ) else: + emoji = "💥" if new_count > original_count else "✂️" logger.info( - f"📊 Column '{generator.config.name}' resized batch: {original_count} -> {new_count} records. " - f"Subsequent columns will operate on the new record count." + f"{emoji} Column '{generator.config.name}' resized batch: {original_count} -> {new_count} records." 
) self.batch_manager.replace_buffer(df.to_dict(orient="records"), allow_resize=allow_resize) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py index c741d586..4e44f130 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py @@ -562,3 +562,98 @@ def test_process_preview_with_empty_dataframe(simple_builder): assert len(result) == 0 mock_processor.process_after_batch.assert_called_once() mock_processor.process_after_generation.assert_called_once() + + +# allow_resize integration tests + + +@custom_column_generator(required_columns=["seed_id"], side_effect_columns=["copy"]) +def _expand_x3(df: pd.DataFrame) -> pd.DataFrame: + """Triple each row.""" + rows = [] + for _, row in df.iterrows(): + for i in range(3): + rows.append({**row.to_dict(), "expanded": f"{row['seed_id']}_v{i}", "copy": i}) + return pd.DataFrame(rows) + + +@custom_column_generator(required_columns=["seed_id"]) +def _keep_first(df: pd.DataFrame) -> pd.DataFrame: + """Keep only first row per seed_id (retraction).""" + return df.drop_duplicates(subset="seed_id").assign(filtered=True) + + +@custom_column_generator(required_columns=["seed_id"], side_effect_columns=["copy2"]) +def _expand_x3_again(df: pd.DataFrame) -> pd.DataFrame: + """Triple each row (second expansion).""" + rows = [] + for _, row in df.iterrows(): + for i in range(3): + rows.append({**row.to_dict(), "expanded_again": f"{row['seed_id']}_w{i}", "copy2": i}) + return pd.DataFrame(rows) + + +def _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_setup, columns): + """Helper to build a ColumnWiseDatasetBuilder with resize custom columns.""" + config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + 
config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"]))) + for col in columns: + config_builder.add_column(col) + return ColumnWiseDatasetBuilder( + data_designer_config=config_builder.build(), + resource_provider=stub_resource_provider, + ) + + +def test_chained_expand_retract_expand(stub_resource_provider, stub_model_configs, seed_data_setup): + """Expand -> retract -> expand chains correctly in a single batch.""" + builder = _build_resize_builder( + stub_resource_provider, + stub_model_configs, + seed_data_setup, + [ + CustomColumnConfig( + name="expanded", + generator_function=_expand_x3, + generation_strategy=GenerationStrategy.FULL_COLUMN, + allow_resize=True, + ), + CustomColumnConfig( + name="filtered", + generator_function=_keep_first, + generation_strategy=GenerationStrategy.FULL_COLUMN, + allow_resize=True, + ), + CustomColumnConfig( + name="expanded_again", + generator_function=_expand_x3_again, + generation_strategy=GenerationStrategy.FULL_COLUMN, + allow_resize=True, + ), + ], + ) + # 5 seeds -> 15 (x3) -> 5 (dedup) -> 15 (x3) + result = builder.build_preview(num_records=5) + assert len(result) == 15 + + +def test_resize_across_multiple_batches(stub_resource_provider, stub_model_configs, seed_data_setup): + """Resized batches are written independently and combine correctly.""" + stub_resource_provider.run_config = RunConfig(buffer_size=2) + builder = _build_resize_builder( + stub_resource_provider, + stub_model_configs, + seed_data_setup, + [ + CustomColumnConfig( + name="expanded", + generator_function=_expand_x3, + generation_strategy=GenerationStrategy.FULL_COLUMN, + allow_resize=True, + ), + ], + ) + # 5 seeds, buffer_size=2 -> batches of [2, 2, 1], each x3 -> [6, 6, 3] = 15 total + builder.build(num_records=5) + df = pd.read_parquet(builder.artifact_storage.final_dataset_path) + assert len(df) == 15 From 07d14bd9ca65744c04f5d064dbbbe980ccdfee5b Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 12 Feb 2026 
14:55:32 -0300 Subject: [PATCH 5/8] extend allow_resize to cell-by-cell (return dict or list[dict]) - Config: allow allow_resize with CELL_BY_CELL; relax validator - Custom generator: accept dict | list[dict] when cell_by_cell + allow_resize; validate per row via _validate_cell_output - Builder: collect results by index when cell allow_resize, flatten and replace_buffer; add _log_resize_if_changed and _column_display_name - Docs: ALL_CAPS for strategies, simplify allow_resize table text - Tests: parametrized preview and multibatch; factories with n param; _RESIZE_SPECS with inline factory calls; ids ordered like specs --- docs/concepts/custom_columns.md | 24 ++- .../data_designer/config/column_configs.py | 9 +- .../column_generators/generators/custom.py | 62 +++++- .../dataset_builders/column_wise_builder.py | 62 ++++-- .../generators/test_custom.py | 98 ++++++++-- .../test_column_wise_builder.py | 184 +++++++++++------- 6 files changed, 326 insertions(+), 113 deletions(-) diff --git a/docs/concepts/custom_columns.md b/docs/concepts/custom_columns.md index f62eebf6..98344485 100644 --- a/docs/concepts/custom_columns.md +++ b/docs/concepts/custom_columns.md @@ -93,11 +93,11 @@ This gives you direct access to all `ModelFacade` capabilities: custom parsers, | `generator_function` | Callable | Yes | Decorated function | | `generation_strategy` | GenerationStrategy | No | `CELL_BY_CELL` or `FULL_COLUMN` | | `generator_params` | BaseModel | No | Typed params passed to function | -| `allow_resize` | bool | No | Allow 1:N or N:1 generation. 
Requires `FULL_COLUMN` strategy | +| `allow_resize` | bool | No | Allow 1:N or N:1 generation | ### Resizing (1:N and N:1) -With `full_column` strategy, you can produce more or fewer records than the input using `allow_resize=True`: +**FULL_COLUMN:** Set `allow_resize=True` and return a DataFrame with more or fewer rows than the input: ```python @dd.custom_column_generator( @@ -123,10 +123,28 @@ dd.CustomColumnConfig( ) ``` +**CELL_BY_CELL:** With `allow_resize=True`, your function may return a single row (`dict`) or multiple rows (`list[dict]`). Return `[]` to drop that input row. + +```python +@dd.custom_column_generator(required_columns=["id"]) +def expand_row(row: dict) -> list[dict]: + return [ + {**row, "variant": "a"}, + {**row, "variant": "b"}, + ] + +dd.CustomColumnConfig( + name="variant", + generator_function=expand_row, + generation_strategy=dd.GenerationStrategy.CELL_BY_CELL, + allow_resize=True, +) +``` + Use cases: - **Expansion (1:N)**: Generate multiple variations per input -- **Retraction (N:1)**: Filter, aggregate, or deduplicate records +- **Retraction (N:1)**: Filter, aggregate, or deduplicate records (FULL_COLUMN) or return `[]` per row (CELL_BY_CELL) ## Multi-Turn Example diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index 9095b44e..e352ac4c 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -563,9 +563,12 @@ def validate_generator_function(self) -> Self: @model_validator(mode="after") def validate_allow_resize_requires_full_column(self) -> Self: - if self.allow_resize and self.generation_strategy != GenerationStrategy.FULL_COLUMN: + if self.allow_resize and self.generation_strategy not in ( + GenerationStrategy.FULL_COLUMN, + GenerationStrategy.CELL_BY_CELL, + ): raise InvalidConfigError( - f"🛑 `allow_resize=True` 
requires `generation_strategy='full_column'` for column '{self.name}'. " - f"Cell-by-cell strategy processes one row at a time and cannot change record count." + f"🛑 `allow_resize=True` requires `generation_strategy` 'full_column' or 'cell_by_cell' " + f"for column '{self.name}'." ) return self diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index 8b9ff6f3..ba5d4bd8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -43,8 +43,11 @@ def get_generation_strategy(self) -> GenerationStrategy: """Return strategy based on config.""" return self.config.generation_strategy - def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame: - """Generate column value(s) for a row (dict) or batch (DataFrame).""" + def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict]: + """Generate column value(s) for a row (dict) or batch (DataFrame). + + For cell_by_cell with allow_resize=True, may return dict or list[dict] (0, 1, or N rows). 
+ """ is_full_column = self.config.generation_strategy == GenerationStrategy.FULL_COLUMN is_dataframe = not isinstance(data, dict) @@ -62,7 +65,7 @@ def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame: return self._generate(data, is_dataframe) - def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.DataFrame: + def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.DataFrame | list[dict]: """Unified generation logic for both strategies.""" # Get columns/keys using unified accessor get_keys = (lambda d: set(d.columns)) if is_dataframe else (lambda d: set(d.keys())) @@ -93,7 +96,23 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. f"Custom generator function failed for column '{self.config.name}': {e}" ) from e - # Validate return type + # Cell-by-cell with allow_resize: accept dict or list[dict] + if not is_dataframe and getattr(self.config, "allow_resize", False): + if isinstance(result, list): + if not all(isinstance(r, dict) for r in result): + raise CustomColumnGenerationError( + f"Custom generator for column '{self.config.name}' with allow_resize must return " + "dict or list[dict]; list elements must be dicts." + ) + return [self._validate_cell_output(r, keys_before) for r in result] + if isinstance(result, dict): + return self._validate_output(result, keys_before, is_dataframe) + raise CustomColumnGenerationError( + f"Custom generator for column '{self.config.name}' with allow_resize must return " + f"dict or list[dict], got {type(result).__name__}" + ) + + # Validate return type for non-resize paths if not isinstance(result, expected_type): raise CustomColumnGenerationError( f"Custom generator for column '{self.config.name}' must return a {type_name}, " @@ -102,6 +121,38 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. 
return self._validate_output(result, keys_before, is_dataframe)

+    def _validate_cell_output(self, row: dict, keys_before: set[str]) -> dict:
+        """Validate a single row output (dict) for cell_by_cell; strip undeclared columns."""
+        expected_new = {self.config.name} | set(self.config.side_effect_columns)
+        result_keys = set(row.keys())
+
+        if self.config.name not in result_keys:
+            raise CustomColumnGenerationError(
+                f"Custom generator for column '{self.config.name}' did not create the expected column. "
+                f"The generator_function must add a key named '{self.config.name}' to the row."
+            )
+        missing = set(self.config.side_effect_columns) - result_keys
+        if missing:
+            raise CustomColumnGenerationError(
+                f"Custom generator for column '{self.config.name}' did not create declared side_effect_columns: "
+                f"{sorted(missing)}. Declared side_effect_columns must be added to the row."
+            )
+        removed = keys_before - result_keys
+        if removed:
+            raise CustomColumnGenerationError(
+                f"Custom generator for column '{self.config.name}' removed pre-existing columns: "
+                f"{sorted(removed)}. The generator_function must not remove any existing columns."
+            )
+        undeclared = (result_keys - keys_before) - expected_new
+        if undeclared:
+            logger.warning(
+                f"⚠️ Custom generator for column '{self.config.name}' created undeclared columns: "
+                f"{sorted(undeclared)}. These columns will be removed. "
+                f"To keep additional columns, declare them in @custom_column_generator(side_effect_columns=[...])."
+ ) + row = {k: v for k, v in row.items() if k not in undeclared} + return row + def _validate_output( self, result: dict | pd.DataFrame, keys_before: set[str], is_dataframe: bool ) -> dict | pd.DataFrame: @@ -147,8 +198,7 @@ def _validate_output( if is_dataframe: result = result.drop(columns=list(undeclared)) else: - for key in undeclared: - del result[key] + result = self._validate_cell_output(result, keys_before) return result diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 1ee1a57b..705778f4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -63,6 +63,8 @@ def __init__( self.batch_manager = DatasetBatchManager(resource_provider.artifact_storage) self._resource_provider = resource_provider self._records_to_drop: set[int] = set() + self._cell_resize_results: list[dict | list[dict] | None] = [] + self._cell_resize_mode = False self._registry = registry or DataDesignerRegistry() self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider) @@ -229,23 +231,23 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None: max_workers = generator.inference_parameters.max_parallel_requests self._fan_out_with_threads(generator, max_workers=max_workers) + def _column_display_name(self, config: ColumnConfigT) -> str: + return f"columns {config.column_names}" if hasattr(config, "column_names") else config.name + + def _log_resize_if_changed(self, column_name: str, original_count: int, new_count: int, allow_resize: bool) -> None: + if not allow_resize or new_count == original_count: + return + if new_count == 0: + logger.warning(f"⚠️ Column '{column_name}' reduced batch to 0 records. 
This batch will be skipped.") + else: + emoji = "💥" if new_count > original_count else "✂️" + logger.info(f"{emoji} Column '{column_name}' resized batch: {original_count} -> {new_count} records.") + def _run_full_column_generator(self, generator: ColumnGenerator) -> None: original_count = self.batch_manager.num_records_in_buffer df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True)) allow_resize = getattr(generator.config, "allow_resize", False) - new_count = len(df) - - if allow_resize and new_count != original_count: - if new_count == 0: - logger.warning( - f"⚠️ Column '{generator.config.name}' reduced batch to 0 records. This batch will be skipped." - ) - else: - emoji = "💥" if new_count > original_count else "✂️" - logger.info( - f"{emoji} Column '{generator.config.name}' resized batch: {original_count} -> {new_count} records." - ) - + self._log_resize_if_changed(self._column_display_name(generator.config), original_count, len(df), allow_resize) self.batch_manager.replace_buffer(df.to_dict(orient="records"), allow_resize=allow_resize) def _run_model_health_check_if_needed(self) -> None: @@ -279,6 +281,11 @@ def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max if getattr(generator.config, "tool_alias", None): logger.info("🛠️ Tool calling enabled") + allow_resize = getattr(generator.config, "allow_resize", False) + if allow_resize: + self._cell_resize_results = [None] * self.batch_manager.num_records_batch + self._cell_resize_mode = True + progress_tracker = ProgressTracker( total_records=self.batch_manager.num_records_batch, label=f"{generator.config.column_type} column '{generator.config.name}'", @@ -300,7 +307,27 @@ def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max progress_tracker.log_final() - if len(self._records_to_drop) > 0: + if allow_resize: + # Flatten results in index order; skip indices in _records_to_drop (failed cells), + # so those rows are omitted from the new 
buffer. + new_records: list[dict] = [] + for i in range(len(self._cell_resize_results)): + if i in self._records_to_drop: + continue + r = self._cell_resize_results[i] + if r is not None: + new_records.extend(r if isinstance(r, list) else [r]) + self._log_resize_if_changed( + self._column_display_name(generator.config), + self.batch_manager.num_records_in_buffer, + len(new_records), + True, + ) + self.batch_manager.replace_buffer(new_records, allow_resize=True) + self._records_to_drop.clear() + self._cell_resize_mode = False + self._cell_resize_results = [] + elif len(self._records_to_drop) > 0: self.batch_manager.drop_records(self._records_to_drop) self._records_to_drop.clear() @@ -372,8 +399,11 @@ def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) ) self._records_to_drop.add(context["index"]) - def _worker_result_callback(self, result: dict, *, context: dict | None = None) -> None: - self.batch_manager.update_record(context["index"], result) + def _worker_result_callback(self, result: dict | list[dict], *, context: dict | None = None) -> None: + if self._cell_resize_mode: + self._cell_resize_results[context["index"]] = result + else: + self.batch_manager.update_record(context["index"], result) def _emit_batch_inference_events( self, batch_mode: str, usage_deltas: dict[str, ModelUsageStats], group_id: str diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py index 1f3d1a03..c768a720 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py @@ -18,7 +18,6 @@ from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy from data_designer.config.custom_column import custom_column_generator -from data_designer.config.errors import 
InvalidConfigError from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError from data_designer.engine.resources.resource_provider import ResourceProvider @@ -114,29 +113,21 @@ def test_config_validation_non_callable() -> None: CustomColumnConfig(name="test", generator_function="not_a_function") -def test_config_validation_allow_resize_requires_full_column() -> None: - """Test that allow_resize=True requires generation_strategy=FULL_COLUMN.""" +def test_config_validation_allow_resize_allows_full_column_and_cell_by_cell() -> None: + """allow_resize=True is valid with full_column or cell_by_cell.""" @custom_column_generator() def dummy_fn(row: dict) -> dict: return row - with pytest.raises(InvalidConfigError, match="allow_resize=True.*requires.*full_column"): - CustomColumnConfig( + for strategy in (GenerationStrategy.FULL_COLUMN, GenerationStrategy.CELL_BY_CELL): + config = CustomColumnConfig( name="test", generator_function=dummy_fn, allow_resize=True, - generation_strategy=GenerationStrategy.CELL_BY_CELL, + generation_strategy=strategy, ) - - # Should work with FULL_COLUMN - config = CustomColumnConfig( - name="test", - generator_function=dummy_fn, - allow_resize=True, - generation_strategy=GenerationStrategy.FULL_COLUMN, - ) - assert config.allow_resize is True + assert config.allow_resize is True # Cell-by-cell generation tests @@ -186,6 +177,83 @@ def test_side_effect_columns() -> None: assert result["secondary"] == 15 +# cell_by_cell allow_resize: dict | list[dict] + + +def test_cell_by_cell_allow_resize_return_dict() -> None: + """With allow_resize, returning a single dict (1:1) works like normal cell-by-cell.""" + config = CustomColumnConfig( + name="result", + generator_function=generator_with_required_columns, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + allow_resize=True, + ) + generator = 
CustomColumnGenerator(config=config, resource_provider=Mock(spec=ResourceProvider)) + result = generator.generate({"input": "hi"}) + assert isinstance(result, dict) + assert result["result"] == "HI" + + +def test_cell_by_cell_allow_resize_return_list_expand() -> None: + """With allow_resize, returning list[dict] expands one row into multiple.""" + + @custom_column_generator(required_columns=["x"]) + def expand(row: dict) -> list[dict]: + return [ + {**row, "out": row["x"] * 1}, + {**row, "out": row["x"] * 2}, + ] + + config = CustomColumnConfig( + name="out", + generator_function=expand, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + allow_resize=True, + ) + generator = CustomColumnGenerator(config=config, resource_provider=Mock(spec=ResourceProvider)) + result = generator.generate({"x": 10}) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0] == {"x": 10, "out": 10} + assert result[1] == {"x": 10, "out": 20} + + +def test_cell_by_cell_allow_resize_return_empty_list() -> None: + """With allow_resize, returning [] drops that row (0 rows).""" + + @custom_column_generator(required_columns=["x"]) + def drop(row: dict) -> list[dict]: + return [] + + config = CustomColumnConfig( + name="out", + generator_function=drop, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + allow_resize=True, + ) + generator = CustomColumnGenerator(config=config, resource_provider=Mock(spec=ResourceProvider)) + result = generator.generate({"x": 1}) + assert result == [] + + +def test_cell_by_cell_allow_resize_invalid_return_type() -> None: + """With allow_resize, return must be dict or list[dict].""" + + @custom_column_generator(required_columns=["x"]) + def bad_return(row: dict): + return [1, 2] + + config = CustomColumnConfig( + name="out", + generator_function=bad_return, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + allow_resize=True, + ) + generator = CustomColumnGenerator(config=config, resource_provider=Mock(spec=ResourceProvider)) 
+ with pytest.raises(CustomColumnGenerationError, match="list elements must be dicts"): + generator.generate({"x": 1}) + + # Error handling tests diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py index 4e44f130..78ea2855 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py @@ -565,36 +565,72 @@ def test_process_preview_with_empty_dataframe(simple_builder): # allow_resize integration tests +# +# Factories and keep_first stub; _RESIZE_SPECS calls factories inline (n, col names). -@custom_column_generator(required_columns=["seed_id"], side_effect_columns=["copy"]) -def _expand_x3(df: pd.DataFrame) -> pd.DataFrame: - """Triple each row.""" - rows = [] - for _, row in df.iterrows(): - for i in range(3): - rows.append({**row.to_dict(), "expanded": f"{row['seed_id']}_v{i}", "copy": i}) - return pd.DataFrame(rows) +def _make_resize_full_expand(n: int, primary_col: str, side_effect_col: str): + @custom_column_generator(required_columns=["seed_id"], side_effect_columns=[side_effect_col]) + def fn(df: pd.DataFrame) -> pd.DataFrame: + rows = [] + for _, row in df.iterrows(): + for i in range(n): + rows.append({**row.to_dict(), primary_col: f"{row['seed_id']}_v{i}", side_effect_col: i}) + return pd.DataFrame(rows) + + return fn + + +def _make_resize_cell_expand(n: int, col_name: str): + suffixes = tuple("abcdefgh"[:n]) + + @custom_column_generator(required_columns=["seed_id"]) + def fn(row: dict) -> list[dict]: + return [{**row, col_name: f"{row['seed_id']}_{s}"} for s in suffixes] + + return fn @custom_column_generator(required_columns=["seed_id"]) -def _keep_first(df: pd.DataFrame) -> pd.DataFrame: - """Keep only first row per seed_id (retraction).""" +def _resize_full_keep_first(df: pd.DataFrame) 
-> pd.DataFrame: + """FULL_COLUMN: keep first row per seed_id (retraction).""" return df.drop_duplicates(subset="seed_id").assign(filtered=True) -@custom_column_generator(required_columns=["seed_id"], side_effect_columns=["copy2"]) -def _expand_x3_again(df: pd.DataFrame) -> pd.DataFrame: - """Triple each row (second expansion).""" - rows = [] - for _, row in df.iterrows(): - for i in range(3): - rows.append({**row.to_dict(), "expanded_again": f"{row['seed_id']}_w{i}", "copy2": i}) - return pd.DataFrame(rows) +FULL = GenerationStrategy.FULL_COLUMN +CELL = GenerationStrategy.CELL_BY_CELL + +_RESIZE_SPECS: dict[str, list[tuple[str, object, GenerationStrategy]]] = { + "single_full_x3": [("expanded", _make_resize_full_expand(3, "expanded", "copy"), FULL)], + "single_cell_x2": [("doubled", _make_resize_cell_expand(2, "doubled"), CELL)], + "full_chain": [ + ("expanded", _make_resize_full_expand(3, "expanded", "copy"), FULL), + ("filtered", _resize_full_keep_first, FULL), + ("expanded_again", _make_resize_full_expand(3, "expanded_again", "copy2"), FULL), + ], + "cell_then_full_chain": [ + ("doubled", _make_resize_cell_expand(2, "doubled"), CELL), + ("filtered", _resize_full_keep_first, FULL), + ("expanded_again", _make_resize_full_expand(3, "expanded_again", "copy2"), FULL), + ], +} + + +def _resize_columns(spec: str) -> list[CustomColumnConfig]: + """Return column configs for a given allow_resize recipe.""" + return [ + CustomColumnConfig( + name=name, + generator_function=fn, + generation_strategy=strat, + allow_resize=True, + ) + for name, fn, strat in _RESIZE_SPECS[spec] + ] def _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_setup, columns): - """Helper to build a ColumnWiseDatasetBuilder with resize custom columns.""" + """Build a ColumnWiseDatasetBuilder with the given resize column configs.""" config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) 
config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"]))) for col in columns: @@ -605,55 +641,63 @@ def _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_ ) -def test_chained_expand_retract_expand(stub_resource_provider, stub_model_configs, seed_data_setup): - """Expand -> retract -> expand chains correctly in a single batch.""" - builder = _build_resize_builder( - stub_resource_provider, - stub_model_configs, - seed_data_setup, - [ - CustomColumnConfig( - name="expanded", - generator_function=_expand_x3, - generation_strategy=GenerationStrategy.FULL_COLUMN, - allow_resize=True, - ), - CustomColumnConfig( - name="filtered", - generator_function=_keep_first, - generation_strategy=GenerationStrategy.FULL_COLUMN, - allow_resize=True, - ), - CustomColumnConfig( - name="expanded_again", - generator_function=_expand_x3_again, - generation_strategy=GenerationStrategy.FULL_COLUMN, - allow_resize=True, - ), - ], - ) - # 5 seeds -> 15 (x3) -> 5 (dedup) -> 15 (x3) - result = builder.build_preview(num_records=5) - assert len(result) == 15 - - -def test_resize_across_multiple_batches(stub_resource_provider, stub_model_configs, seed_data_setup): - """Resized batches are written independently and combine correctly.""" - stub_resource_provider.run_config = RunConfig(buffer_size=2) - builder = _build_resize_builder( - stub_resource_provider, - stub_model_configs, - seed_data_setup, - [ - CustomColumnConfig( - name="expanded", - generator_function=_expand_x3, - generation_strategy=GenerationStrategy.FULL_COLUMN, - allow_resize=True, - ), - ], - ) - # 5 seeds, buffer_size=2 -> batches of [2, 2, 1], each x3 -> [6, 6, 3] = 15 total - builder.build(num_records=5) +@pytest.mark.parametrize( + "spec,num_records,expected_len,check_doubled_order", + [ + ("single_full_x3", 5, 15, False), + ("single_cell_x2", 5, 10, True), + ("full_chain", 5, 15, False), + ("cell_then_full_chain", 5, 15, False), + ], + ids=["single_full_x3", 
"single_cell_x2", "full_chain", "cell_then_full_chain"], +) +def test_allow_resize_preview( + stub_resource_provider, + stub_model_configs, + seed_data_setup, + spec, + num_records, + expected_len, + check_doubled_order, +): + """Preview with allow_resize columns (FULL_COLUMN and/or CELL_BY_CELL) yields expected length.""" + columns = _resize_columns(spec) + builder = _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_setup, columns) + result = builder.build_preview(num_records=num_records) + assert len(result) == expected_len + if check_doubled_order: + expected = [x for i in range(1, 6) for x in (f"{i}_a", f"{i}_b")] + assert result["doubled"].tolist() == expected + + +@pytest.mark.parametrize( + "spec,num_records,buffer_size,expected_total_rows", + [ + ("single_full_x3", 5, 2, 15), # batches [2,2,1] -> each x3 -> 6+6+3 + ("single_cell_x2", 5, 2, 10), # batches [2,2,1] -> each x2 -> 4+4+2 + ("full_chain", 5, 2, 15), # batches [2,2,1] -> x3, dedup, x3 -> 15 + ("single_full_x3", 4, 2, 12), # batches [2,2] -> 6+6 + ], + ids=[ + "single_full_x3_multibatch", + "single_cell_x2_multibatch", + "full_chain_multibatch", + "single_full_x3_4_2_multibatch", + ], +) +def test_allow_resize_multiple_batches( + stub_resource_provider, + stub_model_configs, + seed_data_setup, + spec, + num_records, + buffer_size, + expected_total_rows, +): + """Resized batches are written independently and combine to expected total rows.""" + stub_resource_provider.run_config = RunConfig(buffer_size=buffer_size) + columns = _resize_columns(spec) + builder = _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_setup, columns) + builder.build(num_records=num_records) df = pd.read_parquet(builder.artifact_storage.final_dataset_path) - assert len(df) == 15 + assert len(df) == expected_total_rows From 0ff511d8d2c19bbe29f77379786caa7d70eb98f9 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 12 Feb 2026 15:16:38 -0300 Subject: [PATCH 6/8] reorder 
allow_resize specs and add edge-case tests - Rename specs: full_x3, cell_x2, cell_plus_full_chain; add cell_filter_odd, cell_drop_all to _RESIZE_SPECS - Stubs before specs: _resize_full_keep_first, _resize_cell_expand, _resize_cell_filter_odd, _resize_cell_drop_all; drop cell factories - Remove FULL/CELL constants; use GenerationStrategy.* in _RESIZE_SPECS - Preview/multibatch parametrize: _preview and _multibatch ids; two full_x3 multibatch cases (5_2, 4_2) first - Handle all-batches-skipped in multibatch test (empty df when path missing) - test_custom: add test_cell_by_cell_allow_resize_return_list_single (1:1 via list) --- .../generators/test_custom.py | 18 ++++ .../test_column_wise_builder.py | 95 ++++++++++++------- 2 files changed, 81 insertions(+), 32 deletions(-) diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py index c768a720..f3d36219 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py @@ -218,6 +218,24 @@ def expand(row: dict) -> list[dict]: assert result[1] == {"x": 10, "out": 20} +def test_cell_by_cell_allow_resize_return_list_single() -> None: + """With allow_resize, returning [dict] (1:1 via list) is valid.""" + + @custom_column_generator(required_columns=["x"]) + def one_row(row: dict) -> list[dict]: + return [{**row, "out": row["x"]}] + + config = CustomColumnConfig( + name="out", + generator_function=one_row, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + allow_resize=True, + ) + generator = CustomColumnGenerator(config=config, resource_provider=Mock(spec=ResourceProvider)) + result = generator.generate({"x": 42}) + assert result == [{"x": 42, "out": 42}] + + def test_cell_by_cell_allow_resize_return_empty_list() -> None: """With allow_resize, returning [] drops 
that row (0 rows).""" diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py index 78ea2855..7005df44 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py @@ -566,10 +566,12 @@ def test_process_preview_with_empty_dataframe(simple_builder): # allow_resize integration tests # -# Factories and keep_first stub; _RESIZE_SPECS calls factories inline (n, col names). +# Factory: _make_resize_full_expand. Stubs: _resize_full_keep_first, _resize_cell_*. def _make_resize_full_expand(n: int, primary_col: str, side_effect_col: str): + """FULL_COLUMN: expand n times per seed_id.""" + @custom_column_generator(required_columns=["seed_id"], side_effect_columns=[side_effect_col]) def fn(df: pd.DataFrame) -> pd.DataFrame: rows = [] @@ -581,37 +583,49 @@ def fn(df: pd.DataFrame) -> pd.DataFrame: return fn -def _make_resize_cell_expand(n: int, col_name: str): - suffixes = tuple("abcdefgh"[:n]) +@custom_column_generator(required_columns=["seed_id"]) +def _resize_full_keep_first(df: pd.DataFrame) -> pd.DataFrame: + """FULL_COLUMN: keep first row per seed_id (retraction).""" + return df.drop_duplicates(subset="seed_id").assign(filtered=True) - @custom_column_generator(required_columns=["seed_id"]) - def fn(row: dict) -> list[dict]: - return [{**row, col_name: f"{row['seed_id']}_{s}"} for s in suffixes] - return fn +@custom_column_generator(required_columns=["seed_id"]) +def _resize_cell_expand(row: dict) -> list[dict]: + """CELL_BY_CELL: one row -> two rows (doubled).""" + return [ + {**row, "doubled": f"{row['seed_id']}_a"}, + {**row, "doubled": f"{row['seed_id']}_b"}, + ] @custom_column_generator(required_columns=["seed_id"]) -def _resize_full_keep_first(df: pd.DataFrame) -> pd.DataFrame: - """FULL_COLUMN: keep
first row per seed_id (retraction).""" - return df.drop_duplicates(subset="seed_id").assign(filtered=True) +def _resize_cell_filter_odd(row: dict) -> dict | list[dict]: + """CELL_BY_CELL: drop even seed_id, keep odd.""" + if row["seed_id"] % 2 == 0: + return [] + return {**row, "kept": row["seed_id"]} + +@custom_column_generator(required_columns=["seed_id"]) +def _resize_cell_drop_all(row: dict) -> list[dict]: + """CELL_BY_CELL: return [] for every row (drop all).""" + return [] -FULL = GenerationStrategy.FULL_COLUMN -CELL = GenerationStrategy.CELL_BY_CELL _RESIZE_SPECS: dict[str, list[tuple[str, object, GenerationStrategy]]] = { - "single_full_x3": [("expanded", _make_resize_full_expand(3, "expanded", "copy"), FULL)], - "single_cell_x2": [("doubled", _make_resize_cell_expand(2, "doubled"), CELL)], + "cell_filter_odd": [("kept", _resize_cell_filter_odd, GenerationStrategy.CELL_BY_CELL)], + "cell_x2": [("doubled", _resize_cell_expand, GenerationStrategy.CELL_BY_CELL)], + "cell_drop_all": [("dropped", _resize_cell_drop_all, GenerationStrategy.CELL_BY_CELL)], + "full_x3": [("expanded", _make_resize_full_expand(3, "expanded", "copy"), GenerationStrategy.FULL_COLUMN)], "full_chain": [ - ("expanded", _make_resize_full_expand(3, "expanded", "copy"), FULL), - ("filtered", _resize_full_keep_first, FULL), - ("expanded_again", _make_resize_full_expand(3, "expanded_again", "copy2"), FULL), + ("expanded", _make_resize_full_expand(3, "expanded", "copy"), GenerationStrategy.FULL_COLUMN), + ("filtered", _resize_full_keep_first, GenerationStrategy.FULL_COLUMN), + ("expanded_again", _make_resize_full_expand(3, "expanded_again", "copy2"), GenerationStrategy.FULL_COLUMN), ], - "cell_then_full_chain": [ - ("doubled", _make_resize_cell_expand(2, "doubled"), CELL), - ("filtered", _resize_full_keep_first, FULL), - ("expanded_again", _make_resize_full_expand(3, "expanded_again", "copy2"), FULL), + "cell_plus_full_chain": [ + ("doubled", _resize_cell_expand, 
GenerationStrategy.CELL_BY_CELL), + ("filtered", _resize_full_keep_first, GenerationStrategy.FULL_COLUMN), + ("expanded_again", _make_resize_full_expand(3, "expanded_again", "copy2"), GenerationStrategy.FULL_COLUMN), ], } @@ -644,12 +658,21 @@ def _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_ @pytest.mark.parametrize( "spec,num_records,expected_len,check_doubled_order", [ - ("single_full_x3", 5, 15, False), - ("single_cell_x2", 5, 10, True), + ("cell_filter_odd", 5, 3, False), + ("cell_x2", 5, 10, True), + ("cell_drop_all", 5, 0, False), + ("full_x3", 5, 15, False), ("full_chain", 5, 15, False), - ("cell_then_full_chain", 5, 15, False), + ("cell_plus_full_chain", 5, 15, False), + ], + ids=[ + "cell_filter_odd_preview", + "cell_x2_preview", + "cell_drop_all_preview", + "full_x3_preview", + "full_chain_preview", + "cell_plus_full_chain_preview", ], - ids=["single_full_x3", "single_cell_x2", "full_chain", "cell_then_full_chain"], ) def test_allow_resize_preview( stub_resource_provider, @@ -673,16 +696,20 @@ def test_allow_resize_preview( @pytest.mark.parametrize( "spec,num_records,buffer_size,expected_total_rows", [ - ("single_full_x3", 5, 2, 15), # batches [2,2,1] -> each x3 -> 6+6+3 - ("single_cell_x2", 5, 2, 10), # batches [2,2,1] -> each x2 -> 4+4+2 + ("cell_x2", 5, 2, 10), # batches [2,2,1] -> each x2 -> 4+4+2 + ("cell_filter_odd", 5, 2, 3), # batches [2,2,1] -> keep odd -> 1+1+1 + ("cell_drop_all", 5, 2, 0), # each batch -> 0 rows + ("full_x3", 5, 2, 15), # batches [2,2,1] -> each x3 -> 6+6+3 + ("full_x3", 4, 2, 12), # batches [2,2] -> 6+6 ("full_chain", 5, 2, 15), # batches [2,2,1] -> x3, dedup, x3 -> 15 - ("single_full_x3", 4, 2, 12), # batches [2,2] -> 6+6 ], ids=[ - "single_full_x3_multibatch", - "single_cell_x2_multibatch", + "cell_x2_multibatch", + "cell_filter_odd_multibatch", + "cell_drop_all_multibatch", + "full_x3_multibatch_5_2", + "full_x3_multibatch_4_2", "full_chain_multibatch", - "single_full_x3_4_2_multibatch", ], ) 
def test_allow_resize_multiple_batches( @@ -699,5 +726,9 @@ def test_allow_resize_multiple_batches( columns = _resize_columns(spec) builder = _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_setup, columns) builder.build(num_records=num_records) - df = pd.read_parquet(builder.artifact_storage.final_dataset_path) + final_path = builder.artifact_storage.final_dataset_path + if expected_total_rows == 0 and not final_path.exists(): + df = pd.DataFrame() + else: + df = pd.read_parquet(final_path) assert len(df) == expected_total_rows From ab8ce119d48cb9f4903c070f5d75baef464b2b8c Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 12 Feb 2026 15:35:28 -0300 Subject: [PATCH 7/8] tidy allow_resize: drop validator, shared stub, explicit flag - Remove validate_allow_resize_requires_full_column from CustomColumnConfig - Rename StubColumnConfigWithoutEmoji to StubColumnConfig in test_columns - Pass allow_resize=False in _write_processed_batch replace_buffer call --- .../src/data_designer/config/column_configs.py | 12 ------------ .../tests/config/test_columns.py | 10 +++++----- .../engine/dataset_builders/column_wise_builder.py | 2 +- 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index e352ac4c..b2eefd26 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -560,15 +560,3 @@ def validate_generator_function(self) -> Self: f"Expected a function decorated with @custom_column_generator." 
) return self - - @model_validator(mode="after") - def validate_allow_resize_requires_full_column(self) -> Self: - if self.allow_resize and self.generation_strategy not in ( - GenerationStrategy.FULL_COLUMN, - GenerationStrategy.CELL_BY_CELL, - ): - raise InvalidConfigError( - f"🛑 `allow_resize=True` requires `generation_strategy` 'full_column' or 'cell_by_cell' " - f"for column '{self.name}'." - ) - return self diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index a369e9d3..20025094 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -516,8 +516,8 @@ def test_sampler_column_config_discriminated_union_wrong_params_type(): ) -class StubColumnConfigWithoutEmoji(SingleColumnConfig): - column_type: Literal["stub-without-emoji"] = "stub-without-emoji" +class StubColumnConfig(SingleColumnConfig): + column_type: Literal["stub"] = "stub" @property def required_columns(self) -> list[str]: @@ -530,10 +530,10 @@ def side_effect_columns(self) -> list[str]: def test_default_column_emoji_for_custom_column_type() -> None: """Ensure the base get_column_emoji implementation is used when not overridden.""" - assert StubColumnConfigWithoutEmoji.get_column_emoji() == "🎨" + assert StubColumnConfig.get_column_emoji() == "🎨" def test_allow_resize_inherited_by_subclasses() -> None: """Subclasses inherit allow_resize from SingleColumnConfig.""" - assert StubColumnConfigWithoutEmoji(name="test").allow_resize is False - assert StubColumnConfigWithoutEmoji(name="test", allow_resize=True).allow_resize is True + assert StubColumnConfig(name="test").allow_resize is False + assert StubColumnConfig(name="test", allow_resize=True).allow_resize is True diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py 
b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 705778f4..7e8fc0fa 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -346,7 +346,7 @@ def callback(exc: Exception, *, context: dict | None = None) -> None: return callback def _write_processed_batch(self, dataframe: pd.DataFrame) -> None: - self.batch_manager.replace_buffer(dataframe.to_dict(orient="records")) + self.batch_manager.replace_buffer(dataframe.to_dict(orient="records"), allow_resize=False) self.batch_manager.write() def _validate_column_configs(self) -> None: From 8285b48a07cbfd2e0ca173b13718c859d476b0e0 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 12 Feb 2026 15:38:06 -0300 Subject: [PATCH 8/8] fix: add missing f prefix to error message in custom.py --- .../data_designer/engine/column_generators/generators/custom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index ba5d4bd8..5e7a98ce 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -129,7 +129,7 @@ def _validate_cell_output(self, row: dict, keys_before: set[str]) -> dict: if self.config.name not in result_keys: raise CustomColumnGenerationError( f"Custom generator for column '{self.config.name}' did not create the expected column. " - "The generator_function must add a key named '{self.config.name}' to the row." + f"The generator_function must add a key named '{self.config.name}' to the row." ) missing = set(self.config.side_effect_columns) - result_keys if missing: