From e95a20af66171dd4708cf536e6fa67cd3c24c922 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Thu, 29 Jan 2026 16:11:08 -0500 Subject: [PATCH 1/3] feat: add execution graph builder plan with reference implementation This introduces a design plan for a memory-efficient execution graph that models cell-level dependencies for async dataset generation. The reference implementation is included to help build intuition about how the concepts could work in practice. The plan is the primary artifact - the code is exploratory. --- .../column_generators/generators/base.py | 15 + .../generators/expression.py | 5 + .../column_generators/generators/samplers.py | 5 + .../generators/seed_dataset.py | 5 + .../engine/execution_graph/__init__.py | 64 ++ .../engine/execution_graph/builder.py | 238 ++++ .../execution_graph/column_descriptor.py | 76 ++ .../engine/execution_graph/completion.py | 269 +++++ .../engine/execution_graph/graph.py | 447 ++++++++ .../engine/execution_graph/node_id.py | 57 + .../engine/execution_graph/traits.py | 38 + .../tests/engine/execution_graph/__init__.py | 2 + .../tests/engine/execution_graph/conftest.py | 88 ++ .../engine/execution_graph/test_builder.py | 315 +++++ .../execution_graph/test_column_descriptor.py | 80 ++ .../engine/execution_graph/test_completion.py | 205 ++++ .../engine/execution_graph/test_graph.py | 374 ++++++ .../engine/execution_graph/test_node_id.py | 101 ++ .../engine/execution_graph/test_traits.py | 62 + .../execution_graph_builder.plan.md | 356 ++++++ .../explore_execution_graph.ipynb | 1015 +++++++++++++++++ 21 files changed, 3817 insertions(+) create mode 100644 packages/data-designer-engine/src/data_designer/engine/execution_graph/__init__.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/execution_graph/builder.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/execution_graph/column_descriptor.py create mode 100644 
packages/data-designer-engine/src/data_designer/engine/execution_graph/completion.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/execution_graph/graph.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/execution_graph/node_id.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/execution_graph/traits.py create mode 100644 packages/data-designer-engine/tests/engine/execution_graph/__init__.py create mode 100644 packages/data-designer-engine/tests/engine/execution_graph/conftest.py create mode 100644 packages/data-designer-engine/tests/engine/execution_graph/test_builder.py create mode 100644 packages/data-designer-engine/tests/engine/execution_graph/test_column_descriptor.py create mode 100644 packages/data-designer-engine/tests/engine/execution_graph/test_completion.py create mode 100644 packages/data-designer-engine/tests/engine/execution_graph/test_graph.py create mode 100644 packages/data-designer-engine/tests/engine/execution_graph/test_node_id.py create mode 100644 packages/data-designer-engine/tests/engine/execution_graph/test_traits.py create mode 100644 plans/async-engine/execution_graph_builder.plan.md create mode 100644 plans/async-engine/explore_execution_graph.ipynb diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index 862061c3..b17de4a6 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -32,6 +32,21 @@ class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC): def can_generate_from_scratch(self) -> bool: return False + @property + def is_row_streamable(self) -> bool: + """Whether this generator can emit results as rows complete. 
+ + For cell-by-cell generators, this is always True since they process + rows independently. For full-column generators, this defaults to False + (barrier behavior) but can be overridden for generators that can + process rows independently despite operating on full columns. + + Returns: + True if the generator can emit results row-by-row, False if it + requires all inputs before producing any output. + """ + return self.get_generation_strategy() == GenerationStrategy.CELL_BY_CELL + @staticmethod @abstractmethod def get_generation_strategy() -> GenerationStrategy: ... diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py index 98c8fa7b..4fb0c0f6 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py @@ -20,6 +20,11 @@ class ExpressionColumnGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorFullColumn[ExpressionColumnConfig]): + @property + def is_row_streamable(self) -> bool: + """Expression generators process rows independently.""" + return True + def generate(self, data: pd.DataFrame) -> pd.DataFrame: logger.info(f"🧩 Generating column `{self.config.name}` from expression") diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/samplers.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/samplers.py index de18598a..ea510927 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/samplers.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/samplers.py @@ -29,6 +29,11 @@ class SamplerColumnGenerator(FromScratchColumnGenerator[SamplerMultiColumnConfig def get_generation_strategy() -> 
GenerationStrategy: return GenerationStrategy.FULL_COLUMN + @property + def is_row_streamable(self) -> bool: + """Sampler generators produce independent per-row data.""" + return True + def generate(self, data: pd.DataFrame) -> pd.DataFrame: df_samplers = self.generate_from_scratch(len(data)) return concat_datasets([data, df_samplers]) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py index 64193aee..d256f6d2 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py @@ -28,6 +28,11 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.FULL_COLUMN + @property + def is_row_streamable(self) -> bool: + """Seed dataset generators produce independent per-row data.""" + return True + @property def num_records_sampled(self) -> int: return self._num_records_sampled diff --git a/packages/data-designer-engine/src/data_designer/engine/execution_graph/__init__.py b/packages/data-designer-engine/src/data_designer/engine/execution_graph/__init__.py new file mode 100644 index 00000000..f7b74a5d --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/execution_graph/__init__.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Execution graph for async dataset generation. + +This package provides a memory-efficient execution graph for modeling cell-level +dependencies in dataset generation. 
The graph supports different generator +execution traits (start, cell-by-cell, row-streamable, barrier) with a hybrid +representation that can handle millions of records efficiently. + +Example: + >>> from data_designer.engine.execution_graph import ( + ... GraphBuilder, + ... ExecutionGraph, + ... CompletionTracker, + ... ) + >>> + >>> # Build graph from config + >>> builder = GraphBuilder(column_generator_registry) + >>> graph = builder.build(config, num_records=1_000_000) + >>> + >>> # Execute with completion tracking + >>> tracker = CompletionTracker(graph.num_records) + >>> for node in graph.iter_ready_nodes(tracker): + ... gen_cls, config = graph.get_generator_and_config(node) + ... # Execute node... + ... tracker.mark_complete(node) +""" + +from data_designer.engine.execution_graph.builder import GraphBuilder +from data_designer.engine.execution_graph.column_descriptor import ColumnDescriptor +from data_designer.engine.execution_graph.completion import ( + CompletionTracker, + ThreadSafeCompletionTracker, +) +from data_designer.engine.execution_graph.graph import ( + CompletionTrackerProtocol, + ExecutionGraph, +) +from data_designer.engine.execution_graph.node_id import ( + BarrierNodeId, + CellNodeId, + NodeId, +) +from data_designer.engine.execution_graph.traits import ExecutionTraits + +__all__ = [ + # Node identification + "CellNodeId", + "BarrierNodeId", + "NodeId", + # Traits + "ExecutionTraits", + # Column descriptor + "ColumnDescriptor", + # Graph + "ExecutionGraph", + "CompletionTrackerProtocol", + # Builder + "GraphBuilder", + # Completion tracking + "CompletionTracker", + "ThreadSafeCompletionTracker", +] diff --git a/packages/data-designer-engine/src/data_designer/engine/execution_graph/builder.py b/packages/data-designer-engine/src/data_designer/engine/execution_graph/builder.py new file mode 100644 index 00000000..37753ac6 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/execution_graph/builder.py @@ -0,0 +1,238 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Graph builder for constructing execution graphs from DataDesigner configs. + +This module provides the GraphBuilder class that constructs ExecutionGraph +instances from DataDesignerConfig objects. It infers execution traits from +generator properties (not class names) to support plugin generators. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from data_designer.config.base import ConfigBase +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, +) +from data_designer.engine.column_generators.registry import ColumnGeneratorRegistry +from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig +from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs +from data_designer.engine.execution_graph.column_descriptor import ColumnDescriptor +from data_designer.engine.execution_graph.graph import ExecutionGraph +from data_designer.engine.execution_graph.traits import ExecutionTraits + +if TYPE_CHECKING: + from data_designer.config.data_designer_config import DataDesignerConfig + + +class GraphBuilder: + """Factory for constructing ExecutionGraph instances from DataDesigner configs. + + The GraphBuilder infers execution traits from generator properties (not class + names) to support plugin generators. It handles multi-column configs by marking + additional columns as side effects. + + Example: + >>> builder = GraphBuilder(column_generator_registry) + >>> graph = builder.build(config, num_records=1_000_000) + >>> for node in graph.iter_start_nodes(): + ... print(node) + """ + + def __init__(self, registry: ColumnGeneratorRegistry) -> None: + """Initialize the graph builder. + + Args: + registry: The column generator registry to use for looking up generators. 
+ """ + self._registry = registry + + def build( + self, + config: DataDesignerConfig, + num_records: int, + *, + strict: bool = True, + ) -> ExecutionGraph: + """Build an execution graph from a DataDesigner config. + + Args: + config: The DataDesigner configuration. + num_records: The number of records to generate. + strict: If True, validate all dependencies exist during construction. + + Returns: + An ExecutionGraph ready for async execution. + + Raises: + ValueError: If no columns have the START trait (can generate from scratch). + """ + descriptors = self._build_column_descriptors(config) + + # Validate at least one start column exists + has_start = any(desc.is_start_column for desc in descriptors) + if not has_start: + raise ValueError( + "At least one column must be able to generate from scratch (have can_generate_from_scratch=True)" + ) + + return ExecutionGraph(num_records, descriptors, strict=strict) + + def _build_column_descriptors(self, config: DataDesignerConfig) -> list[ColumnDescriptor]: + """Build column descriptors from config. + + This method compiles the user-facing column configs (e.g., SamplerColumnConfig) + into internal multi-column configs (e.g., SamplerMultiColumnConfig) that the + registry expects, then builds descriptors from those compiled configs. + + Args: + config: The DataDesigner configuration. + + Returns: + List of ColumnDescriptor objects in topological order. + """ + # Compile user-facing configs into internal multi-column configs + compiled_configs = compile_dataset_builder_column_configs(config) + + descriptors: list[ColumnDescriptor] = [] + for col_config in compiled_configs: + descriptor = self._build_column_descriptor(col_config) + descriptors.append(descriptor) + + return descriptors + + def _build_column_descriptor(self, col_config: ConfigBase) -> ColumnDescriptor: + """Build a single column descriptor from a config. + + Args: + col_config: The column configuration (SingleColumnConfig or MultiColumnConfig). 
+ + Returns: + A ColumnDescriptor for the column. + """ + gen_cls = self._registry.get_for_config_type(type(col_config)) + traits = self._infer_traits(gen_cls) + + if isinstance(col_config, MultiColumnConfig): + # Multi-column configs use the first column as primary name + # Additional columns are marked as side effects + primary_name = col_config.columns[0].name + additional_columns = [c.name for c in col_config.columns[1:]] + + return ColumnDescriptor( + name=primary_name, + config=col_config, + generator_cls=gen_cls, + traits=traits, + dependencies=[], # Multi-column configs typically have no dependencies + side_effects=additional_columns, + ) + + # Single column config + return ColumnDescriptor( + name=col_config.name, + config=col_config, + generator_cls=gen_cls, + traits=traits, + dependencies=col_config.required_columns, + side_effects=col_config.side_effect_columns, + ) + + def _infer_traits(self, gen_cls: type[ColumnGenerator]) -> ExecutionTraits: + """Infer execution traits from generator class properties. + + This method uses generator properties (not class names) to determine + execution traits, making it compatible with plugin generators. + + Args: + gen_cls: The generator class to analyze. + + Returns: + ExecutionTraits flags for the generator. 
+ """ + traits = ExecutionTraits.NONE + + # Check can_generate_from_scratch - use getattr for simple property access + # This works for both class attributes and properties with simple True/False returns + can_generate = getattr(gen_cls, "can_generate_from_scratch", False) + # Handle property objects vs actual values + if isinstance(can_generate, property): + # For properties, check if the class overrides with a known pattern + can_generate = self._evaluate_property_default(gen_cls, "can_generate_from_scratch", False) + if can_generate: + traits |= ExecutionTraits.START + + # Check generation strategy + strategy = gen_cls.get_generation_strategy() + if strategy == GenerationStrategy.CELL_BY_CELL: + traits |= ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE + else: # FULL_COLUMN + # Check is_row_streamable property + is_streamable = getattr(gen_cls, "is_row_streamable", False) + if isinstance(is_streamable, property): + is_streamable = self._evaluate_property_default(gen_cls, "is_row_streamable", False) + if is_streamable: + traits |= ExecutionTraits.ROW_STREAMABLE + else: + traits |= ExecutionTraits.BARRIER + + return traits + + def _evaluate_property_default( + self, + cls: type[ColumnGenerator], + property_name: str, + default: bool, + ) -> bool: + """Evaluate the default value of a property. + + For simple properties that return True or False, this inspects the + bytecode to determine the return value without instantiating. + + Args: + cls: The class to check. + property_name: The name of the property. + default: The default value if the property cannot be determined. + + Returns: + The property's default value. 
+ """ + try: + import dis + + prop = getattr(cls, property_name, None) + if not isinstance(prop, property) or prop.fget is None: + return default + + fget = prop.fget + code = fget.__code__ + instructions = list(dis.get_instructions(code)) + + for i, instr in enumerate(instructions): + # Python 3.13+ uses RETURN_CONST for simple constant returns + if instr.opname == "RETURN_CONST" and isinstance(instr.argval, bool): + return instr.argval + + # Pre-3.13: LOAD_CONST True/False followed by RETURN_VALUE + if instr.opname == "RETURN_VALUE" and i > 0: + prev = instructions[i - 1] + if prev.opname == "LOAD_CONST" and isinstance(prev.argval, bool): + return prev.argval + + # If not a simple constant return, check based on property name + if property_name == "is_row_streamable": + # Default implementation compares generation strategy + strategy = cls.get_generation_strategy() + return strategy == GenerationStrategy.CELL_BY_CELL + + if property_name == "can_generate_from_scratch": + # Default is False in base ColumnGenerator + return default + + except Exception: + pass + + return default diff --git a/packages/data-designer-engine/src/data_designer/engine/execution_graph/column_descriptor.py b/packages/data-designer-engine/src/data_designer/engine/execution_graph/column_descriptor.py new file mode 100644 index 00000000..037ab60b --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/execution_graph/column_descriptor.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Column descriptor for the execution graph. + +This module defines the ColumnDescriptor dataclass that stores metadata about +a column in the execution graph, including its configuration, generator class, +execution traits, and dependencies. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from data_designer.engine.execution_graph.traits import ExecutionTraits + +if TYPE_CHECKING: + from data_designer.config.base import ConfigBase + from data_designer.engine.column_generators.generators.base import ColumnGenerator + + +@dataclass(slots=True) +class ColumnDescriptor: + """Metadata describing a column in the execution graph. + + Attributes: + name: The primary column name (for multi-column configs, this is the first column). + config: The column configuration object. + generator_cls: The generator class to use for this column. + traits: Execution traits inferred from the generator. + dependencies: List of column names this column depends on (from required_columns). + side_effects: List of additional column names this generator produces. + """ + + name: str + config: ConfigBase + generator_cls: type[ColumnGenerator] + traits: ExecutionTraits + dependencies: list[str] = field(default_factory=list) + side_effects: list[str] = field(default_factory=list) + + @property + def is_start_column(self) -> bool: + """Whether this column can generate data from scratch.""" + return bool(self.traits & ExecutionTraits.START) + + @property + def is_cell_by_cell(self) -> bool: + """Whether this column processes individual cells independently.""" + return bool(self.traits & ExecutionTraits.CELL_BY_CELL) + + @property + def is_row_streamable(self) -> bool: + """Whether this column can emit results as rows complete.""" + return bool(self.traits & ExecutionTraits.ROW_STREAMABLE) + + @property + def is_barrier(self) -> bool: + """Whether this column requires all inputs before producing any output.""" + return bool(self.traits & ExecutionTraits.BARRIER) + + @property + def has_dependencies(self) -> bool: + """Whether this column has any dependencies.""" + return len(self.dependencies) > 0 + + @property + def has_side_effects(self) -> bool: + """Whether 
this column produces additional columns as side effects.""" + return len(self.side_effects) > 0 + + @property + def all_produced_columns(self) -> list[str]: + """All column names produced by this generator (primary + side effects).""" + return [self.name, *self.side_effects] diff --git a/packages/data-designer-engine/src/data_designer/engine/execution_graph/completion.py b/packages/data-designer-engine/src/data_designer/engine/execution_graph/completion.py new file mode 100644 index 00000000..152a3d68 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/execution_graph/completion.py @@ -0,0 +1,269 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Completion tracking for execution graph nodes. + +This module provides memory-efficient completion trackers for tracking which +nodes have completed in an execution graph. Two variants are provided: + +- CompletionTracker: Simple, no locks, for asyncio/single-threaded use +- ThreadSafeCompletionTracker: Thread-safe variant with internal locking + +Both trackers use O(C) memory instead of O(C×R) by tracking fully completed +columns as sets and only storing partial progress for in-progress columns. +""" + +from __future__ import annotations + +import threading + +from data_designer.engine.execution_graph.node_id import BarrierNodeId, CellNodeId, NodeId + + +class CompletionTracker: + """Memory-efficient completion tracking for large datasets. + + Instead of storing every completed NodeId (O(C×R) memory), this tracker: + - Tracks fully completed columns as a set of names (O(C)) + - Only stores partial completion progress for in-progress columns + - Automatically compacts when columns fully complete + + This tracker is NOT thread-safe. Use ThreadSafeCompletionTracker for + concurrent access from multiple threads. 
+ + Example: + >>> tracker = CompletionTracker(num_records=1000) + >>> tracker.mark_complete(CellNodeId(0, "col_a")) + >>> CellNodeId(0, "col_a") in tracker + True + >>> tracker.is_column_complete("col_a") + False + """ + + def __init__(self, num_records: int) -> None: + """Initialize the completion tracker. + + Args: + num_records: The total number of records in the dataset. + """ + self._num_records = num_records + self._completed_columns: set[str] = set() + self._completed_barriers: set[str] = set() + self._column_completion: dict[str, set[int]] = {} + + @property + def num_records(self) -> int: + """The total number of records in the dataset.""" + return self._num_records + + def mark_complete(self, node: NodeId) -> None: + """Mark a node as completed. + + Args: + node: The node to mark as complete. + """ + if isinstance(node, BarrierNodeId): + self._completed_barriers.add(node.column) + elif isinstance(node, CellNodeId): + # Skip if column is already fully complete + if node.column in self._completed_columns: + return + + progress = self._column_completion.setdefault(node.column, set()) + progress.add(node.row) + + # Check if column is now fully complete + if len(progress) == self._num_records: + self._completed_columns.add(node.column) + # Remove partial progress to save memory + del self._column_completion[node.column] + + def is_complete(self, node: NodeId) -> bool: + """Check if a node is completed. + + Args: + node: The node to check. + + Returns: + True if the node is completed, False otherwise. + """ + if isinstance(node, BarrierNodeId): + return node.column in self._completed_barriers + elif isinstance(node, CellNodeId): + if node.column in self._completed_columns: + return True + progress = self._column_completion.get(node.column, set()) + return node.row in progress + return False + + def is_column_complete(self, column: str) -> bool: + """Check if all cells of a column are complete. + + Args: + column: The column name to check. 
+ + Returns: + True if all cells of the column are complete, False otherwise. + """ + return column in self._completed_columns + + def is_barrier_complete(self, column: str) -> bool: + """Check if a barrier is complete. + + Args: + column: The column name of the barrier. + + Returns: + True if the barrier is complete, False otherwise. + """ + return column in self._completed_barriers + + def column_completion_count(self, column: str) -> int: + """Get the number of completed cells for a column. + + Args: + column: The column name. + + Returns: + The number of completed cells. + """ + if column in self._completed_columns: + return self._num_records + return len(self._column_completion.get(column, set())) + + def reset(self) -> None: + """Reset the tracker to its initial state.""" + self._completed_columns.clear() + self._completed_barriers.clear() + self._column_completion.clear() + + def __contains__(self, node: NodeId) -> bool: + """Support `node in tracker` syntax. + + Args: + node: The node to check. + + Returns: + True if the node is completed. + """ + return self.is_complete(node) + + def __len__(self) -> int: + """Return the total number of completed nodes. + + Note: This is O(C) where C is the number of in-progress columns, + not O(C×R) since we compact fully completed columns. + """ + completed_cells = len(self._completed_columns) * self._num_records + sum( + len(progress) for progress in self._column_completion.values() + ) + completed_barriers = len(self._completed_barriers) + return completed_cells + completed_barriers + + +class ThreadSafeCompletionTracker: + """Thread-safe completion tracker for concurrent access. + + This tracker wraps a CompletionTracker with a lock to ensure thread-safe + access when marking nodes complete from multiple threads concurrently. + + Use this tracker when using thread pool executors or other multi-threaded + execution environments. 
+ + Example: + >>> tracker = ThreadSafeCompletionTracker(num_records=1000) + >>> # Safe to call from multiple threads + >>> tracker.mark_complete(CellNodeId(0, "col_a")) + """ + + def __init__(self, num_records: int) -> None: + """Initialize the thread-safe completion tracker. + + Args: + num_records: The total number of records in the dataset. + """ + self._tracker = CompletionTracker(num_records) + self._lock = threading.Lock() + + @property + def num_records(self) -> int: + """The total number of records in the dataset.""" + return self._tracker.num_records + + def mark_complete(self, node: NodeId) -> None: + """Mark a node as completed (thread-safe). + + Args: + node: The node to mark as complete. + """ + with self._lock: + self._tracker.mark_complete(node) + + def is_complete(self, node: NodeId) -> bool: + """Check if a node is completed (thread-safe). + + Args: + node: The node to check. + + Returns: + True if the node is completed, False otherwise. + """ + with self._lock: + return self._tracker.is_complete(node) + + def is_column_complete(self, column: str) -> bool: + """Check if all cells of a column are complete (thread-safe). + + Args: + column: The column name to check. + + Returns: + True if all cells of the column are complete, False otherwise. + """ + with self._lock: + return self._tracker.is_column_complete(column) + + def is_barrier_complete(self, column: str) -> bool: + """Check if a barrier is complete (thread-safe). + + Args: + column: The column name of the barrier. + + Returns: + True if the barrier is complete, False otherwise. + """ + with self._lock: + return self._tracker.is_barrier_complete(column) + + def column_completion_count(self, column: str) -> int: + """Get the number of completed cells for a column (thread-safe). + + Args: + column: The column name. + + Returns: + The number of completed cells. 
+ """ + with self._lock: + return self._tracker.column_completion_count(column) + + def reset(self) -> None: + """Reset the tracker to its initial state (thread-safe).""" + with self._lock: + self._tracker.reset() + + def __contains__(self, node: NodeId) -> bool: + """Support `node in tracker` syntax (thread-safe). + + Args: + node: The node to check. + + Returns: + True if the node is completed. + """ + return self.is_complete(node) + + def __len__(self) -> int: + """Return the total number of completed nodes (thread-safe).""" + with self._lock: + return len(self._tracker) diff --git a/packages/data-designer-engine/src/data_designer/engine/execution_graph/graph.py b/packages/data-designer-engine/src/data_designer/engine/execution_graph/graph.py new file mode 100644 index 00000000..feb343d7 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/execution_graph/graph.py @@ -0,0 +1,447 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Execution graph for async dataset generation. + +This module provides the ExecutionGraph class, which implements a hybrid +representation where column structure is stored explicitly while cell-level +nodes are virtual/computed on-demand to handle millions of records efficiently. 
from __future__ import annotations

from collections.abc import Iterator
from typing import TYPE_CHECKING, Protocol, runtime_checkable

from data_designer.engine.execution_graph.column_descriptor import ColumnDescriptor
from data_designer.engine.execution_graph.node_id import BarrierNodeId, CellNodeId, NodeId

if TYPE_CHECKING:
    from data_designer.config.base import ConfigBase
    from data_designer.engine.column_generators.generators.base import ColumnGenerator


@runtime_checkable
class CompletionTrackerProtocol(Protocol):
    """Protocol for completion tracking - enables duck typing with any tracker implementation."""

    def is_complete(self, node: NodeId) -> bool:
        """Check if a node is completed."""
        ...

    def __contains__(self, node: NodeId) -> bool:
        """Support `node in tracker` syntax."""
        ...


class ExecutionGraph:
    """Execution graph for async dataset generation with hybrid representation.

    The ExecutionGraph models cell-level dependencies for dataset generation while
    maintaining memory efficiency through a hybrid representation:

    - **Explicit**: Column structure (ColumnDescriptors) is stored in memory
    - **Virtual**: Cell-level nodes are computed on-demand, not stored

    This allows handling datasets with millions of records without creating
    millions of explicit nodes and edges.

    The graph supports different execution traits:
    - START: Columns that can generate data from scratch (no dependencies)
    - CELL_BY_CELL: Columns that process individual cells independently
    - ROW_STREAMABLE: Columns that can emit results as rows complete
    - BARRIER: Columns that require all input rows before producing any output

    Attributes:
        num_records: The number of records (rows) in the dataset.
        num_columns: The number of columns in the dataset.
        num_nodes: The total number of virtual nodes (cells + barriers).

    Examples:
        >>> graph = ExecutionGraph(num_records=1000, column_descriptors=[...])
        >>> for node in graph.iter_start_nodes():
        ...     print(node)  # Cell(0, 'category'), Cell(1, 'category'), ...
        >>> deps = graph.get_dependencies(CellNodeId(5, 'question'))
        >>> # Returns [CellNodeId(5, 'context')] for a cell-by-cell column
    """

    def __init__(
        self,
        num_records: int,
        column_descriptors: list[ColumnDescriptor],
        *,
        strict: bool = True,
    ) -> None:
        """Initialize the execution graph.

        Args:
            num_records: The number of records (rows) to generate.
            column_descriptors: List of column descriptors in topological order.
            strict: If True, validate all dependencies exist during construction.

        Raises:
            ValueError: If num_records is not positive, column_descriptors is empty,
                or strict=True and dependencies are invalid.
        """
        if num_records <= 0:
            raise ValueError(f"num_records must be positive, got {num_records}")
        if not column_descriptors:
            raise ValueError("column_descriptors cannot be empty")

        self._num_records = num_records
        self._columns: dict[str, ColumnDescriptor] = {desc.name: desc for desc in column_descriptors}
        self._topo_order: list[str] = [desc.name for desc in column_descriptors]

        # Build reverse mapping for side effects
        self._side_effect_to_parent: dict[str, str] = {}
        for desc in column_descriptors:
            for side_effect in desc.side_effects:
                self._side_effect_to_parent[side_effect] = desc.name

        # Cache start columns
        self._start_columns: list[str] = [name for name, desc in self._columns.items() if desc.is_start_column]

        # Cache barrier columns for efficient lookup
        self._barrier_columns: set[str] = {name for name, desc in self._columns.items() if desc.is_barrier}

        if strict:
            self._validate_dependencies()

    def _validate_dependencies(self) -> None:
        """Validate all dependencies can be resolved.

        Raises:
            ValueError: If any column has an unresolvable dependency.
        """
        all_columns = set(self._columns.keys())
        all_side_effects = set(self._side_effect_to_parent.keys())
        known_names = all_columns | all_side_effects

        errors: list[str] = []
        for col_name, desc in self._columns.items():
            for dep in desc.dependencies:
                if dep not in known_names:
                    errors.append(f"Column '{col_name}' depends on unknown column '{dep}'")

        if errors:
            msg = "Invalid dependencies in execution graph:\n" + "\n".join(f"  - {e}" for e in errors)
            raise ValueError(msg)

    @property
    def num_records(self) -> int:
        """The number of records (rows) in the dataset."""
        return self._num_records

    @property
    def num_columns(self) -> int:
        """The number of columns in the dataset."""
        return len(self._columns)

    @property
    def num_nodes(self) -> int:
        """The total number of virtual nodes in the graph.

        This is computed as:
        - For each non-barrier column: num_records cell nodes
        - For each barrier column: 1 barrier node + num_records cell nodes
        """
        barrier_count = len(self._barrier_columns)
        return (self.num_columns * self._num_records) + barrier_count

    @property
    def start_columns(self) -> list[str]:
        """List of column names that can generate data from scratch."""
        return self._start_columns.copy()

    @property
    def column_names(self) -> list[str]:
        """List of column names in topological order."""
        return self._topo_order.copy()

    def get_column_descriptor(self, column: str) -> ColumnDescriptor:
        """Get the descriptor for a column.

        Args:
            column: The column name.

        Returns:
            The ColumnDescriptor for the specified column.

        Raises:
            KeyError: If the column does not exist.
        """
        return self._columns[column]

    def get_generator_and_config(self, node: NodeId) -> tuple[type[ColumnGenerator], ConfigBase]:
        """Get the generator class and config for a node.

        For barrier nodes, returns the generator/config of the associated column.
        For cell nodes, returns the generator/config of that column.

        Args:
            node: The node identifier.

        Returns:
            A tuple of (generator_class, column_config).

        Raises:
            KeyError: If the column does not exist.
        """
        column = node.column
        desc = self._columns[column]
        return desc.generator_cls, desc.config

    def _resolve_column(self, name: str) -> str | None:
        """Resolve a column name, following side effect mappings if needed.

        Args:
            name: The column name to resolve.

        Returns:
            The resolved column name, or None if not found.
        """
        if name in self._columns:
            return name
        if name in self._side_effect_to_parent:
            return self._side_effect_to_parent[name]
        return None

    def _column_depends_on(self, desc: ColumnDescriptor, column: str) -> bool:
        """Whether `desc` depends on `column`, directly or via one of its side effects.

        Mirrors the name resolution performed by `_resolve_column` so that
        `get_dependents` stays the exact inverse of `get_dependencies`.
        """
        if column in desc.dependencies:
            return True
        return any(self._side_effect_to_parent.get(dep) == column for dep in desc.dependencies)

    def get_dependencies(self, node: NodeId) -> list[NodeId]:
        """Get the dependencies for a node.

        Dependency resolution depends on the column's execution traits:

        - **START columns**: No dependencies (empty list)
        - **CELL_BY_CELL / ROW_STREAMABLE**: Row-local dependencies
          - Cell (row=r, col=C) depends on Cell (row=r, col=D) for each D in required_columns
        - **BARRIER columns**:
          - BarrierNodeId depends on ALL cells of ALL dependency columns
          - CellNodeId depends on the BarrierNodeId of its column

        Args:
            node: The node to get dependencies for.

        Returns:
            List of NodeIds that this node depends on.
        """
        column = node.column
        desc = self._columns[column]

        # START columns with no dependencies
        if desc.is_start_column and not desc.has_dependencies:
            return []

        # Handle barrier columns
        if desc.is_barrier:
            if isinstance(node, BarrierNodeId):
                # Barrier depends on ALL cells of ALL dependency columns
                deps: list[NodeId] = []
                for dep_name in desc.dependencies:
                    resolved = self._resolve_column(dep_name)
                    if resolved is None:
                        continue

                    dep_desc = self._columns[resolved]
                    # If dep is also a barrier, depend on its barrier node
                    if dep_desc.is_barrier:
                        deps.append(BarrierNodeId(resolved))
                    else:
                        # Depend on all cells of the dependency column
                        for r in range(self._num_records):
                            deps.append(CellNodeId(r, resolved))
                return deps
            elif isinstance(node, CellNodeId):
                # Output cells depend on the barrier
                return [BarrierNodeId(column)]

        # CELL_BY_CELL and ROW_STREAMABLE: row-local dependencies
        if isinstance(node, CellNodeId):
            row = node.row
            deps = []
            for dep_name in desc.dependencies:
                resolved = self._resolve_column(dep_name)
                if resolved is None:
                    continue

                dep_desc = self._columns[resolved]
                # If dep is a barrier, depend on the barrier node (not individual cells)
                if dep_desc.is_barrier:
                    deps.append(BarrierNodeId(resolved))
                else:
                    deps.append(CellNodeId(row, resolved))
            return deps

        return []

    def get_dependents(self, node: NodeId) -> list[NodeId]:
        """Get the nodes that depend on this node.

        This is the exact inverse of get_dependencies: ``b in get_dependents(a)``
        if and only if ``a in get_dependencies(b)``. Useful for scheduling
        dependent tasks when a node completes.

        Args:
            node: The node to get dependents for.

        Returns:
            List of NodeIds that depend on this node.
        """
        dependents: list[NodeId] = []
        column = node.column

        if isinstance(node, CellNodeId):
            # Cells of a barrier column are pure outputs: get_dependencies makes
            # downstream work depend on the BarrierNodeId of this column, never
            # on individual output cells, so output cells have no dependents.
            if self._columns[column].is_barrier:
                return dependents

            row = node.row

            # Find columns that depend on this column (directly or via side effect)
            for col_name, desc in self._columns.items():
                if not self._column_depends_on(desc, column):
                    continue

                if desc.is_barrier:
                    # Barrier depends on all cells, so add barrier node
                    dependents.append(BarrierNodeId(col_name))
                else:
                    # Row-local dependency
                    dependents.append(CellNodeId(row, col_name))

        elif isinstance(node, BarrierNodeId):
            # Barrier completion triggers all output cells of that column
            for r in range(self._num_records):
                dependents.append(CellNodeId(r, column))

            # Downstream columns depend on this barrier node directly (mirrors
            # get_dependencies): barrier dependents wait as a single barrier
            # node, while cell-level dependents wait per row.
            for col_name, desc in self._columns.items():
                if not self._column_depends_on(desc, column):
                    continue

                if desc.is_barrier:
                    dependents.append(BarrierNodeId(col_name))
                else:
                    for r in range(self._num_records):
                        dependents.append(CellNodeId(r, col_name))

        return dependents

    def iter_nodes(self) -> Iterator[NodeId]:
        """Iterate over all virtual nodes in the graph.

        Yields nodes in a consistent order:
        1. For each column in topological order:
           - If barrier column: yield BarrierNodeId, then all CellNodeIds
           - Otherwise: yield all CellNodeIds

        Yields:
            NodeId instances (CellNodeId or BarrierNodeId).
        """
        for col_name in self._topo_order:
            desc = self._columns[col_name]
            if desc.is_barrier:
                yield BarrierNodeId(col_name)
            for row in range(self._num_records):
                yield CellNodeId(row, col_name)

    def iter_start_nodes(self) -> Iterator[CellNodeId]:
        """Iterate over nodes that can start immediately (no dependencies).

        These are all cell nodes from START columns.

        Yields:
            CellNodeId instances for start columns.
        """
        for col_name in self._start_columns:
            for row in range(self._num_records):
                yield CellNodeId(row, col_name)

    def iter_ready_nodes(self, completed: CompletionTrackerProtocol) -> Iterator[NodeId]:
        """Iterate over nodes whose dependencies are all satisfied.

        A node is ready if all of its dependencies are in the completed set.
        This method is the primary interface for async execution engines.

        Args:
            completed: A completion tracker (or set) indicating completed nodes.

        Yields:
            NodeId instances that are ready for execution.

        Note:
            For large datasets, consider using iter_ready_nodes_for_column
            for more efficient targeted queries.
        """
        for col_name in self._topo_order:
            desc = self._columns[col_name]

            if desc.is_barrier:
                barrier_node = BarrierNodeId(col_name)
                if barrier_node not in completed:
                    deps = self.get_dependencies(barrier_node)
                    if all(dep in completed for dep in deps):
                        yield barrier_node

                # Check cell nodes only if barrier is complete
                if barrier_node in completed:
                    for row in range(self._num_records):
                        cell_node = CellNodeId(row, col_name)
                        if cell_node not in completed:
                            yield cell_node
            else:
                for row in range(self._num_records):
                    cell_node = CellNodeId(row, col_name)
                    if cell_node not in completed:
                        deps = self.get_dependencies(cell_node)
                        if all(dep in completed for dep in deps):
                            yield cell_node

    def iter_ready_nodes_for_column(self, column: str, completed: CompletionTrackerProtocol) -> Iterator[NodeId]:
        """Iterate over ready nodes for a specific column.

        More efficient than iter_ready_nodes when you know which column
        to check, as it avoids scanning all columns.

        Args:
            column: The column name to check.
            completed: A completion tracker (or set) indicating completed nodes.

        Yields:
            NodeId instances from the specified column that are ready.
        """
        desc = self._columns[column]

        if desc.is_barrier:
            barrier_node = BarrierNodeId(column)
            if barrier_node not in completed:
                deps = self.get_dependencies(barrier_node)
                if all(dep in completed for dep in deps):
                    yield barrier_node
                return  # Can't yield cells until barrier is complete

            # Barrier is complete, yield all incomplete cells
            for row in range(self._num_records):
                cell_node = CellNodeId(row, column)
                if cell_node not in completed:
                    yield cell_node
        else:
            for row in range(self._num_records):
                cell_node = CellNodeId(row, column)
                if cell_node not in completed:
                    deps = self.get_dependencies(cell_node)
                    if all(dep in completed for dep in deps):
                        yield cell_node

    def is_complete(self, completed: CompletionTrackerProtocol) -> bool:
        """Check if all nodes in the graph have been completed.

        Args:
            completed: A completion tracker (or set) indicating completed nodes.

        Returns:
            True if all nodes are in the completed set.
        """
        # Quick check: minimum required completions
        expected = self.num_nodes
        if hasattr(completed, "__len__") and len(completed) < expected:
            return False

        # Verify all expected nodes are present
        for node in self.iter_nodes():
            if node not in completed:
                return False
        return True


# ---------------------------------------------------------------------------
# file: src/data_designer/engine/execution_graph/node_id.py
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
There are two types of nodes: + +- CellNodeId: Identifies a single cell (row, column) in the dataset +- BarrierNodeId: Identifies a barrier synchronization point for a column + +Using frozen dataclasses with slots=True for memory efficiency when handling +millions of nodes. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TypeAlias + + +@dataclass(frozen=True, slots=True) +class CellNodeId: + """Identifies a single cell in the dataset. + + Attributes: + row: The row index (0-based). + column: The column name. + """ + + row: int + column: str + + def __repr__(self) -> str: + return f"Cell({self.row}, {self.column!r})" + + +@dataclass(frozen=True, slots=True) +class BarrierNodeId: + """Identifies a barrier synchronization point for a column. + + Barrier nodes represent a synchronization point where all input cells + must complete before any output cells can begin. This is used for + full-column generators that cannot process rows independently. + + Attributes: + column: The column name this barrier is for. + """ + + column: str + + def __repr__(self) -> str: + return f"Barrier({self.column!r})" + + +NodeId: TypeAlias = CellNodeId | BarrierNodeId +"""Type alias for any node identifier in the execution graph.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/execution_graph/traits.py b/packages/data-designer-engine/src/data_designer/engine/execution_graph/traits.py new file mode 100644 index 00000000..bdae8afb --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/execution_graph/traits.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Execution traits for column generators. + +This module defines the ExecutionTraits flag enum that describes the execution +characteristics of column generators. 
"""Execution traits for column generators.

The ExecutionTraits flag enum describes how a column generator executes.
Traits are inferred from generator properties rather than hardcoded class
names, so plugin generators are supported.
"""

from enum import Flag


class ExecutionTraits(Flag):
    """Bit flags describing execution characteristics of a column generator.

    Members combine with bitwise operators::

        traits = ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE

    Attributes:
        NONE: No special traits (default).
        START: Generator can produce data from scratch without dependencies.
            Inferred from `can_generate_from_scratch = True`.
        CELL_BY_CELL: Generator processes individual cells independently.
            Inferred from `get_generation_strategy() == CELL_BY_CELL`.
        ROW_STREAMABLE: Generator can emit results as rows complete.
            Inferred from `is_row_streamable = True`.
        BARRIER: Generator requires all input rows before producing any output.
            Inferred from full-column strategy with `is_row_streamable = False`.
    """

    NONE = 0
    # Explicit powers of two — identical to the values auto() would assign,
    # but stable and obvious when serialized or logged.
    START = 1
    CELL_BY_CELL = 2
    ROW_STREAMABLE = 4
    BARRIER = 8


# ---------------------------------------------------------------------------
# file: tests/engine/execution_graph/__init__.py
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# ---------------------------------------------------------------------------
# file: tests/engine/execution_graph/conftest.py
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Shared fixtures for execution graph tests."""

from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import Mock

import pytest

from data_designer.engine.execution_graph.column_descriptor import ColumnDescriptor
from data_designer.engine.execution_graph.traits import ExecutionTraits

if TYPE_CHECKING:
    from data_designer.config.base import ConfigBase
    from data_designer.engine.column_generators.generators.base import ColumnGenerator


def create_mock_config(name: str) -> ConfigBase:
    """Create a mock config with a name attribute."""
    mock = Mock()
    mock.name = name
    mock.required_columns = []
    mock.side_effect_columns = []
    return mock


def create_mock_generator() -> type[ColumnGenerator]:
    """Create a mock generator class."""
    return Mock()


def create_descriptor(
    name: str,
    traits: ExecutionTraits,
    dependencies: list[str] | None = None,
    side_effects: list[str] | None = None,
) -> ColumnDescriptor:
    """Helper to create column descriptors for testing."""
    return ColumnDescriptor(
        name=name,
        config=create_mock_config(name),
        generator_cls=create_mock_generator(),
        traits=traits,
        dependencies=dependencies or [],
        side_effects=side_effects or [],
    )


@pytest.fixture
def start_descriptor() -> ColumnDescriptor:
    """A descriptor for a start column (can generate from scratch)."""
    return create_descriptor(
        name="category",
        traits=ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE,
    )


@pytest.fixture
def cell_by_cell_descriptor() -> ColumnDescriptor:
    """A descriptor for a cell-by-cell column with a dependency."""
    return create_descriptor(
        name="question",
        traits=ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE,
        dependencies=["category"],
    )


@pytest.fixture
def barrier_descriptor() -> ColumnDescriptor:
    """A descriptor for a barrier column."""
    return create_descriptor(
        name="validation",
        traits=ExecutionTraits.BARRIER,
        dependencies=["question"],
    )


@pytest.fixture
def multi_column_descriptor() -> ColumnDescriptor:
    """A descriptor for a multi-column generator with side effects."""
    return create_descriptor(
        name="person_name",
        traits=ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE,
        side_effects=["person_email", "person_phone"],
    )


# ---------------------------------------------------------------------------
# file: tests/engine/execution_graph/test_builder.py
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tests for graph builder."""

from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import Mock, patch

import pytest

from data_designer.engine.column_generators.generators.base import GenerationStrategy
from data_designer.engine.execution_graph.builder import GraphBuilder
from data_designer.engine.execution_graph.traits import ExecutionTraits

if TYPE_CHECKING:
    from data_designer.config.data_designer_config import DataDesignerConfig


def create_mock_single_column_config(
    name: str,
    required_columns: list[str] | None = None,
    side_effect_columns: list[str] | None = None,
) -> Mock:
    """Create a mock single column config."""
    config = Mock()
    config.name = name
    config.required_columns = required_columns or []
    config.side_effect_columns = side_effect_columns or []
    return config


def create_mock_multi_column_config(column_names: list[str]) -> Mock:
    """Create a mock multi-column config."""
    from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig

    config = Mock(spec=MultiColumnConfig)
    # Mock(name=...) configures the mock's repr name, not a `.name` attribute,
    # which is why each column's .name is assigned explicitly in the loop below.
    config.columns = [Mock(name=name) for name in column_names]
    for i, col in enumerate(config.columns):
        col.name = column_names[i]
    return config


def create_mock_generator_class(
    can_generate_from_scratch: bool = False,
    generation_strategy: GenerationStrategy = GenerationStrategy.FULL_COLUMN,
    is_row_streamable: bool = False,
) -> type:
    """Create a mock generator class with specified traits.

    Returns an actual class (not a Mock) so that getattr works correctly
    for property access.
    """
    # Capture variables in local scope
    _can_gen = can_generate_from_scratch
    _is_stream = is_row_streamable
    _gen_strategy = generation_strategy

    class MockGenerator:
        pass

    # Set class-level attributes that getattr will find
    MockGenerator.can_generate_from_scratch = _can_gen
    MockGenerator.is_row_streamable = _is_stream

    # Create a static method that returns the captured strategy
    def get_generation_strategy() -> GenerationStrategy:
        return _gen_strategy

    MockGenerator.get_generation_strategy = staticmethod(get_generation_strategy)

    return MockGenerator


def create_mock_registry(generator_map: dict[type, Mock]) -> Mock:
    """Create a mock registry that returns generators based on config type."""
    registry = Mock()
    registry.get_for_config_type.side_effect = lambda config_type: generator_map.get(
        config_type, create_mock_generator_class()
    )
    return registry


def create_mock_data_designer_config(column_configs: list[Mock]) -> DataDesignerConfig:
    """Create a mock DataDesignerConfig."""
    config = Mock()
    config.columns = column_configs
    return config


# --- Build tests ---


@patch("data_designer.engine.execution_graph.builder.compile_dataset_builder_column_configs")
def test_build_with_single_start_column(mock_compile: Mock) -> None:
    col_config = create_mock_single_column_config("starter")
    gen_cls = create_mock_generator_class(
        can_generate_from_scratch=True,
        generation_strategy=GenerationStrategy.CELL_BY_CELL,
    )
    registry = create_mock_registry({type(col_config): gen_cls})
    config = create_mock_data_designer_config([col_config])

    # Mock the compiler to return the configs directly
    mock_compile.return_value = [col_config]

    builder = GraphBuilder(registry)
    graph = builder.build(config, num_records=100)

    assert graph.num_records == 100
    assert graph.num_columns == 1
    assert "starter" in graph.start_columns


@patch("data_designer.engine.execution_graph.builder.compile_dataset_builder_column_configs")
def test_build_requires_at_least_one_start_column(mock_compile: Mock) -> None:
    col_config = create_mock_single_column_config("no_start")
    gen_cls = create_mock_generator_class(
        can_generate_from_scratch=False,
        generation_strategy=GenerationStrategy.FULL_COLUMN,
        is_row_streamable=False,
    )
    registry = create_mock_registry({type(col_config): gen_cls})
    config = create_mock_data_designer_config([col_config])

    # Mock the compiler to return the configs directly
    mock_compile.return_value = [col_config]

    builder = GraphBuilder(registry)
    with pytest.raises(ValueError, match="generate from scratch"):
        builder.build(config, num_records=100)


@patch("data_designer.engine.execution_graph.builder.compile_dataset_builder_column_configs")
def test_build_with_multi_column_config(mock_compile: Mock) -> None:
    multi_config = create_mock_multi_column_config(["person_name", "person_email", "person_phone"])
    gen_cls = create_mock_generator_class(
        can_generate_from_scratch=True,
        generation_strategy=GenerationStrategy.FULL_COLUMN,
        is_row_streamable=True,
    )
    registry = create_mock_registry({type(multi_config): gen_cls})
    config = create_mock_data_designer_config([multi_config])

    # Mock the compiler to return the multi-column config directly
    mock_compile.return_value = [multi_config]

    builder = GraphBuilder(registry)
    graph = builder.build(config, num_records=100)

    # Should have one descriptor with first column as primary
    assert graph.num_columns == 1
    desc = graph.get_column_descriptor("person_name")
    assert desc.name == "person_name"
    assert desc.side_effects == ["person_email", "person_phone"]


# --- Trait inference tests ---


def test_infer_start_trait() -> None:
    gen_cls = create_mock_generator_class(
        can_generate_from_scratch=True,
        generation_strategy=GenerationStrategy.FULL_COLUMN,
    )

    builder = GraphBuilder(Mock())
    traits = builder._infer_traits(gen_cls)

    assert bool(traits & ExecutionTraits.START)


def test_infer_cell_by_cell_trait() -> None:
    gen_cls = create_mock_generator_class(
        generation_strategy=GenerationStrategy.CELL_BY_CELL,
    )

    builder = GraphBuilder(Mock())
    traits = builder._infer_traits(gen_cls)

    assert bool(traits & ExecutionTraits.CELL_BY_CELL)
    assert bool(traits & ExecutionTraits.ROW_STREAMABLE)


def test_infer_row_streamable_full_column() -> None:
    gen_cls = create_mock_generator_class(
        generation_strategy=GenerationStrategy.FULL_COLUMN,
        is_row_streamable=True,
    )

    builder = GraphBuilder(Mock())
    traits = builder._infer_traits(gen_cls)

    assert bool(traits & ExecutionTraits.ROW_STREAMABLE)
    assert not bool(traits & ExecutionTraits.BARRIER)


def test_infer_barrier_trait() -> None:
    gen_cls = create_mock_generator_class(
        generation_strategy=GenerationStrategy.FULL_COLUMN,
        is_row_streamable=False,
    )

    builder = GraphBuilder(Mock())
    traits = builder._infer_traits(gen_cls)

    assert bool(traits & ExecutionTraits.BARRIER)
    assert not bool(traits & ExecutionTraits.ROW_STREAMABLE)


# --- Dependencies tests ---


@patch("data_designer.engine.execution_graph.builder.compile_dataset_builder_column_configs")
def test_single_column_dependencies(mock_compile: Mock) -> None:
    start_config = create_mock_single_column_config("starter")
    dep_config = create_mock_single_column_config(
        "dependent",
        required_columns=["starter"],
    )

    start_gen = create_mock_generator_class(
        can_generate_from_scratch=True,
        generation_strategy=GenerationStrategy.CELL_BY_CELL,
    )
    dep_gen = create_mock_generator_class(
        generation_strategy=GenerationStrategy.CELL_BY_CELL,
    )

    registry = create_mock_registry(
        {
            type(start_config): start_gen,
            type(dep_config): dep_gen,
        }
    )
    config = create_mock_data_designer_config([start_config, dep_config])

    # Mock the compiler to return the configs directly
    mock_compile.return_value = [start_config, dep_config]

    builder = GraphBuilder(registry)
    graph = builder.build(config, num_records=10)

    dep_desc = graph.get_column_descriptor("dependent")
    assert dep_desc.dependencies == ["starter"]


@patch("data_designer.engine.execution_graph.builder.compile_dataset_builder_column_configs")
def test_side_effect_columns_captured(mock_compile: Mock) -> None:
    config = create_mock_single_column_config(
        "llm_output",
        side_effect_columns=["reasoning_trace"],
    )
    gen_cls = create_mock_generator_class(
        can_generate_from_scratch=True,
        generation_strategy=GenerationStrategy.CELL_BY_CELL,
    )
    registry = create_mock_registry({type(config): gen_cls})
    dd_config = create_mock_data_designer_config([config])

    # Mock the compiler to return the config directly
    mock_compile.return_value = [config]

    builder = GraphBuilder(registry)
    graph = builder.build(dd_config, num_records=10)

    desc = graph.get_column_descriptor("llm_output")
    assert desc.side_effects == ["reasoning_trace"]


# --- Strict mode tests ---


@patch("data_designer.engine.execution_graph.builder.compile_dataset_builder_column_configs")
def test_strict_mode_validates_dependencies(mock_compile: Mock) -> None:
    config = create_mock_single_column_config(
        "dependent",
        required_columns=["nonexistent"],
    )
    gen_cls = create_mock_generator_class(
        can_generate_from_scratch=True,  # To pass the start column check
        generation_strategy=GenerationStrategy.CELL_BY_CELL,
    )
    registry = create_mock_registry({type(config): gen_cls})
    dd_config = create_mock_data_designer_config([config])

    # Mock the compiler to return the config directly
    mock_compile.return_value = [config]

    builder = GraphBuilder(registry)
    with pytest.raises(ValueError, match="unknown column"):
        builder.build(dd_config, num_records=10, strict=True)


@patch("data_designer.engine.execution_graph.builder.compile_dataset_builder_column_configs")
def test_non_strict_mode_skips_validation(mock_compile: Mock) -> None:
    config = create_mock_single_column_config(
        "dependent",
        required_columns=["nonexistent"],
    )
    gen_cls = create_mock_generator_class(
        can_generate_from_scratch=True,
        generation_strategy=GenerationStrategy.CELL_BY_CELL,
    )
    registry = create_mock_registry({type(config): gen_cls})
    dd_config = create_mock_data_designer_config([config])

    # Mock the compiler to return the config directly
    mock_compile.return_value = [config]

    builder = GraphBuilder(registry)
    # Should not raise
    graph = builder.build(dd_config, num_records=10, strict=False)
    assert graph.num_columns == 1


# ---------------------------------------------------------------------------
# file: tests/engine/execution_graph/test_column_descriptor.py
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tests for column descriptor."""

from data_designer.engine.execution_graph.traits import ExecutionTraits

from .conftest import create_descriptor  # noqa: TID252


def test_is_start_column() -> None:
    start = create_descriptor("col", ExecutionTraits.START)
    non_start = create_descriptor("col", ExecutionTraits.CELL_BY_CELL)

    assert start.is_start_column
    assert not non_start.is_start_column


def test_is_cell_by_cell() -> None:
    cell = create_descriptor("col", ExecutionTraits.CELL_BY_CELL)
    non_cell = create_descriptor("col", ExecutionTraits.BARRIER)

    assert cell.is_cell_by_cell
    assert not non_cell.is_cell_by_cell


def test_is_row_streamable() -> None:
    streamable = create_descriptor("col", ExecutionTraits.ROW_STREAMABLE)
    non_streamable = create_descriptor("col", ExecutionTraits.BARRIER)

    assert streamable.is_row_streamable
    assert not non_streamable.is_row_streamable


def test_is_barrier() -> None:
    barrier = create_descriptor("col", ExecutionTraits.BARRIER)
    non_barrier = create_descriptor("col", ExecutionTraits.CELL_BY_CELL)

    assert barrier.is_barrier
    assert not non_barrier.is_barrier


def test_has_dependencies() -> None:
    with_deps = create_descriptor("col", ExecutionTraits.NONE, dependencies=["dep1"])
    without_deps = create_descriptor("col", ExecutionTraits.NONE)

    assert with_deps.has_dependencies
    assert not without_deps.has_dependencies


def test_has_side_effects() -> None:
    with_effects = create_descriptor("col", ExecutionTraits.NONE, side_effects=["effect1"])
    without_effects = create_descriptor("col", ExecutionTraits.NONE)

    assert with_effects.has_side_effects
    assert not without_effects.has_side_effects


def test_all_produced_columns() -> None:
    without_effects = create_descriptor("primary", ExecutionTraits.NONE)
    with_effects = create_descriptor(
        "primary",
        ExecutionTraits.NONE,
        side_effects=["secondary", "tertiary"],
    )

    # The primary column always comes first, followed by side effects in order.
    assert without_effects.all_produced_columns == ["primary"]
    assert with_effects.all_produced_columns == ["primary", "secondary", "tertiary"]


def test_combined_traits() -> None:
    combined = create_descriptor(
        "col",
        ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE,
    )

    assert combined.is_start_column
    assert combined.is_row_streamable
    assert not combined.is_barrier
    assert not combined.is_cell_by_cell


# ---------------------------------------------------------------------------
# file: tests/engine/execution_graph/test_completion.py
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tests for completion tracking."""

import threading

from data_designer.engine.execution_graph.completion import (
    CompletionTracker,
    ThreadSafeCompletionTracker,
)
from data_designer.engine.execution_graph.node_id import BarrierNodeId, CellNodeId

# --- CompletionTracker tests ---


def test_completion_tracker_initial_state() -> None:
    tracker = CompletionTracker(num_records=100)
    assert tracker.num_records == 100
    assert len(tracker) == 0


def test_completion_tracker_mark_cell_complete() -> None:
    tracker = CompletionTracker(num_records=10)
    node = CellNodeId(5, "col_a")

    assert not tracker.is_complete(node)
    tracker.mark_complete(node)
    assert tracker.is_complete(node)


def test_completion_tracker_mark_barrier_complete() -> None:
    tracker = CompletionTracker(num_records=10)
    node = BarrierNodeId("col_a")

    assert not tracker.is_barrier_complete("col_a")
    tracker.mark_complete(node)
    assert tracker.is_barrier_complete("col_a")


def test_completion_tracker_contains_operator() -> None:
    tracker = CompletionTracker(num_records=10)
    node = CellNodeId(0, "col_a")

    assert node not in tracker
    tracker.mark_complete(node)
    assert node in tracker


def test_completion_tracker_column_completion_tracking() -> None:
    tracker = CompletionTracker(num_records=3)

    # Complete some cells
    tracker.mark_complete(CellNodeId(0, "col_a"))
    assert tracker.column_completion_count("col_a") == 1
    assert not tracker.is_column_complete("col_a")

    tracker.mark_complete(CellNodeId(1, "col_a"))
    assert tracker.column_completion_count("col_a") == 2
    assert not tracker.is_column_complete("col_a")

    tracker.mark_complete(CellNodeId(2, "col_a"))
    assert tracker.column_completion_count("col_a") == 3
    assert tracker.is_column_complete("col_a")


def test_completion_tracker_memory_compaction() -> None:
    tracker = CompletionTracker(num_records=3)

    # Complete all cells of a column
    for r in range(3):
        tracker.mark_complete(CellNodeId(r, "col_a"))

    # Column should be in completed_columns set, not in partial tracking
    assert "col_a" in tracker._completed_columns
    assert "col_a" not in tracker._column_completion

    # Individual cells should still be queryable
    assert tracker.is_complete(CellNodeId(0, "col_a"))
    assert tracker.is_complete(CellNodeId(1, "col_a"))
    assert tracker.is_complete(CellNodeId(2, "col_a"))


def test_completion_tracker_len_counts_all_completed() -> None:
    tracker = CompletionTracker(num_records=5)

    tracker.mark_complete(CellNodeId(0, "col_a"))
    tracker.mark_complete(CellNodeId(1, "col_a"))
    tracker.mark_complete(BarrierNodeId("col_b"))

    assert len(tracker) == 3


def test_completion_tracker_len_with_compacted_column() -> None:
    tracker = CompletionTracker(num_records=3)

    # Complete entire column
    for r in range(3):
        tracker.mark_complete(CellNodeId(r, "col_a"))

    # Add one cell from another column
    tracker.mark_complete(CellNodeId(0, "col_b"))

    # 3 from col_a + 1 from col_b = 4
    assert len(tracker) == 4


def test_completion_tracker_reset() -> None:
    tracker = CompletionTracker(num_records=3)

    tracker.mark_complete(CellNodeId(0, "col_a"))
    tracker.mark_complete(BarrierNodeId("col_b"))

    tracker.reset()

    assert len(tracker) == 0
    assert not tracker.is_complete(CellNodeId(0, "col_a"))
    assert not tracker.is_barrier_complete("col_b")


def test_completion_tracker_duplicate_mark_complete() -> None:
    tracker = CompletionTracker(num_records=3)
    node = CellNodeId(0, "col_a")

    tracker.mark_complete(node)
    tracker.mark_complete(node)  # Duplicate should be safe

    assert tracker.column_completion_count("col_a") == 1


def test_completion_tracker_mark_complete_after_column_compacted() -> None:
    tracker = CompletionTracker(num_records=3)

    # Complete entire column
    for r in range(3):
        tracker.mark_complete(CellNodeId(r, "col_a"))

    # Try to mark a cell that's already complete
    tracker.mark_complete(CellNodeId(0, "col_a"))  # Should be no-op

    # Column should still be complete
    assert tracker.is_column_complete("col_a")
    assert len(tracker) == 3


# --- ThreadSafeCompletionTracker tests ---


def test_thread_safe_tracker_basic_operations() -> None:
    tracker = ThreadSafeCompletionTracker(num_records=10)

    assert tracker.num_records == 10
    assert len(tracker) == 0

    node = CellNodeId(0, "col_a")
    tracker.mark_complete(node)
    assert tracker.is_complete(node)
    assert node in tracker


def test_thread_safe_tracker_thread_safety() -> None:
    tracker = ThreadSafeCompletionTracker(num_records=1000)
    errors: list[Exception] = []

    def mark_cells(start: int, end: int, column: str) -> None:
        try:
            for r in range(start, end):
                tracker.mark_complete(CellNodeId(r, column))
        except Exception as e:
            errors.append(e)

    # Create threads that mark different ranges
    threads = [
        threading.Thread(target=mark_cells, args=(0, 500, "col_a")),
        threading.Thread(target=mark_cells, args=(500, 1000, "col_a")),
        threading.Thread(target=mark_cells, args=(0, 500, "col_b")),
        threading.Thread(target=mark_cells, args=(500, 1000, "col_b")),
    ]

    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert len(errors) == 0
    assert tracker.is_column_complete("col_a")
    assert tracker.is_column_complete("col_b")
    # 2 columns x 1000 rows = 2000 completed cells in total
    assert len(tracker) == 2000


def test_thread_safe_tracker_barrier_operations() -> None:
    tracker = ThreadSafeCompletionTracker(num_records=10)

    tracker.mark_complete(BarrierNodeId("barrier_col"))
    assert tracker.is_barrier_complete("barrier_col")


def test_thread_safe_tracker_reset() -> None:
    tracker = ThreadSafeCompletionTracker(num_records=10)

    tracker.mark_complete(CellNodeId(0, "col_a"))
    tracker.reset()

    assert len(tracker) == 0
    assert not tracker.is_complete(CellNodeId(0, "col_a"))


# ---------------------------------------------------------------------------
# file: tests/engine/execution_graph/test_graph.py (continues past this chunk)
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + +"""Tests for execution graph.""" + +import pytest + +from data_designer.engine.execution_graph.completion import CompletionTracker +from data_designer.engine.execution_graph.graph import ExecutionGraph +from data_designer.engine.execution_graph.node_id import BarrierNodeId, CellNodeId +from data_designer.engine.execution_graph.traits import ExecutionTraits + +from .conftest import create_descriptor # noqa: TID252 + +# --- Init tests --- + + +def test_init_with_valid_params() -> None: + desc = create_descriptor("col", ExecutionTraits.START) + graph = ExecutionGraph(num_records=10, column_descriptors=[desc]) + + assert graph.num_records == 10 + assert graph.num_columns == 1 + + +def test_init_rejects_zero_records() -> None: + desc = create_descriptor("col", ExecutionTraits.START) + with pytest.raises(ValueError, match="num_records must be positive"): + ExecutionGraph(num_records=0, column_descriptors=[desc]) + + +def test_init_rejects_negative_records() -> None: + desc = create_descriptor("col", ExecutionTraits.START) + with pytest.raises(ValueError, match="num_records must be positive"): + ExecutionGraph(num_records=-5, column_descriptors=[desc]) + + +def test_init_rejects_empty_descriptors() -> None: + with pytest.raises(ValueError, match="column_descriptors cannot be empty"): + ExecutionGraph(num_records=10, column_descriptors=[]) + + +def test_init_validates_dependencies_by_default() -> None: + desc = create_descriptor("col", ExecutionTraits.NONE, dependencies=["nonexistent"]) + with pytest.raises(ValueError, match="depends on unknown column"): + ExecutionGraph(num_records=10, column_descriptors=[desc]) + + +def test_init_skips_validation_when_strict_false() -> None: + desc = create_descriptor("col", ExecutionTraits.START, dependencies=["nonexistent"]) + # Should not raise + graph = ExecutionGraph(num_records=10, column_descriptors=[desc], strict=False) + assert graph.num_columns == 1 + + +# --- Properties tests --- + + 
+def test_num_nodes_without_barriers() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor("b", ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, ["a"]), + ] + graph = ExecutionGraph(num_records=100, column_descriptors=descriptors) + # 2 columns × 100 records = 200 nodes + assert graph.num_nodes == 200 + + +def test_num_nodes_with_barriers() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor("b", ExecutionTraits.BARRIER, ["a"]), + ] + graph = ExecutionGraph(num_records=100, column_descriptors=descriptors) + # 2 columns × 100 records + 1 barrier = 201 nodes + assert graph.num_nodes == 201 + + +def test_start_columns() -> None: + descriptors = [ + create_descriptor("starter", ExecutionTraits.START), + create_descriptor("dependent", ExecutionTraits.CELL_BY_CELL, ["starter"]), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + assert graph.start_columns == ["starter"] + + +def test_column_names() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START), + create_descriptor("b", ExecutionTraits.CELL_BY_CELL, ["a"]), + create_descriptor("c", ExecutionTraits.CELL_BY_CELL, ["b"]), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + assert graph.column_names == ["a", "b", "c"] + + +# --- Dependencies tests --- + + +def test_start_column_no_dependencies() -> None: + desc = create_descriptor("start", ExecutionTraits.START) + graph = ExecutionGraph(num_records=10, column_descriptors=[desc]) + + deps = graph.get_dependencies(CellNodeId(0, "start")) + assert deps == [] + + +def test_cell_by_cell_row_local_dependencies() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + 
graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + + # Row 5 of column b depends on row 5 of column a + deps = graph.get_dependencies(CellNodeId(5, "b")) + assert deps == [CellNodeId(5, "a")] + + +def test_barrier_node_depends_on_all_cells() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor("b", ExecutionTraits.BARRIER, ["a"]), + ] + graph = ExecutionGraph(num_records=3, column_descriptors=descriptors) + + # Barrier depends on all cells of dependency column + barrier_deps = graph.get_dependencies(BarrierNodeId("b")) + expected = [CellNodeId(0, "a"), CellNodeId(1, "a"), CellNodeId(2, "a")] + assert barrier_deps == expected + + +def test_barrier_cell_depends_on_barrier() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor("b", ExecutionTraits.BARRIER, ["a"]), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + + # Cell nodes of barrier column depend on the barrier node + deps = graph.get_dependencies(CellNodeId(5, "b")) + assert deps == [BarrierNodeId("b")] + + +def test_dependency_on_side_effect_column() -> None: + descriptors = [ + create_descriptor( + "person_name", + ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE, + side_effects=["person_email"], + ), + create_descriptor( + "greeting", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["person_email"], # Depends on side effect, not primary + ), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + + # Should resolve to the parent column + deps = graph.get_dependencies(CellNodeId(5, "greeting")) + assert deps == [CellNodeId(5, "person_name")] + + +# --- Dependents tests --- + + +def test_get_dependents_of_start_cell() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + 
ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + + # Row 5 of column a should have row 5 of column b as dependent + dependents = graph.get_dependents(CellNodeId(5, "a")) + assert CellNodeId(5, "b") in dependents + + +def test_get_dependents_of_barrier() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor("b", ExecutionTraits.BARRIER, ["a"]), + ] + graph = ExecutionGraph(num_records=3, column_descriptors=descriptors) + + # Barrier completion triggers all cells of that column + dependents = graph.get_dependents(BarrierNodeId("b")) + expected = [CellNodeId(0, "b"), CellNodeId(1, "b"), CellNodeId(2, "b")] + assert dependents == expected + + +# --- Iteration tests --- + + +def test_iter_nodes() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=3, column_descriptors=descriptors) + + nodes = list(graph.iter_nodes()) + # 2 columns × 3 records = 6 nodes + assert len(nodes) == 6 + + # First column's cells come first + assert nodes[0] == CellNodeId(0, "a") + assert nodes[1] == CellNodeId(1, "a") + assert nodes[2] == CellNodeId(2, "a") + + # Second column's cells come next + assert nodes[3] == CellNodeId(0, "b") + assert nodes[4] == CellNodeId(1, "b") + assert nodes[5] == CellNodeId(2, "b") + + +def test_iter_nodes_with_barrier() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor("b", ExecutionTraits.BARRIER, ["a"]), + ] + graph = ExecutionGraph(num_records=2, column_descriptors=descriptors) + + nodes = list(graph.iter_nodes()) + # Column a: 2 cells + # Column b: 1 barrier + 2 cells + assert len(nodes) == 5 + + # 
Barrier comes before cells of barrier column + assert nodes[2] == BarrierNodeId("b") + assert nodes[3] == CellNodeId(0, "b") + assert nodes[4] == CellNodeId(1, "b") + + +def test_iter_start_nodes() -> None: + descriptors = [ + create_descriptor("starter", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "dependent", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["starter"], + ), + ] + graph = ExecutionGraph(num_records=3, column_descriptors=descriptors) + + start_nodes = list(graph.iter_start_nodes()) + assert len(start_nodes) == 3 + assert all(node.column == "starter" for node in start_nodes) + + +def test_iter_ready_nodes_initial() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=3, column_descriptors=descriptors) + tracker = CompletionTracker(3) + + # Initially, only start column cells are ready + ready = list(graph.iter_ready_nodes(tracker)) + assert len(ready) == 3 + assert all(node.column == "a" for node in ready) + + +def test_iter_ready_nodes_after_completion() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=3, column_descriptors=descriptors) + tracker = CompletionTracker(3) + + # Complete row 0 of column a + tracker.mark_complete(CellNodeId(0, "a")) + + ready = list(graph.iter_ready_nodes(tracker)) + # Remaining rows 1, 2 of column a are still ready + # Row 0 of column b is now also ready + assert CellNodeId(1, "a") in ready + assert CellNodeId(2, "a") in ready + assert CellNodeId(0, "b") in ready + assert CellNodeId(0, "a") not in ready # Already completed + + +def 
test_iter_ready_nodes_for_column() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=3, column_descriptors=descriptors) + tracker = CompletionTracker(3) + + # Complete row 0 of column a + tracker.mark_complete(CellNodeId(0, "a")) + + # Check only column b + ready_b = list(graph.iter_ready_nodes_for_column("b", tracker)) + assert ready_b == [CellNodeId(0, "b")] + + +# --- Completion tests --- + + +def test_is_complete_empty_tracker() -> None: + desc = create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE) + graph = ExecutionGraph(num_records=10, column_descriptors=[desc]) + tracker = CompletionTracker(10) + + assert not graph.is_complete(tracker) + + +def test_is_complete_all_done() -> None: + desc = create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE) + graph = ExecutionGraph(num_records=3, column_descriptors=[desc]) + tracker = CompletionTracker(3) + + for r in range(3): + tracker.mark_complete(CellNodeId(r, "a")) + + assert graph.is_complete(tracker) + + +# --- Getter tests --- + + +def test_get_column_descriptor() -> None: + desc = create_descriptor("my_col", ExecutionTraits.START) + graph = ExecutionGraph(num_records=10, column_descriptors=[desc]) + + retrieved = graph.get_column_descriptor("my_col") + assert retrieved.name == "my_col" + + +def test_get_column_descriptor_not_found() -> None: + desc = create_descriptor("my_col", ExecutionTraits.START) + graph = ExecutionGraph(num_records=10, column_descriptors=[desc]) + + with pytest.raises(KeyError): + graph.get_column_descriptor("nonexistent") + + +def test_get_generator_and_config() -> None: + desc = create_descriptor("my_col", ExecutionTraits.START) + graph = ExecutionGraph(num_records=10, column_descriptors=[desc]) + + gen_cls, config = 
graph.get_generator_and_config(CellNodeId(0, "my_col")) + assert gen_cls is desc.generator_cls + assert config is desc.config diff --git a/packages/data-designer-engine/tests/engine/execution_graph/test_node_id.py b/packages/data-designer-engine/tests/engine/execution_graph/test_node_id.py new file mode 100644 index 00000000..c429f2f1 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/execution_graph/test_node_id.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for node identification types.""" + +from data_designer.engine.execution_graph.node_id import BarrierNodeId, CellNodeId, NodeId + +# --- CellNodeId tests --- + + +def test_cell_node_id_creation() -> None: + node = CellNodeId(row=5, column="test_col") + assert node.row == 5 + assert node.column == "test_col" + + +def test_cell_node_id_frozen() -> None: + node = CellNodeId(row=0, column="col") + try: + node.row = 1 # type: ignore[misc] + assert False, "Should have raised FrozenInstanceError" + except AttributeError: + pass # Expected + + +def test_cell_node_id_equality() -> None: + node1 = CellNodeId(row=0, column="col") + node2 = CellNodeId(row=0, column="col") + node3 = CellNodeId(row=1, column="col") + node4 = CellNodeId(row=0, column="other") + + assert node1 == node2 + assert node1 != node3 + assert node1 != node4 + + +def test_cell_node_id_hashable() -> None: + node1 = CellNodeId(row=0, column="col") + node2 = CellNodeId(row=0, column="col") + node_set = {node1, node2} + assert len(node_set) == 1 + + +def test_cell_node_id_repr() -> None: + node = CellNodeId(row=5, column="test_col") + assert repr(node) == "Cell(5, 'test_col')" + + +# --- BarrierNodeId tests --- + + +def test_barrier_node_id_creation() -> None: + node = BarrierNodeId(column="barrier_col") + assert node.column == "barrier_col" + + +def test_barrier_node_id_frozen() -> None: + node = 
BarrierNodeId(column="col") + try: + node.column = "other" # type: ignore[misc] + assert False, "Should have raised FrozenInstanceError" + except AttributeError: + pass # Expected + + +def test_barrier_node_id_equality() -> None: + node1 = BarrierNodeId(column="col") + node2 = BarrierNodeId(column="col") + node3 = BarrierNodeId(column="other") + + assert node1 == node2 + assert node1 != node3 + + +def test_barrier_node_id_hashable() -> None: + node1 = BarrierNodeId(column="col") + node2 = BarrierNodeId(column="col") + node_set = {node1, node2} + assert len(node_set) == 1 + + +def test_barrier_node_id_repr() -> None: + node = BarrierNodeId(column="barrier_col") + assert repr(node) == "Barrier('barrier_col')" + + +# --- NodeId type alias tests --- + + +def test_node_id_can_be_cell_or_barrier() -> None: + cell: NodeId = CellNodeId(row=0, column="col") + barrier: NodeId = BarrierNodeId(column="col") + + assert isinstance(cell, CellNodeId) + assert isinstance(barrier, BarrierNodeId) + + +def test_cell_and_barrier_not_equal() -> None: + cell = CellNodeId(row=0, column="col") + barrier = BarrierNodeId(column="col") + assert cell != barrier diff --git a/packages/data-designer-engine/tests/engine/execution_graph/test_traits.py b/packages/data-designer-engine/tests/engine/execution_graph/test_traits.py new file mode 100644 index 00000000..6e5121e8 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/execution_graph/test_traits.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Tests for execution traits.""" + +from data_designer.engine.execution_graph.traits import ExecutionTraits + + +def test_none_trait() -> None: + traits = ExecutionTraits.NONE + assert traits == ExecutionTraits.NONE + assert not bool(traits) + + +def test_individual_traits() -> None: + assert ExecutionTraits.START != ExecutionTraits.NONE + assert ExecutionTraits.CELL_BY_CELL != ExecutionTraits.NONE + assert ExecutionTraits.ROW_STREAMABLE != ExecutionTraits.NONE + assert ExecutionTraits.BARRIER != ExecutionTraits.NONE + + +def test_traits_are_distinct() -> None: + traits = [ + ExecutionTraits.START, + ExecutionTraits.CELL_BY_CELL, + ExecutionTraits.ROW_STREAMABLE, + ExecutionTraits.BARRIER, + ] + # Each trait should be distinct + for i, t1 in enumerate(traits): + for j, t2 in enumerate(traits): + if i != j: + assert t1 != t2 + + +def test_trait_combination() -> None: + combined = ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE + assert combined & ExecutionTraits.START + assert combined & ExecutionTraits.ROW_STREAMABLE + assert not (combined & ExecutionTraits.BARRIER) + assert not (combined & ExecutionTraits.CELL_BY_CELL) + + +def test_trait_check_with_bool() -> None: + traits = ExecutionTraits.START | ExecutionTraits.CELL_BY_CELL + + assert bool(traits & ExecutionTraits.START) + assert bool(traits & ExecutionTraits.CELL_BY_CELL) + assert not bool(traits & ExecutionTraits.BARRIER) + + +def test_cell_by_cell_implies_row_streamable() -> None: + # This is a common pattern: cell-by-cell generators are always row-streamable + traits = ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE + assert bool(traits & ExecutionTraits.ROW_STREAMABLE) + + +def test_start_with_barrier() -> None: + # A generator can start from scratch but still be a barrier + traits = ExecutionTraits.START | ExecutionTraits.BARRIER + assert bool(traits & ExecutionTraits.START) + assert bool(traits & ExecutionTraits.BARRIER) diff --git 
a/plans/async-engine/execution_graph_builder.plan.md b/plans/async-engine/execution_graph_builder.plan.md new file mode 100644 index 00000000..077d5f67 --- /dev/null +++ b/plans/async-engine/execution_graph_builder.plan.md @@ -0,0 +1,356 @@ +--- +name: Execution Graph Builder +overview: Build a memory-efficient execution graph that models cell-level dependencies for async dataset generation, supporting different generator execution traits (start, cell-by-cell, row-streamable, barrier) with trait inference from generator properties. +todos: + - id: node-types + content: Create node_id.py with CellNodeId, BarrierNodeId, and NodeId type alias + status: pending + - id: traits-enum + content: Create traits.py with ExecutionTraits Flag enum + status: pending + - id: column-descriptor + content: Create column_descriptor.py with ColumnDescriptor dataclass + status: pending + - id: base-class-property + content: Add is_row_streamable property to ColumnGenerator base class and override in subclasses + status: pending + - id: execution-graph + content: Create graph.py with ExecutionGraph class (hybrid representation, node iteration, dependency queries) + status: pending + - id: graph-builder + content: Create builder.py with GraphBuilder factory that infers traits from generator properties + status: pending + - id: completion-tracker + content: Create completion.py with CompletionTracker (simple) and ThreadSafeCompletionTracker variants + status: pending + - id: init-exports + content: Create __init__.py with public API exports + status: pending + - id: tests + content: Add comprehensive tests for graph building, node iteration, and dependency resolution + status: pending +isProject: false +--- + +# Execution Graph Builder + +## Overview + +Build a new graph-based execution framework that models cell-level dependencies for async dataset generation. 
The graph uses a **hybrid representation** where column structure is stored explicitly while cell-level nodes are computed on-demand to handle millions of records efficiently. + +## Architecture + +```mermaid +graph TD + subgraph config [Configuration Layer] + DDC[DataDesignerConfig] + CC[ColumnConfigs] + GEN[ColumnGenerators] + end + + subgraph graphLayer [Graph Builder] + GB[GraphBuilder] + GT[GeneratorTraits] + CD[ColumnDescriptor] + EG[ExecutionGraph] + end + + subgraph execution [Execution Engine Interface] + NI[iter_ready_nodes] + EC[ExecutionContext] + CT[CompletionTracker] + end + + DDC --> CC + CC --> GB + GEN --> GT + GB --> CD + CD --> EG + EG --> NI + EG --> EC + NI --> CT +``` + + + +## Generator Execution Traits + +Traits are inferred from generator properties (no hardcoded class-name matching): + + +| Trait | Source Property | Description | +| ---------------- | ------------------------------------------- | ------------------------------------------- | +| `START` | `can_generate_from_scratch = True` | Can initiate workflows without input | +| `CELL_BY_CELL` | `get_generation_strategy() == CELL_BY_CELL` | Processes individual cells independently | +| `ROW_STREAMABLE` | `is_row_streamable = True` | Full column generator that can emit per-row | +| `BARRIER` | `is_row_streamable = False` on full-column | Requires all inputs before any output | + + +### New Generator Property + +Add `is_row_streamable` property to `ColumnGenerator` base class in [packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py](packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py): + +```python +@property +def is_row_streamable(self) -> bool: + """Whether this generator can emit results as rows complete.""" + return self.get_generation_strategy() == GenerationStrategy.CELL_BY_CELL +``` + +Override in specific generators: + +- `ExpressionColumnGenerator`: `True` (processes rows independently) +- 
`ValidationColumnGenerator`: `False` (needs all rows for batch processing) +- `SamplerColumnGenerator`: `True` (generates per-row data) + +## Dependency Resolution + +### Cell-by-Cell / Row-Streamable Columns + +Row-local dependencies: + +- Cell `(row=5, col="question")` depends on `(row=5, col="context")` + +### Barrier Columns + +All-or-nothing dependencies: + +- `BarrierNodeId("validation")` depends on ALL cells of dependency columns +- Output cells `(row=*, col="validation")` depend on `BarrierNodeId("validation")` + +```mermaid +graph LR + subgraph inputColumn [Input Column] + C0[Cell 0] + C1[Cell 1] + CN[Cell N] + end + + subgraph barrier [Barrier Column] + B[BarrierNode] + O0[Output 0] + O1[Output 1] + ON[Output N] + end + + C0 --> B + C1 --> B + CN --> B + B --> O0 + B --> O1 + B --> ON +``` + + + +## File Structure + +``` +packages/data-designer-engine/src/data_designer/engine/execution_graph/ + __init__.py # Public exports + node_id.py # CellNodeId, BarrierNodeId, NodeId type alias + traits.py # ExecutionTraits Flag enum + column_descriptor.py # ColumnDescriptor dataclass + graph.py # ExecutionGraph class (hybrid representation) + builder.py # GraphBuilder factory (infers traits from generators) + completion.py # CompletionTracker (memory-efficient tracking) +``` + +## Key Components + +### 1. Node Identification (`node_id.py`) + +```python +@dataclass(frozen=True, slots=True) +class CellNodeId: + row: int + column: str + +@dataclass(frozen=True, slots=True) +class BarrierNodeId: + column: str + +NodeId = CellNodeId | BarrierNodeId +``` + +### 2. Execution Traits (`traits.py`) + +```python +class ExecutionTraits(Flag): + NONE = 0 + START = auto() # Can generate from scratch + CELL_BY_CELL = auto() # Processes individual cells + ROW_STREAMABLE = auto() # Can emit as rows complete + BARRIER = auto() # Requires all inputs first +``` + +### 3. 
Column Descriptor (`column_descriptor.py`)
+
+```python
+@dataclass(slots=True)
+class ColumnDescriptor:
+    name: str
+    config: ConfigBase
+    generator_cls: type[ColumnGenerator]
+    traits: ExecutionTraits
+    dependencies: list[str]  # from required_columns
+    side_effects: list[str]  # additional columns produced
+```
+
+### 4. Execution Graph (`graph.py`)
+
+Key methods for execution engines:
+
+- `iter_start_nodes()` - Nodes that can begin immediately
+- `iter_ready_nodes(completed)` - Nodes with satisfied dependencies
+- `get_dependencies(node)` - Dependencies for a node
+- `get_generator_and_config(node)` - Generator class and config for execution
+
+### 5. Graph Builder (`builder.py`)
+
+```python
+class GraphBuilder:
+    def build(self, config: DataDesignerConfig, num_records: int) -> ExecutionGraph:
+        # 1. Build descriptors from column configs
+        # 2. Infer traits from generator properties (not class names!)
+        # 3. Handle multi-column configs (mark additional columns as side effects)
+        # 4. Validate dependencies exist
+        # 5. Return ExecutionGraph with descriptors in topo order
+```
+
+Trait inference (plugin-compatible):
+
+```python
+def _infer_traits(self, gen_cls: type[ColumnGenerator]) -> ExecutionTraits:
+    traits = ExecutionTraits.NONE
+
+    # NOTE(review): can_generate_from_scratch is a @property, so getattr on the *class* returns a truthy descriptor object, not the bool — evaluate on an instance (or via the property's fget) in the real implementation
+    if getattr(gen_cls, 'can_generate_from_scratch', False):
+        traits |= ExecutionTraits.START
+
+    # Check generation strategy (assumes get_generation_strategy is callable on the class — confirm it is a classmethod/staticmethod)
+    strategy = gen_cls.get_generation_strategy()
+    if strategy == GenerationStrategy.CELL_BY_CELL:
+        traits |= ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE
+    else:  # FULL_COLUMN
+        # NOTE(review): same @property caveat as above — a property on the class is always truthy; read is_row_streamable from an instance
+        if getattr(gen_cls, 'is_row_streamable', False):
+            traits |= ExecutionTraits.ROW_STREAMABLE
+        else:
+            traits |= ExecutionTraits.BARRIER
+
+    return traits
+```
+
+### 6. 
Completion Tracker (`completion.py`) + +Memory-efficient tracking for large datasets with two variants: + +**Base Protocol:** + +```python +class CompletionTrackerProtocol(Protocol): + """Protocol for completion tracking - enables duck typing.""" + + def mark_complete(self, node: NodeId) -> None: ... + def is_complete(self, node: NodeId) -> bool: ... + def is_column_complete(self, column: str) -> bool: ... + def __contains__(self, node: NodeId) -> bool: ... +``` + +**Simple Tracker (asyncio / single-threaded):** + +```python +class CompletionTracker: + """O(C) memory instead of O(C x R). No locks - for single-threaded use.""" + + def mark_complete(self, node: NodeId) -> None: ... + def is_complete(self, node: NodeId) -> bool: ... + def is_column_complete(self, column: str) -> bool: ... +``` + +**Thread-Safe Tracker (thread pool / multi-threaded):** + +```python +class ThreadSafeCompletionTracker: + """Thread-safe variant with internal locking for concurrent access.""" + + def mark_complete(self, node: NodeId) -> None: + with self._lock: + # ... same logic as CompletionTracker + + def is_complete(self, node: NodeId) -> bool: + with self._lock: + # ... 
same logic as CompletionTracker +``` + +Both trackers use the same memory-efficient strategy: + +- Track fully completed columns as a set of names: O(C) +- Only store partial completion progress for in-progress columns +- Automatically compact when columns fully complete + +**Thread-Safety Notes:** + +- `ExecutionGraph` is **immutable after construction** - no locks needed +- Only the `CompletionTracker` holds mutable state +- Choose tracker variant based on your execution model: + - `CompletionTracker`: asyncio, single-threaded executors + - `ThreadSafeCompletionTracker`: thread pools, concurrent.futures + +## Memory Efficiency + + +| Aspect | Storage | +| ----------------- | ---------------------------------------------- | +| Column metadata | O(C) where C = number of columns | +| Cell nodes | Virtual - O(1) per node, computed on iteration | +| Topological order | O(C), cached once | +| NodeId objects | Frozen dataclasses with `slots=True` | +| Dependency edges | Computed mathematically, not stored | + + +A graph with 1M records and 10 columns uses ~same memory as 10 records. + +## Example Usage + +```python +from data_designer.engine.execution_graph import GraphBuilder, ExecutionGraph +from data_designer.engine.execution_graph.completion import ( + CompletionTracker, + ThreadSafeCompletionTracker, +) + +# Build graph from config (immutable after construction) +builder = GraphBuilder(column_generator_registry) +graph: ExecutionGraph = builder.build(config, num_records=1_000_000) + +# Option 1: Single-threaded / asyncio execution +tracker = CompletionTracker(graph.num_records) +for node in graph.iter_ready_nodes(tracker): + gen_cls, config = graph.get_generator_and_config(node) + # Execute node... 
+ tracker.mark_complete(node) + +# Option 2: Multi-threaded execution (e.g., with ThreadPoolExecutor) +tracker = ThreadSafeCompletionTracker(graph.num_records) +# Safe to call mark_complete() from multiple threads concurrently +``` + +## Multi-Column Config Handling + +Multi-column configs (samplers, seed datasets) produce multiple output columns. The graph handles this by: + +1. Using the first column as the primary descriptor name +2. Marking additional columns as `side_effects` +3. Building a reverse mapping `_side_effect_to_parent` for dependency resolution +4. Allowing downstream columns to depend on any produced column + +## Validation + +The graph builder validates: + +1. All dependency columns exist (or are side effects of existing columns) +2. No circular dependencies (enforced by topological order) +3. At least one column has `START` trait (can generate from scratch) + diff --git a/plans/async-engine/explore_execution_graph.ipynb b/plans/async-engine/explore_execution_graph.ipynb new file mode 100644 index 00000000..2b3892cf --- /dev/null +++ b/plans/async-engine/explore_execution_graph.ipynb @@ -0,0 +1,1015 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exploring the Execution Graph Builder\n", + "\n", + "This notebook explores the new execution graph framework for async dataset generation.\n", + "The graph uses a **hybrid representation** where column structure is stored explicitly\n", + "while cell-level nodes are computed on-demand to handle millions of records efficiently.\n", + "\n", + "## Key Concepts\n", + "\n", + "- **CellNodeId**: Identifies a single cell (row, column) in the dataset\n", + "- **BarrierNodeId**: Synchronization point where all inputs must complete before outputs begin\n", + "- **ExecutionTraits**: Flags describing generator behavior (START, CELL_BY_CELL, ROW_STREAMABLE, BARRIER)\n", + "- **ColumnDescriptor**: Metadata about a column including its traits and dependencies\n", + "- 
**ExecutionGraph**: The main graph structure with virtual node iteration\n", + "- **CompletionTracker**: Memory-efficient tracking of completed nodes\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from data_designer.engine.execution_graph import (\n", + " BarrierNodeId,\n", + " CellNodeId,\n", + " ColumnDescriptor,\n", + " CompletionTracker,\n", + " ExecutionGraph,\n", + " ExecutionTraits,\n", + " GraphBuilder,\n", + " NodeId,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Node Identification\n", + "\n", + "Nodes are identified using frozen dataclasses with `slots=True` for memory efficiency.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cell node: Cell(5, 'question')\n", + " Row: 5, Column: question\n", + "\n", + "Barrier node: Barrier('validation')\n", + " Column: validation\n" + ] + } + ], + "source": [ + "# CellNodeId identifies a single cell (row, column)\n", + "cell = CellNodeId(row=5, column=\"question\")\n", + "print(f\"Cell node: {cell}\")\n", + "print(f\" Row: {cell.row}, Column: {cell.column}\")\n", + "\n", + "# BarrierNodeId identifies a synchronization point for a column\n", + "barrier = BarrierNodeId(column=\"validation\")\n", + "print(f\"\\nBarrier node: {barrier}\")\n", + "print(f\" Column: {barrier.column}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Completed nodes: {Barrier('validation'), Cell(0, 'category'), Cell(1, 'category')}\n", + "\n", + "CellNodeId(0, 'category') in set: True\n", + "CellNodeId(2, 'category') in set: False\n" + ] + } + ], + "source": [ + "# Nodes are hashable and can be used in sets/dicts\n", + "completed_nodes: set[NodeId] = {\n", + " CellNodeId(0, \"category\"),\n", + " 
CellNodeId(1, \"category\"),\n", + " BarrierNodeId(\"validation\"),\n", + "}\n", + "\n", + "print(f\"Completed nodes: {completed_nodes}\")\n", + "print(f\"\\nCellNodeId(0, 'category') in set: {CellNodeId(0, 'category') in completed_nodes}\")\n", + "print(f\"CellNodeId(2, 'category') in set: {CellNodeId(2, 'category') in completed_nodes}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Execution Traits\n", + "\n", + "Traits describe the execution characteristics of column generators. They're inferred from generator properties (not class names) to support plugin generators.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Available traits:\n", + " - START: 1\n", + " - CELL_BY_CELL: 2\n", + " - ROW_STREAMABLE: 4\n", + " - BARRIER: 8\n" + ] + } + ], + "source": [ + "# ExecutionTraits is a Flag enum - traits can be combined\n", + "print(\"Available traits:\")\n", + "for trait in ExecutionTraits:\n", + " if trait != ExecutionTraits.NONE:\n", + " print(f\" - {trait.name}: {trait.value}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sampler traits: ExecutionTraits.START|CELL_BY_CELL|ROW_STREAMABLE\n", + " Has START: True\n", + " Has BARRIER: False\n", + "\n", + "Barrier traits: ExecutionTraits.BARRIER\n", + " Has BARRIER: True\n" + ] + } + ], + "source": [ + "# Traits can be combined using bitwise OR\n", + "sampler_traits = ExecutionTraits.START | ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE\n", + "print(f\"Sampler traits: {sampler_traits}\")\n", + "print(f\" Has START: {bool(sampler_traits & ExecutionTraits.START)}\")\n", + "print(f\" Has BARRIER: {bool(sampler_traits & ExecutionTraits.BARRIER)}\")\n", + "\n", + "# Validator/barrier column traits\n", + "barrier_traits = 
ExecutionTraits.BARRIER\n", + "print(f\"\\nBarrier traits: {barrier_traits}\")\n", + "print(f\" Has BARRIER: {bool(barrier_traits & ExecutionTraits.BARRIER)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Column Descriptors\n", + "\n", + "ColumnDescriptor stores metadata about a column including its configuration, generator class, execution traits, and dependencies.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Category column:\n", + " is_start_column: True\n", + " is_cell_by_cell: True\n", + " is_row_streamable: True\n", + " is_barrier: False\n" + ] + } + ], + "source": [ + "# Import a mock config and generator for demonstration\n", + "from data_designer.engine.column_generators.generators.base import ColumnGenerator\n", + "\n", + "# For demonstration, we'll create descriptors manually\n", + "# In practice, GraphBuilder creates these from DataDesignerConfig\n", + "\n", + "# A start column (sampler) - can generate data from scratch\n", + "category_desc = ColumnDescriptor(\n", + " name=\"category\",\n", + " config=None, # Would be a real config in practice\n", + " generator_cls=ColumnGenerator, # Would be a real generator\n", + " traits=ExecutionTraits.START | ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE,\n", + " dependencies=[],\n", + " side_effects=[],\n", + ")\n", + "\n", + "print(\"Category column:\")\n", + "print(f\" is_start_column: {category_desc.is_start_column}\")\n", + "print(f\" is_cell_by_cell: {category_desc.is_cell_by_cell}\")\n", + "print(f\" is_row_streamable: {category_desc.is_row_streamable}\")\n", + "print(f\" is_barrier: {category_desc.is_barrier}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question column:\n", + " is_start_column: False\n", + " 
has_dependencies: True\n", + " dependencies: ['category']\n" + ] + } + ], + "source": [ + "# A dependent column (LLM text) - processes cells independently\n", + "question_desc = ColumnDescriptor(\n", + " name=\"question\",\n", + " config=None,\n", + " generator_cls=ColumnGenerator,\n", + " traits=ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE,\n", + " dependencies=[\"category\"], # Depends on category column\n", + " side_effects=[],\n", + ")\n", + "\n", + "print(\"Question column:\")\n", + "print(f\" is_start_column: {question_desc.is_start_column}\")\n", + "print(f\" has_dependencies: {question_desc.has_dependencies}\")\n", + "print(f\" dependencies: {question_desc.dependencies}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation column:\n", + " is_barrier: True\n", + " is_row_streamable: False\n" + ] + } + ], + "source": [ + "# A barrier column (validator) - needs all inputs before producing any output\n", + "validation_desc = ColumnDescriptor(\n", + " name=\"validation\",\n", + " config=None,\n", + " generator_cls=ColumnGenerator,\n", + " traits=ExecutionTraits.BARRIER,\n", + " dependencies=[\"question\"],\n", + " side_effects=[],\n", + ")\n", + "\n", + "print(\"Validation column:\")\n", + "print(f\" is_barrier: {validation_desc.is_barrier}\")\n", + "print(f\" is_row_streamable: {validation_desc.is_row_streamable}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 
Execution Graph\n", + "\n", + "The ExecutionGraph models cell-level dependencies while maintaining memory efficiency through a hybrid representation:\n", + "\n", + "- **Explicit**: Column structure (ColumnDescriptors) is stored in memory\n", + "- **Virtual**: Cell-level nodes are computed on-demand, not stored\n", + "\n", + "This allows handling datasets with millions of records without creating millions of explicit nodes.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Graph properties:\n", + " num_records: 5\n", + " num_columns: 3\n", + " num_nodes: 16\n", + " start_columns: ['category']\n", + " column_names: ['category', 'question', 'validation']\n" + ] + } + ], + "source": [ + "# Create an execution graph with our descriptors\n", + "# Descriptors must be in topological order (dependencies before dependents)\n", + "graph = ExecutionGraph(\n", + " num_records=5,\n", + " column_descriptors=[category_desc, question_desc, validation_desc],\n", + ")\n", + "\n", + "print(\"Graph properties:\")\n", + "print(f\" num_records: {graph.num_records}\")\n", + "print(f\" num_columns: {graph.num_columns}\")\n", + "print(f\" num_nodes: {graph.num_nodes}\")\n", + "print(f\" start_columns: {graph.start_columns}\")\n", + "print(f\" column_names: {graph.column_names}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Large graph with 1M records:\n", + " num_nodes: 3,000,001\n", + " Memory: O(3) columns, not O(3,000,001) nodes\n" + ] + } + ], + "source": [ + "# Memory efficiency demonstration\n", + "# With 1M records and 10 columns, we'd have ~10M virtual nodes\n", + "# But we only store O(C) column metadata, not O(C x R) nodes\n", + "\n", + "large_graph = ExecutionGraph(\n", + " num_records=1_000_000,\n", + " column_descriptors=[category_desc, question_desc, 
validation_desc],\n", + ")\n", + "\n", + "print(\"Large graph with 1M records:\")\n", + "print(f\" num_nodes: {large_graph.num_nodes:,}\")\n", + "print(f\" Memory: O({large_graph.num_columns}) columns, not O({large_graph.num_nodes:,}) nodes\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Node Iteration\n", + "\n", + "The graph provides several iteration methods for execution engines.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Start nodes (first 10):\n", + " Cell(0, 'category')\n", + " Cell(1, 'category')\n", + " Cell(2, 'category')\n", + " Cell(3, 'category')\n", + " Cell(4, 'category')\n" + ] + } + ], + "source": [ + "# iter_start_nodes() - nodes that can begin immediately (no dependencies)\n", + "print(\"Start nodes (first 10):\")\n", + "for i, node in enumerate(graph.iter_start_nodes()):\n", + " if i >= 10:\n", + " print(\" ...\")\n", + " break\n", + " print(f\" {node}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "All nodes:\n", + " Cell(0, 'category')\n", + " Cell(1, 'category')\n", + " Cell(2, 'category')\n", + " Cell(3, 'category')\n", + " Cell(4, 'category')\n", + " Cell(0, 'question')\n", + " Cell(1, 'question')\n", + " Cell(2, 'question')\n", + " Cell(3, 'question')\n", + " Cell(4, 'question')\n", + " Barrier('validation')\n", + " Cell(0, 'validation')\n", + " Cell(1, 'validation')\n", + " Cell(2, 'validation')\n", + " Cell(3, 'validation')\n", + " Cell(4, 'validation')\n" + ] + } + ], + "source": [ + "# iter_nodes() - all virtual nodes in topological order\n", + "print(\"All nodes:\")\n", + "for node in graph.iter_nodes():\n", + " print(f\" {node}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. 
Dependency Resolution\n", + "\n", + "Dependencies are resolved differently based on column traits:\n", + "\n", + "- **CELL_BY_CELL / ROW_STREAMABLE**: Row-local dependencies (same row)\n", + "- **BARRIER**: BarrierNodeId depends on ALL cells of dependency columns\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dependencies for Cell(2, 'category'): []\n" + ] + } + ], + "source": [ + "# Dependencies for a start column (none)\n", + "cell = CellNodeId(row=2, column=\"category\")\n", + "deps = graph.get_dependencies(cell)\n", + "print(f\"Dependencies for {cell}: {deps}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dependencies for Cell(2, 'question'): [Cell(2, 'category')]\n", + " -> Same row in dependency column (category)\n" + ] + } + ], + "source": [ + "# Dependencies for a cell-by-cell column (row-local)\n", + "cell = CellNodeId(row=2, column=\"question\")\n", + "deps = graph.get_dependencies(cell)\n", + "print(f\"Dependencies for {cell}: {deps}\")\n", + "print(\" -> Same row in dependency column (category)\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dependencies for Barrier('validation'):\n", + " Cell(0, 'question')\n", + " Cell(1, 'question')\n", + " Cell(2, 'question')\n", + " Cell(3, 'question')\n", + " Cell(4, 'question')\n", + " -> ALL cells of dependency column (question)\n" + ] + } + ], + "source": [ + "# Dependencies for a barrier node (ALL cells of dependency columns)\n", + "barrier = BarrierNodeId(column=\"validation\")\n", + "deps = graph.get_dependencies(barrier)\n", + "print(f\"Dependencies for {barrier}:\")\n", + "for dep in deps:\n", + " print(f\" {dep}\")\n", + "print(\" -> ALL cells of dependency 
column (question)\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dependencies for Cell(2, 'validation'): [Barrier('validation')]\n", + " -> Depends on the barrier being complete\n" + ] + } + ], + "source": [ + "# Dependencies for output cells of a barrier column\n", + "# They depend on the barrier node, not directly on input cells\n", + "cell = CellNodeId(row=2, column=\"validation\")\n", + "deps = graph.get_dependencies(cell)\n", + "print(f\"Dependencies for {cell}: {deps}\")\n", + "print(\" -> Depends on the barrier being complete\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Completion Tracking\n", + "\n", + "The CompletionTracker provides memory-efficient tracking:\n", + "\n", + "- Tracks fully completed columns as a set of names: O(C)\n", + "- Only stores partial completion progress for in-progress columns\n", + "- Automatically compacts when columns fully complete\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initial state:\n", + " Completed nodes: 0\n", + " category complete: False\n" + ] + } + ], + "source": [ + "# Create a completion tracker\n", + "tracker = CompletionTracker(num_records=graph.num_records)\n", + "\n", + "print(\"Initial state:\")\n", + "print(f\" Completed nodes: {len(tracker)}\")\n", + "print(f\" category complete: {tracker.is_column_complete('category')}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "After marking 2 cells complete:\n", + " Completed nodes: 2\n", + " category completion count: 2\n", + " category complete: False\n" + ] + } + ], + "source": [ + "# Mark individual cells complete\n", + "tracker.mark_complete(CellNodeId(0, \"category\"))\n", + 
"tracker.mark_complete(CellNodeId(1, \"category\"))\n", + "\n", + "print(\"After marking 2 cells complete:\")\n", + "print(f\" Completed nodes: {len(tracker)}\")\n", + "print(f\" category completion count: {tracker.column_completion_count('category')}\")\n", + "print(f\" category complete: {tracker.is_column_complete('category')}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CellNodeId(0, 'category') in tracker: True\n", + "CellNodeId(2, 'category') in tracker: False\n" + ] + } + ], + "source": [ + "# Check node completion using 'in' syntax\n", + "print(f\"CellNodeId(0, 'category') in tracker: {CellNodeId(0, 'category') in tracker}\")\n", + "print(f\"CellNodeId(2, 'category') in tracker: {CellNodeId(2, 'category') in tracker}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "After completing category column:\n", + " Completed nodes: 5\n", + " category complete: True\n" + ] + } + ], + "source": [ + "# Complete the rest of the category column\n", + "for row in range(2, graph.num_records):\n", + " tracker.mark_complete(CellNodeId(row, \"category\"))\n", + "\n", + "print(\"After completing category column:\")\n", + "print(f\" Completed nodes: {len(tracker)}\")\n", + "print(f\" category complete: {tracker.is_column_complete('category')}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. 
Ready Node Iteration\n", + "\n", + "The `iter_ready_nodes()` method is the primary interface for async execution engines.\n", + "It yields nodes whose dependencies are all satisfied.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ready nodes after category completion:\n", + " Cell(0, 'question')\n", + " Cell(1, 'question')\n", + " Cell(2, 'question')\n", + " Cell(3, 'question')\n", + " Cell(4, 'question')\n" + ] + } + ], + "source": [ + "# Now that category is complete, question cells should be ready\n", + "ready_nodes = list(graph.iter_ready_nodes(tracker))\n", + "\n", + "print(\"Ready nodes after category completion:\")\n", + "for node in ready_nodes:\n", + " print(f\" {node}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ready nodes after question completion:\n", + " Barrier('validation')\n" + ] + } + ], + "source": [ + "# Complete some question cells\n", + "for row in range(graph.num_records):\n", + " tracker.mark_complete(CellNodeId(row, \"question\"))\n", + "\n", + "# Now the validation barrier should be ready\n", + "ready_nodes = list(graph.iter_ready_nodes(tracker))\n", + "\n", + "print(\"Ready nodes after question completion:\")\n", + "for node in ready_nodes:\n", + " print(f\" {node}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ready nodes after barrier completion:\n", + " Cell(0, 'validation')\n", + " Cell(1, 'validation')\n", + " Cell(2, 'validation')\n", + " Cell(3, 'validation')\n", + " Cell(4, 'validation')\n" + ] + } + ], + "source": [ + "# Complete the barrier\n", + "tracker.mark_complete(BarrierNodeId(\"validation\"))\n", + "\n", + "# Now validation output cells should be ready\n", + "ready_nodes = 
list(graph.iter_ready_nodes(tracker))\n", + "\n", + "print(\"Ready nodes after barrier completion:\")\n", + "for node in ready_nodes:\n", + " print(f\" {node}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Graph complete: True\n", + "Total completed nodes: 16\n" + ] + } + ], + "source": [ + "# Complete validation cells\n", + "for row in range(graph.num_records):\n", + " tracker.mark_complete(CellNodeId(row, \"validation\"))\n", + "\n", + "# Check if graph is complete\n", + "print(f\"Graph complete: {graph.is_complete(tracker)}\")\n", + "print(f\"Total completed nodes: {len(tracker)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. Simulation: Full Execution Loop\n", + "\n", + "Let's simulate a complete execution loop using the graph.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Starting Execution Simulation ===\n", + "Graph: 3 records, 3 columns\n", + "\n", + "Iteration 1: 3 ready nodes\n", + " Executing: Cell(0, 'category')\n", + " Executing: Cell(1, 'category')\n", + " Executing: Cell(2, 'category')\n", + "\n", + "Iteration 2: 3 ready nodes\n", + " Executing: Cell(0, 'question')\n", + " Executing: Cell(1, 'question')\n", + " Executing: Cell(2, 'question')\n", + "\n", + "Iteration 3: 1 ready nodes\n", + " Executing: Barrier('validation')\n", + "\n", + "Iteration 4: 3 ready nodes\n", + " Executing: Cell(0, 'validation')\n", + " Executing: Cell(1, 'validation')\n", + " Executing: Cell(2, 'validation')\n", + "\n", + "=== Execution Complete ===\n", + "Total iterations: 4\n", + "Total nodes processed: 10\n" + ] + } + ], + "source": [ + "# Create a fresh graph and tracker\n", + "graph = ExecutionGraph(\n", + " num_records=3,\n", + " column_descriptors=[category_desc, question_desc, 
validation_desc],\n", + ")\n", + "tracker = CompletionTracker(num_records=graph.num_records)\n", + "\n", + "print(\"=== Starting Execution Simulation ===\")\n", + "print(f\"Graph: {graph.num_records} records, {graph.num_columns} columns\\n\")\n", + "\n", + "iteration = 0\n", + "while not graph.is_complete(tracker):\n", + " iteration += 1\n", + " ready = list(graph.iter_ready_nodes(tracker))\n", + "\n", + " if not ready:\n", + " print(\"ERROR: No ready nodes but graph not complete!\")\n", + " break\n", + "\n", + " print(f\"Iteration {iteration}: {len(ready)} ready nodes\")\n", + " for node in ready:\n", + " # Simulate execution\n", + " print(f\" Executing: {node}\")\n", + " tracker.mark_complete(node)\n", + " print()\n", + "\n", + "print(\"=== Execution Complete ===\")\n", + "print(f\"Total iterations: {iteration}\")\n", + "print(f\"Total nodes processed: {len(tracker)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 10. Using GraphBuilder with Real Configs\n", + "\n", + "In practice, you'll use `GraphBuilder` to construct graphs from `DataDesignerConfig` objects.\n", + "The builder infers execution traits from generator properties automatically.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Config columns: ['category', 'question']\n" + ] + } + ], + "source": [ + "# Import the config API and registry\n", + "import data_designer.config as dd\n", + "from data_designer.engine.column_generators.registry import create_default_column_generator_registry\n", + "\n", + "# Create a simple config\n", + "config_builder = dd.DataDesignerConfigBuilder()\n", + "\n", + "# Add a sampler column (START trait)\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"category\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(values=[\"science\", \"history\", \"math\"]),\n", + 
" )\n", + ")\n", + "\n", + "# Add an LLM text column (depends on category)\n", + "config_builder.add_column(\n", + " dd.LLMTextColumnConfig(\n", + " name=\"question\",\n", + " prompt=\"Generate a trivia question about {{ category }}\",\n", + " model_alias=\"gpt-4o-mini\",\n", + " )\n", + ")\n", + "\n", + "config = config_builder.build()\n", + "print(f\"Config columns: {[c.name for c in config.columns]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Built graph:\n", + " num_records: 100\n", + " num_columns: 2\n", + " num_nodes: 200\n", + " start_columns: ['category']\n", + " column_names: ['category', 'question']\n" + ] + } + ], + "source": [ + "# Build the graph using GraphBuilder\n", + "builder = GraphBuilder(create_default_column_generator_registry())\n", + "\n", + "graph = builder.build(config, num_records=100)\n", + "\n", + "print(\"Built graph:\")\n", + "print(f\" num_records: {graph.num_records}\")\n", + "print(f\" num_columns: {graph.num_columns}\")\n", + "print(f\" num_nodes: {graph.num_nodes}\")\n", + "print(f\" start_columns: {graph.start_columns}\")\n", + "print(f\" column_names: {graph.column_names}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "category:\n", + " traits: ExecutionTraits.START|ROW_STREAMABLE\n", + " is_start_column: True\n", + " is_barrier: False\n", + " dependencies: []\n", + "\n", + "question:\n", + " traits: ExecutionTraits.CELL_BY_CELL|ROW_STREAMABLE\n", + " is_start_column: False\n", + " is_barrier: False\n", + " dependencies: ['category']\n" + ] + } + ], + "source": [ + "# Inspect inferred traits\n", + "for col_name in graph.column_names:\n", + " desc = graph.get_column_descriptor(col_name)\n", + " print(f\"\\n{col_name}:\")\n", + " print(f\" traits: {desc.traits}\")\n", + " print(f\" 
is_start_column: {desc.is_start_column}\")\n", + " print(f\" is_barrier: {desc.is_barrier}\")\n", + " print(f\" dependencies: {desc.dependencies}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "The execution graph framework provides:\n", + "\n", + "1. **Memory Efficiency**: O(C) storage for column metadata, virtual nodes computed on-demand\n", + "2. **Trait-Based Execution**: Generators describe their execution characteristics via properties\n", + "3. **Flexible Dependencies**: Row-local for streaming, barrier synchronization for batch operations\n", + "4. **Async-Ready**: `iter_ready_nodes()` enables efficient async/parallel execution\n", + "5. **Plugin Support**: Traits inferred from properties, not class names\n", + "\n", + "### Key API Methods\n", + "\n", + "| Method | Description |\n", + "| -------------------------------- | ---------------------------------------- |\n", + "| `iter_start_nodes()` | Nodes that can begin immediately |\n", + "| `iter_ready_nodes(tracker)` | Nodes with satisfied dependencies |\n", + "| `get_dependencies(node)` | Dependencies for a node |\n", + "| `get_dependents(node)` | Nodes that depend on this node |\n", + "| `get_generator_and_config(node)` | Generator class and config for execution |\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 9df4fd181914c84b74640767627d7f290fc05695 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Fri, 30 Jan 2026 10:54:26 -0500 Subject: [PATCH 2/3] feat: add checkpoint/restart support to execution graph Add row-complete batch checkpointing for resuming interrupted 
generation: - Add is_row_complete() and get_completed_row_count() to ExecutionGraph - Add to_checkpoint() and from_checkpoint() to CompletionTracker - Add thread-safe checkpoint methods to ThreadSafeCompletionTracker - Export CHECKPOINT_VERSION for compatibility checking - Add comprehensive tests for checkpoint functionality - Update plan documentation with checkpoint/restart section --- .../engine/execution_graph/__init__.py | 2 + .../engine/execution_graph/completion.py | 115 ++++++++ .../engine/execution_graph/graph.py | 50 ++++ .../engine/execution_graph/test_completion.py | 279 ++++++++++++++++++ .../engine/execution_graph/test_graph.py | 202 +++++++++++++ .../execution_graph_builder.plan.md | 107 +++++++ 6 files changed, 755 insertions(+) diff --git a/packages/data-designer-engine/src/data_designer/engine/execution_graph/__init__.py b/packages/data-designer-engine/src/data_designer/engine/execution_graph/__init__.py index f7b74a5d..1ef84b1f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/execution_graph/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/execution_graph/__init__.py @@ -30,6 +30,7 @@ from data_designer.engine.execution_graph.builder import GraphBuilder from data_designer.engine.execution_graph.column_descriptor import ColumnDescriptor from data_designer.engine.execution_graph.completion import ( + CHECKPOINT_VERSION, CompletionTracker, ThreadSafeCompletionTracker, ) @@ -61,4 +62,5 @@ # Completion tracking "CompletionTracker", "ThreadSafeCompletionTracker", + "CHECKPOINT_VERSION", ] diff --git a/packages/data-designer-engine/src/data_designer/engine/execution_graph/completion.py b/packages/data-designer-engine/src/data_designer/engine/execution_graph/completion.py index 152a3d68..e6375cc7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/execution_graph/completion.py +++ b/packages/data-designer-engine/src/data_designer/engine/execution_graph/completion.py @@ -11,14 +11,24 @@ Both 
trackers use O(C) memory instead of O(C×R) by tracking fully completed columns as sets and only storing partial progress for in-progress columns. + +Checkpoint/Restart Support: +- to_checkpoint(): Create a compact checkpoint representing complete rows +- from_checkpoint(): Restore tracker state from a checkpoint """ from __future__ import annotations import threading +from typing import TYPE_CHECKING from data_designer.engine.execution_graph.node_id import BarrierNodeId, CellNodeId, NodeId +if TYPE_CHECKING: + from data_designer.engine.execution_graph.graph import ExecutionGraph + +CHECKPOINT_VERSION = 1 + class CompletionTracker: """Memory-efficient completion tracking for large datasets. @@ -160,6 +170,74 @@ def __len__(self) -> int: completed_barriers = len(self._completed_barriers) return completed_cells + completed_barriers + def to_checkpoint(self, graph: ExecutionGraph) -> dict: + """Create a row-complete checkpoint. + + Returns a compact checkpoint representing complete rows. The checkpoint + format is O(1) regardless of dataset size, storing only the count of + contiguously completed rows. + + Args: + graph: The ExecutionGraph to use for row completion checking. + + Returns: + A checkpoint dictionary with version and completed_rows count. + + Example: + >>> checkpoint = tracker.to_checkpoint(graph) + >>> # {"version": 1, "completed_rows": 5000} + """ + completed_rows = graph.get_completed_row_count(self) + return { + "version": CHECKPOINT_VERSION, + "completed_rows": completed_rows, + } + + @classmethod + def from_checkpoint(cls, checkpoint: dict, graph: ExecutionGraph) -> "CompletionTracker": + """Restore a tracker from a row-complete checkpoint. + + Marks all cells for completed rows as done, reconstructing the tracker + state from the compact checkpoint format. + + Args: + checkpoint: A checkpoint dictionary from to_checkpoint(). + graph: The ExecutionGraph to use for restoration. + + Returns: + A new CompletionTracker with the restored state. 
+ + Raises: + ValueError: If the checkpoint version is incompatible. + + Example: + >>> tracker = CompletionTracker.from_checkpoint(checkpoint, graph) + >>> # Resume generation from where we left off + """ + version = checkpoint.get("version", 0) + if version != CHECKPOINT_VERSION: + raise ValueError(f"Incompatible checkpoint version: {version}, expected {CHECKPOINT_VERSION}") + + tracker = cls(graph.num_records) + completed_rows = checkpoint["completed_rows"] + + # Mark all cells for completed rows across all columns + for col_name in graph.column_names: + desc = graph.get_column_descriptor(col_name) + if desc.is_barrier: + # For barrier columns, mark the barrier as complete + tracker._completed_barriers.add(col_name) + + # Mark cells as complete for this column + if completed_rows == graph.num_records: + # Entire column is complete + tracker._completed_columns.add(col_name) + elif completed_rows > 0: + # Partial completion - store row indices + tracker._column_completion[col_name] = set(range(completed_rows)) + + return tracker + class ThreadSafeCompletionTracker: """Thread-safe completion tracker for concurrent access. @@ -267,3 +345,40 @@ def __len__(self) -> int: """Return the total number of completed nodes (thread-safe).""" with self._lock: return len(self._tracker) + + def to_checkpoint(self, graph: ExecutionGraph) -> dict: + """Create a row-complete checkpoint (thread-safe). + + Returns a compact checkpoint representing complete rows. The checkpoint + format is O(1) regardless of dataset size. + + Args: + graph: The ExecutionGraph to use for row completion checking. + + Returns: + A checkpoint dictionary with version and completed_rows count. + """ + with self._lock: + return self._tracker.to_checkpoint(graph) + + @classmethod + def from_checkpoint(cls, checkpoint: dict, graph: ExecutionGraph) -> "ThreadSafeCompletionTracker": + """Restore a tracker from a row-complete checkpoint (thread-safe). 
+ + Marks all cells for completed rows as done, reconstructing the tracker + state from the compact checkpoint format. + + Args: + checkpoint: A checkpoint dictionary from to_checkpoint(). + graph: The ExecutionGraph to use for restoration. + + Returns: + A new ThreadSafeCompletionTracker with the restored state. + + Raises: + ValueError: If the checkpoint version is incompatible. + """ + inner_tracker = CompletionTracker.from_checkpoint(checkpoint, graph) + wrapper = cls(graph.num_records) + wrapper._tracker = inner_tracker + return wrapper diff --git a/packages/data-designer-engine/src/data_designer/engine/execution_graph/graph.py b/packages/data-designer-engine/src/data_designer/engine/execution_graph/graph.py index feb343d7..7025332c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/execution_graph/graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/execution_graph/graph.py @@ -445,3 +445,53 @@ def is_complete(self, completed: CompletionTrackerProtocol) -> bool: if node not in completed: return False return True + + def is_row_complete(self, row: int, completed: CompletionTrackerProtocol) -> bool: + """Check if all cells for a row are complete. + + A row is considered complete when all cells across all columns for that + row have been completed. For barrier columns, this also requires that + the barrier itself is complete. + + Args: + row: The row index to check. + completed: A completion tracker (or set) indicating completed nodes. + + Returns: + True if all cells for the row are complete, False otherwise. + + Raises: + ValueError: If row is out of range. 
+ """ + if row < 0 or row >= self._num_records: + raise ValueError(f"Row {row} is out of range [0, {self._num_records})") + + for col_name in self._topo_order: + desc = self._columns[col_name] + if desc.is_barrier: + if BarrierNodeId(col_name) not in completed: + return False + if CellNodeId(row, col_name) not in completed: + return False + return True + + def get_completed_row_count(self, completed: CompletionTrackerProtocol) -> int: + """Get the count of contiguous complete rows starting from row 0. + + Returns the highest N where rows 0..N-1 are all complete. This is useful + for checkpoint/restart scenarios where you want to know how many rows + can be safely saved as a usable partial dataset. + + Args: + completed: A completion tracker (or set) indicating completed nodes. + + Returns: + The count of contiguous complete rows starting from row 0. + """ + count = 0 + for row in range(self._num_records): + if self.is_row_complete(row, completed): + count += 1 + else: + break + return count diff --git a/packages/data-designer-engine/tests/engine/execution_graph/test_completion.py b/packages/data-designer-engine/tests/engine/execution_graph/test_completion.py index 79cd889a..bbf7815c 100644 --- a/packages/data-designer-engine/tests/engine/execution_graph/test_completion.py +++ b/packages/data-designer-engine/tests/engine/execution_graph/test_completion.py @@ -5,11 +5,18 @@ import threading +import pytest + from data_designer.engine.execution_graph.completion import ( + CHECKPOINT_VERSION, CompletionTracker, ThreadSafeCompletionTracker, ) +from data_designer.engine.execution_graph.graph import ExecutionGraph from data_designer.engine.execution_graph.node_id import BarrierNodeId, CellNodeId +from data_designer.engine.execution_graph.traits import ExecutionTraits + +from .conftest import create_descriptor # noqa: TID252 # --- CompletionTracker tests --- @@ -203,3 +210,275 @@ def test_thread_safe_tracker_reset() -> None: assert len(tracker) == 0 assert not 
tracker.is_complete(CellNodeId(0, "col_a")) + + +# --- Checkpoint tests --- + + +def _create_simple_graph(num_records: int = 10) -> ExecutionGraph: + """Create a simple graph with two columns for checkpoint testing.""" + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + return ExecutionGraph(num_records=num_records, column_descriptors=descriptors) + + +def _create_barrier_graph(num_records: int = 10) -> ExecutionGraph: + """Create a graph with a barrier column for checkpoint testing.""" + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor("b", ExecutionTraits.BARRIER, ["a"]), + ] + return ExecutionGraph(num_records=num_records, column_descriptors=descriptors) + + +def test_checkpoint_no_completed_rows() -> None: + graph = _create_simple_graph(num_records=10) + tracker = CompletionTracker(num_records=10) + + checkpoint = tracker.to_checkpoint(graph) + + assert checkpoint["version"] == CHECKPOINT_VERSION + assert checkpoint["completed_rows"] == 0 + + +def test_checkpoint_some_completed_rows() -> None: + graph = _create_simple_graph(num_records=10) + tracker = CompletionTracker(num_records=10) + + # Complete first 5 rows across all columns + for row in range(5): + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + checkpoint = tracker.to_checkpoint(graph) + + assert checkpoint["version"] == CHECKPOINT_VERSION + assert checkpoint["completed_rows"] == 5 + + +def test_checkpoint_all_completed_rows() -> None: + graph = _create_simple_graph(num_records=5) + tracker = CompletionTracker(num_records=5) + + # Complete all rows + for row in range(5): + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + checkpoint = tracker.to_checkpoint(graph) + + assert 
checkpoint["version"] == CHECKPOINT_VERSION + assert checkpoint["completed_rows"] == 5 + + +def test_checkpoint_partial_row_not_counted() -> None: + graph = _create_simple_graph(num_records=10) + tracker = CompletionTracker(num_records=10) + + # Complete first 3 rows fully + for row in range(3): + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + # Complete row 3 only in column a (partial row) + tracker.mark_complete(CellNodeId(3, "a")) + + checkpoint = tracker.to_checkpoint(graph) + + # Only 3 complete rows (row 3 is incomplete) + assert checkpoint["completed_rows"] == 3 + + +def test_checkpoint_non_contiguous_rows_only_counts_contiguous() -> None: + graph = _create_simple_graph(num_records=10) + tracker = CompletionTracker(num_records=10) + + # Complete rows 0, 1, 2, but skip row 3, complete row 4 + for row in [0, 1, 2, 4]: + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + checkpoint = tracker.to_checkpoint(graph) + + # Only count contiguous rows from 0 (0, 1, 2) + assert checkpoint["completed_rows"] == 3 + + +def test_from_checkpoint_restores_state() -> None: + graph = _create_simple_graph(num_records=10) + checkpoint = {"version": CHECKPOINT_VERSION, "completed_rows": 5} + + tracker = CompletionTracker.from_checkpoint(checkpoint, graph) + + # All cells in rows 0-4 should be complete + for row in range(5): + assert tracker.is_complete(CellNodeId(row, "a")) + assert tracker.is_complete(CellNodeId(row, "b")) + + # Cells in rows 5-9 should not be complete + for row in range(5, 10): + assert not tracker.is_complete(CellNodeId(row, "a")) + assert not tracker.is_complete(CellNodeId(row, "b")) + + +def test_from_checkpoint_restores_all_complete() -> None: + graph = _create_simple_graph(num_records=5) + checkpoint = {"version": CHECKPOINT_VERSION, "completed_rows": 5} + + tracker = CompletionTracker.from_checkpoint(checkpoint, graph) + + # All cells should be complete + for row in 
range(5): + assert tracker.is_complete(CellNodeId(row, "a")) + assert tracker.is_complete(CellNodeId(row, "b")) + + # Both columns should be marked as fully complete + assert tracker.is_column_complete("a") + assert tracker.is_column_complete("b") + + +def test_from_checkpoint_empty_checkpoint() -> None: + graph = _create_simple_graph(num_records=10) + checkpoint = {"version": CHECKPOINT_VERSION, "completed_rows": 0} + + tracker = CompletionTracker.from_checkpoint(checkpoint, graph) + + # No cells should be complete + for row in range(10): + assert not tracker.is_complete(CellNodeId(row, "a")) + assert not tracker.is_complete(CellNodeId(row, "b")) + + +def test_from_checkpoint_invalid_version() -> None: + graph = _create_simple_graph(num_records=10) + checkpoint = {"version": 999, "completed_rows": 5} + + with pytest.raises(ValueError, match="Incompatible checkpoint version"): + CompletionTracker.from_checkpoint(checkpoint, graph) + + +def test_checkpoint_roundtrip() -> None: + graph = _create_simple_graph(num_records=10) + tracker = CompletionTracker(num_records=10) + + # Complete first 7 rows + for row in range(7): + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + # Create checkpoint and restore + checkpoint = tracker.to_checkpoint(graph) + restored = CompletionTracker.from_checkpoint(checkpoint, graph) + + # Verify restored state matches original + for row in range(10): + for col in ["a", "b"]: + assert tracker.is_complete(CellNodeId(row, col)) == restored.is_complete(CellNodeId(row, col)) + + +def test_checkpoint_with_barrier_column() -> None: + graph = _create_barrier_graph(num_records=5) + tracker = CompletionTracker(num_records=5) + + # Complete all cells in column a + for row in range(5): + tracker.mark_complete(CellNodeId(row, "a")) + + # Complete barrier and all cells in column b + tracker.mark_complete(BarrierNodeId("b")) + for row in range(5): + tracker.mark_complete(CellNodeId(row, "b")) + + checkpoint = 
tracker.to_checkpoint(graph) + assert checkpoint["completed_rows"] == 5 + + +def test_from_checkpoint_with_barrier_marks_barrier_complete() -> None: + graph = _create_barrier_graph(num_records=5) + checkpoint = {"version": CHECKPOINT_VERSION, "completed_rows": 5} + + tracker = CompletionTracker.from_checkpoint(checkpoint, graph) + + # Barrier should be marked complete + assert tracker.is_barrier_complete("b") + + # All cells should be complete + for row in range(5): + assert tracker.is_complete(CellNodeId(row, "a")) + assert tracker.is_complete(CellNodeId(row, "b")) + + +def test_checkpoint_barrier_incomplete_means_no_rows_complete() -> None: + graph = _create_barrier_graph(num_records=5) + tracker = CompletionTracker(num_records=5) + + # Complete all cells in column a + for row in range(5): + tracker.mark_complete(CellNodeId(row, "a")) + + # Barrier not complete, so no rows are "complete" even though column a is done + checkpoint = tracker.to_checkpoint(graph) + assert checkpoint["completed_rows"] == 0 + + +# --- ThreadSafeCompletionTracker checkpoint tests --- + + +def test_thread_safe_checkpoint_no_completed_rows() -> None: + graph = _create_simple_graph(num_records=10) + tracker = ThreadSafeCompletionTracker(num_records=10) + + checkpoint = tracker.to_checkpoint(graph) + + assert checkpoint["version"] == CHECKPOINT_VERSION + assert checkpoint["completed_rows"] == 0 + + +def test_thread_safe_checkpoint_some_completed_rows() -> None: + graph = _create_simple_graph(num_records=10) + tracker = ThreadSafeCompletionTracker(num_records=10) + + for row in range(5): + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + checkpoint = tracker.to_checkpoint(graph) + + assert checkpoint["completed_rows"] == 5 + + +def test_thread_safe_from_checkpoint() -> None: + graph = _create_simple_graph(num_records=10) + checkpoint = {"version": CHECKPOINT_VERSION, "completed_rows": 5} + + tracker = 
ThreadSafeCompletionTracker.from_checkpoint(checkpoint, graph) + + for row in range(5): + assert tracker.is_complete(CellNodeId(row, "a")) + assert tracker.is_complete(CellNodeId(row, "b")) + + for row in range(5, 10): + assert not tracker.is_complete(CellNodeId(row, "a")) + assert not tracker.is_complete(CellNodeId(row, "b")) + + +def test_thread_safe_checkpoint_roundtrip() -> None: + graph = _create_simple_graph(num_records=10) + tracker = ThreadSafeCompletionTracker(num_records=10) + + for row in range(7): + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + checkpoint = tracker.to_checkpoint(graph) + restored = ThreadSafeCompletionTracker.from_checkpoint(checkpoint, graph) + + for row in range(10): + for col in ["a", "b"]: + assert tracker.is_complete(CellNodeId(row, col)) == restored.is_complete(CellNodeId(row, col)) diff --git a/packages/data-designer-engine/tests/engine/execution_graph/test_graph.py b/packages/data-designer-engine/tests/engine/execution_graph/test_graph.py index 3dcaa87b..0f8be2aa 100644 --- a/packages/data-designer-engine/tests/engine/execution_graph/test_graph.py +++ b/packages/data-designer-engine/tests/engine/execution_graph/test_graph.py @@ -372,3 +372,205 @@ def test_get_generator_and_config() -> None: gen_cls, config = graph.get_generator_and_config(CellNodeId(0, "my_col")) assert gen_cls is desc.generator_cls assert config is desc.config + + +# --- Row completion tests --- + + +def test_is_row_complete_no_cells_done() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + tracker = CompletionTracker(10) + + assert not graph.is_row_complete(0, tracker) + + +def test_is_row_complete_partial_row() -> None: + descriptors = [ + create_descriptor("a", 
ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + tracker = CompletionTracker(10) + + # Only complete column a for row 0 + tracker.mark_complete(CellNodeId(0, "a")) + + assert not graph.is_row_complete(0, tracker) + + +def test_is_row_complete_full_row() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + tracker = CompletionTracker(10) + + tracker.mark_complete(CellNodeId(0, "a")) + tracker.mark_complete(CellNodeId(0, "b")) + + assert graph.is_row_complete(0, tracker) + + +def test_is_row_complete_with_barrier_incomplete() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor("b", ExecutionTraits.BARRIER, ["a"]), + ] + graph = ExecutionGraph(num_records=3, column_descriptors=descriptors) + tracker = CompletionTracker(3) + + # Complete all cells in column a and column b, but not the barrier + for row in range(3): + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + # Row should not be complete because barrier is not complete + assert not graph.is_row_complete(0, tracker) + + +def test_is_row_complete_with_barrier_complete() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor("b", ExecutionTraits.BARRIER, ["a"]), + ] + graph = ExecutionGraph(num_records=3, column_descriptors=descriptors) + tracker = CompletionTracker(3) + + # Complete all cells and the barrier + for row in range(3): + tracker.mark_complete(CellNodeId(row, "a")) + 
tracker.mark_complete(BarrierNodeId("b")) + for row in range(3): + tracker.mark_complete(CellNodeId(row, "b")) + + assert graph.is_row_complete(0, tracker) + assert graph.is_row_complete(1, tracker) + assert graph.is_row_complete(2, tracker) + + +def test_is_row_complete_out_of_range() -> None: + desc = create_descriptor("a", ExecutionTraits.START) + graph = ExecutionGraph(num_records=10, column_descriptors=[desc]) + tracker = CompletionTracker(10) + + with pytest.raises(ValueError, match="Row 10 is out of range"): + graph.is_row_complete(10, tracker) + + with pytest.raises(ValueError, match="Row -1 is out of range"): + graph.is_row_complete(-1, tracker) + + +def test_get_completed_row_count_empty() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + tracker = CompletionTracker(10) + + assert graph.get_completed_row_count(tracker) == 0 + + +def test_get_completed_row_count_some_rows() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + tracker = CompletionTracker(10) + + # Complete first 5 rows + for row in range(5): + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + assert graph.get_completed_row_count(tracker) == 5 + + +def test_get_completed_row_count_all_rows() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=5, 
column_descriptors=descriptors) + tracker = CompletionTracker(5) + + for row in range(5): + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + assert graph.get_completed_row_count(tracker) == 5 + + +def test_get_completed_row_count_non_contiguous() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor( + "b", + ExecutionTraits.CELL_BY_CELL | ExecutionTraits.ROW_STREAMABLE, + ["a"], + ), + ] + graph = ExecutionGraph(num_records=10, column_descriptors=descriptors) + tracker = CompletionTracker(10) + + # Complete rows 0, 1, 2, skip 3, complete 4 + for row in [0, 1, 2, 4]: + tracker.mark_complete(CellNodeId(row, "a")) + tracker.mark_complete(CellNodeId(row, "b")) + + # Only count contiguous from row 0 + assert graph.get_completed_row_count(tracker) == 3 + + +def test_get_completed_row_count_with_barrier() -> None: + descriptors = [ + create_descriptor("a", ExecutionTraits.START | ExecutionTraits.ROW_STREAMABLE), + create_descriptor("b", ExecutionTraits.BARRIER, ["a"]), + ] + graph = ExecutionGraph(num_records=5, column_descriptors=descriptors) + tracker = CompletionTracker(5) + + # Complete all cells in column a + for row in range(5): + tracker.mark_complete(CellNodeId(row, "a")) + + # Without barrier complete, no rows are complete + assert graph.get_completed_row_count(tracker) == 0 + + # Complete barrier and all cells in column b + tracker.mark_complete(BarrierNodeId("b")) + for row in range(5): + tracker.mark_complete(CellNodeId(row, "b")) + + # Now all rows are complete + assert graph.get_completed_row_count(tracker) == 5 diff --git a/plans/async-engine/execution_graph_builder.plan.md b/plans/async-engine/execution_graph_builder.plan.md index 077d5f67..7af825f1 100644 --- a/plans/async-engine/execution_graph_builder.plan.md +++ b/plans/async-engine/execution_graph_builder.plan.md @@ -29,6 +29,9 @@ todos: - id: tests content: Add comprehensive tests 
for graph building, node iteration, and dependency resolution status: pending + - id: checkpoint-support + content: Add checkpoint/restart support with row-complete batch checkpointing + status: done isProject: false --- @@ -354,3 +357,107 @@ The graph builder validates: 2. No circular dependencies (enforced by topological order) 3. At least one column has `START` trait (can generate from scratch) +## Checkpoint/Restart Support + +The execution graph supports **row-complete batch checkpointing** for resuming interrupted generation runs. + +### Key Design Decisions + +1. **Row-complete batches**: Checkpoint when all cells for a batch of rows are complete +2. **Compact checkpoint format**: Store `{"completed_rows": 5000}` not individual cell indices +3. **Usable partial datasets**: Each checkpoint can be loaded and used independently +4. **Restore via graph**: Restoration uses the graph structure to expand row ranges into cell-level tracking + +### Mental Model + +A row is "complete" when **all columns for that row** have their cells done: +- For regular columns: `CellNodeId(row, column)` is complete +- For barrier columns: the `BarrierNodeId(column)` is complete AND `CellNodeId(row, column)` is complete + +### Checkpoint Format + +```json +{ + "version": 1, + "completed_rows": 5000 +} +``` + +This is **O(1) regardless of dataset size** - no row indices stored. 
+ +### Row Completion API + +`ExecutionGraph` provides methods for row-level completion checking: + +```python +# Check if a specific row is complete (all columns done) +graph.is_row_complete(row=5, completed=tracker) + +# Get count of contiguous complete rows from row 0 +completed_count = graph.get_completed_row_count(tracker) +``` + +### Checkpoint API + +Both `CompletionTracker` and `ThreadSafeCompletionTracker` support checkpointing: + +```python +# Create checkpoint +checkpoint = tracker.to_checkpoint(graph) +# Returns: {"version": 1, "completed_rows": 5000} + +# Restore from checkpoint +tracker = CompletionTracker.from_checkpoint(checkpoint, graph) +# or +tracker = ThreadSafeCompletionTracker.from_checkpoint(checkpoint, graph) +``` + +### Usage Pattern + +**Checkpoint during execution:** +```python +BATCH_SIZE = 1000 +last_checkpoint = 0 + +# Periodically check and checkpoint +completed = graph.get_completed_row_count(tracker) +if completed >= last_checkpoint + BATCH_SIZE: + storage.update_metadata({"checkpoint": tracker.to_checkpoint(graph)}) + last_checkpoint = completed +``` + +**Resume from checkpoint:** +```python +try: + metadata = storage.read_metadata() + if "checkpoint" in metadata: + tracker = CompletionTracker.from_checkpoint(metadata["checkpoint"], graph) + print(f"Resumed: {metadata['checkpoint']['completed_rows']} rows complete") + else: + tracker = CompletionTracker(graph.num_records) +except FileNotFoundError: + tracker = CompletionTracker(graph.num_records) + +# iter_ready_nodes automatically skips completed rows +for node in graph.iter_ready_nodes(tracker): + ... +``` + +### Why Row-Complete Batches? + +| Approach | Checkpoint Size (1M records, 50% done) | Usable? 
| +|----------|----------------------------------------|---------| +| Per-cell indices | ~4MB per partial column | Partial | +| Row-complete batch | ~50 bytes | ✓ Yes | + +Row-complete batches give you: +- **Constant-size checkpoints**: O(1) storage +- **Usable partial data**: Each checkpoint is a valid dataset subset +- **Simple mental model**: "5000 rows are done" + +### Out-of-Order Completion Note + +If rows don't complete in order (due to concurrent workers), the current implementation uses **Option A: Wait for contiguous completion**: +- Only checkpoint contiguous rows from 0 +- Some work may be re-done on restart +- Simpler implementation that matches the batch-oriented nature of dataset generation From 9a364efd4bb8fdf1d8f673b8bab26bab46901617 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Fri, 30 Jan 2026 11:17:08 -0500 Subject: [PATCH 3/3] docs: document barrier column checkpoint limitation and options Add "Open Decision: Checkpointing with Barrier Columns" section to the execution graph builder plan. This documents that barrier columns prevent intermediate checkpoints and outlines four potential approaches to address this in the future. --- .../execution_graph_builder.plan.md | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/plans/async-engine/execution_graph_builder.plan.md b/plans/async-engine/execution_graph_builder.plan.md index 7af825f1..3e20e7cc 100644 --- a/plans/async-engine/execution_graph_builder.plan.md +++ b/plans/async-engine/execution_graph_builder.plan.md @@ -461,3 +461,114 @@ If rows don't complete in order (due to concurrent workers), the current impleme - Only checkpoint contiguous rows from 0 - Some work may be re-done on restart - Simpler implementation that matches the batch-oriented nature of dataset generation + +--- + +## Open Decision: Checkpointing with Barrier Columns + +### The Problem + +With the current design, **barrier columns prevent any intermediate checkpoints**. 
+
+A row is only "complete" when ALL columns (including barriers) are done for that row. Since barrier columns require ALL input rows before producing ANY output, no rows can be checkpointed until the barrier finishes processing the entire dataset.
+
+**Example Pipeline**: `A (start) → B (row-streamable) → C (barrier/validation)`
+
+| Stage | `get_completed_row_count()` |
+|-------|----------------------------|
+| A: 5000/10000 rows done | 0 |
+| B: 5000/10000 rows done | 0 |
+| C barrier waiting | 0 |
+| C barrier executing | 0 |
+| C barrier complete | 10000 |
+
+**Impact**: If generation fails during or before barrier execution, all pre-barrier work is lost. (NOTE, review: `CompletionTracker.from_checkpoint` in the reference implementation currently marks every barrier column complete even when `completed_rows` is 0, which diverges from `to_checkpoint` — verify restore stays consistent for empty checkpoints on barrier graphs.)
+
+### Options
+
+#### Option A: Accept Current Behavior (Dataset-Scoped Barriers)
+
+Keep barriers dataset-scoped. Accept that checkpoints only occur after all barriers complete.
+
+| Pros | Cons |
+|------|------|
+| Simple mental model | No intermediate checkpoints with barriers |
+| No changes needed | Work lost on failure before barrier completes |
+| Correct for cross-row operations (global dedup, normalization) | |
+
+#### Option B: Batch-Scoped Barriers
+
+Execute barriers per-batch instead of per-dataset. Each batch completes independently.
+ +``` +Batch 0: A[0:999] → B[0:999] → Barrier(C, batch=0) → C[0:999] → checkpoint(1000) +Batch 1: A[1000:1999] → B[1000:1999] → Barrier(C, batch=1) → C[1000:1999] → checkpoint(2000) +``` + +**Implementation changes:** +```python +@dataclass(frozen=True, slots=True) +class BarrierNodeId: + column: str + batch: int | None = None # None = dataset-scoped, int = batch-scoped +``` + +| Pros | Cons | +|------|------| +| Enables incremental checkpoints | More complex graph structure | +| Limits work lost on failure | Not valid for cross-row operations | +| Natural fit for batch-oriented execution | Requires generator to declare batch-safety | + +**Semantic validity depends on barrier type:** +- ✅ Row-independent validation (format checks, schema validation) +- ✅ Per-batch statistics +- ❌ Global uniqueness checks +- ❌ Min-max normalization across dataset + +#### Option C: Column-Level Checkpointing + +Extend checkpoint format to track completed columns, not just completed rows. + +```json +{ + "version": 2, + "completed_rows": 0, + "completed_columns": ["A", "B"] +} +``` + +On restart, skip re-generating completed columns and feed their data into the barrier. + +| Pros | Cons | +|------|------| +| Preserves pre-barrier work | More complex checkpoint format | +| Works with dataset-scoped barriers | Requires loading partial data on restart | +| No changes to barrier semantics | Larger checkpoint size O(C) vs O(1) | + +#### Option D: Hybrid Approach + +Add a `BATCH_BARRIER` trait for barriers that are batch-safe, keeping `BARRIER` for dataset-scoped operations. + +```python +class ExecutionTraits(Flag): + ... + BARRIER = auto() # Dataset-scoped (e.g., global dedup) + BATCH_BARRIER = auto() # Batch-scoped (e.g., validation) +``` + +Generators declare which type they support. Graph builder creates appropriate node structure. 
+ +| Pros | Cons | +|------|------| +| Flexibility - use right tool for job | Most complex implementation | +| Preserves correctness for cross-row ops | Generators must declare batch-safety | +| Enables checkpoints where semantically valid | Two barrier code paths to maintain | + +### Decision Needed + +Which approach should we implement? Considerations: + +1. **How common are barriers in typical pipelines?** (validation is common) +2. **How long do generation runs typically take?** (longer = more checkpoint value) +3. **Are there cross-row barriers we need to support?** (global dedup, normalization) +4. **Implementation complexity vs. user value tradeoff?**