diff --git a/docs/concepts/processors.md b/docs/concepts/processors.md index 46773ecb..26dbd2eb 100644 --- a/docs/concepts/processors.md +++ b/docs/concepts/processors.md @@ -13,7 +13,18 @@ Each processor: - Applies its transformation - Passes the result to the next processor (or to output) -Currently, processors run only at the `POST_BATCH` stage, i.e., after column generation completes for each batch. +Processors can run at three stages, determined by which callback methods they implement: + +| Stage | When it runs | Callback method | Use cases | +|-------|--------------|-----------------|-----------| +| Pre-batch | After seed columns, before dependent columns | `process_before_batch()` | Transform seed data before other columns are generated | +| Post-batch | After each batch completes | `process_after_batch()` | Drop columns, transform schema per batch | +| After generation | Once, on final dataset after all batches | `process_after_generation()` | Deduplicate, aggregate statistics, final cleanup | + +!!! info "Full Schema Available During Generation" + Each batch carries the full dataset schema during generation. Post-batch schema changes such as column dropping only alter past batches, so all columns remain accessible to generators while building follow-up batches. + +A processor can implement any combination of these callbacks. The built-in processors use `process_after_batch()` by default. ## Processor Types @@ -134,7 +145,6 @@ Processors execute in the order they're added. 
Plan accordingly when one process | Parameter | Type | Description | |-----------|------|-------------| | `name` | str | Identifier for the processor, used in output directory names | -| `build_stage` | BuildStage | When to run (default: `POST_BATCH`) | ### DropColumnsProcessorConfig diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index 306192f0..2ea641e7 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -30,7 +30,6 @@ from data_designer.config.config_builder import DataDesignerConfigBuilder # noqa: F401 from data_designer.config.custom_column import custom_column_generator # noqa: F401 from data_designer.config.data_designer_config import DataDesignerConfig # noqa: F401 - from data_designer.config.dataset_builders import BuildStage # noqa: F401 from data_designer.config.mcp import ( # noqa: F401 LocalStdioMCPProvider, MCPProvider, @@ -141,8 +140,6 @@ "custom_column_generator": (f"{_MOD_BASE}.custom_column", "custom_column_generator"), # data_designer_config "DataDesignerConfig": (f"{_MOD_BASE}.data_designer_config", "DataDesignerConfig"), - # dataset_builders - "BuildStage": (f"{_MOD_BASE}.dataset_builders", "BuildStage"), # mcp "LocalStdioMCPProvider": (_MOD_MCP, "LocalStdioMCPProvider"), "MCPProvider": (_MOD_MCP, "MCPProvider"), diff --git a/packages/data-designer-config/src/data_designer/config/config_builder.py b/packages/data-designer-config/src/data_designer/config/config_builder.py index 332170b7..df27e47a 100644 --- a/packages/data-designer-config/src/data_designer/config/config_builder.py +++ b/packages/data-designer-config/src/data_designer/config/config_builder.py @@ -22,7 +22,6 @@ get_column_display_order, ) from data_designer.config.data_designer_config import DataDesignerConfig -from data_designer.config.dataset_builders import BuildStage from 
data_designer.config.default_model_settings import get_default_model_configs from data_designer.config.errors import BuilderConfigurationError, BuilderSerializationError, InvalidColumnTypeError from data_designer.config.exportable_config import ExportableConfigBase @@ -572,7 +571,7 @@ def get_columns_excluding_type(self, column_type: DataDesignerColumnType) -> lis column_type = resolve_string_enum(column_type, DataDesignerColumnType) return [c for c in self._column_configs.values() if c.column_type != column_type] - def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfigT]]: + def get_processor_configs(self) -> list[ProcessorConfigT]: """Get processor configuration objects. Returns: diff --git a/packages/data-designer-config/src/data_designer/config/dataset_builders.py b/packages/data-designer-config/src/data_designer/config/dataset_builders.py deleted file mode 100644 index bbfbb2fb..00000000 --- a/packages/data-designer-config/src/data_designer/config/dataset_builders.py +++ /dev/null @@ -1,13 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from enum import Enum - - -class BuildStage(str, Enum): - PRE_BATCH = "pre_batch" - POST_BATCH = "post_batch" - PRE_GENERATION = "pre_generation" - POST_GENERATION = "post_generation" diff --git a/packages/data-designer-config/src/data_designer/config/processors.py b/packages/data-designer-config/src/data_designer/config/processors.py index db7bb9ce..0435457b 100644 --- a/packages/data-designer-config/src/data_designer/config/processors.py +++ b/packages/data-designer-config/src/data_designer/config/processors.py @@ -12,11 +12,8 @@ from typing_extensions import TypeAlias from data_designer.config.base import ConfigBase -from data_designer.config.dataset_builders import BuildStage from data_designer.config.errors import InvalidConfigError -SUPPORTED_STAGES = [BuildStage.POST_BATCH] - class ProcessorType(str, Enum): """Enumeration of available processor types. @@ -33,33 +30,22 @@ class ProcessorType(str, Enum): class ProcessorConfig(ConfigBase, ABC): """Abstract base class for all processor configuration types. - Processors are transformations that run before or after columns are generated. - They can modify, reshape, or augment the dataset before it's saved. + Processors are transformations that run at different stages of the generation + pipeline. They can modify, reshape, or augment the dataset. + + The processor implementation determines which stages it handles by overriding + the appropriate callback methods (process_before_batch, process_after_batch, process_after_generation). Attributes: name: Unique name of the processor, used to identify the processor in results and to name output artifacts on disk. - build_stage: The stage at which the processor runs. Currently only `POST_BATCH` - is supported, meaning processors run after each batch of columns is generated. 
""" name: str = Field( description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.", ) - build_stage: BuildStage = Field( - default=BuildStage.POST_BATCH, - description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}", - ) processor_type: str - @field_validator("build_stage") - def validate_build_stage(cls, v: BuildStage) -> BuildStage: - if v not in SUPPORTED_STAGES: - raise ValueError( - f"Invalid dataset builder stage: {v}. Only these stages are supported: {', '.join(SUPPORTED_STAGES)}" - ) - return v - def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs: Any) -> ProcessorConfig: """Create a processor configuration from a processor type and keyword arguments. diff --git a/packages/data-designer-config/tests/config/test_processors.py b/packages/data-designer-config/tests/config/test_processors.py index b18814e6..e688be15 100644 --- a/packages/data-designer-config/tests/config/test_processors.py +++ b/packages/data-designer-config/tests/config/test_processors.py @@ -4,7 +4,6 @@ import pytest from pydantic import ValidationError -from data_designer.config.dataset_builders import BuildStage from data_designer.config.errors import InvalidConfigError from data_designer.config.processors import ( DropColumnsProcessorConfig, @@ -16,92 +15,64 @@ def test_drop_columns_processor_config_creation(): - config = DropColumnsProcessorConfig( - name="drop_columns_processor", build_stage=BuildStage.POST_BATCH, column_names=["col1", "col2"] - ) + config = DropColumnsProcessorConfig(name="drop_columns_processor", column_names=["col1", "col2"]) - assert config.build_stage == BuildStage.POST_BATCH assert config.column_names == ["col1", "col2"] assert config.processor_type == ProcessorType.DROP_COLUMNS assert isinstance(config, ProcessorConfig) def test_drop_columns_processor_config_validation(): - # Test unsupported stage raises error - with 
pytest.raises(ValidationError, match="Invalid dataset builder stage"): - DropColumnsProcessorConfig( - name="drop_columns_processor", build_stage=BuildStage.PRE_BATCH, column_names=["col1"] - ) - # Test missing required field raises error with pytest.raises(ValidationError, match="Field required"): - DropColumnsProcessorConfig(name="drop_columns_processor", build_stage=BuildStage.POST_BATCH) + DropColumnsProcessorConfig(name="drop_columns_processor") def test_drop_columns_processor_config_serialization(): - config = DropColumnsProcessorConfig( - name="drop_columns_processor", build_stage=BuildStage.POST_BATCH, column_names=["col1", "col2"] - ) + config = DropColumnsProcessorConfig(name="drop_columns_processor", column_names=["col1", "col2"]) # Serialize to dict config_dict = config.model_dump() - assert config_dict["build_stage"] == "post_batch" assert config_dict["column_names"] == ["col1", "col2"] # Deserialize from dict config_restored = DropColumnsProcessorConfig.model_validate(config_dict) - assert config_restored.build_stage == config.build_stage assert config_restored.column_names == config.column_names def test_schema_transform_processor_config_creation(): config = SchemaTransformProcessorConfig( name="output_format_processor", - build_stage=BuildStage.POST_BATCH, template={"text": "{{ col1 }}"}, ) - assert config.build_stage == BuildStage.POST_BATCH assert config.template == {"text": "{{ col1 }}"} assert config.processor_type == ProcessorType.SCHEMA_TRANSFORM assert isinstance(config, ProcessorConfig) def test_schema_transform_processor_config_validation(): - # Test unsupported stage raises error - with pytest.raises(ValidationError, match="Invalid dataset builder stage"): - SchemaTransformProcessorConfig( - name="schema_transform_processor", - build_stage=BuildStage.PRE_BATCH, - template={"text": "{{ col1 }}"}, - ) - # Test missing required field raises error with pytest.raises(ValidationError, match="Field required"): - 
SchemaTransformProcessorConfig(name="schema_transform_processor", build_stage=BuildStage.POST_BATCH) + SchemaTransformProcessorConfig(name="schema_transform_processor") # Test invalid template raises error with pytest.raises(InvalidConfigError, match="Template must be JSON serializable"): - SchemaTransformProcessorConfig( - name="schema_transform_processor", build_stage=BuildStage.POST_BATCH, template={"text": {1, 2, 3}} - ) + SchemaTransformProcessorConfig(name="schema_transform_processor", template={"text": {1, 2, 3}}) def test_schema_transform_processor_config_serialization(): config = SchemaTransformProcessorConfig( name="schema_transform_processor", - build_stage=BuildStage.POST_BATCH, template={"text": "{{ col1 }}"}, ) # Serialize to dict config_dict = config.model_dump() - assert config_dict["build_stage"] == "post_batch" assert config_dict["template"] == {"text": "{{ col1 }}"} # Deserialize from dict config_restored = SchemaTransformProcessorConfig.model_validate(config_dict) - assert config_restored.build_stage == config.build_stage assert config_restored.template == config.template @@ -110,7 +81,6 @@ def test_get_processor_config_from_kwargs(): config_drop_columns = get_processor_config_from_kwargs( ProcessorType.DROP_COLUMNS, name="drop_columns_processor", - build_stage=BuildStage.POST_BATCH, column_names=["col1"], ) assert isinstance(config_drop_columns, DropColumnsProcessorConfig) @@ -120,7 +90,6 @@ def test_get_processor_config_from_kwargs(): config_schema_transform = get_processor_config_from_kwargs( ProcessorType.SCHEMA_TRANSFORM, name="output_format_processor", - build_stage=BuildStage.POST_BATCH, template={"text": "{{ col1 }}"}, ) assert isinstance(config_schema_transform, SchemaTransformProcessorConfig) @@ -134,6 +103,6 @@ class UnknownProcessorType(str, Enum): UNKNOWN = "unknown" result = get_processor_config_from_kwargs( - UnknownProcessorType.UNKNOWN, name="unknown_processor", build_stage=BuildStage.POST_BATCH, column_names=["col1"] + 
UnknownProcessorType.UNKNOWN, name="unknown_processor", column_names=["col1"] ) assert result is None diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 781b0673..e5404d49 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -14,7 +14,6 @@ from data_designer.config.column_types import ColumnConfigT from data_designer.config.config_builder import BuilderConfig from data_designer.config.data_designer_config import DataDesignerConfig -from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import ( DropColumnsProcessorConfig, ProcessorConfig, @@ -29,11 +28,12 @@ from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated from data_designer.engine.compiler import compile_data_designer_config from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage -from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError +from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager +from data_designer.engine.dataset_builders.utils.processor_runner import ProcessorRunner from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker from 
data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler from data_designer.engine.processing.processors.base import Processor @@ -67,8 +67,10 @@ def __init__( self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider) self._column_configs = compile_dataset_builder_column_configs(self._data_designer_config) - self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors( - self._data_designer_config.processors or [] + processors = self._initialize_processors(self._data_designer_config.processors or []) + self._processor_runner = ProcessorRunner( + processors=processors, + artifact_storage=resource_provider.artifact_storage, ) self._validate_column_configs() @@ -76,6 +78,17 @@ def __init__( def artifact_storage(self) -> ArtifactStorage: return self._resource_provider.artifact_storage + @property + def processors(self) -> tuple[Processor, ...]: + return self._processor_runner.processors + + def set_processor_runner(self, processors: list[Processor]) -> None: + """Replace the processor runner with a new one using the given processors.""" + self._processor_runner = ProcessorRunner( + processors=processors, + artifact_storage=self.artifact_storage, + ) + @functools.cached_property def single_column_configs(self) -> list[ColumnConfigT]: configs = [] @@ -107,15 +120,15 @@ def build( self.batch_manager.start(num_records=num_records, buffer_size=buffer_size) for batch_idx in range(self.batch_manager.num_batches): logger.info(f"âŗ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}") - self._run_batch(generators, batch_mode="batch", group_id=group_id) - df_batch = self._run_processors( - stage=BuildStage.POST_BATCH, - dataframe=self.batch_manager.get_current_batch(as_dataframe=True), + self._run_batch( + generators, + batch_mode="batch", + group_id=group_id, current_batch_number=batch_idx, + on_batch_complete=on_batch_complete, ) - 
self._write_processed_batch(df_batch) - self.batch_manager.finish_batch(on_batch_complete) self.batch_manager.finish() + self._processor_runner.run_after_generation(buffer_size) self._resource_provider.model_registry.log_model_usage(time.perf_counter() - start_time) @@ -138,11 +151,8 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame: return dataset def process_preview(self, dataset: pd.DataFrame) -> pd.DataFrame: - return self._run_processors( - stage=BuildStage.POST_BATCH, - dataframe=dataset.copy(), - current_batch_number=None, # preview mode does not have a batch number - ) + df = self._processor_runner.run_post_batch(dataset.copy(), current_batch_number=None) + return self._processor_runner.run_after_generation_on_df(df) def _initialize_generators(self) -> list[ColumnGenerator]: return [ @@ -159,15 +169,27 @@ def _write_builder_config(self) -> None: ) def _run_batch( - self, generators: list[ColumnGenerator], *, batch_mode: str, save_partial_results: bool = True, group_id: str + self, + generators: list[ColumnGenerator], + *, + batch_mode: str, + save_partial_results: bool = True, + group_id: str, + current_batch_number: int | None = None, + on_batch_complete: Callable[[Path], None] | None = None, ) -> None: pre_batch_snapshot = self._resource_provider.model_registry.get_model_usage_snapshot() + ran_pre_batch = False for generator in generators: generator.log_pre_generation() try: generation_strategy = generator.get_generation_strategy() if generator.can_generate_from_scratch and self.batch_manager.buffer_is_empty: self._run_from_scratch_column_generator(generator) + # Run PRE_BATCH after seed generator, before other columns + if not ran_pre_batch: + self._processor_runner.run_pre_batch(self.batch_manager) + ran_pre_batch = True elif generation_strategy == GenerationStrategy.CELL_BY_CELL: self._run_cell_by_cell_generator(generator) elif generation_strategy == GenerationStrategy.FULL_COLUMN: @@ -191,6 +213,12 @@ def _run_batch( except Exception: 
pass + if current_batch_number is not None: + df_batch = self.batch_manager.get_current_batch(as_dataframe=True) + df_batch = self._processor_runner.run_post_batch(df_batch, current_batch_number=current_batch_number) + self._write_processed_batch(df_batch) + self.batch_manager.finish_batch(on_batch_complete) + def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None: df = generator.generate_from_scratch(self.batch_manager.num_records_batch) self.batch_manager.add_records(df.to_dict(orient="records")) @@ -288,20 +316,20 @@ def _validate_column_configs(self) -> None: ).can_generate_from_scratch: raise DatasetGenerationError("🛑 The first column config must be a from-scratch column generator.") - def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> dict[BuildStage, list[Processor]]: + def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> list[Processor]: # Check columns marked for drop columns_to_drop = [config.name for config in self.single_column_configs if config.drop] - processors: dict[BuildStage, list[Processor]] = {stage: [] for stage in BuildStage} + processors: list[Processor] = [] for config in processor_configs: - processors[config.build_stage].append( + processors.append( self._registry.processors.get_for_config_type(type(config))( config=config, resource_provider=self._resource_provider, ) ) - # Manually included "drop columns" processor takes precedence (can e.g., pick stages other than post-batch) + # Manually included "drop columns" processor takes precedence if config.processor_type == ProcessorType.DROP_COLUMNS: for column in config.column_names: if column in columns_to_drop: @@ -309,12 +337,11 @@ def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> di # If there are still columns marked for drop, add the "drop columns" processor to drop them if len(columns_to_drop) > 0: - processors[BuildStage.POST_BATCH].append( # as post-batch by default + 
processors.append( DropColumnsProcessor( config=DropColumnsProcessorConfig( name="default_drop_columns_processor", column_names=columns_to_drop, - build_stage=BuildStage.POST_BATCH, ), resource_provider=self._resource_provider, ) @@ -322,18 +349,6 @@ def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> di return processors - def _run_processors( - self, stage: BuildStage, dataframe: pd.DataFrame, current_batch_number: int | None = None - ) -> pd.DataFrame: - for processor in self._processors[stage]: - try: - dataframe = processor.process(dataframe, current_batch_number=current_batch_number) - except Exception as e: - raise DatasetProcessingError( - f"🛑 Failed to process dataset with processor {processor.name} in stage {stage}: {e}" - ) from e - return dataframe - def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None: """If a worker fails, we can handle the exception here.""" logger.warning( diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py index e277088a..a60d52b9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py @@ -198,3 +198,9 @@ def update_records(self, records: list[dict]) -> None: f"the number of records in the buffer ({len(self._buffer)})." 
) self._buffer = records + + def replace_buffer(self, records: list[dict]) -> None: + """Replace the buffer contents, updating the current batch size.""" + self._buffer = records + if self._num_records_list is not None: + self._num_records_list[self._current_batch_number] = len(records) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py new file mode 100644 index 00000000..284d45f5 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +import shutil +from enum import Enum +from typing import TYPE_CHECKING + +from data_designer.engine.dataset_builders.artifact_storage import BatchStage +from data_designer.engine.dataset_builders.errors import DatasetProcessingError + +if TYPE_CHECKING: + import pandas as pd + + from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage + from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager + from data_designer.engine.processing.processors.base import Processor + +logger = logging.getLogger(__name__) + + +class ProcessorStage(str, Enum): + """Valid processor callback stages.""" + + PRE_BATCH = "process_before_batch" + POST_BATCH = "process_after_batch" + AFTER_GENERATION = "process_after_generation" + + +class ProcessorRunner: + """Runs processor callbacks at various stages of dataset generation.""" + + def __init__( + self, + processors: list[Processor], + artifact_storage: ArtifactStorage, + ): + self._processors = processors + self._artifact_storage = artifact_storage + + @property + def processors(self) -> tuple[Processor, ...]: 
+ return tuple(self._processors) + + def has_processors_for(self, stage: ProcessorStage) -> bool: + """Check if any processor implements the given stage.""" + return any(p.implements(stage.value) for p in self._processors) + + def _run_stage(self, df: pd.DataFrame, stage: ProcessorStage, **kwargs) -> pd.DataFrame: + """Run a processor callback on all processors that implement it.""" + original_len = len(df) + for processor in self._processors: + if not processor.implements(stage.value): + continue + try: + df = getattr(processor, stage.value)(df, **kwargs) + except Exception as e: + raise DatasetProcessingError(f"🛑 Failed in {stage.value} for {processor.name}: {e}") from e + if len(df) != original_len: + delta = len(df) - original_len + logger.info(f"â„šī¸ {stage.name} processors changed the record count by {delta:+d} records.") + return df + + def run_pre_batch(self, batch_manager: DatasetBatchManager) -> None: + """Run process_before_batch() on current batch.""" + if not self.has_processors_for(ProcessorStage.PRE_BATCH): + return + + df = batch_manager.get_current_batch(as_dataframe=True) + df = self._run_stage(df, ProcessorStage.PRE_BATCH) + batch_manager.replace_buffer(df.to_dict(orient="records")) + + def run_post_batch(self, df: pd.DataFrame, current_batch_number: int | None) -> pd.DataFrame: + """Run process_after_batch() on processors that implement it.""" + return self._run_stage(df, ProcessorStage.POST_BATCH, current_batch_number=current_batch_number) + + def run_after_generation_on_df(self, df: pd.DataFrame) -> pd.DataFrame: + """Run process_after_generation() on a DataFrame (for preview mode).""" + return self._run_stage(df, ProcessorStage.AFTER_GENERATION) + + def run_after_generation(self, batch_size: int) -> None: + """Load final dataset, run process_after_generation(), rewrite in chunks. 
+ + Re-chunks the processed dataset using the given batch_size so that output + files stay consistently sized regardless of how many rows the processor + adds or removes. + """ + if not self.has_processors_for(ProcessorStage.AFTER_GENERATION): + return + + logger.info("âŗ Running process_after_generation on final dataset...") + df = self._artifact_storage.load_dataset() + df = self._run_stage(df, ProcessorStage.AFTER_GENERATION) + + shutil.rmtree(self._artifact_storage.final_dataset_path) + for i in range(0, max(len(df), 1), batch_size): + self._artifact_storage.write_batch_to_parquet_file( + batch_number=i // batch_size, + dataframe=df.iloc[i : i + batch_size], + batch_stage=BatchStage.FINAL_RESULT, + ) + logger.info(f"✅ process_after_generation complete. Final dataset has {len(df)} rows.") diff --git a/packages/data-designer-engine/src/data_designer/engine/processing/processors/base.py b/packages/data-designer-engine/src/data_designer/engine/processing/processors/base.py index 8dd47132..d6a32e91 100644 --- a/packages/data-designer-engine/src/data_designer/engine/processing/processors/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/processing/processors/base.py @@ -3,11 +3,58 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import ABC from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT class Processor(ConfigurableTask[TaskConfigT], ABC): - @abstractmethod - def process(self, data: DataT, *, current_batch_number: int | None = None) -> DataT: ... + """Base class for dataset processors. + + Processors transform data at different stages of the generation pipeline. + Override the callback methods for the stages you want to handle. 
+ """ + + def implements(self, method_name: str) -> bool: + """Check if subclass overrides a callback method.""" + return getattr(type(self), method_name) is not getattr(Processor, method_name) + + def process_before_batch(self, data: DataT) -> DataT: + """Called at PRE_BATCH stage before each batch is generated. + + Override to transform batch data before generation begins. + + Args: + data: The batch data before generation. + + Returns: + Transformed batch data. + """ + return data + + def process_after_batch(self, data: DataT, *, current_batch_number: int | None) -> DataT: + """Called at POST_BATCH stage after each batch is generated. + + Override to process each batch of generated data. + + Args: + data: The generated batch data. + current_batch_number: The current batch number (0-indexed), or None in preview mode. + + Returns: + Transformed batch data. + """ + return data + + def process_after_generation(self, data: DataT) -> DataT: + """Called at AFTER_GENERATION stage on the final combined dataset. + + Override to transform the complete generated dataset. + + Args: + data: The final combined dataset. + + Returns: + Transformed final dataset. 
+ """ + return data diff --git a/packages/data-designer-engine/src/data_designer/engine/processing/processors/drop_columns.py b/packages/data-designer-engine/src/data_designer/engine/processing/processors/drop_columns.py index 98369a6b..bb26af2a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/processing/processors/drop_columns.py +++ b/packages/data-designer-engine/src/data_designer/engine/processing/processors/drop_columns.py @@ -18,10 +18,15 @@ class DropColumnsProcessor(Processor[DropColumnsProcessorConfig]): - def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame: + """Drops specified columns from the dataset after each batch.""" + + def process_after_batch(self, data: pd.DataFrame, *, current_batch_number: int | None) -> pd.DataFrame: logger.info(f"🙈 Dropping columns: {self.config.column_names}") - if current_batch_number is not None: # not in preview mode - self._save_dropped_columns_if_needed(data, current_batch_number) + if current_batch_number is not None: + self._save_dropped_columns(data, current_batch_number) + return self._drop_columns(data) + + def _drop_columns(self, data: pd.DataFrame) -> pd.DataFrame: for column in self.config.column_names: if column in data.columns: data.drop(columns=[column], inplace=True) @@ -29,7 +34,12 @@ def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None logger.warning(f"âš ī¸ Cannot drop column: `{column}` not found in the dataset.") return data - def _save_dropped_columns_if_needed(self, data: pd.DataFrame, current_batch_number: int) -> None: + def _save_dropped_columns(self, data: pd.DataFrame, current_batch_number: int) -> None: + # Only save columns that actually exist + existing_columns = [col for col in self.config.column_names if col in data.columns] + if not existing_columns: + return + logger.debug("đŸ“Ļ Saving dropped columns to dropped-columns directory") dropped_column_parquet_file_name = 
self.artifact_storage.create_batch_file_path( batch_number=current_batch_number, @@ -37,6 +47,6 @@ def _save_dropped_columns_if_needed(self, data: pd.DataFrame, current_batch_numb ).name self.artifact_storage.write_parquet_file( parquet_file_name=dropped_column_parquet_file_name, - dataframe=data[self.config.column_names], + dataframe=data[existing_columns], batch_stage=BatchStage.DROPPED_COLUMNS, ) diff --git a/packages/data-designer-engine/src/data_designer/engine/processing/processors/schema_transform.py b/packages/data-designer-engine/src/data_designer/engine/processing/processors/schema_transform.py index b84339e6..349afddc 100644 --- a/packages/data-designer-engine/src/data_designer/engine/processing/processors/schema_transform.py +++ b/packages/data-designer-engine/src/data_designer/engine/processing/processors/schema_transform.py @@ -41,19 +41,14 @@ def escape_for_json_string(s: str) -> str: class SchemaTransformProcessor(WithJinja2UserTemplateRendering, Processor[SchemaTransformProcessorConfig]): + """Transforms dataset schema using Jinja2 templates after each batch.""" + @property def template_as_str(self) -> str: return json.dumps(self.config.template) - def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame: - self.prepare_jinja2_template_renderer(self.template_as_str, data.columns.to_list()) - formatted_records = [] - for record in data.to_dict(orient="records"): - deserialized = deserialize_json_values(record) - escaped = _json_escape_record(deserialized) - rendered = self.render_template(escaped) - formatted_records.append(json.loads(rendered)) - formatted_data = pd.DataFrame(formatted_records) + def process_after_batch(self, data: pd.DataFrame, *, current_batch_number: int | None) -> pd.DataFrame: + formatted_data = self._transform(data) if current_batch_number is not None: self.artifact_storage.write_batch_to_parquet_file( batch_number=current_batch_number, @@ -67,5 +62,14 @@ def process(self, data: 
pd.DataFrame, *, current_batch_number: int | None = None dataframe=formatted_data, batch_stage=BatchStage.PROCESSORS_OUTPUTS, ) - return data + + def _transform(self, data: pd.DataFrame) -> pd.DataFrame: + self.prepare_jinja2_template_renderer(self.template_as_str, data.columns.to_list()) + formatted_records = [] + for record in data.to_dict(orient="records"): + deserialized = deserialize_json_values(record) + escaped = _json_escape_record(deserialized) + rendered = self.render_template(escaped) + formatted_records.append(json.loads(rendered)) + return pd.DataFrame(formatted_records) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py index 521c39f5..c512871a 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py @@ -11,16 +11,18 @@ from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig, SamplerColumnConfig from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.custom_column import custom_column_generator -from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import DropColumnsProcessorConfig from data_designer.config.run_config import RunConfig from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams +from data_designer.config.seed_source import DataFrameSeedSource, LocalFileSeedSource from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder -from data_designer.engine.dataset_builders.errors import DatasetGenerationError +from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError 
from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum from data_designer.engine.models.usage import ModelUsageStats, TokenUsageStats +from data_designer.engine.processing.processors.base import Processor from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry +from data_designer.engine.resources.seed_reader import DataFrameSeedReader from data_designer.lazy_heavy_imports import pd if TYPE_CHECKING: @@ -38,11 +40,7 @@ def stub_test_column_configs(): @pytest.fixture def stub_test_processor_configs(): - return [ - DropColumnsProcessorConfig( - name="drop_columns_processor", build_stage=BuildStage.POST_BATCH, column_names=["column_to_drop"] - ) - ] + return [DropColumnsProcessorConfig(name="drop_columns_processor", column_names=["column_to_drop"])] @pytest.fixture @@ -53,7 +51,6 @@ def stub_test_config_builder(stub_test_column_configs, stub_model_configs): config_builder.add_processor( processor_type="drop_columns", name="drop_columns_processor", - build_stage=BuildStage.POST_BATCH, column_names=["column_to_drop"], ) return config_builder @@ -87,6 +84,45 @@ def stub_column_wise_builder(stub_resource_provider, stub_test_config_builder): ) +@pytest.fixture +def seed_data_setup(stub_resource_provider, tmp_path): + """Set up seed reader with test data and write seed file to disk.""" + seed_df = pd.DataFrame({"seed_id": [1, 2, 3, 4, 5], "text": ["a", "b", "c", "d", "e"]}) + seed_source = DataFrameSeedSource(df=seed_df) + seed_reader = DataFrameSeedReader() + seed_reader.attach(seed_source, Mock()) + stub_resource_provider.seed_reader = seed_reader + + seed_path = tmp_path / "seed.parquet" + seed_df.to_parquet(seed_path, index=False) + + return {"seed_df": seed_df, "seed_path": seed_path} + + +@pytest.fixture +def builder_with_seed(stub_resource_provider, stub_model_configs, seed_data_setup): + """Create a builder with seed dataset configured.""" + config_builder = 
DataDesignerConfigBuilder(model_configs=stub_model_configs) + config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"]))) + config_builder.add_column(SamplerColumnConfig(name="extra", sampler_type="uuid", params=UUIDSamplerParams())) + + return ColumnWiseDatasetBuilder( + data_designer_config=config_builder.build(), + resource_provider=stub_resource_provider, + ) + + +def create_mock_processor(name: str, stages: list[str]) -> Mock: + """Create a mock processor that implements specified stages.""" + mock_processor = Mock(spec=Processor) + mock_processor.name = name + mock_processor.implements.side_effect = lambda m: m in stages + mock_processor.process_before_batch.side_effect = lambda df: df + mock_processor.process_after_batch.side_effect = lambda df, **kw: df + mock_processor.process_after_generation.side_effect = lambda df: df + return mock_processor + + def test_column_wise_dataset_builder_creation(stub_resource_provider, stub_test_config_builder): builder = ColumnWiseDatasetBuilder( data_designer_config=stub_test_config_builder.build(), @@ -170,6 +206,7 @@ def test_column_wise_dataset_builder_build_method_basic_flow( stub_resource_provider, ): stub_resource_provider.run_config = RunConfig(buffer_size=50) + stub_resource_provider.seed_reader = None # No seed data for this basic flow test stub_resource_provider.model_registry.run_health_check = Mock() stub_resource_provider.model_registry.get_model_usage_stats = Mock(return_value={"test": "stats"}) stub_resource_provider.model_registry.models = {} @@ -184,6 +221,7 @@ def test_column_wise_dataset_builder_build_method_basic_flow( stub_batch_manager.iter_current_batch.return_value = [(0, {"test": "data"})] stub_column_wise_builder.batch_manager = stub_batch_manager + stub_column_wise_builder.set_processor_runner([]) # No processors for basic flow test result_path = stub_column_wise_builder.build(num_records=100) @@ -232,16 +270,6 @@ def 
test_column_wise_dataset_builder_validate_column_configs( ) -def test_column_wise_dataset_builder_initialize_processors(stub_column_wise_builder): - processors = stub_column_wise_builder._processors - assert processors.keys() == set(BuildStage) - assert len(processors[BuildStage.PRE_BATCH]) == 0 - assert len(processors[BuildStage.POST_BATCH]) == 1 - assert len(processors[BuildStage.PRE_GENERATION]) == 0 - assert len(processors[BuildStage.POST_GENERATION]) == 0 - assert processors[BuildStage.POST_BATCH][0].config.column_names == ["column_to_drop"] - - def test_run_config_default_non_inference_max_parallel_workers() -> None: run_config = RunConfig() assert run_config.non_inference_max_parallel_workers == 4 @@ -352,8 +380,6 @@ def test_fan_out_with_threads_uses_early_shutdown_settings_from_resource_provide shutdown_error_window: int, ) -> None: """Test that _fan_out_with_threads uses run settings from resource_provider.""" - from data_designer.config.run_config import RunConfig - stub_resource_provider.run_config = RunConfig( disable_early_shutdown=disable_early_shutdown, shutdown_error_rate=configured_rate, @@ -405,3 +431,134 @@ def bad_fn(df: pd.DataFrame) -> pd.DataFrame: with pytest.raises(DatasetGenerationError, match=r"(?s)Failed to process column 'col'.*something broke"): builder.build_preview(num_records=3) + + +# Processor tests + + +@pytest.fixture +def simple_builder(stub_resource_provider, stub_model_configs): + """Minimal builder with a single UUID column and no batch files on disk.""" + config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + config_builder.add_column(SamplerColumnConfig(name="id", sampler_type="uuid", params=UUIDSamplerParams())) + return ColumnWiseDatasetBuilder( + data_designer_config=config_builder.build(), + resource_provider=stub_resource_provider, + ) + + +def test_initialize_processors(stub_column_wise_builder): + processors = stub_column_wise_builder.processors + assert isinstance(processors, tuple) + 
assert len(processors) == 1 + assert processors[0].config.column_names == ["column_to_drop"] + + +@pytest.mark.parametrize( + "processor_fn,batch_size,expected_rows,expected_files", + [ + pytest.param(lambda df: df, 3, 9, 3, id="noop_even"), + pytest.param(lambda df: df[df["id"] > 3], 3, 6, 2, id="filter_even"), + pytest.param(lambda df: df[df["id"] != 3].reset_index(drop=True), 3, 8, 3, id="filter_uneven"), + pytest.param(lambda df: df[df["id"] > 8], 3, 1, 1, id="filter_fewer_than_batch_size"), + ], +) +def test_run_after_generation( + stub_resource_provider, simple_builder, processor_fn, batch_size, expected_rows, expected_files +): + """Test that process_after_generation re-chunks output by batch_size.""" + storage = stub_resource_provider.artifact_storage + storage.mkdir_if_needed(storage.final_dataset_path) + pd.DataFrame({"id": list(range(1, 10))}).to_parquet(storage.final_dataset_path / "batch_00000.parquet", index=False) + + mock_processor = create_mock_processor("proc", ["process_after_generation"]) + mock_processor.process_after_generation.side_effect = processor_fn + + simple_builder.set_processor_runner([mock_processor]) + simple_builder._processor_runner.run_after_generation(batch_size) + + mock_processor.process_after_generation.assert_called_once() + batch_files = sorted(storage.final_dataset_path.glob("*.parquet")) + assert len(batch_files) == expected_files + assert sum(len(pd.read_parquet(f)) for f in batch_files) == expected_rows + + +@pytest.mark.parametrize("mode", ["preview", "build"]) +def test_all_processor_stages_run_in_order(builder_with_seed, mode): + """Test that all 3 processor stages run in correct order for both preview and build modes.""" + call_order = [] + all_stages = ["process_before_batch", "process_after_batch", "process_after_generation"] + + mock_processor = create_mock_processor("all_stages_processor", all_stages) + mock_processor.process_before_batch.side_effect = lambda df: (call_order.append("process_before_batch"), 
df)[1] + mock_processor.process_after_batch.side_effect = lambda df, **kw: (call_order.append("process_after_batch"), df)[1] + mock_processor.process_after_generation.side_effect = lambda df: ( + call_order.append("process_after_generation"), + df, + )[1] + + builder_with_seed.set_processor_runner([mock_processor]) + + if mode == "preview": + raw_dataset = builder_with_seed.build_preview(num_records=3) + builder_with_seed.process_preview(raw_dataset) + else: + builder_with_seed.build(num_records=3) + + mock_processor.process_before_batch.assert_called_once() + mock_processor.process_after_batch.assert_called_once() + mock_processor.process_after_generation.assert_called_once() + + assert call_order == all_stages + + +def test_processor_exception_in_process_after_batch_raises_error(simple_builder): + """Test that processor exceptions during process_after_batch are properly wrapped.""" + mock_processor = create_mock_processor("failing_processor", ["process_after_batch"]) + mock_processor.process_after_batch.side_effect = ValueError("Post-batch processing failed") + + simple_builder.set_processor_runner([mock_processor]) + + with pytest.raises(DatasetProcessingError, match="Failed in process_after_batch"): + simple_builder._processor_runner.run_post_batch(pd.DataFrame({"id": [1, 2, 3]}), current_batch_number=0) + + +def test_processor_with_no_implemented_stages_is_skipped(builder_with_seed): + """Test that a processor implementing no stages doesn't cause errors.""" + mock_processor = create_mock_processor("noop_processor", []) + builder_with_seed.set_processor_runner([mock_processor]) + + result = builder_with_seed.build_preview(num_records=3) + + assert len(result) == 3 + mock_processor.process_before_batch.assert_not_called() + mock_processor.process_after_batch.assert_not_called() + mock_processor.process_after_generation.assert_not_called() + + +def test_multiple_processors_run_in_definition_order(builder_with_seed): + """Test that multiple processors run in the 
order they were defined.""" + call_order = [] + + processors = [] + for label in ["a", "b", "c"]: + p = create_mock_processor(f"processor_{label}", ["process_before_batch"]) + p.process_before_batch.side_effect = lambda df, lbl=label: (call_order.append(lbl), df)[1] + processors.append(p) + + builder_with_seed.set_processor_runner(processors) + builder_with_seed.build(num_records=3) + + assert call_order == ["a", "b", "c"] + + +def test_process_preview_with_empty_dataframe(simple_builder): + """Test that process_preview handles empty DataFrames gracefully.""" + mock_processor = create_mock_processor("test_processor", ["process_after_batch", "process_after_generation"]) + simple_builder.set_processor_runner([mock_processor]) + + result = simple_builder.process_preview(pd.DataFrame()) + + assert len(result) == 0 + mock_processor.process_after_batch.assert_called_once() + mock_processor.process_after_generation.assert_called_once() diff --git a/packages/data-designer-engine/tests/engine/processing/processors/test_drop_columns.py b/packages/data-designer-engine/tests/engine/processing/processors/test_drop_columns.py index 53da3e4a..97662e98 100644 --- a/packages/data-designer-engine/tests/engine/processing/processors/test_drop_columns.py +++ b/packages/data-designer-engine/tests/engine/processing/processors/test_drop_columns.py @@ -8,7 +8,6 @@ import pytest -from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import DropColumnsProcessorConfig from data_designer.engine.dataset_builders.artifact_storage import BatchStage from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor @@ -20,9 +19,7 @@ @pytest.fixture def stub_processor_config(): - return DropColumnsProcessorConfig( - name="drop_columns_processor", build_stage=BuildStage.POST_BATCH, column_names=["col1", "col2"] - ) + return DropColumnsProcessorConfig(name="drop_columns_processor", column_names=["col1", "col2"]) @pytest.fixture @@ 
-84,34 +81,34 @@ def stub_empty_dataframe(): ), ], ) -def test_process_scenarios( +def test_process_after_batch_scenarios( stub_processor, stub_sample_dataframe, test_case, column_names, expected_result, expected_warning ): stub_processor.config.column_names = column_names if expected_warning: with patch("data_designer.engine.processing.processors.drop_columns.logger") as mock_logger: - result = stub_processor.process(stub_sample_dataframe.copy()) + result = stub_processor.process_after_batch(stub_sample_dataframe.copy(), current_batch_number=0) pd.testing.assert_frame_equal(result, pd.DataFrame(expected_result)) mock_logger.warning.assert_called_once_with(expected_warning) else: - result = stub_processor.process(stub_sample_dataframe.copy()) + result = stub_processor.process_after_batch(stub_sample_dataframe.copy(), current_batch_number=0) pd.testing.assert_frame_equal(result, pd.DataFrame(expected_result)) -def test_process_logging(stub_processor, stub_sample_dataframe): +def test_process_after_batch_logging(stub_processor, stub_sample_dataframe): with patch("data_designer.engine.processing.processors.drop_columns.logger") as mock_logger: - stub_processor.process(stub_sample_dataframe.copy()) + stub_processor.process_after_batch(stub_sample_dataframe.copy(), current_batch_number=0) mock_logger.info.assert_called_once_with("🙈 Dropping columns: ['col1', 'col2']") -def test_save_dropped_columns_without_preview(stub_processor, stub_sample_dataframe): +def test_save_dropped_columns(stub_processor, stub_sample_dataframe): stub_processor.config.column_names = ["col1", "col2"] with patch("data_designer.engine.processing.processors.drop_columns.logger") as mock_logger: - stub_processor.process(stub_sample_dataframe.copy(), current_batch_number=0) + stub_processor.process_after_batch(stub_sample_dataframe.copy(), current_batch_number=0) stub_processor.artifact_storage.write_parquet_file.assert_called_once() call_args = 
stub_processor.artifact_storage.write_parquet_file.call_args @@ -126,24 +123,19 @@ def test_save_dropped_columns_without_preview(stub_processor, stub_sample_datafr mock_logger.debug.assert_called_once_with("📦 Saving dropped columns to dropped-columns directory") -def test_save_dropped_columns_with_preview(stub_processor, stub_sample_dataframe): - stub_processor.config.column_names = ["col1", "col2"] - - stub_processor.process(stub_sample_dataframe.copy()) - stub_processor.artifact_storage.write_parquet_file.assert_not_called() - - def test_save_dropped_columns_with_nonexistent_columns(stub_processor, stub_sample_dataframe): + """When columns don't exist, no file is written but warnings are logged.""" stub_processor.config.column_names = ["nonexistent1", "nonexistent2"] with patch("data_designer.engine.processing.processors.drop_columns.logger"): - with pytest.raises(KeyError): - stub_processor.process(stub_sample_dataframe.copy(), current_batch_number=0) + stub_processor.process_after_batch(stub_sample_dataframe.copy(), current_batch_number=0) + # No file is written for nonexistent columns + stub_processor.artifact_storage.write_parquet_file.assert_not_called() -def test_process_inplace_modification(stub_processor, stub_sample_dataframe): +def test_process_after_batch_inplace_modification(stub_processor, stub_sample_dataframe): original_df = stub_sample_dataframe.copy() - result = stub_processor.process(original_df) + result = stub_processor.process_after_batch(original_df, current_batch_number=0) assert result is original_df @@ -152,11 +144,26 @@ def test_process_inplace_modification(stub_processor, stub_sample_dataframe): assert "col3" in result.columns -def test_process_empty_dataframe(stub_processor, stub_empty_dataframe): +def test_process_after_batch_empty_dataframe(stub_processor, stub_empty_dataframe): stub_processor.config.column_names = ["col1"] with patch("data_designer.engine.processing.processors.drop_columns.logger") as mock_logger: - result = 
stub_processor.process(stub_empty_dataframe) + result = stub_processor.process_after_batch(stub_empty_dataframe, current_batch_number=0) pd.testing.assert_frame_equal(result, stub_empty_dataframe) mock_logger.warning.assert_called_once_with("⚠️ Cannot drop column: `col1` not found in the dataset.") + + +def test_process_after_batch_preview_mode_does_not_save(stub_processor, stub_sample_dataframe): + """In preview mode (current_batch_number=None), columns are dropped but not saved to disk.""" + stub_processor.config.column_names = ["col1", "col2"] + + result = stub_processor.process_after_batch(stub_sample_dataframe.copy(), current_batch_number=None) + + # Columns should still be dropped + assert "col1" not in result.columns + assert "col2" not in result.columns + assert "col3" in result.columns + + # But no file should be written + stub_processor.artifact_storage.write_parquet_file.assert_not_called() diff --git a/packages/data-designer-engine/tests/engine/processing/processors/test_schema_transform.py b/packages/data-designer-engine/tests/engine/processing/processors/test_schema_transform.py index 520d67da..69b5b357 100644 --- a/packages/data-designer-engine/tests/engine/processing/processors/test_schema_transform.py +++ b/packages/data-designer-engine/tests/engine/processing/processors/test_schema_transform.py @@ -9,7 +9,6 @@ import pytest -from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import SchemaTransformProcessorConfig from data_designer.engine.dataset_builders.artifact_storage import BatchStage from data_designer.engine.processing.processors.schema_transform import SchemaTransformProcessor @@ -23,7 +22,6 @@ @pytest.fixture def stub_processor_config() -> SchemaTransformProcessorConfig: return SchemaTransformProcessorConfig( - build_stage=BuildStage.POST_BATCH, template={"text": "{{ col1 }}", "value": "{{ col2 }}"}, name="test_schema_transform", ) @@ -53,20 +51,20 @@ def stub_simple_dataframe() -> 
pd.DataFrame: ) -def test_process_returns_original_dataframe( +def test_process_after_batch_returns_original_dataframe( stub_processor: SchemaTransformProcessor, stub_sample_dataframe: pd.DataFrame ) -> None: - result = stub_processor.process(stub_sample_dataframe, current_batch_number=0) + result = stub_processor.process_after_batch(stub_sample_dataframe, current_batch_number=0) pd.testing.assert_frame_equal(result, stub_sample_dataframe) -def test_process_writes_formatted_output_to_parquet( +def test_process_after_batch_writes_formatted_output_to_parquet( stub_processor: SchemaTransformProcessor, stub_sample_dataframe: pd.DataFrame ) -> None: # Process the dataframe - result = stub_processor.process(stub_sample_dataframe, current_batch_number=0) + result = stub_processor.process_after_batch(stub_sample_dataframe, current_batch_number=0) - # Verify the original dataframe is returned + # Verify the original dataframe is returned (formatted data is artifact-only) pd.testing.assert_frame_equal(result, stub_sample_dataframe) # Verify write_batch_to_parquet_file was called with correct parameters @@ -97,20 +95,7 @@ def test_process_writes_formatted_output_to_parquet( assert json.loads(actual) == json.loads(expected), f"Row {i} mismatch: {actual} != {expected}" -def test_process_without_batch_number_does_not_write( - stub_processor: SchemaTransformProcessor, stub_sample_dataframe: pd.DataFrame -) -> None: - # Process without batch number (preview mode) - result = stub_processor.process(stub_sample_dataframe, current_batch_number=None) - - # Verify the original dataframe is returned - pd.testing.assert_frame_equal(result, stub_sample_dataframe) - - # Verify write_batch_to_parquet_file was NOT called - stub_processor.artifact_storage.write_batch_to_parquet_file.assert_not_called() - - -def test_process_with_json_serialized_values(stub_processor: SchemaTransformProcessor) -> None: +def test_process_after_batch_with_json_serialized_values(stub_processor: 
SchemaTransformProcessor) -> None: # Test with JSON-serialized values in dataframe df_with_json = pd.DataFrame( { @@ -120,7 +105,7 @@ def test_process_with_json_serialized_values(stub_processor: SchemaTransformProc ) # Process the dataframe - stub_processor.process(df_with_json, current_batch_number=0) + stub_processor.process_after_batch(df_with_json, current_batch_number=0) written_dataframe: pd.DataFrame = stub_processor.artifact_storage.write_batch_to_parquet_file.call_args.kwargs[ "dataframe" ] @@ -136,7 +121,7 @@ def test_process_with_json_serialized_values(stub_processor: SchemaTransformProc assert first_output["value"] == '{"nested": "value1"}' -def test_process_with_special_characters_in_llm_output(stub_processor: SchemaTransformProcessor) -> None: +def test_process_after_batch_with_special_characters_in_llm_output(stub_processor: SchemaTransformProcessor) -> None: """Test that LLM outputs with special characters are properly escaped for JSON. This addresses GitHub issue #227 where SchemaTransformProcessor fails with JSONDecodeError @@ -155,7 +140,7 @@ def test_process_with_special_characters_in_llm_output(stub_processor: SchemaTra ) # Process should not raise JSONDecodeError - stub_processor.process(df_with_special_chars, current_batch_number=0) + stub_processor.process_after_batch(df_with_special_chars, current_batch_number=0) written_dataframe: pd.DataFrame = stub_processor.artifact_storage.write_batch_to_parquet_file.call_args.kwargs[ "dataframe" ] @@ -172,7 +157,7 @@ def test_process_with_special_characters_in_llm_output(stub_processor: SchemaTra assert outputs[3]["text"] == "Tab\there" -def test_process_with_mixed_special_characters(stub_processor: SchemaTransformProcessor) -> None: +def test_process_after_batch_with_mixed_special_characters(stub_processor: SchemaTransformProcessor) -> None: """Test complex LLM output with multiple types of special characters.""" df_complex = pd.DataFrame( { @@ -183,7 +168,7 @@ def 
test_process_with_mixed_special_characters(stub_processor: SchemaTransformPr } ) - stub_processor.process(df_complex, current_batch_number=0) + stub_processor.process_after_batch(df_complex, current_batch_number=0) written_dataframe: pd.DataFrame = stub_processor.artifact_storage.write_batch_to_parquet_file.call_args.kwargs[ "dataframe" ] @@ -191,3 +176,19 @@ def test_process_with_mixed_special_characters(stub_processor: SchemaTransformPr assert len(written_dataframe) == 1 output = written_dataframe.iloc[0].to_dict() assert output["text"] == 'She replied: "I\'m not sure about that\\nLet me think..."' + + +def test_process_after_batch_preview_mode_writes_single_file( + stub_processor: SchemaTransformProcessor, stub_sample_dataframe: pd.DataFrame +) -> None: + """In preview mode (current_batch_number=None), transformed output is written as a single file.""" + stub_processor.artifact_storage.write_parquet_file = Mock() + result = stub_processor.process_after_batch(stub_sample_dataframe, current_batch_number=None) + + pd.testing.assert_frame_equal(result, stub_sample_dataframe) + + stub_processor.artifact_storage.write_batch_to_parquet_file.assert_not_called() + stub_processor.artifact_storage.write_parquet_file.assert_called_once() + call_args = stub_processor.artifact_storage.write_parquet_file.call_args + assert call_args.kwargs["parquet_file_name"] == "test_schema_transform.parquet" + assert call_args.kwargs["batch_stage"] == BatchStage.PROCESSORS_OUTPUTS diff --git a/packages/data-designer-engine/tests/engine/test_validation.py b/packages/data-designer-engine/tests/engine/test_validation.py index c0cc4bc0..97f795b5 100644 --- a/packages/data-designer-engine/tests/engine/test_validation.py +++ b/packages/data-designer-engine/tests/engine/test_validation.py @@ -12,7 +12,6 @@ Score, ValidationColumnConfig, ) -from data_designer.config.dataset_builders import BuildStage from data_designer.config.models import ImageContext, ModalityDataType from 
data_designer.config.processors import ( DropColumnsProcessorConfig, @@ -104,12 +103,10 @@ DropColumnsProcessorConfig( name="drop_columns_processor", column_names=["inexistent_column"], - build_stage=BuildStage.POST_BATCH, ), SchemaTransformProcessorConfig( name="schema_transform_processor_invalid_reference", template={"text": "{{ invalid_reference }}"}, - build_stage=BuildStage.POST_BATCH, ), ] ALLOWED_REFERENCE = [c.name for c in COLUMNS] diff --git a/packages/data-designer/tests/interface/test_data_designer.py b/packages/data-designer/tests/interface/test_data_designer.py index 84636e59..692bfc13 100644 --- a/packages/data-designer/tests/interface/test_data_designer.py +++ b/packages/data-designer/tests/interface/test_data_designer.py @@ -12,7 +12,6 @@ from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.config_builder import DataDesignerConfigBuilder -from data_designer.config.dataset_builders import BuildStage from data_designer.config.errors import InvalidConfigError from data_designer.config.models import ModelProvider from data_designer.config.processors import DropColumnsProcessorConfig @@ -323,11 +322,7 @@ def test_preview_with_dropped_columns( SamplerColumnConfig(name="uniform", sampler_type="uniform", params={"low": 1, "high": 100}) ) - config_builder.add_processor( - DropColumnsProcessorConfig( - name="drop_columns_processor", build_stage=BuildStage.POST_BATCH, column_names=["category"] - ) - ) + config_builder.add_processor(DropColumnsProcessorConfig(name="drop_columns_processor", column_names=["category"])) data_designer = DataDesigner( artifact_path=stub_artifact_path,