diff --git a/docs/concepts/custom_columns.md b/docs/concepts/custom_columns.md index d080bd70..98344485 100644 --- a/docs/concepts/custom_columns.md +++ b/docs/concepts/custom_columns.md @@ -93,6 +93,58 @@ This gives you direct access to all `ModelFacade` capabilities: custom parsers, | `generator_function` | Callable | Yes | Decorated function | | `generation_strategy` | GenerationStrategy | No | `CELL_BY_CELL` or `FULL_COLUMN` | | `generator_params` | BaseModel | No | Typed params passed to function | +| `allow_resize` | bool | No | Allow 1:N or N:1 generation | + +### Resizing (1:N and N:1) + +**FULL_COLUMN:** Set `allow_resize=True` and return a DataFrame with more or fewer rows than the input: + +```python +@dd.custom_column_generator( + required_columns=["topic"], + side_effect_columns=["variation_id"], +) +def expand_topics(df: pd.DataFrame, params: None, models: dict) -> pd.DataFrame: + rows = [] + for _, row in df.iterrows(): + for i in range(3): # Generate 3 variations per input + rows.append({ + "topic": row["topic"], + "question": f"Question {i+1} about {row['topic']}", + "variation_id": i, + }) + return pd.DataFrame(rows) + +dd.CustomColumnConfig( + name="question", + generator_function=expand_topics, + generation_strategy=dd.GenerationStrategy.FULL_COLUMN, + allow_resize=True, +) +``` + +**CELL_BY_CELL:** With `allow_resize=True`, your function may return a single row (`dict`) or multiple rows (`list[dict]`). Return `[]` to drop that input row. + +```python +@dd.custom_column_generator(required_columns=["id"]) +def expand_row(row: dict) -> list[dict]: + return [ + {**row, "variant": "a"}, + {**row, "variant": "b"}, + ] + +dd.CustomColumnConfig( + name="variant", + generator_function=expand_row, + generation_strategy=dd.GenerationStrategy.CELL_BY_CELL, + allow_resize=True, +) +``` + +Use cases: + +- **Expansion (1:N)**: Generate multiple variations per input +- **Retraction (N:1)**: Filter, aggregate, or deduplicate records (FULL_COLUMN) or return `[]` per row (CELL_BY_CELL) ## Multi-Turn Example diff --git a/docs/plugins/example.md b/docs/plugins/example.md index 273fee3c..4f9e561f 100644 --- a/docs/plugins/example.md +++ b/docs/plugins/example.md @@ -82,6 +82,17 @@ class IndexMultiplierColumnConfig(SingleColumnConfig): - `required_columns` lists any columns this generator depends on (empty if none) - `side_effect_columns` lists any additional columns this generator produces beyond the primary column (empty if none) +**If your plugin can expand or retract the number of rows (1:N or N:1):** set `allow_resize=True` in the config class so the pipeline updates batch bookkeeping correctly. For example: + +```python +class MyColumnConfig(SingleColumnConfig): + column_type: Literal["my-plugin"] = "my-plugin" + allow_resize: bool = True # required when output row count can differ from input + # ... +``` + +The default is `False`; only set it to `True` when your `generate` method can return more or fewer rows than it receives. + ### Step 3: Create the implementation class The implementation class defines the actual business logic of the plugin. For column generator plugins, inherit from `ColumnGeneratorFullColumn` or `ColumnGeneratorCellByCell` and implement the `generate` method. diff --git a/example_allow_resize.py b/example_allow_resize.py new file mode 100644 index 00000000..43366429 --- /dev/null +++ b/example_allow_resize.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+
+"""Example: Chaining expand -> retract -> expand resize operations.
+
+Pipeline: 5 topics -> 15 questions (3 per topic) -> 10 non-easy questions (easy dropped)
+    -> 30 answer variants (3 per question)
+"""
+
+from __future__ import annotations
+
+import data_designer.config as dd
+from data_designer.interface import DataDesigner
+from data_designer.lazy_heavy_imports import pd
+
+
+# Step 1: Expand (1:N): generate 3 questions per topic
+@dd.custom_column_generator(required_columns=["topic"], side_effect_columns=["question_id", "difficulty"])
+def expand_to_questions(df: pd.DataFrame) -> pd.DataFrame:
+    rows = []
+    for _, row in df.iterrows():
+        for i in range(3):
+            rows.append(
+                {
+                    "topic": row["topic"],
+                    "question": f"Q{i + 1} about {row['topic']}?",
+                    "question_id": i,
+                    "difficulty": ["easy", "medium", "hard"][i],
+                }
+            )
+    return pd.DataFrame(rows)
+
+
+# Step 2: Retract (N:1): keep only medium/hard questions
+@dd.custom_column_generator(required_columns=["difficulty"])
+def filter_non_easy(df: pd.DataFrame) -> pd.DataFrame:
+    return df[df["difficulty"] != "easy"].copy().assign(filtered=True)
+
+
+# Step 3: Expand again (1:N): generate 3 answer variants per surviving question
+@dd.custom_column_generator(required_columns=["question"], side_effect_columns=["variant"])
+def expand_to_answers(df: pd.DataFrame) -> pd.DataFrame:
+    rows = []
+    for _, row in df.iterrows():
+        for v in range(3):
+            rows.append({**row.to_dict(), "answer": f"Answer v{v} to: {row['question']}", "variant": v})
+    return pd.DataFrame(rows)
+
+
+def main() -> None:
+    data_designer = DataDesigner()
+    config_builder = dd.DataDesignerConfigBuilder()
+
+    # Seed: 5 topics
+    config_builder.add_column(
+        dd.SamplerColumnConfig(
+            name="topic",
+            sampler_type=dd.SamplerType.CATEGORY,
+            params=dd.CategorySamplerParams(values=["Python", "ML", "Data", "Stats", "SQL"]),
+        )
+    )
+
+    # Expand: 5 topics -> 15 questions
+    config_builder.add_column(
+        dd.CustomColumnConfig(
+            name="question",
+            generator_function=expand_to_questions,
+            generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
+            allow_resize=True,
+        )
+    )
+
+    # Retract: 15 -> 10 (drop "easy" questions)
+    config_builder.add_column(
+        dd.CustomColumnConfig(
+            name="filtered",
+            generator_function=filter_non_easy,
+            generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
+            allow_resize=True,
+        )
+    )
+
+    # Expand again: 10 -> 30 answer variants
+    config_builder.add_column(
+        dd.CustomColumnConfig(
+            name="answer",
+            generator_function=expand_to_answers,
+            generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
+            allow_resize=True,
+        )
+    )
+
+    # Preview (single batch)
+    preview = data_designer.preview(config_builder=config_builder, num_records=5)
+    print(f"Preview: 5 topics -> {len(preview.dataset)} answer variants")
+    print(preview.dataset[["topic", "difficulty", "question", "variant", "answer"]].to_string())
+    print()
+
+    # Build (multiple batches: 10 records with buffer_size=3 -> 4 batches)
+    data_designer.set_run_config(dd.RunConfig(buffer_size=3))
+    results = data_designer.create(config_builder=config_builder, num_records=10)
+    df = results.load_dataset()
+    print(f"Build: 10 topics (4 batches of 3+3+3+1) -> {len(df)} answer variants")
+    print(df[["topic", "difficulty", "question", "variant"]].to_string())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/packages/data-designer-config/src/data_designer/config/base.py b/packages/data-designer-config/src/data_designer/config/base.py index 3bf5b6fa..bcf4ac6d 100644 --- 
a/packages/data-designer-config/src/data_designer/config/base.py +++ b/packages/data-designer-config/src/data_designer/config/base.py @@ -37,6 +37,7 @@ class SingleColumnConfig(ConfigBase, ABC): name: str drop: bool = False + allow_resize: bool = False column_type: str @staticmethod diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index e633518d..1fadf277 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -517,19 +517,24 @@ def test_sampler_column_config_discriminated_union_wrong_params_type(): ) -def test_default_column_emoji_for_custom_column_type() -> None: - """Ensure the base get_column_emoji implementation is used when not overridden.""" +class StubColumnConfig(SingleColumnConfig): + column_type: Literal["stub"] = "stub" + + @property + def required_columns(self) -> list[str]: + return [] - class StubColumnConfigWithoutEmoji(SingleColumnConfig): - column_type: Literal["stub-without-emoji"] = "stub-without-emoji" - value: str + @property + def side_effect_columns(self) -> list[str]: + return [] - @property - def required_columns(self) -> list[str]: - return [] - @property - def side_effect_columns(self) -> list[str]: - return [] +def test_default_column_emoji_for_custom_column_type() -> None: + """Ensure the base get_column_emoji implementation is used when not overridden.""" + assert StubColumnConfig.get_column_emoji() == "🎨" + - assert StubColumnConfigWithoutEmoji.get_column_emoji() == "🎨" +def test_allow_resize_inherited_by_subclasses() -> None: + """Subclasses inherit allow_resize from SingleColumnConfig.""" + assert StubColumnConfig(name="test").allow_resize is False + assert StubColumnConfig(name="test", allow_resize=True).allow_resize is True diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index f0a942e7..5e7a98ce 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -43,8 +43,11 @@ def get_generation_strategy(self) -> GenerationStrategy: """Return strategy based on config.""" return self.config.generation_strategy - def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame: - """Generate column value(s) for a row (dict) or batch (DataFrame).""" + def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict]: + """Generate column value(s) for a row (dict) or batch (DataFrame). + + For cell_by_cell with allow_resize=True, may return dict or list[dict] (0, 1, or N rows). 
+ """ is_full_column = self.config.generation_strategy == GenerationStrategy.FULL_COLUMN is_dataframe = not isinstance(data, dict) @@ -62,7 +65,7 @@ def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame: return self._generate(data, is_dataframe) - def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.DataFrame: + def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.DataFrame | list[dict]: """Unified generation logic for both strategies.""" # Get columns/keys using unified accessor get_keys = (lambda d: set(d.columns)) if is_dataframe else (lambda d: set(d.keys())) @@ -93,7 +96,23 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. f"Custom generator function failed for column '{self.config.name}': {e}" ) from e - # Validate return type + # Cell-by-cell with allow_resize: accept dict or list[dict] + if not is_dataframe and getattr(self.config, "allow_resize", False): + if isinstance(result, list): + if not all(isinstance(r, dict) for r in result): + raise CustomColumnGenerationError( + f"Custom generator for column '{self.config.name}' with allow_resize must return " + "dict or list[dict]; list elements must be dicts." + ) + return [self._validate_cell_output(r, keys_before) for r in result] + if isinstance(result, dict): + return self._validate_output(result, keys_before, is_dataframe) + raise CustomColumnGenerationError( + f"Custom generator for column '{self.config.name}' with allow_resize must return " + f"dict or list[dict], got {type(result).__name__}" + ) + + # Validate return type for non-resize paths if not isinstance(result, expected_type): raise CustomColumnGenerationError( f"Custom generator for column '{self.config.name}' must return a {type_name}, " @@ -102,6 +121,38 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. return self._validate_output(result, keys_before, is_dataframe) + def _validate_cell_output(self, row: dict, keys_before: set[str]) -> dict: + """Validate a single row output (dict) for cell_by_cell; strip undeclared columns.""" + expected_new = {self.config.name} | set(self.config.side_effect_columns) + result_keys = set(row.keys()) + + if self.config.name not in result_keys: + raise CustomColumnGenerationError( + f"Custom generator for column '{self.config.name}' did not create the expected column. " + f"The generator_function must add a key named '{self.config.name}' to the row." + ) + missing = set(self.config.side_effect_columns) - result_keys + if missing: + raise CustomColumnGenerationError( + f"Custom generator for column '{self.config.name}' did not create declared side_effect_columns: " + f"{sorted(missing)}. Declared side_effect_columns must be added to the row." + ) + removed = keys_before - result_keys + if removed: + raise CustomColumnGenerationError( + f"Custom generator for column '{self.config.name}' removed pre-existing columns: " + f"{sorted(removed)}. The generator_function must not remove any existing columns." + ) + undeclared = (result_keys - keys_before) - expected_new + if undeclared: + logger.warning( + f"⚠️ Custom generator for column '{self.config.name}' created undeclared columns: " + f"{sorted(undeclared)}. These columns will be removed. " + f"To keep additional columns, declare them in @custom_column_generator(side_effect_columns=[...])." 
+ ) + row = {k: v for k, v in row.items() if k not in undeclared} + return row + def _validate_output( self, result: dict | pd.DataFrame, keys_before: set[str], is_dataframe: bool ) -> dict | pd.DataFrame: @@ -147,8 +198,7 @@ def _validate_output( if is_dataframe: result = result.drop(columns=list(undeclared)) else: - for key in undeclared: - del result[key] + result = self._validate_cell_output(result, keys_before) return result @@ -199,3 +249,5 @@ def log_pre_generation(self) -> None: logger.info(f"{LOG_INDENT}model_aliases: {self.config.model_aliases}") if self.config.generator_params: logger.info(f"{LOG_INDENT}generator_params: {self.config.generator_params}") + if self.config.allow_resize: + logger.info(f"{LOG_INDENT}allow_resize: {self.config.allow_resize}") diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 9b0d91b8..9c7eaa5e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -64,6 +64,8 @@ def __init__( self.batch_manager = DatasetBatchManager(resource_provider.artifact_storage) self._resource_provider = resource_provider self._records_to_drop: set[int] = set() + self._cell_resize_results: list[dict | list[dict] | None] = [] + self._cell_resize_mode = False self._registry = registry or DataDesignerRegistry() self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider) @@ -257,9 +259,24 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None: max_workers = generator.inference_parameters.max_parallel_requests self._fan_out_with_threads(generator, max_workers=max_workers) + def _column_display_name(self, config: ColumnConfigT) -> str: + return f"columns {config.column_names}" if hasattr(config, "column_names") else config.name + + def _log_resize_if_changed(self, column_name: str, original_count: int, new_count: int, allow_resize: bool) -> None: + if not allow_resize or new_count == original_count: + return + if new_count == 0: + logger.warning(f"⚠️ Column '{column_name}' reduced batch to 0 records. 
This batch will be skipped.") + else: + emoji = "💥" if new_count > original_count else "✂️" + logger.info(f"{emoji} Column '{column_name}' resized batch: {original_count} -> {new_count} records.") + def _run_full_column_generator(self, generator: ColumnGenerator) -> None: + original_count = self.batch_manager.num_records_in_buffer df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True)) - self.batch_manager.update_records(df.to_dict(orient="records")) + allow_resize = getattr(generator.config, "allow_resize", False) + self._log_resize_if_changed(self._column_display_name(generator.config), original_count, len(df), allow_resize) + self.batch_manager.replace_buffer(df.to_dict(orient="records"), allow_resize=allow_resize) def _run_model_health_check_if_needed(self) -> None: model_aliases: set[str] = set() @@ -292,6 +309,11 @@ def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max if getattr(generator.config, "tool_alias", None): logger.info("🛠️ Tool calling enabled") + allow_resize = getattr(generator.config, "allow_resize", False) + if allow_resize: + self._cell_resize_results = [None] * self.batch_manager.num_records_batch + self._cell_resize_mode = True + progress_tracker = ProgressTracker( total_records=self.batch_manager.num_records_batch, label=f"{generator.config.column_type} column '{generator.config.name}'", @@ -313,7 +335,27 @@ def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max progress_tracker.log_final() - if len(self._records_to_drop) > 0: + if allow_resize: + # Flatten results in index order; skip indices in _records_to_drop (failed cells), + # so those rows are omitted from the new buffer. + new_records: list[dict] = [] + for i in range(len(self._cell_resize_results)): + if i in self._records_to_drop: + continue + r = self._cell_resize_results[i] + if r is not None: + new_records.extend(r if isinstance(r, list) else [r]) + self._log_resize_if_changed( + self._column_display_name(generator.config), + self.batch_manager.num_records_in_buffer, + len(new_records), + True, + ) + self.batch_manager.replace_buffer(new_records, allow_resize=True) + self._records_to_drop.clear() + self._cell_resize_mode = False + self._cell_resize_results = [] + elif len(self._records_to_drop) > 0: self._cleanup_dropped_record_images(self._records_to_drop) self.batch_manager.drop_records(self._records_to_drop) self._records_to_drop.clear() @@ -333,7 +375,7 @@ def callback(exc: Exception, *, context: dict | None = None) -> None: return callback def _write_processed_batch(self, dataframe: pd.DataFrame) -> None: - self.batch_manager.update_records(dataframe.to_dict(orient="records")) + self.batch_manager.replace_buffer(dataframe.to_dict(orient="records"), allow_resize=False) self.batch_manager.write() def _validate_column_configs(self) -> None: @@ -410,8 +452,11 @@ def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) ) self._records_to_drop.add(context["index"]) - def _worker_result_callback(self, result: dict, *, context: dict | None = None) -> None: - self.batch_manager.update_record(context["index"], result) + def _worker_result_callback(self, result: dict | list[dict], *, context: dict | None = None) -> None: + if self._cell_resize_mode: + self._cell_resize_results[context["index"]] = result + else: + self.batch_manager.update_record(context["index"], result) def _emit_batch_inference_events( self, batch_mode: str, usage_deltas: dict[str, ModelUsageStats], group_id: str diff --git 
a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py index efb60b8b..1a8b41c4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py @@ -25,6 +25,7 @@ def __init__(self, artifact_storage: ArtifactStorage): self._current_batch_number = 0 self._num_records_list: list[int] | None = None self._buffer_size: int | None = None + self._actual_num_records: int = 0 self.artifact_storage = artifact_storage @property @@ -83,11 +84,13 @@ def finish_batch(self, on_complete: Callable[[Path], None] | None = None) -> Pat raise DatasetBatchManagementError("🛑 All batches have been processed.") if self.write() is not None: + self._actual_num_records += len(self._buffer) final_file_path = self.artifact_storage.move_partial_result_to_final_file_path(self._current_batch_number) self.artifact_storage.write_metadata( { "target_num_records": sum(self.num_records_list), + "actual_num_records": self._actual_num_records, "total_num_batches": self.num_batches, "buffer_size": self._buffer_size, "schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)}, @@ -141,6 +144,7 @@ def iter_current_batch(self) -> Iterator[tuple[int, dict]]: def reset(self, delete_files: bool = False) -> None: self._current_batch_number = 0 self._buffer: list[dict] = [] + self._actual_num_records = 0 if delete_files: for dir_path in [ self.artifact_storage.final_dataset_path, @@ -191,16 +195,18 @@ def update_record(self, index: int, record: dict) -> None: raise IndexError(f"🛑 Index {index} is out of bounds for buffer of size {len(self._buffer)}.") self._buffer[index] = record - def update_records(self, records: list[dict]) -> None: - if len(records) != len(self._buffer): + def replace_buffer(self, records: list[dict], *, allow_resize: bool = False) -> None: + """Replace the buffer contents. + + Args: + records: New records to replace the buffer. + allow_resize: If True, allows the number of records to differ from the current + buffer size (1:N or N:1 patterns). Defaults to False for strict 1:1 mapping. + """ + if not allow_resize and len(records) != len(self._buffer): raise DatasetBatchManagementError( - f"🛑 Number of records to update ({len(records)}) must match " - f"the number of records in the buffer ({len(self._buffer)})." + f"🛑 Number of records ({len(records)}) must match the current buffer size ({len(self._buffer)})." 
) self._buffer = records - - def replace_buffer(self, records: list[dict]) -> None: - """Replace the buffer contents, updating the current batch size.""" - self._buffer = records - if self._num_records_list is not None: + if allow_resize and self._num_records_list is not None: self._num_records_list[self._current_batch_number] = len(records) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py index c78ee2b3..8163c807 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py @@ -70,7 +70,7 @@ def run_pre_batch(self, batch_manager: DatasetBatchManager) -> None: df = batch_manager.get_current_batch(as_dataframe=True) df = self._run_stage(df, ProcessorStage.PRE_BATCH) - batch_manager.replace_buffer(df.to_dict(orient="records")) + batch_manager.replace_buffer(df.to_dict(orient="records"), allow_resize=True) def run_post_batch(self, df: pd.DataFrame, current_batch_number: int | None) -> pd.DataFrame: """Run process_after_batch() on processors that implement it.""" diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py index 383636ba..f3d36219 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py @@ -113,6 +113,23 @@ def test_config_validation_non_callable() -> None: CustomColumnConfig(name="test", generator_function="not_a_function") +def test_config_validation_allow_resize_allows_full_column_and_cell_by_cell() -> None: + """allow_resize=True is valid with full_column or cell_by_cell.""" + + @custom_column_generator() + def dummy_fn(row: dict) -> dict: + return row + + for strategy in (GenerationStrategy.FULL_COLUMN, GenerationStrategy.CELL_BY_CELL): + config = CustomColumnConfig( + name="test", + generator_function=dummy_fn, + allow_resize=True, + generation_strategy=strategy, + ) + assert config.allow_resize is True + + # Cell-by-cell generation tests @@ -160,6 +177,101 @@ def test_side_effect_columns() -> None: assert result["secondary"] == 15 +# cell_by_cell allow_resize: dict | list[dict] + + +def test_cell_by_cell_allow_resize_return_dict() -> None: + """With allow_resize, returning a single dict (1:1) works like normal cell-by-cell.""" + config = CustomColumnConfig( + name="result", + generator_function=generator_with_required_columns, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + allow_resize=True, + ) + generator = CustomColumnGenerator(config=config, resource_provider=Mock(spec=ResourceProvider)) + result = generator.generate({"input": "hi"}) + assert isinstance(result, dict) + assert result["result"] == "HI" + + +def test_cell_by_cell_allow_resize_return_list_expand() -> None: + """With allow_resize, returning list[dict] expands one row into multiple.""" + + @custom_column_generator(required_columns=["x"]) + def expand(row: dict) -> list[dict]: + return [ + {**row, "out": row["x"] * 1}, + {**row, "out": row["x"] * 2}, + ] + + config = CustomColumnConfig( + name="out", + generator_function=expand, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + allow_resize=True, + ) + generator = 
CustomColumnGenerator(config=config, resource_provider=Mock(spec=ResourceProvider)) + result = generator.generate({"x": 10}) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0] == {"x": 10, "out": 10} + assert result[1] == {"x": 10, "out": 20} + + +def test_cell_by_cell_allow_resize_return_list_single() -> None: + """With allow_resize, returning [dict] (1:1 via list) is valid.""" + + @custom_column_generator(required_columns=["x"]) + def one_row(row: dict) -> list[dict]: + return [{**row, "out": row["x"]}] + + config = CustomColumnConfig( + name="out", + generator_function=one_row, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + allow_resize=True, + ) + generator = CustomColumnGenerator(config=config, resource_provider=Mock(spec=ResourceProvider)) + result = generator.generate({"x": 42}) + assert result == [{"x": 42, "out": 42}] + + +def test_cell_by_cell_allow_resize_return_empty_list() -> None: + """With allow_resize, returning [] drops that row (0 rows).""" + + @custom_column_generator(required_columns=["x"]) + def drop(row: dict) -> list[dict]: + return [] + + config = CustomColumnConfig( + name="out", + generator_function=drop, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + allow_resize=True, + ) + generator = CustomColumnGenerator(config=config, resource_provider=Mock(spec=ResourceProvider)) + result = generator.generate({"x": 1}) + assert result == [] + + +def test_cell_by_cell_allow_resize_invalid_return_type() -> None: + """With allow_resize, return must be dict or list[dict].""" + + @custom_column_generator(required_columns=["x"]) + def bad_return(row: dict): + return [1, 2] + + config = CustomColumnConfig( + name="out", + generator_function=bad_return, + generation_strategy=GenerationStrategy.CELL_BY_CELL, + allow_resize=True, + ) + generator = CustomColumnGenerator(config=config, resource_provider=Mock(spec=ResourceProvider)) + with pytest.raises(CustomColumnGenerationError, match="list elements must be dicts"): + generator.generate({"x": 1}) + + # Error handling tests diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py index c512871a..7005df44 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py @@ -64,7 +64,7 @@ def stub_batch_manager(): mock_batch_manager.finish = Mock() mock_batch_manager.write = Mock() mock_batch_manager.add_records = Mock() - mock_batch_manager.update_records = Mock() + mock_batch_manager.replace_buffer = Mock() mock_batch_manager.update_record = Mock() mock_batch_manager.get_current_batch = Mock() mock_batch_manager.get_current_batch.side_effect = [ @@ -562,3 +562,173 @@ def test_process_preview_with_empty_dataframe(simple_builder): assert len(result) == 0 mock_processor.process_after_batch.assert_called_once() mock_processor.process_after_generation.assert_called_once() + + +# allow_resize integration tests +# +# Factory: _make_resize_full_expand. Stubs: _resize_full_keep_first, _resize_cell_*. 
+
+
+def _make_resize_full_expand(n: int, primary_col: str, side_effect_col: str):
+    """FULL_COLUMN: expand n times per seed_id."""
+
+    @custom_column_generator(required_columns=["seed_id"], side_effect_columns=[side_effect_col])
+    def fn(df: pd.DataFrame) -> pd.DataFrame:
+        rows = []
+        for _, row in df.iterrows():
+            for i in range(n):
+                rows.append({**row.to_dict(), primary_col: f"{row['seed_id']}_v{i}", side_effect_col: i})
+        return pd.DataFrame(rows)
+
+    return fn
+
+
+@custom_column_generator(required_columns=["seed_id"])
+def _resize_full_keep_first(df: pd.DataFrame) -> pd.DataFrame:
+    """FULL_COLUMN: keep first row per seed_id (retraction)."""
+    return df.drop_duplicates(subset="seed_id").assign(filtered=True)
+
+
+@custom_column_generator(required_columns=["seed_id"])
+def _resize_cell_expand(row: dict) -> list[dict]:
+    """CELL_BY_CELL: one row -> two rows (doubled)."""
+    return [
+        {**row, "doubled": f"{row['seed_id']}_a"},
+        {**row, "doubled": f"{row['seed_id']}_b"},
+    ]
+
+
+@custom_column_generator(required_columns=["seed_id"])
+def _resize_cell_filter_odd(row: dict) -> dict | list[dict]:
+    """CELL_BY_CELL: drop even seed_id, keep odd."""
+    if row["seed_id"] % 2 == 0:
+        return []
+    return {**row, "kept": row["seed_id"]}
+
+
+@custom_column_generator(required_columns=["seed_id"])
+def _resize_cell_drop_all(row: dict) -> list[dict]:
+    """CELL_BY_CELL: return [] for every row (drop all)."""
+    return []
+
+
+_RESIZE_SPECS: dict[str, list[tuple[str, object, GenerationStrategy]]] = {
+    "cell_filter_odd": [("kept", _resize_cell_filter_odd, GenerationStrategy.CELL_BY_CELL)],
+    "cell_x2": [("doubled", _resize_cell_expand, GenerationStrategy.CELL_BY_CELL)],
+    "cell_drop_all": [("dropped", _resize_cell_drop_all, GenerationStrategy.CELL_BY_CELL)],
+    "full_x3": [("expanded", _make_resize_full_expand(3, "expanded", "copy"), GenerationStrategy.FULL_COLUMN)],
+    "full_chain": [
+        ("expanded", _make_resize_full_expand(3, "expanded", "copy"), GenerationStrategy.FULL_COLUMN),
+        ("filtered", _resize_full_keep_first, GenerationStrategy.FULL_COLUMN),
+        ("expanded_again", _make_resize_full_expand(3, "expanded_again", "copy2"), GenerationStrategy.FULL_COLUMN),
+    ],
+    "cell_plus_full_chain": [
+        ("doubled", _resize_cell_expand, GenerationStrategy.CELL_BY_CELL),
+        ("filtered", _resize_full_keep_first, GenerationStrategy.FULL_COLUMN),
+        ("expanded_again", _make_resize_full_expand(3, "expanded_again", "copy2"), GenerationStrategy.FULL_COLUMN),
+    ],
+}
+
+
+def _resize_columns(spec: str) -> list[CustomColumnConfig]:
+    """Return column configs for a given allow_resize recipe."""
+    return [
+        CustomColumnConfig(
+            name=name,
+            generator_function=fn,
+            generation_strategy=strat,
+            allow_resize=True,
+        )
+        for name, fn, strat in _RESIZE_SPECS[spec]
+    ]
+
+
+def _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_setup, columns):
+    """Build a ColumnWiseDatasetBuilder with the given resize column configs."""
+    config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
+    config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"])))
+    for col in columns:
+        config_builder.add_column(col)
+    return ColumnWiseDatasetBuilder(
+        data_designer_config=config_builder.build(),
+        resource_provider=stub_resource_provider,
+    )
+
+
+@pytest.mark.parametrize(
+    "spec,num_records,expected_len,check_doubled_order",
+    [
+        ("cell_filter_odd", 5, 3, False),
+        ("cell_x2", 5, 10, True),
+        ("cell_drop_all", 5, 0, False),
+        ("full_x3", 5, 15, False),
("full_chain", 5, 15, False), + ("cell_plus_full_chain", 5, 15, False), + ], + ids=[ + "cell_filter_odd_preview", + "cell_x2_preview", + "cell_drop_all_preview", + "full_x3_preview", + "full_chain_preview", + "cell_plus_full_chain_preview", + ], +) +def test_allow_resize_preview( + stub_resource_provider, + stub_model_configs, + seed_data_setup, + spec, + num_records, + expected_len, + check_doubled_order, +): + """Preview with allow_resize columns (FULL_COLUMN and/or CELL_BY_CELL) yields expected length.""" + columns = _resize_columns(spec) + builder = _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_setup, columns) + result = builder.build_preview(num_records=num_records) + assert len(result) == expected_len + if check_doubled_order: + expected = [x for i in range(1, 6) for x in (f"{i}_a", f"{i}_b")] + assert result["doubled"].tolist() == expected + + +@pytest.mark.parametrize( + "spec,num_records,buffer_size,expected_total_rows", + [ + ("cell_x2", 5, 2, 10), # batches [2,2,1] -> each x2 -> 4+4+2 + ("cell_filter_odd", 5, 2, 3), # batches [2,2,1] -> keep odd -> 1+1+1 + ("cell_drop_all", 5, 2, 0), # each batch -> 0 rows + ("full_x3", 5, 2, 15), # batches [2,2,1] -> each x3 -> 6+6+3 + ("full_x3", 4, 2, 12), # batches [2,2] -> 6+6 + ("full_chain", 5, 2, 15), # batches [2,2,1] -> x3, dedup, x3 -> 15 + ], + ids=[ + "cell_x2_multibatch", + "cell_filter_odd_multibatch", + "cell_drop_all_multibatch", + "full_x3_multibatch_5_2", + "full_x3_multibatch_4_2", + "full_chain_multibatch", + ], +) +def test_allow_resize_multiple_batches( + stub_resource_provider, + stub_model_configs, + seed_data_setup, + spec, + num_records, + buffer_size, + expected_total_rows, +): + """Resized batches are written independently and combine to expected total rows.""" + stub_resource_provider.run_config = RunConfig(buffer_size=buffer_size) + columns = _resize_columns(spec) + builder = _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_setup, columns) + builder.build(num_records=num_records) + final_path = builder.artifact_storage.final_dataset_path + if expected_total_rows == 0 and not final_path.exists(): + df = pd.DataFrame() + else: + df = pd.read_parquet(final_path) + assert len(df) == expected_total_rows diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py index 67e96a0c..ca99d716 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py @@ -153,24 +153,59 @@ def test_update_record_invalid_index(stub_batch_manager_with_data): stub_batch_manager_with_data.update_record(-1, {"id": -1, "name": "test"}) -def test_update_records(stub_batch_manager_with_data): +def test_replace_buffer(stub_batch_manager_with_data): records = [{"id": i, "name": f"test{i}"} for i in range(3)] stub_batch_manager_with_data.add_records(records) new_records = [{"id": i, "name": f"updated{i}"} for i in range(3)] - stub_batch_manager_with_data.update_records(new_records) + stub_batch_manager_with_data.replace_buffer(new_records) assert stub_batch_manager_with_data._buffer == new_records -def test_update_records_wrong_length(stub_batch_manager_with_data): +def test_replace_buffer_wrong_length(stub_batch_manager_with_data): records = [{"id": i, "name": f"test{i}"} for i in range(3)] 
stub_batch_manager_with_data.add_records(records) wrong_length_records = [{"id": i, "name": f"test{i}"} for i in range(2)] - with pytest.raises(DatasetBatchManagementError, match="Number of records to update.*must match"): - stub_batch_manager_with_data.update_records(wrong_length_records) + with pytest.raises(DatasetBatchManagementError, match="Number of records.*must match"): + stub_batch_manager_with_data.replace_buffer(wrong_length_records) + + +@pytest.mark.parametrize( + "new_size", + [6, 1, 0], + ids=["expansion", "retraction", "empty"], +) +def test_replace_buffer_allow_resize(stub_batch_manager_with_data, new_size): + """allow_resize=True permits any record count change and updates bookkeeping.""" + stub_batch_manager_with_data.add_records([{"id": i} for i in range(3)]) + + new_records = [{"id": i} for i in range(new_size)] + stub_batch_manager_with_data.replace_buffer(new_records, allow_resize=True) + + assert stub_batch_manager_with_data.num_records_in_buffer == new_size + assert stub_batch_manager_with_data._buffer == new_records + assert stub_batch_manager_with_data.num_records_batch == new_size + + +def test_actual_num_records_tracks_expansion(stub_batch_manager_with_data): + """Test that actual_num_records correctly tracks when buffer is resized.""" + # Add 3 records, then expand to 6 + records = [{"id": i} for i in range(3)] + stub_batch_manager_with_data.add_records(records) + expanded = [{"id": i} for i in range(6)] + stub_batch_manager_with_data.replace_buffer(expanded, allow_resize=True) + + # Finish batch and check metadata + stub_batch_manager_with_data.finish_batch() + + with open(stub_batch_manager_with_data.artifact_storage.metadata_file_path) as f: + metadata = json.load(f) + + assert metadata["target_num_records"] == 13 # [6, 3, 3, 1] after resize from [3, 3, 3, 1] + assert metadata["actual_num_records"] == 6 # actual expanded count # Test write method @@ -271,6 +306,7 @@ def test_finish_batch_metadata_content(stub_batch_manager_with_data): metadata = json.load(f) assert metadata["target_num_records"] == 10 + assert metadata["actual_num_records"] == 3 # actual records written in this batch assert metadata["total_num_batches"] == 4 assert metadata["buffer_size"] == 3 assert metadata["num_completed_batches"] == 1
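
Note: the custom_columns.md docs above list deduplication under "Retraction (N:1)" but only show expansion and filter-style retraction in code. A minimal sketch of an N:1 dedup generator using the same `CustomColumnConfig` API introduced in this change (the `dedupe_by_topic` function and `is_unique` column are illustrative names, not part of the diff):

```python
import pandas as pd

import data_designer.config as dd


@dd.custom_column_generator(required_columns=["topic"])
def dedupe_by_topic(df: pd.DataFrame) -> pd.DataFrame:
    # N:1 retraction: keep the first row per topic and flag survivors.
    # Returning fewer rows than received is only valid with allow_resize=True.
    return df.drop_duplicates(subset="topic").assign(is_unique=True)


dd.CustomColumnConfig(
    name="is_unique",  # primary column this generator adds
    generator_function=dedupe_by_topic,
    generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
    allow_resize=True,  # row count may shrink (N:1)
)
```

As with `filter_non_easy` in the example script, `allow_resize=True` is what lets `replace_buffer` accept the smaller row count and update the batch bookkeeping.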