From dc041f78a95d40b835a9ec34356f10248885daf9 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 25 Nov 2025 12:16:17 -0700 Subject: [PATCH 01/64] Add generation type to ModelConfig --- src/data_designer/config/models.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 6bff8efd..3e06a8fc 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -205,11 +205,18 @@ def _is_value_in_range(self, value: float, min_value: float, max_value: float) - return min_value <= value <= max_value +class GenerationType(str, Enum): + CHAT_COMPLETION = "chat-completion" + TEXT_EMBEDDING = "text-embedding" + IMAGE_GENERATION = "image-generation" + + class ModelConfig(ConfigBase): alias: str model: str inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters) provider: Optional[str] = None + generation_type: GenerationType = GenerationType.CHAT_COMPLETION class ModelProvider(ConfigBase): From 0d6b830f6439b6bece0b642c921281fd817de5d3 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 25 Nov 2025 12:21:28 -0700 Subject: [PATCH 02/64] pass tests --- src/data_designer/config/default_model_settings.py | 5 +++-- src/data_designer/config/models.py | 2 +- tests/cli/repositories/test_model_repository.py | 4 +++- tests/config/test_config_builder.py | 2 +- tests/config/test_models.py | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/data_designer/config/default_model_settings.py b/src/data_designer/config/default_model_settings.py index 32c1d42b..33d6dad4 100644 --- a/src/data_designer/config/default_model_settings.py +++ b/src/data_designer/config/default_model_settings.py @@ -103,7 +103,8 @@ def resolve_seed_default_model_settings() -> None: f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}" ) save_config_file( - MODEL_CONFIGS_FILE_PATH, {"model_configs": 
[mc.model_dump() for mc in get_builtin_model_configs()]} + MODEL_CONFIGS_FILE_PATH, + {"model_configs": [mc.model_dump(mode="json") for mc in get_builtin_model_configs()]}, ) if not MODEL_PROVIDERS_FILE_PATH.exists(): @@ -111,7 +112,7 @@ def resolve_seed_default_model_settings() -> None: f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}" ) save_config_file( - MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]} + MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump(mode="json") for p in get_builtin_model_providers()]} ) if not MANAGED_ASSETS_PATH.exists(): diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 3e06a8fc..17698346 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -215,8 +215,8 @@ class ModelConfig(ConfigBase): alias: str model: str inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters) - provider: Optional[str] = None generation_type: GenerationType = GenerationType.CHAT_COMPLETION + provider: Optional[str] = None class ModelProvider(ConfigBase): diff --git a/tests/cli/repositories/test_model_repository.py b/tests/cli/repositories/test_model_repository.py index 01884b5c..624cd360 100644 --- a/tests/cli/repositories/test_model_repository.py +++ b/tests/cli/repositories/test_model_repository.py @@ -21,7 +21,9 @@ def test_load_does_not_exist(): def test_load_exists(tmp_path: Path, stub_model_configs: list[ModelConfig]): model_configs_file_path = tmp_path / MODEL_CONFIGS_FILE_NAME - save_config_file(model_configs_file_path, {"model_configs": [mc.model_dump() for mc in stub_model_configs]}) + save_config_file( + model_configs_file_path, {"model_configs": [mc.model_dump(mode="json") for mc in stub_model_configs]} + ) repository = ModelRepository(tmp_path) assert repository.load() is not None assert repository.load().model_configs == 
stub_model_configs diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py index 337d934e..aab8112a 100644 --- a/tests/config/test_config_builder.py +++ b/tests/config/test_config_builder.py @@ -54,7 +54,7 @@ def stub_data_designer_builder(stub_data_designer_builder_config_str): def test_loading_model_configs_in_constructor(stub_model_configs): - stub_model_configs_dict = [mc.model_dump() for mc in stub_model_configs] + stub_model_configs_dict = [mc.model_dump(mode="json") for mc in stub_model_configs] # test loading model configs from a list builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) assert builder.model_configs == stub_model_configs diff --git a/tests/config/test_models.py b/tests/config/test_models.py index 9ccda6d5..6a3d7b25 100644 --- a/tests/config/test_models.py +++ b/tests/config/test_models.py @@ -212,7 +212,7 @@ def test_load_model_configs(): ModelConfig(alias="test", model="test"), ModelConfig(alias="test2", model="test2"), ] - stub_model_configs_dict_list = [mc.model_dump() for mc in stub_model_configs] + stub_model_configs_dict_list = [mc.model_dump(mode="json") for mc in stub_model_configs] assert load_model_configs([]) == [] assert load_model_configs(stub_model_configs) == stub_model_configs From 254fd8a71e261a7bb3ac71ad14d8aa10772529ea Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 25 Nov 2025 14:36:02 -0700 Subject: [PATCH 03/64] added generate_text_embeddings --- .../generators/llm_generators.py | 1 - src/data_designer/engine/models/facade.py | 51 ++++++++++++++++++- .../generators/test_llm_generators.py | 17 ------- tests/engine/models/test_facade.py | 48 +++++++++++------ 4 files changed, 82 insertions(+), 35 deletions(-) diff --git a/src/data_designer/engine/column_generators/generators/llm_generators.py b/src/data_designer/engine/column_generators/generators/llm_generators.py index ee0ab58a..8f4cfc90 100644 --- 
a/src/data_designer/engine/column_generators/generators/llm_generators.py +++ b/src/data_designer/engine/column_generators/generators/llm_generators.py @@ -96,7 +96,6 @@ def generate(self, data: dict) -> dict: max_correction_steps=self.max_conversation_correction_steps, max_conversation_restarts=self.max_conversation_restarts, purpose=f"running generation for column '{self.config.name}'", - **self.inference_parameters.generate_kwargs, ) data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response)) diff --git a/src/data_designer/engine/models/facade.py b/src/data_designer/engine/models/facade.py index 93ca0fd7..b0ad3472 100644 --- a/src/data_designer/engine/models/facade.py +++ b/src/data_designer/engine/models/facade.py @@ -9,7 +9,7 @@ from typing import Any from litellm.types.router import DeploymentTypedDict, LiteLLM_Params -from litellm.types.utils import ModelResponse +from litellm.types.utils import EmbeddingResponse, ModelResponse from data_designer.config.models import ModelConfig, ModelProvider from data_designer.engine.model_provider import ModelProviderRegistry @@ -67,6 +67,7 @@ def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = extra={"model": self.model_name, "messages": messages, "sensitive": True}, ) response = None + kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} if self.model_provider.extra_body: kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} try: @@ -87,6 +88,41 @@ def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = if not skip_usage_tracking: self._track_usage(response) + @catch_llm_exceptions + def generate_text_embeddings( + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs + ) -> list[list[float]]: + logger.debug( + f"Generating embeddings with model {self.model_name!r}...", + extra={ + "model": self.model_name, + "input_count": len(input_texts), + 
"sensitive": True, + }, + ) + kwargs |= self._model_config.inference_parameters.generate_kwargs + if self.model_provider.extra_body: + kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + try: + response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs) + logger.debug( + f"Received embeddings from model {self.model_name!r}", + extra={ + "model": self.model_name, + "embedding_count": len(response.data) if response.data else 0, + "usage": self._usage_stats.model_dump(), + }, + ) + if response.data and len(response.data) == len(input_texts): + return [data["embedding"] for data in response.data] + else: + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") + except Exception as e: + raise e + finally: + if not skip_usage_tracking: + self._track_usage_from_embedding(response) + @catch_llm_exceptions def generate( self, @@ -223,3 +259,16 @@ def _track_usage(self, response: ModelResponse | None) -> None: ), request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) + + def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None: + if response is None: + self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) + return + if response.usage is not None and response.usage.prompt_tokens is not None: + self._usage_stats.extend( + token_usage=TokenUsageStats( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=0, + ), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) diff --git a/tests/engine/column_generators/generators/test_llm_generators.py b/tests/engine/column_generators/generators/test_llm_generators.py index 259f3a08..acaa2c6f 100644 --- a/tests/engine/column_generators/generators/test_llm_generators.py +++ b/tests/engine/column_generators/generators/test_llm_generators.py @@ -259,20 +259,3 @@ def 
test_generate_with_json_deserialization(): result = generator.generate(data) assert result["test_column"] == {"result": "json_output"} - - -def test_generate_with_inference_parameters(): - generator, _, mock_model, _, mock_inference_params, mock_prompt_renderer, mock_response_recipe = ( - _create_generator_with_mocks() - ) - - mock_inference_params.generate_kwargs = {"temperature": 0.7, "max_tokens": 100} - _setup_generate_mocks(mock_prompt_renderer, mock_response_recipe, mock_model) - - data = {"input": "test_input"} - generator.generate(data) - - call_args = mock_model.generate.call_args - assert call_args[1]["temperature"] == 0.7 - assert call_args[1]["max_tokens"] == 100 - assert call_args[1]["purpose"] == "running generation for column 'test_column'" diff --git a/tests/engine/models/test_facade.py b/tests/engine/models/test_facade.py index 4fa73d9a..d240eeaa 100644 --- a/tests/engine/models/test_facade.py +++ b/tests/engine/models/test_facade.py @@ -133,7 +133,9 @@ def raise_exception(*args, **kwargs): stub_model_facade.completion(messages) -def test_completion_with_kwargs(stub_model_facade, stub_expected_response): +def test_completion_kwargs_overrides_model_config_generate_kwargs( + stub_model_configs, stub_model_facade, stub_expected_response +): captured_kwargs = {} def mock_completion(model_name, messages, **kwargs): @@ -147,28 +149,42 @@ def mock_completion(model_name, messages, **kwargs): result = stub_model_facade.completion(messages, **kwargs) assert result == stub_expected_response - assert captured_kwargs == kwargs + # completion kwargs overrides model config generate kwargs + assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs} @patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) -def test_completion_with_extra_body(mock_router_completion, stub_model_facade): +def test_provider_extra_body_overrides_completion_kwargs(mock_router_completion, stub_model_configs, 
stub_model_facade): messages = [{"role": "user", "content": "test"}] + stub_provider_extra_body = {"foo": "bar"} - # completion call has no extra body argument and provider has no extra body + # model config has generate kwargs, completion call has no kwargs, and provider has no extra body _ = stub_model_facade.completion(messages) assert len(mock_router_completion.call_args) == 2 assert mock_router_completion.call_args[0][1] == "stub-model-text" assert mock_router_completion.call_args[0][2] == messages + assert mock_router_completion.call_args[1] == stub_model_configs[0].inference_parameters.generate_kwargs - # completion call has no extra body argument and provider has extra body. - # Should pull extra body from model provider - custom_extra_body = {"some_custom_key": "some_custom_value"} - stub_model_facade.model_provider.extra_body = custom_extra_body - _ = stub_model_facade.completion(messages) - assert mock_router_completion.call_args[1] == {"extra_body": custom_extra_body} - - # completion call has extra body argument and provider has extra body. 
- # Should merge the two with provider extra body taking precedence - completion_extra_body = {"some_completion_key": "some_completion_value", "some_custom_key": "some_different_value"} - _ = stub_model_facade.completion(messages, extra_body=completion_extra_body) - assert mock_router_completion.call_args[1] == {"extra_body": {**completion_extra_body, **custom_extra_body}} + # model config has generate kwargs, completion call has kwargs, and provider has no extra body + # completion kwargs overrides model config generate kwargs + _ = stub_model_facade.completion(messages, temperature=0.1) + assert len(mock_router_completion.call_args) == 2 + assert mock_router_completion.call_args[0][1] == "stub-model-text" + assert mock_router_completion.call_args[0][2] == messages + assert mock_router_completion.call_args[1] == { + **stub_model_configs[0].inference_parameters.generate_kwargs, + "temperature": 0.1, + } + + # model config has generate kwargs, completion call has kwargs, and provider has extra body + # provider extra body overrides completion kwargs + stub_model_facade.model_provider.extra_body = stub_provider_extra_body + _ = stub_model_facade.completion(messages, temperature=0.15, extra_body={"foo": "bat"}) + assert len(mock_router_completion.call_args) == 2 + assert mock_router_completion.call_args[0][1] == "stub-model-text" + assert mock_router_completion.call_args[0][2] == messages + assert mock_router_completion.call_args[1] == { + **stub_model_configs[0].inference_parameters.generate_kwargs, + "temperature": 0.15, + "extra_body": stub_provider_extra_body, + } From 1126ea1bdfdf842ed8073aaf0cea6e405a77c0ce Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 25 Nov 2025 15:18:59 -0700 Subject: [PATCH 04/64] tests --- src/data_designer/engine/models/facade.py | 21 +-- tests/engine/models/test_facade.py | 152 +++++++++++++--------- 2 files changed, 105 insertions(+), 68 deletions(-) diff --git a/src/data_designer/engine/models/facade.py 
b/src/data_designer/engine/models/facade.py index b0ad3472..4e3f36ef 100644 --- a/src/data_designer/engine/models/facade.py +++ b/src/data_designer/engine/models/facade.py @@ -67,11 +67,9 @@ def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = extra={"model": self.model_name, "messages": messages, "sensitive": True}, ) response = None - kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} - if self.model_provider.extra_body: - kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + kwargs = self.consolidate_kwargs(**kwargs) try: - response = self._router.completion(self.model_name, messages, **kwargs) + response = self._router.completion(model=self.model_name, messages=messages, **kwargs) logger.debug( f"Received completion from model {self.model_name!r}", extra={ @@ -85,9 +83,15 @@ def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = except Exception as e: raise e finally: - if not skip_usage_tracking: + if not skip_usage_tracking and response is not None: self._track_usage(response) + def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: + kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} + if self.model_provider.extra_body: + kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + return kwargs + @catch_llm_exceptions def generate_text_embeddings( self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs @@ -100,9 +104,8 @@ def generate_text_embeddings( "sensitive": True, }, ) - kwargs |= self._model_config.inference_parameters.generate_kwargs - if self.model_provider.extra_body: - kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + kwargs = self.consolidate_kwargs(**kwargs) + response = None try: response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs) logger.debug( @@ -120,7 +123,7 @@ def 
generate_text_embeddings( except Exception as e: raise e finally: - if not skip_usage_tracking: + if not skip_usage_tracking and response is not None: self._track_usage_from_embedding(response) @catch_llm_exceptions diff --git a/tests/engine/models/test_facade.py b/tests/engine/models/test_facade.py index d240eeaa..afe27730 100644 --- a/tests/engine/models/test_facade.py +++ b/tests/engine/models/test_facade.py @@ -4,7 +4,7 @@ from collections import namedtuple from unittest.mock import patch -from litellm.types.utils import Choices, Message, ModelResponse +from litellm.types.utils import Choices, EmbeddingResponse, Message, ModelResponse import pytest from data_designer.engine.models.errors import ModelGenerationValidationFailureError @@ -30,10 +30,20 @@ def stub_model_facade(stub_model_configs, stub_secrets_resolver, stub_model_prov @pytest.fixture -def stub_expected_response(): +def stub_completion_messages(): + return [{"role": "user", "content": "test"}] + + +@pytest.fixture +def stub_expected_completion_response(): return ModelResponse(choices=Choices(message=Message(content="Test response"))) +@pytest.fixture +def stub_expected_embedding_response(): + return EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2) + + @pytest.mark.parametrize( "max_correction_steps,max_conversation_restarts,total_calls", [ @@ -105,6 +115,24 @@ def test_usage_stats_property(stub_model_facade): assert hasattr(stub_model_facade.usage_stats, "model_dump") +def test_consolidate_kwargs(stub_model_configs, stub_model_facade): + # Model config generate kwargs are used as base + result = stub_model_facade.consolidate_kwargs() + assert result == stub_model_configs[0].inference_parameters.generate_kwargs + + # kwargs overrides model config generate kwargs + result = stub_model_facade.consolidate_kwargs(temperature=0.01) + assert result == {**stub_model_configs[0].inference_parameters.generate_kwargs, "temperature": 0.01} + + # Provider extra_body overrides all other kwargs + 
stub_model_facade.model_provider.extra_body = {"foo_provider": "bar_provider"} + result = stub_model_facade.consolidate_kwargs(extra_body={"foo": "bar"}) + assert result == { + **stub_model_configs[0].inference_parameters.generate_kwargs, + "extra_body": {"foo_provider": "bar_provider", "foo": "bar"}, + } + + @pytest.mark.parametrize( "skip_usage_tracking", [ @@ -112,79 +140,85 @@ def test_usage_stats_property(stub_model_facade): True, ], ) -def test_completion_success(stub_model_facade, stub_expected_response, skip_usage_tracking): - stub_model_facade._router.completion = lambda model_name, messages, **kwargs: stub_expected_response - - messages = [{"role": "user", "content": "test"}] - result = stub_model_facade.completion(messages, skip_usage_tracking=skip_usage_tracking) - - assert result == stub_expected_response - - -def test_completion_with_exception(stub_model_facade): - def raise_exception(*args, **kwargs): - raise Exception("Router error") +@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) +def test_completion_success( + mock_router_completion, + stub_completion_messages, + stub_model_configs, + stub_model_facade, + stub_expected_completion_response, + skip_usage_tracking, +): + mock_router_completion.side_effect = lambda self, model, messages, **kwargs: stub_expected_completion_response + result = stub_model_facade.completion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking) + assert result == stub_expected_completion_response + assert mock_router_completion.call_count == 1 + assert mock_router_completion.call_args[1] == { + "model": "stub-model-text", + "messages": stub_completion_messages, + **stub_model_configs[0].inference_parameters.generate_kwargs, + } - stub_model_facade._router.completion = raise_exception - messages = [{"role": "user", "content": "test"}] +@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) +def 
test_completion_with_exception(mock_router_completion, stub_completion_messages, stub_model_facade): + mock_router_completion.side_effect = Exception("Router error") with pytest.raises(Exception, match="Router error"): - stub_model_facade.completion(messages) + stub_model_facade.completion(stub_completion_messages) -def test_completion_kwargs_overrides_model_config_generate_kwargs( - stub_model_configs, stub_model_facade, stub_expected_response +@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) +def test_completion_with_kwargs( + mock_router_completion, + stub_completion_messages, + stub_model_configs, + stub_model_facade, + stub_expected_completion_response, ): captured_kwargs = {} - def mock_completion(model_name, messages, **kwargs): + def mock_completion(self, model, messages, **kwargs): captured_kwargs.update(kwargs) - return stub_expected_response + return stub_expected_completion_response - stub_model_facade._router.completion = mock_completion + mock_router_completion.side_effect = mock_completion - messages = [{"role": "user", "content": "test"}] kwargs = {"temperature": 0.7, "max_tokens": 100} - result = stub_model_facade.completion(messages, **kwargs) + result = stub_model_facade.completion(stub_completion_messages, **kwargs) - assert result == stub_expected_response + assert result == stub_expected_completion_response # completion kwargs overrides model config generate kwargs assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs} -@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) -def test_provider_extra_body_overrides_completion_kwargs(mock_router_completion, stub_model_configs, stub_model_facade): - messages = [{"role": "user", "content": "test"}] - stub_provider_extra_body = {"foo": "bar"} - - # model config has generate kwargs, completion call has no kwargs, and provider has no extra body - _ = stub_model_facade.completion(messages) - 
assert len(mock_router_completion.call_args) == 2 - assert mock_router_completion.call_args[0][1] == "stub-model-text" - assert mock_router_completion.call_args[0][2] == messages - assert mock_router_completion.call_args[1] == stub_model_configs[0].inference_parameters.generate_kwargs - - # model config has generate kwargs, completion call has kwargs, and provider has no extra body - # completion kwargs overrides model config generate kwargs - _ = stub_model_facade.completion(messages, temperature=0.1) - assert len(mock_router_completion.call_args) == 2 - assert mock_router_completion.call_args[0][1] == "stub-model-text" - assert mock_router_completion.call_args[0][2] == messages - assert mock_router_completion.call_args[1] == { - **stub_model_configs[0].inference_parameters.generate_kwargs, - "temperature": 0.1, - } +@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True) +def test_generate_text_embeddings_success(mock_router_embedding, stub_model_facade, stub_expected_embedding_response): + mock_router_embedding.side_effect = lambda self, model, input, **kwargs: stub_expected_embedding_response + input_texts = ["test1", "test2"] + result = stub_model_facade.generate_text_embeddings(input_texts) + assert result == [data["embedding"] for data in stub_expected_embedding_response.data] - # model config has generate kwargs, completion call has kwargs, and provider has extra body - # provider extra body overrides completion kwargs - stub_model_facade.model_provider.extra_body = stub_provider_extra_body - _ = stub_model_facade.completion(messages, temperature=0.15, extra_body={"foo": "bat"}) - assert len(mock_router_completion.call_args) == 2 - assert mock_router_completion.call_args[0][1] == "stub-model-text" - assert mock_router_completion.call_args[0][2] == messages - assert mock_router_completion.call_args[1] == { - **stub_model_configs[0].inference_parameters.generate_kwargs, - "temperature": 0.15, - "extra_body": 
stub_provider_extra_body, - } + +@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True) +def test_generate_text_embeddings_with_exception(mock_router_embedding, stub_model_facade): + mock_router_embedding.side_effect = Exception("Router error") + + with pytest.raises(Exception, match="Router error"): + stub_model_facade.generate_text_embeddings(["test1", "test2"]) + + +@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True) +def test_generate_text_embeddings_with_kwargs( + mock_router_embedding, stub_model_configs, stub_model_facade, stub_expected_embedding_response +): + captured_kwargs = {} + + def mock_embedding(self, model, input, **kwargs): + captured_kwargs.update(kwargs) + return stub_expected_embedding_response + + mock_router_embedding.side_effect = mock_embedding + kwargs = {"temperature": 0.7, "max_tokens": 100, "input_type": "query"} + _ = stub_model_facade.generate_text_embeddings(["test1", "test2"], **kwargs) + assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs} From 744bc8fd4c9d5662966ba09282b69daf326be9b8 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 25 Nov 2025 16:46:42 -0700 Subject: [PATCH 05/64] remove sensitive=True old artifact no longer needed --- src/data_designer/engine/models/facade.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/data_designer/engine/models/facade.py b/src/data_designer/engine/models/facade.py index 4e3f36ef..ea72d4c3 100644 --- a/src/data_designer/engine/models/facade.py +++ b/src/data_designer/engine/models/facade.py @@ -64,7 +64,7 @@ def usage_stats(self) -> ModelUsageStats: def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse: logger.debug( f"Prompting model {self.model_name!r}...", - extra={"model": self.model_name, "messages": messages, "sensitive": True}, + extra={"model": self.model_name, "messages": messages}, ) 
response = None kwargs = self.consolidate_kwargs(**kwargs) @@ -101,7 +101,6 @@ def generate_text_embeddings( extra={ "model": self.model_name, "input_count": len(input_texts), - "sensitive": True, }, ) kwargs = self.consolidate_kwargs(**kwargs) From b913f8d6dfc0d3717badc30a3ae4176287cbf9b8 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 25 Nov 2025 17:11:23 -0700 Subject: [PATCH 06/64] Slight refactor --- .../utils/column_statistics_calculations.py | 2 +- ...llm_generators.py => generation_mixins.py} | 94 ++++--------------- .../generators/llm_completion_generators.py | 71 ++++++++++++++ .../engine/column_generators/registry.py | 2 +- .../dataset_builders/column_wise_builder.py | 6 +- ...s.py => test_llm_completion_generators.py} | 6 +- .../engine/column_generators/test_registry.py | 2 +- 7 files changed, 97 insertions(+), 86 deletions(-) rename src/data_designer/engine/column_generators/generators/{llm_generators.py => generation_mixins.py} (64%) create mode 100644 src/data_designer/engine/column_generators/generators/llm_completion_generators.py rename tests/engine/column_generators/generators/{test_llm_generators.py => test_llm_completion_generators.py} (97%) diff --git a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py index 120caef4..1b23c0ea 100644 --- a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py +++ b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py @@ -23,7 +23,7 @@ SingleColumnConfig, ValidationColumnConfig, ) -from data_designer.engine.column_generators.generators.llm_generators import ( +from data_designer.engine.column_generators.utils.prompt_renderer import ( PromptType, RecordBasedPromptRenderer, create_response_recipe, diff --git a/src/data_designer/engine/column_generators/generators/llm_generators.py b/src/data_designer/engine/column_generators/generators/generation_mixins.py 
similarity index 64% rename from src/data_designer/engine/column_generators/generators/llm_generators.py rename to src/data_designer/engine/column_generators/generators/generation_mixins.py index 8f4cfc90..4e29a37a 100644 --- a/src/data_designer/engine/column_generators/generators/llm_generators.py +++ b/src/data_designer/engine/column_generators/generators/generation_mixins.py @@ -4,20 +4,9 @@ import functools import logging -from data_designer.config.column_configs import ( - LLMCodeColumnConfig, - LLMJudgeColumnConfig, - LLMStructuredColumnConfig, - LLMTextColumnConfig, -) from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP from data_designer.config.models import InferenceParameters, ModelConfig from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX -from data_designer.engine.column_generators.generators.base import ( - ColumnGenerator, - GenerationStrategy, - GeneratorMetadata, -) from data_designer.engine.column_generators.utils.prompt_renderer import ( PromptType, RecordBasedPromptRenderer, @@ -26,7 +15,6 @@ from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.recipes.base import ResponseRecipe from data_designer.engine.processing.utils import deserialize_json_values -from data_designer.engine.resources.resource_provider import ResourceType DEFAULT_MAX_CONVERSATION_RESTARTS = 5 DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0 @@ -35,7 +23,7 @@ logger = logging.getLogger(__name__) -class WithLLMGeneration: +class WithModelGeneration: @functools.cached_property def model(self) -> ModelFacade: return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias) @@ -59,6 +47,21 @@ def prompt_renderer(self) -> RecordBasedPromptRenderer: }, ) + def log_pre_generation(self) -> None: + emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] + logger.info(f"{emoji} Preparing {self.config.column_type} column generation") + logger.info(f" |-- column name: 
{self.config.name!r}") + logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") + if self.model_config.provider is None: + logger.info(f" |-- default model provider: {self._get_provider_name()!r}") + + def _get_provider_name(self) -> str: + model_alias = self.model_config.alias + provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) + return provider.name + + +class WithCompletionGeneration(WithModelGeneration): @functools.cached_property def response_recipe(self) -> ResponseRecipe: return create_response_recipe(self.config, self.model_config) @@ -104,68 +107,3 @@ def generate(self, data: dict) -> dict: data[self.config.name + REASONING_TRACE_COLUMN_POSTFIX] = reasoning_trace return data - - def log_pre_generation(self) -> None: - emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] - logger.info(f"{emoji} Preparing {self.config.column_type} column generation") - logger.info(f" |-- column name: {self.config.name!r}") - logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") - if self.model_config.provider is None: - logger.info(f" |-- default model provider: {self._get_provider_name()!r}") - - def _get_provider_name(self) -> str: - model_alias = self.model_config.alias - provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) - return provider.name - - -class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfig]): - @staticmethod - def metadata() -> GeneratorMetadata: - return GeneratorMetadata( - name="llm_text_generator", - description="Generate a new dataset cell from a prompt template", - generation_strategy=GenerationStrategy.CELL_BY_CELL, - required_resources=[ResourceType.MODEL_REGISTRY], - ) - - -class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfig]): - @staticmethod - def metadata() -> GeneratorMetadata: - return GeneratorMetadata( - name="llm_code_generator", - 
description="Generate a new dataset cell from a prompt template", - generation_strategy=GenerationStrategy.CELL_BY_CELL, - required_resources=[ResourceType.MODEL_REGISTRY], - ) - - -class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructuredColumnConfig]): - @staticmethod - def metadata() -> GeneratorMetadata: - return GeneratorMetadata( - name="llm_structured_generator", - description="Generate a new dataset cell from a prompt template", - generation_strategy=GenerationStrategy.CELL_BY_CELL, - required_resources=[ResourceType.MODEL_REGISTRY], - ) - - -class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnConfig]): - @staticmethod - def metadata() -> GeneratorMetadata: - return GeneratorMetadata( - name="llm_judge_generator", - description="Judge a new dataset cell based on a set of rubrics", - generation_strategy=GenerationStrategy.CELL_BY_CELL, - required_resources=[ResourceType.MODEL_REGISTRY], - ) - - @property - def max_conversation_correction_steps(self) -> int: - return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS - - @property - def max_conversation_restarts(self) -> int: - return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS diff --git a/src/data_designer/engine/column_generators/generators/llm_completion_generators.py b/src/data_designer/engine/column_generators/generators/llm_completion_generators.py new file mode 100644 index 00000000..cc61c619 --- /dev/null +++ b/src/data_designer/engine/column_generators/generators/llm_completion_generators.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging + +from data_designer.config.column_configs import ( + LLMCodeColumnConfig, + LLMJudgeColumnConfig, + LLMStructuredColumnConfig, + LLMTextColumnConfig, +) +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, + GeneratorMetadata, +) +from data_designer.engine.column_generators.generators.generation_mixins import ( + DEFAULT_MAX_CONVERSATION_RESTARTS, + WithCompletionGeneration, +) +from data_designer.engine.resources.resource_provider import ResourceType + +logger = logging.getLogger(__name__) + + +class LLMTextCellGenerator(WithCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]): + @staticmethod + def metadata() -> GeneratorMetadata: + return GeneratorMetadata( + name="llm_text_generator", + description="Generate a new dataset cell from a prompt template", + generation_strategy=GenerationStrategy.CELL_BY_CELL, + required_resources=[ResourceType.MODEL_REGISTRY], + ) + + +class LLMCodeCellGenerator(WithCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]): + @staticmethod + def metadata() -> GeneratorMetadata: + return GeneratorMetadata( + name="llm_code_generator", + description="Generate a new dataset cell from a prompt template", + generation_strategy=GenerationStrategy.CELL_BY_CELL, + required_resources=[ResourceType.MODEL_REGISTRY], + ) + + +class LLMStructuredCellGenerator(WithCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]): + @staticmethod + def metadata() -> GeneratorMetadata: + return GeneratorMetadata( + name="llm_structured_generator", + description="Generate a new dataset cell from a prompt template", + generation_strategy=GenerationStrategy.CELL_BY_CELL, + required_resources=[ResourceType.MODEL_REGISTRY], + ) + + +class LLMJudgeCellGenerator(WithCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]): + @staticmethod + def metadata() -> GeneratorMetadata: + return GeneratorMetadata( + 
name="llm_judge_generator", + description="Judge a new dataset cell based on a set of rubrics", + generation_strategy=GenerationStrategy.CELL_BY_CELL, + required_resources=[ResourceType.MODEL_REGISTRY], + ) + + @property + def max_conversation_restarts(self) -> int: + return DEFAULT_MAX_CONVERSATION_RESTARTS * 2 diff --git a/src/data_designer/engine/column_generators/registry.py b/src/data_designer/engine/column_generators/registry.py index 61b43753..56a176ae 100644 --- a/src/data_designer/engine/column_generators/registry.py +++ b/src/data_designer/engine/column_generators/registry.py @@ -13,7 +13,7 @@ from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.base import ColumnGenerator from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator -from data_designer.engine.column_generators.generators.llm_generators import ( +from data_designer.engine.column_generators.generators.llm_completion_generators import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, LLMStructuredCellGenerator, diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index e7060f82..ae6c54cc 100644 --- a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -18,7 +18,7 @@ ProcessorType, ) from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy -from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration +from data_designer.engine.column_generators.generators.generation_mixins import WithCompletionGeneration from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from 
data_designer.engine.dataset_builders.multi_column_configs import ( @@ -169,7 +169,7 @@ def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None: max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR - if isinstance(generator, WithLLMGeneration): + if isinstance(generator, WithCompletionGeneration): max_workers = generator.inference_parameters.max_parallel_requests self._fan_out_with_threads(generator, max_workers=max_workers) @@ -183,7 +183,7 @@ def _run_model_health_check_if_needed(self) -> bool: set(config.model_alias for config in self.llm_generated_column_configs) ) - def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) -> None: + def _fan_out_with_threads(self, generator: WithCompletionGeneration, max_workers: int) -> None: if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL: raise DatasetGenerationError( f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} " diff --git a/tests/engine/column_generators/generators/test_llm_generators.py b/tests/engine/column_generators/generators/test_llm_completion_generators.py similarity index 97% rename from tests/engine/column_generators/generators/test_llm_generators.py rename to tests/engine/column_generators/generators/test_llm_completion_generators.py index acaa2c6f..ab398aed 100644 --- a/tests/engine/column_generators/generators/test_llm_generators.py +++ b/tests/engine/column_generators/generators/test_llm_completion_generators.py @@ -11,10 +11,12 @@ LLMStructuredColumnConfig, LLMTextColumnConfig, ) -from data_designer.engine.column_generators.generators.llm_generators import ( +from data_designer.engine.column_generators.generators.generation_mixins import ( DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS, DEFAULT_MAX_CONVERSATION_RESTARTS, REASONING_TRACE_COLUMN_POSTFIX, +) +from data_designer.engine.column_generators.generators.llm_completion_generators 
import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, LLMStructuredCellGenerator, @@ -94,7 +96,7 @@ def test_generate_method(): assert call_args[1]["multi_modal_context"] is None -@patch("data_designer.engine.column_generators.generators.llm_generators.logger", autospec=True) +@patch("data_designer.engine.column_generators.generators.generation_mixins.logger", autospec=True) def test_log_pre_generation(mock_logger): generator, mock_resource_provider, _, mock_model_config, _, _, _ = _create_generator_with_mocks() mock_model_config.model_dump_json.return_value = '{"test": "config"}' diff --git a/tests/engine/column_generators/test_registry.py b/tests/engine/column_generators/test_registry.py index f70b0d90..57457b94 100644 --- a/tests/engine/column_generators/test_registry.py +++ b/tests/engine/column_generators/test_registry.py @@ -3,7 +3,7 @@ from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator -from data_designer.engine.column_generators.generators.llm_generators import ( +from data_designer.engine.column_generators.generators.llm_completion_generators import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, LLMStructuredCellGenerator, From 052db7a41f142c8aa2b28f1701fb0fd3bfaa652f Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 25 Nov 2025 17:21:20 -0700 Subject: [PATCH 07/64] slight refactor --- .../column_generators/generators/base.py | 48 ++++++++ .../generators/generation_mixins.py | 109 ------------------ .../generators/llm_completion_generators.py | 63 +++++++++- .../dataset_builders/column_wise_builder.py | 2 +- .../test_llm_completion_generators.py | 6 +- 5 files changed, 111 insertions(+), 117 deletions(-) delete mode 100644 src/data_designer/engine/column_generators/generators/generation_mixins.py diff --git a/src/data_designer/engine/column_generators/generators/base.py b/src/data_designer/engine/column_generators/generators/base.py 
index f4ddb60c..8977a63b 100644 --- a/src/data_designer/engine/column_generators/generators/base.py +++ b/src/data_designer/engine/column_generators/generators/base.py @@ -2,12 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +import functools +import logging from typing import overload import pandas as pd +from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP +from data_designer.config.models import InferenceParameters, ModelConfig from data_designer.config.utils.type_helpers import StrEnum +from data_designer.engine.column_generators.utils.prompt_renderer import ( + RecordBasedPromptRenderer, +) from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT +from data_designer.engine.models.facade import ModelFacade + +logger = logging.getLogger(__name__) class GenerationStrategy(StrEnum): @@ -59,3 +69,41 @@ def can_generate_from_scratch(self) -> bool: @abstractmethod def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ... 
+ + +class WithModelGeneration: + @functools.cached_property + def model(self) -> ModelFacade: + return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias) + + @functools.cached_property + def model_config(self) -> ModelConfig: + return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias) + + @functools.cached_property + def inference_parameters(self) -> InferenceParameters: + return self.model_config.inference_parameters + + @functools.cached_property + def prompt_renderer(self) -> RecordBasedPromptRenderer: + return RecordBasedPromptRenderer( + response_recipe=self.response_recipe, + error_message_context={ + "column_name": self.config.name, + "column_type": self.config.column_type, + "model_alias": self.config.model_alias, + }, + ) + + def log_pre_generation(self) -> None: + emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] + logger.info(f"{emoji} Preparing {self.config.column_type} column generation") + logger.info(f" |-- column name: {self.config.name!r}") + logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") + if self.model_config.provider is None: + logger.info(f" |-- default model provider: {self._get_provider_name()!r}") + + def _get_provider_name(self) -> str: + model_alias = self.model_config.alias + provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) + return provider.name diff --git a/src/data_designer/engine/column_generators/generators/generation_mixins.py b/src/data_designer/engine/column_generators/generators/generation_mixins.py deleted file mode 100644 index 4e29a37a..00000000 --- a/src/data_designer/engine/column_generators/generators/generation_mixins.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -import functools -import logging - -from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP -from data_designer.config.models import InferenceParameters, ModelConfig -from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX -from data_designer.engine.column_generators.utils.prompt_renderer import ( - PromptType, - RecordBasedPromptRenderer, - create_response_recipe, -) -from data_designer.engine.models.facade import ModelFacade -from data_designer.engine.models.recipes.base import ResponseRecipe -from data_designer.engine.processing.utils import deserialize_json_values - -DEFAULT_MAX_CONVERSATION_RESTARTS = 5 -DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0 - - -logger = logging.getLogger(__name__) - - -class WithModelGeneration: - @functools.cached_property - def model(self) -> ModelFacade: - return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias) - - @functools.cached_property - def model_config(self) -> ModelConfig: - return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias) - - @functools.cached_property - def inference_parameters(self) -> InferenceParameters: - return self.model_config.inference_parameters - - @functools.cached_property - def prompt_renderer(self) -> RecordBasedPromptRenderer: - return RecordBasedPromptRenderer( - response_recipe=self.response_recipe, - error_message_context={ - "column_name": self.config.name, - "column_type": self.config.column_type, - "model_alias": self.config.model_alias, - }, - ) - - def log_pre_generation(self) -> None: - emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] - logger.info(f"{emoji} Preparing {self.config.column_type} column generation") - logger.info(f" |-- column name: {self.config.name!r}") - logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") - if self.model_config.provider is None: - logger.info(f" |-- default model 
provider: {self._get_provider_name()!r}") - - def _get_provider_name(self) -> str: - model_alias = self.model_config.alias - provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) - return provider.name - - -class WithCompletionGeneration(WithModelGeneration): - @functools.cached_property - def response_recipe(self) -> ResponseRecipe: - return create_response_recipe(self.config, self.model_config) - - @property - def max_conversation_correction_steps(self) -> int: - return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS - - @property - def max_conversation_restarts(self) -> int: - return DEFAULT_MAX_CONVERSATION_RESTARTS - - def generate(self, data: dict) -> dict: - deserialized_record = deserialize_json_values(data) - - multi_modal_context = None - if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0: - multi_modal_context = [ - context.get_context(deserialized_record) for context in self.config.multi_modal_context - ] - - response, reasoning_trace = self.model.generate( - prompt=self.prompt_renderer.render( - record=deserialized_record, - prompt_template=self.config.prompt, - prompt_type=PromptType.USER_PROMPT, - ), - system_prompt=self.prompt_renderer.render( - record=deserialized_record, - prompt_template=self.config.system_prompt, - prompt_type=PromptType.SYSTEM_PROMPT, - ), - parser=self.response_recipe.parse, - multi_modal_context=multi_modal_context, - max_correction_steps=self.max_conversation_correction_steps, - max_conversation_restarts=self.max_conversation_restarts, - purpose=f"running generation for column '{self.config.name}'", - ) - - data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response)) - - if reasoning_trace: - data[self.config.name + REASONING_TRACE_COLUMN_POSTFIX] = reasoning_trace - - return data diff --git a/src/data_designer/engine/column_generators/generators/llm_completion_generators.py 
b/src/data_designer/engine/column_generators/generators/llm_completion_generators.py index cc61c619..5665ba85 100644 --- a/src/data_designer/engine/column_generators/generators/llm_completion_generators.py +++ b/src/data_designer/engine/column_generators/generators/llm_completion_generators.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import functools import logging from data_designer.config.column_configs import ( @@ -9,20 +10,76 @@ LLMStructuredColumnConfig, LLMTextColumnConfig, ) +from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, GenerationStrategy, GeneratorMetadata, + WithModelGeneration, ) -from data_designer.engine.column_generators.generators.generation_mixins import ( - DEFAULT_MAX_CONVERSATION_RESTARTS, - WithCompletionGeneration, +from data_designer.engine.column_generators.utils.prompt_renderer import ( + PromptType, + create_response_recipe, ) +from data_designer.engine.models.recipes.base import ResponseRecipe +from data_designer.engine.processing.utils import deserialize_json_values from data_designer.engine.resources.resource_provider import ResourceType logger = logging.getLogger(__name__) +DEFAULT_MAX_CONVERSATION_RESTARTS = 5 +DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0 + + +class WithCompletionGeneration(WithModelGeneration): + @functools.cached_property + def response_recipe(self) -> ResponseRecipe: + return create_response_recipe(self.config, self.model_config) + + @property + def max_conversation_correction_steps(self) -> int: + return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS + + @property + def max_conversation_restarts(self) -> int: + return DEFAULT_MAX_CONVERSATION_RESTARTS + + def generate(self, data: dict) -> dict: + deserialized_record = deserialize_json_values(data) + + multi_modal_context = None + if 
self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0: + multi_modal_context = [ + context.get_context(deserialized_record) for context in self.config.multi_modal_context + ] + + response, reasoning_trace = self.model.generate( + prompt=self.prompt_renderer.render( + record=deserialized_record, + prompt_template=self.config.prompt, + prompt_type=PromptType.USER_PROMPT, + ), + system_prompt=self.prompt_renderer.render( + record=deserialized_record, + prompt_template=self.config.system_prompt, + prompt_type=PromptType.SYSTEM_PROMPT, + ), + parser=self.response_recipe.parse, + multi_modal_context=multi_modal_context, + max_correction_steps=self.max_conversation_correction_steps, + max_conversation_restarts=self.max_conversation_restarts, + purpose=f"running generation for column '{self.config.name}'", + ) + + data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response)) + + if reasoning_trace: + data[self.config.name + REASONING_TRACE_COLUMN_POSTFIX] = reasoning_trace + + return data + + class LLMTextCellGenerator(WithCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index ae6c54cc..78a5e9fa 100644 --- a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -18,7 +18,7 @@ ProcessorType, ) from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy -from data_designer.engine.column_generators.generators.generation_mixins import WithCompletionGeneration +from data_designer.engine.column_generators.generators.llm_completion_generators import WithCompletionGeneration from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from 
data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.dataset_builders.multi_column_configs import ( diff --git a/tests/engine/column_generators/generators/test_llm_completion_generators.py b/tests/engine/column_generators/generators/test_llm_completion_generators.py index ab398aed..3a411fc9 100644 --- a/tests/engine/column_generators/generators/test_llm_completion_generators.py +++ b/tests/engine/column_generators/generators/test_llm_completion_generators.py @@ -11,12 +11,10 @@ LLMStructuredColumnConfig, LLMTextColumnConfig, ) -from data_designer.engine.column_generators.generators.generation_mixins import ( +from data_designer.engine.column_generators.generators.llm_completion_generators import ( DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS, DEFAULT_MAX_CONVERSATION_RESTARTS, REASONING_TRACE_COLUMN_POSTFIX, -) -from data_designer.engine.column_generators.generators.llm_completion_generators import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, LLMStructuredCellGenerator, @@ -96,7 +94,7 @@ def test_generate_method(): assert call_args[1]["multi_modal_context"] is None -@patch("data_designer.engine.column_generators.generators.generation_mixins.logger", autospec=True) +@patch("data_designer.engine.column_generators.generators.base.logger", autospec=True) def test_log_pre_generation(mock_logger): generator, mock_resource_provider, _, mock_model_config, _, _, _ = _create_generator_with_mocks() mock_model_config.model_dump_json.return_value = '{"test": "config"}' From 5504c8dd1b4745e27c6590e2424cc4cb26a7944d Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 25 Nov 2025 18:17:07 -0700 Subject: [PATCH 08/64] Added embedding generator --- src/data_designer/config/column_configs.py | 23 +++++++++++++ src/data_designer/config/column_types.py | 8 +++++ src/data_designer/config/models.py | 2 +- .../config/utils/visualization.py | 15 +++++++- .../generators/embedding_generators.py | 34 
+++++++++++++++++++ .../engine/column_generators/registry.py | 3 ++ src/data_designer/engine/models/facade.py | 6 +++- src/data_designer/engine/models/registry.py | 29 ++++++++++------ src/data_designer/essentials/__init__.py | 4 +++ 9 files changed, 111 insertions(+), 13 deletions(-) create mode 100644 src/data_designer/engine/column_generators/generators/embedding_generators.py diff --git a/src/data_designer/config/column_configs.py b/src/data_designer/config/column_configs.py index d19b6a9e..c5468f19 100644 --- a/src/data_designer/config/column_configs.py +++ b/src/data_designer/config/column_configs.py @@ -377,3 +377,26 @@ class SeedDatasetColumnConfig(SingleColumnConfig): """ column_type: Literal["seed-dataset"] = "seed-dataset" + + +class EmbeddingColumnConfig(SingleColumnConfig): + """Configuration for embedding generation columns. + + Embedding columns generate embeddings for text input using a specified model. + + Attributes: + column_type: Discriminator field, always "embedding" for this configuration type. + target_column: The column to generate embeddings for. + model_alias: The model to use for embedding generation. + chunk_separator: Optional separator to split the text in the target column into chunks. For example, if chunk_separator + is '\n', the text will be split into chunks of text separated by newlines and embeddings generated for each chunk. 
+ """ + + column_type: Literal["embedding"] = "embedding" + target_column: str + model_alias: str + chunk_separator: Optional[str] = None + + @property + def required_columns(self) -> list[str]: + return [self.target_column] diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py index 50ba498d..aab55c4d 100644 --- a/src/data_designer/config/column_types.py +++ b/src/data_designer/config/column_types.py @@ -7,6 +7,7 @@ from ..plugin_manager import PluginManager from .column_configs import ( + EmbeddingColumnConfig, ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -31,6 +32,7 @@ SamplerColumnConfig, SeedDatasetColumnConfig, ValidationColumnConfig, + EmbeddingColumnConfig, ] ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT) @@ -50,6 +52,7 @@ DataDesignerColumnType.SEED_DATASET: "🌱", DataDesignerColumnType.SAMPLER: "🎲", DataDesignerColumnType.VALIDATION: "🔍", + DataDesignerColumnType.EMBEDDING: "🧬", } COLUMN_TYPE_EMOJI_MAP.update( {DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()} @@ -66,6 +69,7 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.VALIDATION, + DataDesignerColumnType.EMBEDDING, } dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) return column_type in dag_column_types @@ -79,6 +83,7 @@ def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType] DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.EMBEDDING, } llm_generated_column_types.update( plugin_manager.get_plugin_column_types( @@ -117,6 +122,8 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name,
kwargs)) if column_type == DataDesignerColumnType.SEED_DATASET: return SeedDatasetColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.EMBEDDING: + return EmbeddingColumnConfig(name=name, **kwargs) if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value): return plugin.config_cls(name=name, **kwargs) raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover @@ -131,6 +138,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]: DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.EMBEDDING, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 17698346..481633ac 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -207,7 +207,7 @@ def _is_value_in_range(self, value: float, min_value: float, max_value: float) - class GenerationType(str, Enum): CHAT_COMPLETION = "chat-completion" - TEXT_EMBEDDING = "text-embedding" + EMBEDDING = "embedding" IMAGE_GENERATION = "image-generation" diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py index 26ab4ad3..0972daf7 100644 --- a/src/data_designer/config/utils/visualization.py +++ b/src/data_designer/config/utils/visualization.py @@ -8,7 +8,7 @@ from functools import cached_property import json import os -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np import pandas as pd @@ -171,6 +171,7 @@ def display_sample_record( + config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION) + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT) + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED) + +
config_builder.get_columns_of_type(DataDesignerColumnType.EMBEDDING) ) if len(non_code_columns) > 0: table = Table(title="Generated Columns", **table_kws) @@ -178,6 +179,10 @@ def display_sample_record( table.add_column("Value") for col in non_code_columns: if not col.drop: + if col.column_type == DataDesignerColumnType.EMBEDDING: + record[col.name]["embeddings"] = [ + get_truncated_list_as_string(embd) for embd in record[col.name].get("embeddings") + ] table.add_row(col.name, convert_to_row_element(record[col.name])) render_list.append(pad_console_element(table)) @@ -237,6 +242,14 @@ def display_sample_record( console.print(Group(*render_list), markup=False) +def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str: + if len(long_list) > max_items: + truncated_part = long_list[:max_items] + return f"[{', '.join(str(x) for x in truncated_part)} ...]" + else: + return str(long_list) + + def display_sampler_table( sampler_params: dict[SamplerType, ConfigBase], title: Optional[str] = None, diff --git a/src/data_designer/engine/column_generators/generators/embedding_generators.py b/src/data_designer/engine/column_generators/generators/embedding_generators.py new file mode 100644 index 00000000..ec827805 --- /dev/null +++ b/src/data_designer/engine/column_generators/generators/embedding_generators.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from data_designer.config.column_configs import EmbeddingColumnConfig +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, + GeneratorMetadata, + WithModelGeneration, +) +from data_designer.engine.processing.utils import deserialize_json_values + + +class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]): + @staticmethod + def metadata() -> GeneratorMetadata: + return GeneratorMetadata( + name="embedding_cell_generator", + description="Generate embeddings for a text column.", + generation_strategy=GenerationStrategy.CELL_BY_CELL, + required_resources=None, + ) + + def generate(self, data: dict) -> dict: + deserialized_record = deserialize_json_values(data) + input_text = deserialized_record[self.config.target_column] + input_chunks = input_text.split(self.config.chunk_separator) if self.config.chunk_separator else [input_text] + embeddings = self.model.generate_text_embeddings(input_texts=input_chunks) + data[self.config.name] = { + "embeddings": embeddings, + "num_embeddings": len(embeddings), + "dimension": len(embeddings[0]) if len(embeddings) > 0 else 0, + } + return data diff --git a/src/data_designer/engine/column_generators/registry.py b/src/data_designer/engine/column_generators/registry.py index 56a176ae..961eac1a 100644 --- a/src/data_designer/engine/column_generators/registry.py +++ b/src/data_designer/engine/column_generators/registry.py @@ -3,6 +3,7 @@ from data_designer.config.base import ConfigBase from data_designer.config.column_configs import ( + EmbeddingColumnConfig, ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -12,6 +13,7 @@ ) from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.base import ColumnGenerator +from data_designer.engine.column_generators.generators.embedding_generators import EmbeddingCellGenerator from 
data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator from data_designer.engine.column_generators.generators.llm_completion_generators import ( LLMCodeCellGenerator, @@ -40,6 +42,7 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig) registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig) registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig) + registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig) registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig) registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig) registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig) diff --git a/src/data_designer/engine/models/facade.py b/src/data_designer/engine/models/facade.py index ea72d4c3..c205a4ca 100644 --- a/src/data_designer/engine/models/facade.py +++ b/src/data_designer/engine/models/facade.py @@ -11,7 +11,7 @@ from litellm.types.router import DeploymentTypedDict, LiteLLM_Params from litellm.types.utils import EmbeddingResponse, ModelResponse -from data_designer.config.models import ModelConfig, ModelProvider +from data_designer.config.models import GenerationType, ModelConfig, ModelProvider from data_designer.engine.model_provider import ModelProviderRegistry from data_designer.engine.models.errors import ( GenerationValidationFailureError, @@ -49,6 +49,10 @@ def model_name(self) -> str: def model_provider(self) -> ModelProvider: return self._model_provider_registry.get_provider(self._model_config.provider) + @property + def model_generation_type(self) -> GenerationType: + return self._model_config.generation_type + @property def 
model_provider_name(self) -> str: return self.model_provider.name diff --git a/src/data_designer/engine/models/registry.py b/src/data_designer/engine/models/registry.py index aafd8c80..4330ea18 100644 --- a/src/data_designer/engine/models/registry.py +++ b/src/data_designer/engine/models/registry.py @@ -5,7 +5,7 @@ import logging -from data_designer.config.models import ModelConfig +from data_designer.config.models import GenerationType, ModelConfig from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.litellm_overrides import apply_litellm_patches @@ -81,15 +81,24 @@ def run_health_check(self, model_aliases: set[str]) -> None: f" |-- πŸ‘€ Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..." ) try: - model.generate( - prompt="Hello!", - parser=lambda x: x, - system_prompt="You are a helpful assistant.", - max_correction_steps=0, - max_conversation_restarts=0, - skip_usage_tracking=True, - purpose="running health checks", - ) + if model.model_generation_type == GenerationType.EMBEDDING: + model.generate_text_embeddings( + input_texts=["Hello!"], + skip_usage_tracking=True, + purpose="running health checks", + ) + elif model.model_generation_type == GenerationType.CHAT_COMPLETION: + model.generate( + prompt="Hello!", + parser=lambda x: x, + system_prompt="You are a helpful assistant.", + max_correction_steps=0, + max_conversation_restarts=0, + skip_usage_tracking=True, + purpose="running health checks", + ) + else: + raise ValueError(f"Unsupported generation type: {model.model_generation_type}") logger.info(" |-- βœ… Passed!") except Exception as e: logger.error(" |-- ❌ Failed!") diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index 8cd8eb92..ee43519c 100644 --- a/src/data_designer/essentials/__init__.py +++ 
b/src/data_designer/essentials/__init__.py @@ -6,6 +6,7 @@ from ..config.analysis.column_profilers import JudgeScoreProfilerConfig from ..config.column_configs import ( + EmbeddingColumnConfig, ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -22,6 +23,7 @@ from ..config.dataset_builders import BuildStage from ..config.datastore import DatastoreSettings from ..config.models import ( + GenerationType, ImageContext, ImageFormat, InferenceParameters, @@ -91,8 +93,10 @@ "DatastoreSettings", "DatetimeSamplerParams", "DropColumnsProcessorConfig", + "EmbeddingColumnConfig", "ExpressionColumnConfig", "GaussianSamplerParams", + "GenerationType", "IndexRange", "InfoType", "ImageContext", From 4b6f877875fa93f718d31211323f9a34207630b7 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 25 Nov 2025 18:20:57 -0700 Subject: [PATCH 09/64] chunk_separator -> chunk_pattern --- src/data_designer/config/column_configs.py | 7 ++++--- .../column_generators/generators/embedding_generators.py | 4 +++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/data_designer/config/column_configs.py b/src/data_designer/config/column_configs.py index c5468f19..339be35d 100644 --- a/src/data_designer/config/column_configs.py +++ b/src/data_designer/config/column_configs.py @@ -388,14 +388,15 @@ class EmbeddingColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "embedding" for this configuration type. target_column: The column to generate embeddings for. model_alias: The model to use for embedding generation. - chunk_separator: Optional separator to split the text in the target column into chunks. For example, if chunk_separator - is '\n', the text will be split into chunks of text separated by newlines and embeddings generated for each chunk. + chunk_pattern: Optional regex pattern to split the text in the target column into chunks. 
For example, if chunk_pattern + is r'\n+', the text will be split into chunks using one or more newlines as separators and embeddings generated for each chunk. + If not provided, the entire text will be embedded as a single chunk. """ column_type: Literal["embedding"] = "embedding" target_column: str model_alias: str - chunk_separator: Optional[str] = None + chunk_pattern: Optional[str] = None @property def required_columns(self) -> list[str]: diff --git a/src/data_designer/engine/column_generators/generators/embedding_generators.py b/src/data_designer/engine/column_generators/generators/embedding_generators.py index ec827805..ac791d4f 100644 --- a/src/data_designer/engine/column_generators/generators/embedding_generators.py +++ b/src/data_designer/engine/column_generators/generators/embedding_generators.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import re + from data_designer.config.column_configs import EmbeddingColumnConfig from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, @@ -24,7 +26,7 @@ def metadata() -> GeneratorMetadata: def generate(self, data: dict) -> dict: deserialized_record = deserialize_json_values(data) input_text = deserialized_record[self.config.target_column] - input_chunks = input_text.split(self.config.chunk_separator) if self.config.chunk_separator else [input_text] + input_chunks = re.split(self.config.chunk_pattern, input_text) if self.config.chunk_pattern else [input_text] embeddings = self.model.generate_text_embeddings(input_texts=input_chunks) data[self.config.name] = { "embeddings": embeddings, From 04fc0f3645062f15b392b70cc64feea2e1d11cab Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 25 Nov 2025 18:22:49 -0700 Subject: [PATCH 10/64] update tests --- tests/config/test_columns.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/config/test_columns.py 
b/tests/config/test_columns.py index f0f5c51a..f7763b07 100644 --- a/tests/config/test_columns.py +++ b/tests/config/test_columns.py @@ -49,6 +49,7 @@ def test_data_designer_column_type_get_display_order(): DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.EMBEDDING, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] @@ -59,6 +60,7 @@ def test_data_designer_column_type_is_llm_generated(): assert column_type_is_llm_generated(DataDesignerColumnType.LLM_CODE) assert column_type_is_llm_generated(DataDesignerColumnType.LLM_STRUCTURED) assert column_type_is_llm_generated(DataDesignerColumnType.LLM_JUDGE) + assert column_type_is_llm_generated(DataDesignerColumnType.EMBEDDING) assert not column_type_is_llm_generated(DataDesignerColumnType.SAMPLER) assert not column_type_is_llm_generated(DataDesignerColumnType.VALIDATION) assert not column_type_is_llm_generated(DataDesignerColumnType.EXPRESSION) @@ -72,6 +74,7 @@ def test_data_designer_column_type_is_in_dag(): assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_STRUCTURED) assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT) assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION) + assert column_type_used_in_execution_dag(DataDesignerColumnType.EMBEDDING) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET) From 26d6da1917326fbb57a6e88cf3392145a4f69362 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 26 Nov 2025 09:44:05 -0700 Subject: [PATCH 11/64] rename for consistency --- .../generators/{embedding_generators.py => embedding.py} | 3 ++- .../{llm_completion_generators.py => llm_completion.py} | 0 src/data_designer/engine/column_generators/registry.py | 4 ++-- .../engine/dataset_builders/column_wise_builder.py | 2 +- 
.../generators/test_llm_completion_generators.py | 2 +- tests/engine/column_generators/test_registry.py | 2 +- 6 files changed, 7 insertions(+), 6 deletions(-) rename src/data_designer/engine/column_generators/generators/{embedding_generators.py => embedding.py} (91%) rename src/data_designer/engine/column_generators/generators/{llm_completion_generators.py => llm_completion.py} (100%) diff --git a/src/data_designer/engine/column_generators/generators/embedding_generators.py b/src/data_designer/engine/column_generators/generators/embedding.py similarity index 91% rename from src/data_designer/engine/column_generators/generators/embedding_generators.py rename to src/data_designer/engine/column_generators/generators/embedding.py index ac791d4f..d9981ccd 100644 --- a/src/data_designer/engine/column_generators/generators/embedding_generators.py +++ b/src/data_designer/engine/column_generators/generators/embedding.py @@ -11,6 +11,7 @@ WithModelGeneration, ) from data_designer.engine.processing.utils import deserialize_json_values +from data_designer.engine.resources.resource_provider import ResourceType class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]): @@ -20,7 +21,7 @@ def metadata() -> GeneratorMetadata: name="embedding_cell_generator", description="Generate embeddings for a text column.", generation_strategy=GenerationStrategy.CELL_BY_CELL, - required_resources=None, + required_resources=[ResourceType.MODEL_REGISTRY], ) def generate(self, data: dict) -> dict: diff --git a/src/data_designer/engine/column_generators/generators/llm_completion_generators.py b/src/data_designer/engine/column_generators/generators/llm_completion.py similarity index 100% rename from src/data_designer/engine/column_generators/generators/llm_completion_generators.py rename to src/data_designer/engine/column_generators/generators/llm_completion.py diff --git a/src/data_designer/engine/column_generators/registry.py 
b/src/data_designer/engine/column_generators/registry.py index 961eac1a..7171e561 100644 --- a/src/data_designer/engine/column_generators/registry.py +++ b/src/data_designer/engine/column_generators/registry.py @@ -13,9 +13,9 @@ ) from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.base import ColumnGenerator -from data_designer.engine.column_generators.generators.embedding_generators import EmbeddingCellGenerator +from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator -from data_designer.engine.column_generators.generators.llm_completion_generators import ( +from data_designer.engine.column_generators.generators.llm_completion import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, LLMStructuredCellGenerator, diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index 78a5e9fa..ff9289ee 100644 --- a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -18,7 +18,7 @@ ProcessorType, ) from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy -from data_designer.engine.column_generators.generators.llm_completion_generators import WithCompletionGeneration +from data_designer.engine.column_generators.generators.llm_completion import WithCompletionGeneration from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.dataset_builders.multi_column_configs import ( diff --git a/tests/engine/column_generators/generators/test_llm_completion_generators.py 
b/tests/engine/column_generators/generators/test_llm_completion_generators.py index 3a411fc9..0b787b7e 100644 --- a/tests/engine/column_generators/generators/test_llm_completion_generators.py +++ b/tests/engine/column_generators/generators/test_llm_completion_generators.py @@ -11,7 +11,7 @@ LLMStructuredColumnConfig, LLMTextColumnConfig, ) -from data_designer.engine.column_generators.generators.llm_completion_generators import ( +from data_designer.engine.column_generators.generators.llm_completion import ( DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS, DEFAULT_MAX_CONVERSATION_RESTARTS, REASONING_TRACE_COLUMN_POSTFIX, diff --git a/tests/engine/column_generators/test_registry.py b/tests/engine/column_generators/test_registry.py index 57457b94..0d325937 100644 --- a/tests/engine/column_generators/test_registry.py +++ b/tests/engine/column_generators/test_registry.py @@ -3,7 +3,7 @@ from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator -from data_designer.engine.column_generators.generators.llm_completion_generators import ( +from data_designer.engine.column_generators.generators.llm_completion import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, LLMStructuredCellGenerator, From 6facbd2c8a710052fc76c3c33c3c451dca04c697 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 26 Nov 2025 11:04:00 -0700 Subject: [PATCH 12/64] Restructure InferenceParameters -> CompletionInferenceParameters, BaseInferenceParameters, EmbeddingInferenceParameters --- .../config/default_model_settings.py | 18 +++-- src/data_designer/config/models.py | 79 ++++++++++++++++--- src/data_designer/essentials/__init__.py | 6 +- tests/cli/conftest.py | 10 +-- .../cli/controllers/test_model_controller.py | 4 +- tests/cli/services/test_model_service.py | 10 ++- tests/config/test_config_builder.py | 12 +-- tests/config/test_default_model_settings.py | 8 +- tests/config/test_models.py | 64 
++++++++------- tests/conftest.py | 4 +- tests/engine/models/conftest.py | 6 +- tests/engine/models/test_model_registry.py | 6 +- tests/essentials/test_init.py | 12 +++ 13 files changed, 162 insertions(+), 77 deletions(-) diff --git a/src/data_designer/config/default_model_settings.py b/src/data_designer/config/default_model_settings.py index 33d6dad4..cb565178 100644 --- a/src/data_designer/config/default_model_settings.py +++ b/src/data_designer/config/default_model_settings.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Any, Literal, Optional -from .models import InferenceParameters, ModelConfig, ModelProvider +from .models import CompletionInferenceParameters, ModelConfig, ModelProvider from .utils.constants import ( MANAGED_ASSETS_PATH, MODEL_CONFIGS_FILE_PATH, @@ -21,28 +21,30 @@ logger = logging.getLogger(__name__) -def get_default_text_alias_inference_parameters() -> InferenceParameters: - return InferenceParameters( +def get_default_text_alias_inference_parameters() -> CompletionInferenceParameters: + return CompletionInferenceParameters( temperature=0.85, top_p=0.95, ) -def get_default_reasoning_alias_inference_parameters() -> InferenceParameters: - return InferenceParameters( +def get_default_reasoning_alias_inference_parameters() -> CompletionInferenceParameters: + return CompletionInferenceParameters( temperature=0.35, top_p=0.95, ) -def get_default_vision_alias_inference_parameters() -> InferenceParameters: - return InferenceParameters( +def get_default_vision_alias_inference_parameters() -> CompletionInferenceParameters: + return CompletionInferenceParameters( temperature=0.85, top_p=0.95, ) -def get_default_inference_parameters(model_alias: Literal["text", "reasoning", "vision"]) -> InferenceParameters: +def get_default_inference_parameters( + model_alias: Literal["text", "reasoning", "vision"], +) -> CompletionInferenceParameters: if model_alias == "reasoning": return get_default_reasoning_alias_inference_parameters() elif 
model_alias == "vision": diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 481633ac..1df7055e 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -5,7 +5,7 @@ from enum import Enum import logging from pathlib import Path -from typing import Any, Generic, List, Optional, TypeVar, Union +from typing import Any, Generic, List, Literal, Optional, TypeVar, Union import numpy as np from pydantic import BaseModel, Field, model_validator @@ -136,10 +136,7 @@ def sample(self) -> float: DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution] -class InferenceParameters(ConfigBase): - temperature: Optional[Union[float, DistributionT]] = None - top_p: Optional[Union[float, DistributionT]] = None - max_tokens: Optional[int] = Field(default=None, ge=1) +class BaseInferenceParameters(ConfigBase, ABC): max_parallel_requests: int = Field(default=4, ge=1) timeout: Optional[int] = Field(default=None, ge=1) extra_body: Optional[dict[str, Any]] = None @@ -147,6 +144,21 @@ class InferenceParameters(ConfigBase): @property def generate_kwargs(self) -> dict[str, Union[float, int]]: result = {} + if self.timeout is not None: + result["timeout"] = self.timeout + if self.extra_body is not None and self.extra_body != {}: + result["extra_body"] = self.extra_body + return result + + +class CompletionInferenceParameters(BaseInferenceParameters): + temperature: Optional[Union[float, DistributionT]] = None + top_p: Optional[Union[float, DistributionT]] = None + max_tokens: Optional[int] = Field(default=None, ge=1) + + @property + def generate_kwargs(self) -> dict[str, Union[float, int]]: + result = super().generate_kwargs if self.temperature is not None: result["temperature"] = ( self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature @@ -155,10 +167,6 @@ def generate_kwargs(self) -> dict[str, Union[float, int]]: result["top_p"] = self.top_p.sample() if 
hasattr(self.top_p, "sample") else self.top_p if self.max_tokens is not None: result["max_tokens"] = self.max_tokens - if self.timeout is not None: - result["timeout"] = self.timeout - if self.extra_body is not None and self.extra_body != {}: - result["extra_body"] = self.extra_body return result @model_validator(mode="after") @@ -205,6 +213,40 @@ def _is_value_in_range(self, value: float, min_value: float, max_value: float) - return min_value <= value <= max_value +# Maintain backwards compatibility with a deprecation warning +class InferenceParameters(CompletionInferenceParameters): + """ + Deprecated: Use CompletionInferenceParameters instead. + This alias will be removed in a future version. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + logger.warning( + "InferenceParameters is deprecated and will be removed in a future version. " + "Use CompletionInferenceParameters instead." + ) + super().__init__(*args, **kwargs) + + +class EmbeddingInferenceParameters(BaseInferenceParameters): + encoding_format: Optional[Literal["float", "base64"]] = "float" + dimensions: Optional[int] = None + + @property + def generate_kwargs(self) -> dict[str, Union[float, int]]: + result = super().generate_kwargs + if self.encoding_format is not None: + result["encoding_format"] = self.encoding_format + if self.dimensions is not None: + result["dimensions"] = self.dimensions + return result + + +InferenceParametersT: TypeAlias = Union[ + InferenceParameters, CompletionInferenceParameters, EmbeddingInferenceParameters +] + + class GenerationType(str, Enum): CHAT_COMPLETION = "chat-completion" EMBEDDING = "embedding" @@ -214,10 +256,25 @@ class GenerationType(str, Enum): class ModelConfig(ConfigBase): alias: str model: str - inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters) - generation_type: GenerationType = GenerationType.CHAT_COMPLETION + inference_parameters: InferenceParametersT = 
Field(default_factory=CompletionInferenceParameters) provider: Optional[str] = None + @model_validator(mode="after") + def _normalize_deprecated_inference_parameters(self) -> Self: + """Normalize deprecated InferenceParameters to CompletionInferenceParameters.""" + if isinstance(self.inference_parameters, InferenceParameters): + self.inference_parameters = CompletionInferenceParameters(**self.inference_parameters.model_dump()) + return self + + @property + def generation_type(self) -> GenerationType: + if isinstance(self.inference_parameters, CompletionInferenceParameters): + return GenerationType.CHAT_COMPLETION + elif isinstance(self.inference_parameters, EmbeddingInferenceParameters): + return GenerationType.EMBEDDING + else: + raise ValueError(f"Unsupported inference parameters type: {type(self.inference_parameters)}") + class ModelProvider(ConfigBase): name: str diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index ee43519c..cd1dd6ba 100644 --- a/src/data_designer/essentials/__init__.py +++ b/src/data_designer/essentials/__init__.py @@ -23,6 +23,8 @@ from ..config.dataset_builders import BuildStage from ..config.datastore import DatastoreSettings from ..config.models import ( + CompletionInferenceParameters, + EmbeddingInferenceParameters, GenerationType, ImageContext, ImageFormat, @@ -80,20 +82,22 @@ "BernoulliMixtureSamplerParams", "BernoulliSamplerParams", "BinomialSamplerParams", + "BuildStage", "CategorySamplerParams", "CodeLang", "CodeValidatorParams", "ColumnInequalityConstraint", + "CompletionInferenceParameters", "configure_logging", "DataDesignerColumnType", "DataDesignerConfig", "DataDesignerConfigBuilder", - "BuildStage", "DatastoreSeedDatasetReference", "DatastoreSettings", "DatetimeSamplerParams", "DropColumnsProcessorConfig", "EmbeddingColumnConfig", + "EmbeddingInferenceParameters", "ExpressionColumnConfig", "GaussianSamplerParams", "GenerationType", diff --git a/tests/cli/conftest.py 
b/tests/cli/conftest.py index 758e837e..66a06347 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -9,16 +9,16 @@ from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository from data_designer.cli.services.model_service import ModelService from data_designer.cli.services.provider_service import ProviderService -from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider +from data_designer.config.models import CompletionInferenceParameters, ModelConfig, ModelProvider @pytest.fixture -def stub_inference_parameters() -> InferenceParameters: - return InferenceParameters(temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4) +def stub_inference_parameters() -> CompletionInferenceParameters: + return CompletionInferenceParameters(temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4) @pytest.fixture -def stub_model_configs(stub_inference_parameters: InferenceParameters) -> list[ModelConfig]: +def stub_model_configs(stub_inference_parameters: CompletionInferenceParameters) -> list[ModelConfig]: return [ ModelConfig( alias="test-alias-1", @@ -41,7 +41,7 @@ def stub_new_model_config() -> ModelConfig: alias="test-alias-3", model="test-model-3", provider="test-provider-1", - inference_parameters=InferenceParameters( + inference_parameters=CompletionInferenceParameters( temperature=0.7, top_p=0.9, max_tokens=2048, diff --git a/tests/cli/controllers/test_model_controller.py b/tests/cli/controllers/test_model_controller.py index b630b04a..4f718ca4 100644 --- a/tests/cli/controllers/test_model_controller.py +++ b/tests/cli/controllers/test_model_controller.py @@ -9,7 +9,7 @@ from data_designer.cli.controllers.model_controller import ModelController from data_designer.cli.repositories.model_repository import ModelConfigRegistry from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository -from 
data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import CompletionInferenceParameters, ModelConfig @pytest.fixture @@ -141,7 +141,7 @@ def test_run_updates_model( alias="test-alias-1-updated", model="test-model-1-updated", provider="test-provider-1", - inference_parameters=InferenceParameters(temperature=0.8, top_p=0.95, max_tokens=1024), + inference_parameters=CompletionInferenceParameters(temperature=0.8, top_p=0.95, max_tokens=1024), ) mock_builder = MagicMock() diff --git a/tests/cli/services/test_model_service.py b/tests/cli/services/test_model_service.py index 1d9bf5aa..4287eee8 100644 --- a/tests/cli/services/test_model_service.py +++ b/tests/cli/services/test_model_service.py @@ -7,7 +7,7 @@ from data_designer.cli.repositories.model_repository import ModelRepository from data_designer.cli.services.model_service import ModelService -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import CompletionInferenceParameters, ModelConfig def test_list_all(stub_model_service: ModelService, stub_model_configs: list[ModelConfig]): @@ -30,7 +30,9 @@ def test_add( assert stub_model_service.list_all() == stub_model_configs + [stub_new_model_config] -def test_add_duplicate_alias(stub_model_service: ModelService, stub_inference_parameters: InferenceParameters): +def test_add_duplicate_alias( + stub_model_service: ModelService, stub_inference_parameters: CompletionInferenceParameters +): """Test adding a model with an alias that already exists.""" duplicate_model = ModelConfig( alias="test-alias-1", @@ -61,7 +63,9 @@ def test_update_nonexistent_model(stub_model_service: ModelService, stub_new_mod stub_model_service.update("nonexistent", stub_new_model_config) -def test_update_to_existing_alias(stub_model_service: ModelService, stub_inference_parameters: InferenceParameters): +def test_update_to_existing_alias( + stub_model_service: ModelService, 
stub_inference_parameters: CompletionInferenceParameters +): """Test updating a model to an alias that already exists.""" updated_model = ModelConfig( alias="test-alias-2", # Already exists diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py index aab8112a..57741e59 100644 --- a/tests/config/test_config_builder.py +++ b/tests/config/test_config_builder.py @@ -26,7 +26,7 @@ from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.datastore import DatastoreSettings from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import CompletionInferenceParameters, ModelConfig from data_designer.config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams from data_designer.config.seed import DatastoreSeedDatasetReference, SamplingStrategy @@ -670,7 +670,7 @@ def test_add_model_config(stub_empty_builder): new_model_config = ModelConfig( alias="new-model", model="openai/gpt-4", - inference_parameters=InferenceParameters( + inference_parameters=CompletionInferenceParameters( temperature=0.7, top_p=0.95, max_tokens=1024, @@ -691,7 +691,7 @@ def test_add_model_config(stub_empty_builder): alias="provider-model", model="anthropic/claude-3", provider="anthropic", - inference_parameters=InferenceParameters(temperature=0.8), + inference_parameters=CompletionInferenceParameters(temperature=0.8), ) stub_empty_builder.add_model_config(provider_model_config) @@ -717,7 +717,7 @@ def test_add_model_config_duplicate_alias(stub_empty_builder): duplicate_model_config = ModelConfig( alias="stub-model", model="different/model", - inference_parameters=InferenceParameters(temperature=0.5), + 
inference_parameters=CompletionInferenceParameters(temperature=0.5), ) with pytest.raises( @@ -733,12 +733,12 @@ def test_delete_model_config(stub_empty_builder): model_config_1 = ModelConfig( alias="model-to-delete", model="model/delete", - inference_parameters=InferenceParameters(temperature=0.5), + inference_parameters=CompletionInferenceParameters(temperature=0.5), ) model_config_2 = ModelConfig( alias="model-to-keep", model="model/keep", - inference_parameters=InferenceParameters(temperature=0.6), + inference_parameters=CompletionInferenceParameters(temperature=0.6), ) stub_empty_builder.add_model_config(model_config_1) stub_empty_builder.add_model_config(model_config_2) diff --git a/tests/config/test_default_model_settings.py b/tests/config/test_default_model_settings.py index 222bb410..8f389a69 100644 --- a/tests/config/test_default_model_settings.py +++ b/tests/config/test_default_model_settings.py @@ -18,20 +18,20 @@ get_default_providers, resolve_seed_default_model_settings, ) -from data_designer.config.models import InferenceParameters +from data_designer.config.models import CompletionInferenceParameters from data_designer.config.utils.visualization import get_nvidia_api_key, get_openai_api_key def test_get_default_inference_parameters(): - assert get_default_inference_parameters("text") == InferenceParameters( + assert get_default_inference_parameters("text") == CompletionInferenceParameters( temperature=0.85, top_p=0.95, ) - assert get_default_inference_parameters("reasoning") == InferenceParameters( + assert get_default_inference_parameters("reasoning") == CompletionInferenceParameters( temperature=0.35, top_p=0.95, ) - assert get_default_inference_parameters("vision") == InferenceParameters( + assert get_default_inference_parameters("vision") == CompletionInferenceParameters( temperature=0.85, top_p=0.95, ) diff --git a/tests/config/test_models.py b/tests/config/test_models.py index 6a3d7b25..f1f65401 100644 --- a/tests/config/test_models.py +++ 
b/tests/config/test_models.py @@ -11,9 +11,9 @@ from data_designer.config.errors import InvalidConfigError from data_designer.config.models import ( + CompletionInferenceParameters, ImageContext, ImageFormat, - InferenceParameters, ManualDistribution, ManualDistributionParams, ModalityDataType, @@ -46,13 +46,13 @@ def test_image_context_validate_image_format(): def test_inference_parameters_default_construction(): - empty_inference_parameters = InferenceParameters() + empty_inference_parameters = CompletionInferenceParameters() assert empty_inference_parameters.generate_kwargs == {} assert empty_inference_parameters.max_parallel_requests == 4 def test_inference_parameters_generate_kwargs(): - assert InferenceParameters( + assert CompletionInferenceParameters( temperature=0.95, top_p=0.95, max_tokens=100, @@ -67,9 +67,9 @@ def test_inference_parameters_generate_kwargs(): "extra_body": {"reasoning_effort": "high"}, } - assert InferenceParameters().generate_kwargs == {} + assert CompletionInferenceParameters().generate_kwargs == {} - inference_parameters_kwargs = InferenceParameters( + inference_parameters_kwargs = CompletionInferenceParameters( temperature=UniformDistribution(params=UniformDistributionParams(low=0.0, high=1.0)), top_p=ManualDistribution(params=ManualDistributionParams(values=[0.0, 1.0], weights=[0.5, 0.5])), ).generate_kwargs @@ -131,32 +131,38 @@ def test_inference_parameters_temperature_validation(): # All temp values provide in a manual destribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters( + CompletionInferenceParameters( temperature=ManualDistribution(params=ManualDistributionParams(values=[0.5, 2.5], weights=[0.5, 0.5])) ) # High and low values of uniform distribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.5))) + CompletionInferenceParameters( + 
temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.5)) + ) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=2.0))) + CompletionInferenceParameters( + temperature=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=2.0)) + ) # Static values should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(temperature=3.0) + CompletionInferenceParameters(temperature=3.0) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(temperature=-1.0) + CompletionInferenceParameters(temperature=-1.0) # Valid temperature values shouldn't raise validation errors try: - InferenceParameters(temperature=0.1) - InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.0))) - InferenceParameters( + CompletionInferenceParameters(temperature=0.1) + CompletionInferenceParameters( + temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.0)) + ) + CompletionInferenceParameters( temperature=ManualDistribution(params=ManualDistributionParams(values=[0.5, 2.0], weights=[0.5, 0.5])) ) except Exception: - pytest.fail("Unexpected exception raised during InferenceParameters temperature validation") + pytest.fail("Unexpected exception raised during CompletionInferenceParameters temperature validation") def test_generation_parameters_top_p_validation(): @@ -164,31 +170,31 @@ def test_generation_parameters_top_p_validation(): # All top_p values provide in a manual destribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters( + CompletionInferenceParameters( top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.5], weights=[0.5, 0.5])) ) # High and low values of uniform distribution should be valid with pytest.raises(ValidationError, 
match=expected_error_msg): - InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.5))) + CompletionInferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.5))) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=1.0))) + CompletionInferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=1.0))) # Static values should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=1.5) + CompletionInferenceParameters(top_p=1.5) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=-0.1) + CompletionInferenceParameters(top_p=-0.1) # Valid top_p values shouldn't raise validation errors try: - InferenceParameters(top_p=0.1) - InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.0))) - InferenceParameters( + CompletionInferenceParameters(top_p=0.1) + CompletionInferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.0))) + CompletionInferenceParameters( top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.0], weights=[0.5, 0.5])) ) except Exception: - pytest.fail("Unexpected exception raised during InferenceParameters top_p validation") + pytest.fail("Unexpected exception raised during CompletionInferenceParameters top_p validation") def test_generation_parameters_max_tokens_validation(): @@ -196,15 +202,15 @@ def test_generation_parameters_max_tokens_validation(): ValidationError, match="Input should be greater than or equal to 1", ): - InferenceParameters(max_tokens=0) + CompletionInferenceParameters(max_tokens=0) # Valid max_tokens values shouldn't raise validation errors try: - InferenceParameters(max_tokens=128_000) - InferenceParameters(max_tokens=4096) - 
InferenceParameters(max_tokens=1) + CompletionInferenceParameters(max_tokens=128_000) + CompletionInferenceParameters(max_tokens=4096) + CompletionInferenceParameters(max_tokens=1) except Exception: - pytest.fail("Unexpected exception raised during InferenceParameters max_tokens validation") + pytest.fail("Unexpected exception raised during CompletionInferenceParameters max_tokens validation") def test_load_model_configs(): @@ -250,4 +256,4 @@ def test_load_model_configs(): def test_model_config_default_construction(): model_config = ModelConfig(alias="test", model="test") - assert model_config.inference_parameters == InferenceParameters() + assert model_config.inference_parameters == CompletionInferenceParameters() diff --git a/tests/conftest.py b/tests/conftest.py index 31dc0057..46b5d318 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.datastore import DatastoreSettings -from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider +from data_designer.config.models import CompletionInferenceParameters, ModelConfig, ModelProvider @pytest.fixture @@ -135,7 +135,7 @@ def stub_model_configs() -> list[ModelConfig]: ModelConfig( alias="stub-model", model="stub-model", - inference_parameters=InferenceParameters( + inference_parameters=CompletionInferenceParameters( temperature=0.9, top_p=0.9, max_tokens=2048, diff --git a/tests/engine/models/conftest.py b/tests/engine/models/conftest.py index 95e6941f..7edcd073 100644 --- a/tests/engine/models/conftest.py +++ b/tests/engine/models/conftest.py @@ -5,7 +5,7 @@ import pytest -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import CompletionInferenceParameters, ModelConfig from data_designer.engine.model_provider import ModelProvider, 
ModelProviderRegistry from data_designer.engine.models.registry import ModelRegistry, create_model_registry from data_designer.engine.secret_resolver import SecretsFileResolver @@ -38,7 +38,7 @@ def stub_model_configs() -> list[ModelConfig]: alias="stub-text", model="stub-model-text", provider="stub-model-provider", - inference_parameters=InferenceParameters( + inference_parameters=CompletionInferenceParameters( temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100 ), ), @@ -46,7 +46,7 @@ def stub_model_configs() -> list[ModelConfig]: alias="stub-reasoning", model="stub-model-reasoning", provider="stub-model-provider", - inference_parameters=InferenceParameters( + inference_parameters=CompletionInferenceParameters( temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100 ), ), diff --git a/tests/engine/models/test_model_registry.py b/tests/engine/models/test_model_registry.py index 571b9605..83e3b650 100644 --- a/tests/engine/models/test_model_registry.py +++ b/tests/engine/models/test_model_registry.py @@ -6,7 +6,7 @@ from litellm import AuthenticationError import pytest -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import CompletionInferenceParameters, ModelConfig from data_designer.engine.models.errors import ModelAuthenticationError from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.registry import ModelRegistry, create_model_registry @@ -24,7 +24,7 @@ def stub_new_model_config(): alias="stub-vision", model="stub-model-vision", provider="stub-model-provider", - inference_parameters=InferenceParameters( + inference_parameters=CompletionInferenceParameters( temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100 ), ) @@ -36,7 +36,7 @@ def stub_no_usage_config(): alias="no-usage", model="no-usage-model", provider="stub-model-provider", - 
inference_parameters=InferenceParameters(), + inference_parameters=CompletionInferenceParameters(), ) diff --git a/tests/essentials/test_init.py b/tests/essentials/test_init.py index 89f8388a..d810bba3 100644 --- a/tests/essentials/test_init.py +++ b/tests/essentials/test_init.py @@ -17,14 +17,17 @@ CodeLang, CodeValidatorParams, ColumnInequalityConstraint, + CompletionInferenceParameters, DataDesignerColumnType, DataDesignerConfig, DataDesignerConfigBuilder, DatastoreSeedDatasetReference, DatastoreSettings, DatetimeSamplerParams, + EmbeddingInferenceParameters, ExpressionColumnConfig, GaussianSamplerParams, + GenerationType, ImageContext, ImageFormat, InferenceParameters, @@ -109,6 +112,9 @@ def test_model_config_imports(): assert ImageContext is not None assert ImageFormat is not None assert InferenceParameters is not None + assert CompletionInferenceParameters is not None + assert EmbeddingInferenceParameters is not None + assert GenerationType is not None assert ManualDistribution is not None assert ManualDistributionParams is not None assert Modality is not None @@ -232,6 +238,7 @@ def test_all_contains_column_configs(): assert "Score" in __all__ assert "SeedDatasetColumnConfig" in __all__ assert "ValidationColumnConfig" in __all__ + assert "EmbeddingColumnConfig" in __all__ def test_all_contains_sampler_params(): @@ -250,6 +257,8 @@ def test_all_contains_sampler_params(): assert "TimeDeltaSamplerParams" in __all__ assert "UniformSamplerParams" in __all__ assert "UUIDSamplerParams" in __all__ + assert "PersonFromFakerSamplerParams" in __all__ + assert "ProcessorType" in __all__ def test_all_contains_constraints(): @@ -263,6 +272,9 @@ def test_all_contains_model_configs(): assert "ImageContext" in __all__ assert "ImageFormat" in __all__ assert "InferenceParameters" in __all__ + assert "CompletionInferenceParameters" in __all__ + assert "EmbeddingInferenceParameters" in __all__ + assert "GenerationType" in __all__ assert "ManualDistribution" in __all__ assert 
"ManualDistributionParams" in __all__ assert "Modality" in __all__ From 2c1b2676fe0234016a7e13fe57171da2295eaf7c Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 26 Nov 2025 13:04:18 -0700 Subject: [PATCH 13/64] Remove purpose from consolidated kwargs --- src/data_designer/config/models.py | 2 +- .../engine/column_generators/generators/embedding.py | 1 + src/data_designer/engine/models/facade.py | 2 ++ tests/engine/models/test_facade.py | 8 ++++---- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 1df7055e..7b129556 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -229,7 +229,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: class EmbeddingInferenceParameters(BaseInferenceParameters): - encoding_format: Optional[Literal["float", "base64"]] = "float" + encoding_format: Optional[Literal["float", "base64"]] = None dimensions: Optional[int] = None @property diff --git a/src/data_designer/engine/column_generators/generators/embedding.py b/src/data_designer/engine/column_generators/generators/embedding.py index d9981ccd..48fc309f 100644 --- a/src/data_designer/engine/column_generators/generators/embedding.py +++ b/src/data_designer/engine/column_generators/generators/embedding.py @@ -28,6 +28,7 @@ def generate(self, data: dict) -> dict: deserialized_record = deserialize_json_values(data) input_text = deserialized_record[self.config.target_column] input_chunks = re.split(self.config.chunk_pattern, input_text) if self.config.chunk_pattern else [input_text] + input_chunks = [chunk.strip() for chunk in input_chunks if chunk.strip()] embeddings = self.model.generate_text_embeddings(input_texts=input_chunks) data[self.config.name] = { "embeddings": embeddings, diff --git a/src/data_designer/engine/models/facade.py b/src/data_designer/engine/models/facade.py index c205a4ca..6b98c0a7 100644 --- 
a/src/data_designer/engine/models/facade.py +++ b/src/data_designer/engine/models/facade.py @@ -91,6 +91,8 @@ def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = self._track_usage(response) def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: + # Remove purpose from kwargs to avoid passing it to the model + kwargs.pop("purpose", None) kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} if self.model_provider.extra_body: kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} diff --git a/tests/engine/models/test_facade.py b/tests/engine/models/test_facade.py index afe27730..8765d0ab 100644 --- a/tests/engine/models/test_facade.py +++ b/tests/engine/models/test_facade.py @@ -116,17 +116,17 @@ def test_usage_stats_property(stub_model_facade): def test_consolidate_kwargs(stub_model_configs, stub_model_facade): - # Model config generate kwargs are used as base - result = stub_model_facade.consolidate_kwargs() + # Model config generate kwargs are used as base, and purpose is removed + result = stub_model_facade.consolidate_kwargs(purpose="test") assert result == stub_model_configs[0].inference_parameters.generate_kwargs # kwargs overrides model config generate kwargs - result = stub_model_facade.consolidate_kwargs(temperature=0.01) + result = stub_model_facade.consolidate_kwargs(temperature=0.01, purpose="test") assert result == {**stub_model_configs[0].inference_parameters.generate_kwargs, "temperature": 0.01} # Provider extra_body overrides all other kwargs stub_model_facade.model_provider.extra_body = {"foo_provider": "bar_provider"} - result = stub_model_facade.consolidate_kwargs(extra_body={"foo": "bar"}) + result = stub_model_facade.consolidate_kwargs(extra_body={"foo": "bar"}, purpose="test") assert result == { **stub_model_configs[0].inference_parameters.generate_kwargs, "extra_body": {"foo_provider": "bar_provider", "foo": "bar"}, From 
4b1492baf805adc0719d73857b1f19a219f49375 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 2 Dec 2025 11:38:33 -0700 Subject: [PATCH 14/64] WithModelConfiguration.inference_parameters should should be typed with BaseInferenceParameters --- src/data_designer/engine/column_generators/generators/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data_designer/engine/column_generators/generators/base.py b/src/data_designer/engine/column_generators/generators/base.py index 8977a63b..580c96a6 100644 --- a/src/data_designer/engine/column_generators/generators/base.py +++ b/src/data_designer/engine/column_generators/generators/base.py @@ -9,7 +9,7 @@ import pandas as pd from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import BaseInferenceParameters, ModelConfig from data_designer.config.utils.type_helpers import StrEnum from data_designer.engine.column_generators.utils.prompt_renderer import ( RecordBasedPromptRenderer, @@ -81,7 +81,7 @@ def model_config(self) -> ModelConfig: return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias) @functools.cached_property - def inference_parameters(self) -> InferenceParameters: + def inference_parameters(self) -> BaseInferenceParameters: return self.model_config.inference_parameters @functools.cached_property From c445caf53f213a54b80b3df71a0c00334ccf519b Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 2 Dec 2025 14:37:07 -0700 Subject: [PATCH 15/64] Type as WithModelGeneration --- .../engine/dataset_builders/column_wise_builder.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index ff9289ee..2e30407c 100644 --- 
a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -17,8 +17,11 @@ ProcessorConfig, ProcessorType, ) -from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy -from data_designer.engine.column_generators.generators.llm_completion import WithCompletionGeneration +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, + WithModelGeneration, +) from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.dataset_builders.multi_column_configs import ( @@ -169,7 +172,7 @@ def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None: max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR - if isinstance(generator, WithCompletionGeneration): + if isinstance(generator, WithModelGeneration): max_workers = generator.inference_parameters.max_parallel_requests self._fan_out_with_threads(generator, max_workers=max_workers) @@ -183,7 +186,7 @@ def _run_model_health_check_if_needed(self) -> bool: set(config.model_alias for config in self.llm_generated_column_configs) ) - def _fan_out_with_threads(self, generator: WithCompletionGeneration, max_workers: int) -> None: + def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None: if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL: raise DatasetGenerationError( f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} " From 4b8aa2bf9258c1a1fc3be10ff1d817ae797ed2d7 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 2 Dec 2025 16:06:52 -0700 Subject: [PATCH 16/64] Add image generation modality --- 
src/data_designer/config/column_configs.py | 40 +++++++++++++++ src/data_designer/config/column_types.py | 8 +++ src/data_designer/config/models.py | 20 +++++++- .../column_generators/generators/base.py | 14 ------ .../column_generators/generators/image.py | 49 +++++++++++++++++++ .../generators/llm_completion.py | 12 +++++ .../engine/column_generators/registry.py | 4 +- src/data_designer/engine/models/facade.py | 36 +++++++++++++- src/data_designer/engine/models/registry.py | 6 +++ src/data_designer/essentials/__init__.py | 4 ++ 10 files changed, 176 insertions(+), 17 deletions(-) create mode 100644 src/data_designer/engine/column_generators/generators/image.py diff --git a/src/data_designer/config/column_configs.py b/src/data_designer/config/column_configs.py index 339be35d..eb93f9f0 100644 --- a/src/data_designer/config/column_configs.py +++ b/src/data_designer/config/column_configs.py @@ -401,3 +401,43 @@ class EmbeddingColumnConfig(SingleColumnConfig): @property def required_columns(self) -> list[str]: return [self.target_column] + + +class ImageGenerationColumnConfig(SingleColumnConfig): + """Configuration for image generation columns. + + Image columns generate images using a specified model. + + Attributes: + column_type: Discriminator field, always "image-generation" for this configuration type. + prompt: Prompt template for image generation. Supports Jinja2 templating to + reference other columns (e.g., "Generate an image of a {{ character_name }}"). + Must be a valid Jinja2 template. + model_alias: The model to use for image generation. + """ + + column_type: Literal["image-generation"] = "image-generation" + prompt: str + model_alias: str + + @property + def required_columns(self) -> list[str]: + """Get columns referenced in the prompt template. + + Returns: + List of unique column names referenced in Jinja2 templates. 
+ """ + return list(get_prompt_template_keywords(self.prompt)) + + @model_validator(mode="after") + def assert_prompt_valid_jinja(self) -> Self: + """Validate that prompt is a valid Jinja2 template. + + Returns: + The validated instance. + + Raises: + InvalidConfigError: If prompt contains invalid Jinja2 syntax. + """ + assert_valid_jinja2_template(self.prompt) + return self diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py index aab55c4d..efdeb094 100644 --- a/src/data_designer/config/column_types.py +++ b/src/data_designer/config/column_types.py @@ -9,6 +9,7 @@ from .column_configs import ( EmbeddingColumnConfig, ExpressionColumnConfig, + ImageGenerationColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -33,6 +34,7 @@ SeedDatasetColumnConfig, ValidationColumnConfig, EmbeddingColumnConfig, + ImageGenerationColumnConfig, ] ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT) @@ -53,6 +55,7 @@ DataDesignerColumnType.SAMPLER: "🎲", DataDesignerColumnType.VALIDATION: "πŸ”", DataDesignerColumnType.EMBEDDING: "🧬", + DataDesignerColumnType.IMAGE_GENERATION: "πŸ–ΌοΈ", } COLUMN_TYPE_EMOJI_MAP.update( {DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()} @@ -70,6 +73,7 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE_GENERATION, } dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) return column_type in dag_column_types @@ -84,6 +88,7 @@ def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType] DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE_GENERATION, } llm_generated_column_types.update( 
plugin_manager.get_plugin_column_types( @@ -124,6 +129,8 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType return SeedDatasetColumnConfig(name=name, **kwargs) if column_type == DataDesignerColumnType.EMBEDDING: return EmbeddingColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.IMAGE_GENERATION: + return ImageGenerationColumnConfig(name=name, **kwargs) if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value): return plugin.config_cls(name=name, **kwargs) raise InvalidColumnTypeError(f"πŸ›‘ {column_type} is not a valid column type.") # pragma: no cover @@ -139,6 +146,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]: DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE_GENERATION, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 7b129556..6e535038 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -242,8 +242,24 @@ def generate_kwargs(self) -> dict[str, Union[float, int]]: return result +class ImageGenerationInferenceParameters(BaseInferenceParameters): + quality: str + size: str + output_format: Optional[ModalityDataType] = ModalityDataType.BASE64 + + @property + def generate_kwargs(self) -> dict[str, Union[float, int]]: + result = super().generate_kwargs + result["size"] = self.size + result["quality"] = self.quality + result["response_format"] = ( + self.output_format.value if self.output_format == ModalityDataType.URL else "b64_json" + ) + return result + + InferenceParametersT: TypeAlias = Union[ - InferenceParameters, CompletionInferenceParameters, EmbeddingInferenceParameters + InferenceParameters, CompletionInferenceParameters, EmbeddingInferenceParameters, ImageGenerationInferenceParameters ] @@ -272,6 +288,8 @@ def 
generation_type(self) -> GenerationType: return GenerationType.CHAT_COMPLETION elif isinstance(self.inference_parameters, EmbeddingInferenceParameters): return GenerationType.EMBEDDING + elif isinstance(self.inference_parameters, ImageGenerationInferenceParameters): + return GenerationType.IMAGE_GENERATION else: raise ValueError(f"Unsupported inference parameters type: {type(self.inference_parameters)}") diff --git a/src/data_designer/engine/column_generators/generators/base.py b/src/data_designer/engine/column_generators/generators/base.py index 580c96a6..a98038b3 100644 --- a/src/data_designer/engine/column_generators/generators/base.py +++ b/src/data_designer/engine/column_generators/generators/base.py @@ -11,9 +11,6 @@ from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP from data_designer.config.models import BaseInferenceParameters, ModelConfig from data_designer.config.utils.type_helpers import StrEnum -from data_designer.engine.column_generators.utils.prompt_renderer import ( - RecordBasedPromptRenderer, -) from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT from data_designer.engine.models.facade import ModelFacade @@ -84,17 +81,6 @@ def model_config(self) -> ModelConfig: def inference_parameters(self) -> BaseInferenceParameters: return self.model_config.inference_parameters - @functools.cached_property - def prompt_renderer(self) -> RecordBasedPromptRenderer: - return RecordBasedPromptRenderer( - response_recipe=self.response_recipe, - error_message_context={ - "column_name": self.config.name, - "column_type": self.config.column_type, - "model_alias": self.config.model_alias, - }, - ) - def log_pre_generation(self) -> None: emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] logger.info(f"{emoji} Preparing {self.config.column_type} column generation") diff --git a/src/data_designer/engine/column_generators/generators/image.py 
b/src/data_designer/engine/column_generators/generators/image.py new file mode 100644 index 00000000..f7cfba89 --- /dev/null +++ b/src/data_designer/engine/column_generators/generators/image.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from litellm.types.utils import ImageResponse + +from data_designer.config.column_configs import ImageGenerationColumnConfig +from data_designer.config.models import ModalityDataType +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, + GeneratorMetadata, + WithModelGeneration, +) +from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering +from data_designer.engine.processing.utils import deserialize_json_values +from data_designer.engine.resources.resource_provider import ResourceType + + +class ImageCellGenerator( + WithModelGeneration, WithJinja2UserTemplateRendering, ColumnGenerator[ImageGenerationColumnConfig] +): + @staticmethod + def metadata() -> GeneratorMetadata: + return GeneratorMetadata( + name="image_cell_generator", + description="Generate images using a specified model.", + generation_strategy=GenerationStrategy.CELL_BY_CELL, + required_resources=[ResourceType.MODEL_REGISTRY], + ) + + def generate(self, data: dict) -> dict: + deserialized_record = deserialize_json_values(data) + missing_columns = list(set(self.config.required_columns) - set(data.keys())) + if len(missing_columns) > 0: + error_msg = ( + f"There was an error preparing the Jinja2 expression template. " + f"The following columns {missing_columns} are missing!" 
+ ) + raise ValueError(error_msg) + + self.prepare_jinja2_template_renderer(self.config.prompt, list(deserialized_record.keys())) + prompt = self.render_template(deserialized_record) + image_response: ImageResponse = self.model.generate_image(prompt=prompt) + if self.model_config.inference_parameters.output_format == ModalityDataType.URL: + data[self.config.name] = image_response.data[0].url + else: + data[self.config.name] = image_response.data[0].b64_json + return data diff --git a/src/data_designer/engine/column_generators/generators/llm_completion.py b/src/data_designer/engine/column_generators/generators/llm_completion.py index 5665ba85..8fae174b 100644 --- a/src/data_designer/engine/column_generators/generators/llm_completion.py +++ b/src/data_designer/engine/column_generators/generators/llm_completion.py @@ -19,6 +19,7 @@ ) from data_designer.engine.column_generators.utils.prompt_renderer import ( PromptType, + RecordBasedPromptRenderer, create_response_recipe, ) from data_designer.engine.models.recipes.base import ResponseRecipe @@ -45,6 +46,17 @@ def max_conversation_correction_steps(self) -> int: def max_conversation_restarts(self) -> int: return DEFAULT_MAX_CONVERSATION_RESTARTS + @functools.cached_property + def prompt_renderer(self) -> RecordBasedPromptRenderer: + return RecordBasedPromptRenderer( + response_recipe=self.response_recipe, + error_message_context={ + "column_name": self.config.name, + "column_type": self.config.column_type, + "model_alias": self.config.model_alias, + }, + ) + def generate(self, data: dict) -> dict: deserialized_record = deserialize_json_values(data) diff --git a/src/data_designer/engine/column_generators/registry.py b/src/data_designer/engine/column_generators/registry.py index 7171e561..3d000729 100644 --- a/src/data_designer/engine/column_generators/registry.py +++ b/src/data_designer/engine/column_generators/registry.py @@ -5,6 +5,7 @@ from data_designer.config.column_configs import ( EmbeddingColumnConfig, 
ExpressionColumnConfig, + ImageGenerationColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -15,6 +16,7 @@ from data_designer.engine.column_generators.generators.base import ColumnGenerator from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator +from data_designer.engine.column_generators.generators.image import ImageCellGenerator from data_designer.engine.column_generators.generators.llm_completion import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, @@ -47,7 +49,7 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig) registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig) registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig) - + registry.register(DataDesignerColumnType.IMAGE_GENERATION, ImageCellGenerator, ImageGenerationColumnConfig) if with_plugins: for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR): registry.register( diff --git a/src/data_designer/engine/models/facade.py b/src/data_designer/engine/models/facade.py index 6b98c0a7..33c79797 100644 --- a/src/data_designer/engine/models/facade.py +++ b/src/data_designer/engine/models/facade.py @@ -9,7 +9,7 @@ from typing import Any from litellm.types.router import DeploymentTypedDict, LiteLLM_Params -from litellm.types.utils import EmbeddingResponse, ModelResponse +from litellm.types.utils import EmbeddingResponse, ImageResponse, ImageUsage, ModelResponse from data_designer.config.models import GenerationType, ModelConfig, ModelProvider from data_designer.engine.model_provider import ModelProviderRegistry @@ -131,6 +131,27 @@ def generate_text_embeddings( if not 
skip_usage_tracking and response is not None: self._track_usage_from_embedding(response) + @catch_llm_exceptions + def generate_image(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> ImageResponse: + logger.debug( + f"Generating image with model {self.model_name!r}...", + extra={"model": self.model_name, "prompt": prompt}, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response = None + try: + response = self._router.image_generation(prompt=prompt, model=self.model_name, **kwargs) + logger.debug( + f"Received image from model {self.model_name!r}", + extra={"model": self.model_name, "response": response}, + ) + return response + except Exception as e: + raise e + finally: + if not skip_usage_tracking and response is not None: + self._track_usage_from_image(response) + @catch_llm_exceptions def generate( self, @@ -280,3 +301,16 @@ def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> Non ), request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) + + def _track_usage_from_image(self, response: ImageResponse | None) -> None: + if response is None: + self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) + return + if response.usage is not None and isinstance(response.usage, ImageUsage): + self._usage_stats.extend( + token_usage=TokenUsageStats( + prompt_tokens=response.usage.input_tokens, + completion_tokens=response.usage.output_tokens, + ), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) diff --git a/src/data_designer/engine/models/registry.py b/src/data_designer/engine/models/registry.py index 4330ea18..91025684 100644 --- a/src/data_designer/engine/models/registry.py +++ b/src/data_designer/engine/models/registry.py @@ -97,6 +97,12 @@ def run_health_check(self, model_aliases: set[str]) -> None: skip_usage_tracking=True, purpose="running health checks", ) + elif model.model_generation_type == GenerationType.IMAGE_GENERATION: 
+ model.generate_image( + prompt="Generate a simple pixel", + skip_usage_tracking=True, + purpose="running health checks", + ) else: raise ValueError(f"Unsupported generation type: {model.model_generation_type}") logger.info(" |-- βœ… Passed!") diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index cd1dd6ba..e8c6091c 100644 --- a/src/data_designer/essentials/__init__.py +++ b/src/data_designer/essentials/__init__.py @@ -8,6 +8,7 @@ from ..config.column_configs import ( EmbeddingColumnConfig, ExpressionColumnConfig, + ImageGenerationColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -28,6 +29,7 @@ GenerationType, ImageContext, ImageFormat, + ImageGenerationInferenceParameters, InferenceParameters, ManualDistribution, ManualDistributionParams, @@ -105,6 +107,8 @@ "InfoType", "ImageContext", "ImageFormat", + "ImageGenerationColumnConfig", + "ImageGenerationInferenceParameters", "InferenceParameters", "JudgeScoreProfilerConfig", "LLMCodeColumnConfig", From 2c5933f789b8e1dc47d40be56f3ff76741850d10 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 2 Dec 2025 17:49:32 -0700 Subject: [PATCH 17/64] update return type for generate_kwargs --- src/data_designer/config/models.py | 10 ++++------ tests/config/test_columns.py | 3 +++ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 6e535038..9d0ee6a6 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -142,7 +142,7 @@ class BaseInferenceParameters(ConfigBase, ABC): extra_body: Optional[dict[str, Any]] = None @property - def generate_kwargs(self) -> dict[str, Union[float, int]]: + def generate_kwargs(self) -> dict[str, Any]: result = {} if self.timeout is not None: result["timeout"] = self.timeout @@ -157,7 +157,7 @@ class CompletionInferenceParameters(BaseInferenceParameters): max_tokens: Optional[int] = 
Field(default=None, ge=1) @property - def generate_kwargs(self) -> dict[str, Union[float, int]]: + def generate_kwargs(self) -> dict[str, Any]: result = super().generate_kwargs if self.temperature is not None: result["temperature"] = ( @@ -248,13 +248,11 @@ class ImageGenerationInferenceParameters(BaseInferenceParameters): output_format: Optional[ModalityDataType] = ModalityDataType.BASE64 @property - def generate_kwargs(self) -> dict[str, Union[float, int]]: + def generate_kwargs(self) -> dict[str, Any]: result = super().generate_kwargs result["size"] = self.size result["quality"] = self.quality - result["response_format"] = ( - self.output_format.value if self.output_format == ModalityDataType.URL else "b64_json" - ) + result["response_format"] = "b64_json" if self.output_format == ModalityDataType.BASE64 else self.output_format return result diff --git a/tests/config/test_columns.py b/tests/config/test_columns.py index f7763b07..2e74695f 100644 --- a/tests/config/test_columns.py +++ b/tests/config/test_columns.py @@ -50,6 +50,7 @@ def test_data_designer_column_type_get_display_order(): DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE_GENERATION, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] @@ -61,6 +62,7 @@ def test_data_designer_column_type_is_llm_generated(): assert column_type_is_llm_generated(DataDesignerColumnType.LLM_STRUCTURED) assert column_type_is_llm_generated(DataDesignerColumnType.LLM_JUDGE) assert column_type_is_llm_generated(DataDesignerColumnType.EMBEDDING) + assert column_type_is_llm_generated(DataDesignerColumnType.IMAGE_GENERATION) assert not column_type_is_llm_generated(DataDesignerColumnType.SAMPLER) assert not column_type_is_llm_generated(DataDesignerColumnType.VALIDATION) assert not column_type_is_llm_generated(DataDesignerColumnType.EXPRESSION) @@ -75,6 +77,7 @@ def test_data_designer_column_type_is_in_dag(): assert 
column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT) assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION) assert column_type_used_in_execution_dag(DataDesignerColumnType.EMBEDDING) + assert column_type_used_in_execution_dag(DataDesignerColumnType.IMAGE_GENERATION) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET) From c6c29d4fdca3a292d06abdbfaee11c2f66269cfb Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 3 Dec 2025 10:25:17 -0700 Subject: [PATCH 18/64] make generation_type a field of ModelConfig as opposed to a prop resolved based on the type of InferenceParameters --- src/data_designer/config/models.py | 25 +++++++++------- tests/config/test_models.py | 47 +++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 11 deletions(-) diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 9d0ee6a6..b10deca0 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -271,6 +271,7 @@ class ModelConfig(ConfigBase): alias: str model: str inference_parameters: InferenceParametersT = Field(default_factory=CompletionInferenceParameters) + generation_type: GenerationType = Field(default=GenerationType.CHAT_COMPLETION) provider: Optional[str] = None @model_validator(mode="after") @@ -280,16 +281,20 @@ def _normalize_deprecated_inference_parameters(self) -> Self: self.inference_parameters = CompletionInferenceParameters(**self.inference_parameters.model_dump()) return self - @property - def generation_type(self) -> GenerationType: - if isinstance(self.inference_parameters, CompletionInferenceParameters): - return GenerationType.CHAT_COMPLETION - elif isinstance(self.inference_parameters, EmbeddingInferenceParameters): - return GenerationType.EMBEDDING - elif isinstance(self.inference_parameters, ImageGenerationInferenceParameters): - return 
GenerationType.IMAGE_GENERATION - else: - raise ValueError(f"Unsupported inference parameters type: {type(self.inference_parameters)}") + @model_validator(mode="after") + def _validate_generation_type(self) -> Self: + generation_type_instance_map = { + GenerationType.CHAT_COMPLETION: CompletionInferenceParameters, + GenerationType.EMBEDDING: EmbeddingInferenceParameters, + GenerationType.IMAGE_GENERATION: ImageGenerationInferenceParameters, + } + if self.generation_type not in generation_type_instance_map: + raise ValueError(f"Invalid generation type: {self.generation_type}") + if not isinstance(self.inference_parameters, generation_type_instance_map[self.generation_type]): + raise ValueError( + f"Inference parameters must be an instance of {generation_type_instance_map[self.generation_type].__name__!r} when generation_type is {self.generation_type!r}" + ) + return self class ModelProvider(ConfigBase): diff --git a/tests/config/test_models.py b/tests/config/test_models.py index f1f65401..40f6afe9 100644 --- a/tests/config/test_models.py +++ b/tests/config/test_models.py @@ -12,8 +12,11 @@ from data_designer.config.errors import InvalidConfigError from data_designer.config.models import ( CompletionInferenceParameters, + EmbeddingInferenceParameters, + GenerationType, ImageContext, ImageFormat, + ImageGenerationInferenceParameters, ManualDistribution, ManualDistributionParams, ModalityDataType, @@ -254,6 +257,48 @@ def test_load_model_configs(): load_model_configs(tmp_file.name) -def test_model_config_default_construction(): +def test_model_config_construction(): + # test default construction model_config = ModelConfig(alias="test", model="test") assert model_config.inference_parameters == CompletionInferenceParameters() + assert model_config.generation_type == GenerationType.CHAT_COMPLETION + + # test construction with completion inference parameters + completion_params = CompletionInferenceParameters(temperature=0.5, top_p=0.5, max_tokens=100) + model_config = 
ModelConfig(alias="test", model="test", inference_parameters=completion_params) + assert model_config.inference_parameters == completion_params + assert model_config.generation_type == GenerationType.CHAT_COMPLETION + + # test construction with embedding inference parameters + embedding_params = EmbeddingInferenceParameters(dimensions=100) + model_config = ModelConfig( + alias="test", model="test", generation_type=GenerationType.EMBEDDING, inference_parameters=embedding_params + ) + assert model_config.inference_parameters == embedding_params + assert model_config.generation_type == GenerationType.EMBEDDING + + # test construction with image generation inference parameters + image_generation_params = ImageGenerationInferenceParameters(size="1024x1024", quality="standard") + model_config = ModelConfig( + alias="test", + model="test", + generation_type=GenerationType.IMAGE_GENERATION, + inference_parameters=image_generation_params, + ) + assert model_config.inference_parameters == image_generation_params + assert model_config.generation_type == GenerationType.IMAGE_GENERATION + + +def test_model_config_invalid_generation_type(): + with pytest.raises(ValidationError, match="Input should be"): + ModelConfig(alias="test", model="test", generation_type="invalid_generation_type") + with pytest.raises( + ValidationError, + match="Inference parameters must be an instance of 'EmbeddingInferenceParameters' when generation_type is 'embedding'", + ): + ModelConfig( + alias="test", + model="test", + generation_type=GenerationType.EMBEDDING, + inference_parameters=CompletionInferenceParameters(), + ) From 06a724b4090df4f150d348dd7c5e9b67b562daa4 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 3 Dec 2025 11:07:57 -0700 Subject: [PATCH 19/64] remove regex based chunking from embedding generator --- src/data_designer/config/column_configs.py | 7 +--- src/data_designer/config/models.py | 2 +- .../column_generators/generators/embedding.py | 10 ++--- 
src/data_designer/engine/processing/utils.py | 38 +++++++++++++++++++ tests/engine/processing/test_utils.py | 17 +++++++++ 5 files changed, 62 insertions(+), 12 deletions(-) diff --git a/src/data_designer/config/column_configs.py b/src/data_designer/config/column_configs.py index eb93f9f0..a3bef936 100644 --- a/src/data_designer/config/column_configs.py +++ b/src/data_designer/config/column_configs.py @@ -386,17 +386,14 @@ class EmbeddingColumnConfig(SingleColumnConfig): Attributes: column_type: Discriminator field, always "embedding" for this configuration type. - target_column: The column to generate embeddings for. + target_column: The column to generate embeddings for. The column could be a single text string or a list of text strings in stringified JSON format. + If it is a list of text strings in stringified JSON format, the embeddings will be generated for each text string. model_alias: The model to use for embedding generation. - chunk_pattern: Optional regex pattern to split the text in the target column into chunks. For example, if chunk_pattern - is r'\n+', the text will be split into chunks using one or more newlines as separators and embeddings generated for each chunk. - If not provided, the entire text will be embedded as a single chunk. 
""" column_type: Literal["embedding"] = "embedding" target_column: str model_alias: str - chunk_pattern: Optional[str] = None @property def required_columns(self) -> list[str]: diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index b10deca0..4b3ae12c 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -271,7 +271,7 @@ class ModelConfig(ConfigBase): alias: str model: str inference_parameters: InferenceParametersT = Field(default_factory=CompletionInferenceParameters) - generation_type: GenerationType = Field(default=GenerationType.CHAT_COMPLETION) + generation_type: Optional[GenerationType] = Field(default=GenerationType.CHAT_COMPLETION) provider: Optional[str] = None @model_validator(mode="after") diff --git a/src/data_designer/engine/column_generators/generators/embedding.py b/src/data_designer/engine/column_generators/generators/embedding.py index 48fc309f..ed738e8f 100644 --- a/src/data_designer/engine/column_generators/generators/embedding.py +++ b/src/data_designer/engine/column_generators/generators/embedding.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -import re from data_designer.config.column_configs import EmbeddingColumnConfig from data_designer.engine.column_generators.generators.base import ( @@ -10,7 +9,7 @@ GeneratorMetadata, WithModelGeneration, ) -from data_designer.engine.processing.utils import deserialize_json_values +from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string from data_designer.engine.resources.resource_provider import ResourceType @@ -26,10 +25,9 @@ def metadata() -> GeneratorMetadata: def generate(self, data: dict) -> dict: deserialized_record = deserialize_json_values(data) - input_text = deserialized_record[self.config.target_column] - input_chunks = re.split(self.config.chunk_pattern, input_text) if self.config.chunk_pattern else [input_text] - input_chunks = [chunk.strip() for chunk in input_chunks if chunk.strip()] - embeddings = self.model.generate_text_embeddings(input_texts=input_chunks) + input_texts = parse_list_string(deserialized_record[self.config.target_column]) + embeddings = self.model.generate_text_embeddings(input_texts=input_texts) + data[self.config.name] = { "embeddings": embeddings, "num_embeddings": len(embeddings), diff --git a/src/data_designer/engine/processing/utils.py b/src/data_designer/engine/processing/utils.py index 3579b3bd..5d42c40e 100644 --- a/src/data_designer/engine/processing/utils.py +++ b/src/data_designer/engine/processing/utils.py @@ -1,8 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +import ast import json import logging +import re from typing import Any, TypeVar, Union, overload import pandas as pd @@ -100,6 +102,42 @@ def deserialize_json_values(data): return data +def parse_list_string(text: str) -> list[str]: + """Parse a list from a string, handling JSON arrays, Python lists, and trailing commas.""" + text = text.strip() + + # Try JSON first + try: + list_obj = json.loads(text) + if isinstance(list_obj, list): + return _clean_whitespace(list_obj) + except json.JSONDecodeError: + pass + + # Remove trailing commas before closing brackets (common in JSON-like strings) + text_cleaned = re.sub(r",\s*]", "]", text) + text_cleaned = re.sub(r",\s*}", "}", text_cleaned) + + # Try JSON again with cleaned text + try: + return _clean_whitespace(json.loads(text_cleaned)) + except json.JSONDecodeError: + pass + + # Try Python literal eval (handles single quotes) + try: + return _clean_whitespace(ast.literal_eval(text_cleaned)) + except (ValueError, SyntaxError): + pass + + # If all else fails, return the original text + return [text.strip()] + + +def _clean_whitespace(texts: list[str]) -> list[str]: + return [text.strip() for text in texts] + + def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None: joined_columns = set() for df in datasets: diff --git a/tests/engine/processing/test_utils.py b/tests/engine/processing/test_utils.py index a41e0ec2..dec0fe6a 100644 --- a/tests/engine/processing/test_utils.py +++ b/tests/engine/processing/test_utils.py @@ -9,6 +9,7 @@ from data_designer.engine.processing.utils import ( concat_datasets, deserialize_json_values, + parse_list_string, ) @@ -116,3 +117,19 @@ def test_concat_datasets_logging(mock_logger, stub_sample_dataframes): def test_deserialize_json_values_scenarios(test_case, input_data, expected_result): result = deserialize_json_values(input_data) assert result == expected_result + + +@pytest.mark.parametrize( + "input_string,expected_result", + [ + 
('["a", "b", "c"]', ["a", "b", "c"]), # valid stringified json array + ('[" a ", " b", "c "]', ["a", "b", "c"]), # valid stringified json array with whitespace + ('["a", "b", "c",]', ["a", "b", "c"]), # valid stringified json array with trailing comma + ("['a', 'b', 'c']", ["a", "b", "c"]), # valid python-style list with single quotes + ("['a', 'b', 'c', ]", ["a", "b", "c"]), # valid python-style list with trailing comma + ("simple string ", ["simple string"]), # simple string with whitespace + ], +) +def test_parse_list_string_scenarios(input_string, expected_result): + result = parse_list_string(input_string) + assert result == expected_result From f291033e6e1e0debdf31f10f732931c724370afe Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 4 Feb 2026 09:59:52 -0700 Subject: [PATCH 20/64] save progress --- .../src/data_designer/config/__init__.py | 6 + .../data_designer/config/column_configs.py | 11 +- .../src/data_designer/config/models.py | 65 ++- .../config/utils/visualization.py | 128 ++++++ .../column_generators/generators/image.py | 20 +- .../src/data_designer/engine/models/facade.py | 116 ++++- .../integrations/huggingface/client.py | 419 ++++++++++++++++++ pyproject.toml | 3 + uv.lock | 412 ++++++++++++++++- 9 files changed, 1148 insertions(+), 32 deletions(-) create mode 100644 packages/data-designer/src/data_designer/integrations/huggingface/client.py diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index 0ebf06be..46122609 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -15,6 +15,7 @@ from data_designer.config.column_configs import ( # noqa: F401 EmbeddingColumnConfig, ExpressionColumnConfig, + ImageGenerationColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -34,7 +35,9 @@ ToolConfig, ) from 
data_designer.config.models import ( # noqa: F401 + ChatCompletionImageInferenceParams, ChatCompletionInferenceParams, + DiffusionImageInferenceParams, EmbeddingInferenceParams, GenerationType, ImageContext, @@ -117,6 +120,7 @@ # column_configs "EmbeddingColumnConfig": (_MOD_COLUMN_CONFIGS, "EmbeddingColumnConfig"), "ExpressionColumnConfig": (_MOD_COLUMN_CONFIGS, "ExpressionColumnConfig"), + "ImageGenerationColumnConfig": (_MOD_COLUMN_CONFIGS, "ImageGenerationColumnConfig"), "LLMCodeColumnConfig": (_MOD_COLUMN_CONFIGS, "LLMCodeColumnConfig"), "LLMJudgeColumnConfig": (_MOD_COLUMN_CONFIGS, "LLMJudgeColumnConfig"), "LLMStructuredColumnConfig": (_MOD_COLUMN_CONFIGS, "LLMStructuredColumnConfig"), @@ -138,7 +142,9 @@ "MCPProvider": (_MOD_MCP, "MCPProvider"), "ToolConfig": (_MOD_MCP, "ToolConfig"), # models + "ChatCompletionImageInferenceParams": (_MOD_MODELS, "ChatCompletionImageInferenceParams"), "ChatCompletionInferenceParams": (_MOD_MODELS, "ChatCompletionInferenceParams"), + "DiffusionImageInferenceParams": (_MOD_MODELS, "DiffusionImageInferenceParams"), "EmbeddingInferenceParams": (_MOD_MODELS, "EmbeddingInferenceParams"), "GenerationType": (_MOD_MODELS, "GenerationType"), "ImageContext": (_MOD_MODELS, "ImageContext"), diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index ee5efa80..9e1f5737 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -480,7 +480,14 @@ def side_effect_columns(self) -> list[str]: class ImageGenerationColumnConfig(SingleColumnConfig): """Configuration for image generation columns. - Image columns generate images using a specified model. + Image columns generate images using either autoregressive or diffusion models. 
+ The API used is automatically determined by the model's inference parameters: + + - **Autoregressive models** (ChatCompletionImageInferenceParams): + GPT-5, gpt-image-*, Gemini image generation models via chat completions API + + - **Diffusion models** (DiffusionImageInferenceParams): + DALL-E, Imagen, Stable Diffusion via image_generation API Attributes: column_type: Discriminator field, always "image-generation" for this configuration type. @@ -505,7 +512,7 @@ def required_columns(self) -> list[str]: Returns: List of unique column names referenced in Jinja2 templates. """ - return list(extract_keywords_from_jinja2_template(self.expr)) + return list(extract_keywords_from_jinja2_template(self.prompt)) @model_validator(mode="after") def assert_prompt_valid_jinja(self) -> Self: diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 5e9b3518..203ddbdb 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -242,7 +242,8 @@ def sample(self) -> float: class GenerationType(str, Enum): CHAT_COMPLETION = "chat-completion" EMBEDDING = "embedding" - IMAGE_GENERATION = "image-generation" + CHAT_COMPLETION_IMAGE = "chat-completion-image" + DIFFUSION_IMAGE = "diffusion-image" class BaseInferenceParams(ConfigBase, ABC): @@ -415,23 +416,64 @@ def generate_kwargs(self) -> dict[str, float | int]: return result -class ImageGenerationInferenceParams(BaseInferenceParams): - generation_type: Literal[GenerationType.IMAGE_GENERATION] = GenerationType.IMAGE_GENERATION +class ChatCompletionImageInferenceParams(BaseInferenceParams): + """Configuration for image generation using autoregressive models via chat completions API. + + Uses the standard chat completions API for autoregressive multimodal models + that can generate images (GPT-5, gpt-image-*, Gemini image generation, etc.). 
+ + Attributes: + generation_type: Type of generation, always "chat-completion-image" for this class. + quality: Optional quality setting for image generation (e.g., "standard", "hd"). + size: Optional size specification for generated images (e.g., "1024x1024", "1792x1024"). + """ + + generation_type: Literal[GenerationType.CHAT_COMPLETION_IMAGE] = GenerationType.CHAT_COMPLETION_IMAGE + quality: str | None = None + size: str | None = None + + @property + def generate_kwargs(self) -> dict[str, Any]: + result = super().generate_kwargs + if self.quality is not None: + result["quality"] = self.quality + if self.size is not None: + result["size"] = self.size + return result + + +class DiffusionImageInferenceParams(BaseInferenceParams): + """Configuration for image generation using diffusion models via image_generation API. + + Uses the legacy image_generation API for diffusion models like DALL-E, Imagen, + and Stable Diffusion. + + Attributes: + generation_type: Type of generation, always "diffusion-image" for this class. + quality: Quality setting for image generation (e.g., "standard", "hd"). + size: Size specification for generated images (e.g., "1024x1024", "1792x1024"). + output_format: Format of the output ("url" or "base64"). Default: "base64". 
+ """ + + generation_type: Literal[GenerationType.DIFFUSION_IMAGE] = GenerationType.DIFFUSION_IMAGE quality: str size: str - output_format: ModalityDataType | None = ModalityDataType.BASE64 + output_format: ModalityDataType = ModalityDataType.BASE64 @property def generate_kwargs(self) -> dict[str, Any]: result = super().generate_kwargs result["size"] = self.size result["quality"] = self.quality - result["response_format"] = "b64_json" if self.output_format == ModalityDataType.BASE64 else self.output_format + result["response_format"] = "b64_json" if self.output_format == ModalityDataType.BASE64 else "url" return result InferenceParamsT: TypeAlias = Annotated[ - ChatCompletionInferenceParams | EmbeddingInferenceParams | ImageGenerationInferenceParams, + ChatCompletionInferenceParams + | EmbeddingInferenceParams + | ChatCompletionImageInferenceParams + | DiffusionImageInferenceParams, Field(discriminator="generation_type"), ] @@ -464,8 +506,15 @@ def generation_type(self) -> GenerationType: def _convert_inference_parameters(cls, value: Any) -> Any: """Convert raw dict to appropriate inference parameters type based on field presence.""" if isinstance(value, dict): - # Infer type from presence of embedding-specific fields - if "encoding_format" in value or "dimensions" in value: + # Check for explicit generation_type first + gen_type = value.get("generation_type") + + # Infer type from generation_type or field presence + if gen_type == "chat-completion-image": + return ChatCompletionImageInferenceParams(**value) + elif gen_type == "diffusion-image": + return DiffusionImageInferenceParams(**value) + elif gen_type == "embedding" or "encoding_format" in value or "dimensions" in value: return EmbeddingInferenceParams(**value) else: return ChatCompletionInferenceParams(**value) diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index 7e5c79a9..38189068 
100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -3,6 +3,8 @@ from __future__ import annotations +import base64 +import io import json import os from collections import OrderedDict @@ -39,6 +41,93 @@ console = Console() +def _is_base64_image(value: str) -> bool: + """Check if a string is base64-encoded image data.""" + if not isinstance(value, str): + return False + # Check if it starts with data URI scheme + if value.startswith("...") or plain base64 + + Returns: + Base64 string without data URI prefix + + Raises: + ModelAPIError: If data URI format is invalid + """ + if data.startswith("data:image/"): + # Extract base64 portion after comma + if "," in data: + return data.split(",", 1)[1] + else: + raise ModelAPIError("Invalid data URI format: missing comma separator") + + # Already plain base64 + return data + + def _download_url_to_base64(self, url: str) -> str: + """Download image from URL and convert to base64. + + Args: + url: Image URL + + Returns: + Base64-encoded image string + + Raises: + ModelAPIError: If download fails + """ + import base64 + + from data_designer.lazy_heavy_imports import httpx + + try: + with httpx.Client(timeout=30.0) as client: + response = client.get(url) + response.raise_for_status() + image_bytes = response.content + return base64.b64encode(image_bytes).decode("utf-8") + except Exception as e: + raise ModelAPIError(f"Failed to download image from URL {url}: {e}") from e diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py new file mode 100644 index 00000000..ad7ef0d5 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from data_designer.engine.storage.image_storage import ImageFormat, ImageStorageManager + +__all__ = ["ImageFormat", "ImageStorageManager"] diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/image_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/image_storage.py new file mode 100644 index 00000000..d632bbc1 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/storage/image_storage.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import base64 +import uuid +from enum import Enum +from pathlib import Path + + +class ImageFormat(str, Enum): + """Supported image formats.""" + + PNG = "png" + JPEG = "jpeg" + JPG = "jpg" + WEBP = "webp" + + +class ImageStorageManager: + """Manages disk storage of generated images. + + Handles: + - Creating images directory + - Decoding base64 to bytes + - Detecting image format + - Saving with UUID filenames + - Returning relative paths + """ + + def __init__(self, base_path: Path, images_subdir: str = "images", validate_images: bool = True) -> None: + """Initialize image storage manager. + + Args: + base_path: Base directory for dataset + images_subdir: Subdirectory name for images (default: "images") + validate_images: Whether to validate images after saving (default: True) + """ + self.base_path = Path(base_path) + self.images_dir = self.base_path / images_subdir + self.images_subdir = images_subdir + self.validate_images = validate_images + self._ensure_images_directory() + + def _ensure_images_directory(self) -> None: + """Create images directory if it doesn't exist.""" + self.images_dir.mkdir(parents=True, exist_ok=True) + + def save_base64_image(self, base64_data: str) -> str: + """Save base64 image to disk and return relative path. 
+ + Args: + base64_data: Base64 encoded image string (with or without data URI prefix) + + Returns: + Relative path to saved image (e.g., "images/f47ac10b-58cc.png") + + Raises: + ValueError: If base64 data is invalid + OSError: If disk write fails + """ + # Decode base64 to bytes + image_bytes = self._decode_base64(base64_data) + + # Detect format + image_format = self._detect_format(image_bytes) + + # Generate unique filename + image_id = uuid.uuid4() + filename = f"{image_id}.{image_format.value}" + full_path = self.images_dir / filename + relative_path = f"{self.images_subdir}/{filename}" + + # Write to disk + with open(full_path, "wb") as f: + f.write(image_bytes) + + # Optional validation + if self.validate_images: + self._validate_image(full_path) + + return relative_path + + def _decode_base64(self, base64_data: str) -> bytes: + """Decode base64 string to bytes. + + Args: + base64_data: Base64 string (with or without data URI prefix) + + Returns: + Decoded bytes + + Raises: + ValueError: If base64 data is invalid + """ + # Remove data URI prefix if present (e.g., "data:image/png;base64,") + if base64_data.startswith("data:"): + if "," in base64_data: + base64_data = base64_data.split(",", 1)[1] + else: + raise ValueError("Invalid data URI format: missing comma separator") + + try: + return base64.b64decode(base64_data, validate=True) + except Exception as e: + raise ValueError(f"Invalid base64 data: {e}") from e + + def _detect_format(self, image_bytes: bytes) -> ImageFormat: + """Detect image format from bytes. 
+ + Args: + image_bytes: Image data as bytes + + Returns: + Detected format (defaults to PNG if unknown) + """ + # Check magic bytes first (fast) + if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"): + return ImageFormat.PNG + elif image_bytes.startswith(b"\xff\xd8\xff"): + return ImageFormat.JPG + elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]: + return ImageFormat.WEBP + + # Fallback to PIL for robust detection + try: + import io + + from PIL import Image + + img = Image.open(io.BytesIO(image_bytes)) + format_str = img.format.lower() if img.format else None + if format_str in ["png", "jpeg", "jpg", "webp"]: + return ImageFormat(format_str if format_str != "jpeg" else "jpg") + except Exception: + pass + + # Default to PNG + return ImageFormat.PNG + + def _validate_image(self, image_path: Path) -> None: + """Validate that saved image is readable. + + Args: + image_path: Path to image file + + Raises: + ValueError: If image is corrupted or unreadable + """ + try: + from PIL import Image + + with Image.open(image_path) as img: + img.verify() + except Exception as e: + # Clean up invalid file + image_path.unlink(missing_ok=True) + raise ValueError(f"Saved image is invalid or corrupted: {e}") from e + + def cleanup(self) -> None: + """Clean up image directory (for preview mode).""" + import shutil + + if self.images_dir.exists(): + shutil.rmtree(self.images_dir) diff --git a/packages/data-designer/src/data_designer/integrations/huggingface/client.py b/packages/data-designer/src/data_designer/integrations/huggingface/client.py index c047d73b..2e84ee3c 100644 --- a/packages/data-designer/src/data_designer/integrations/huggingface/client.py +++ b/packages/data-designer/src/data_designer/integrations/huggingface/client.py @@ -66,6 +66,7 @@ def upload_dataset( Uploads the complete dataset including: - Main parquet batch files from parquet-files/ β†’ data/ + - Images from images/ β†’ images/ (if present) - Processor output batch files from 
processors-files/{name}/ β†’ {name}/ - Existing builder_config.json and metadata.json files - Auto-generated README.md (dataset card) @@ -102,6 +103,7 @@ def upload_dataset( raise HuggingFaceHubClientUploadError(f"Failed to upload dataset card: {e}") from e self._upload_main_dataset_files(repo_id=repo_id, parquet_folder=base_dataset_path / FINAL_DATASET_FOLDER_NAME) + self._upload_images_folder(repo_id=repo_id, images_folder=base_dataset_path / "images") self._upload_processor_files( repo_id=repo_id, processors_folder=base_dataset_path / PROCESSORS_OUTPUTS_FOLDER_NAME ) @@ -178,6 +180,36 @@ def _upload_main_dataset_files(self, repo_id: str, parquet_folder: Path) -> None except Exception as e: raise HuggingFaceHubClientUploadError(f"Failed to upload parquet files: {e}") from e + def _upload_images_folder(self, repo_id: str, images_folder: Path) -> None: + """Upload images folder to Hugging Face Hub. + + Args: + repo_id: Hugging Face dataset repo ID + images_folder: Path to images folder + + Raises: + HuggingFaceUploadError: If upload fails + """ + if not images_folder.exists(): + return + + image_files = list(images_folder.glob("*")) + if not image_files: + return + + logger.info(f" |-- {RandomEmoji.loading()} Uploading {len(image_files)} images...") + + try: + self._api.upload_folder( + repo_id=repo_id, + folder_path=str(images_folder), + path_in_repo="images", + repo_type="dataset", + commit_message="Upload images", + ) + except Exception as e: + raise HuggingFaceHubClientUploadError(f"Failed to upload images: {e}") from e + def _upload_processor_files(self, repo_id: str, processors_folder: Path) -> None: """Upload processor output files. 
From ed9787bf297a5a57c90f5b58ebd049a2fbe07cae Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 6 Feb 2026 10:17:21 -0700 Subject: [PATCH 22/64] support generation of multiple images --- .../src/data_designer/config/models.py | 14 +++- .../config/utils/visualization.py | 24 ++++++- .../column_generators/generators/image.py | 20 +++--- .../src/data_designer/engine/models/facade.py | 72 ++++++++++--------- 4 files changed, 84 insertions(+), 46 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index dc3533e3..3dab2d8d 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -435,8 +435,11 @@ class ImageInferenceParams(BaseInferenceParams): - Preview mode: Images stored as base64 directly in dataframe Common parameters like quality and size are provided as optional fields. - For model-specific parameters, use the `extra_body` field inherited from - BaseInferenceParams. + For model-specific parameters (including n for number of images), use the `extra_body` + field inherited from BaseInferenceParams. + + If the API returns multiple images (either from prompt or API parameters), all images + will be stored as a list in the dataframe. Attributes: generation_type: Type of generation, always "image" for this class. 
@@ -451,6 +454,13 @@ class ImageInferenceParams(BaseInferenceParams): size="1024x1024" ) + # Generate multiple images using extra_body + dd.ImageInferenceParams( + quality="hd", + size="1024x1024", + extra_body={"n": 3} # Request 3 images from API + ) + # With model-specific params via extra_body dd.ImageInferenceParams( quality="hd", diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index 56d28fd3..62b57f5e 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -394,7 +394,28 @@ def display_sample_record( if col.drop: continue image_data = record[col.name] - if _is_base64_image(image_data): + + # Handle list of images + if isinstance(image_data, list): + previews = [] + for idx, img in enumerate(image_data): + if _is_base64_image(img): + previews.append(f"[{idx}] ") + if in_notebook: + images_to_display_later.append((f"{col.name}[{idx}]", img)) + elif _is_image_url(img): + previews.append(f"[{idx}] ") + if in_notebook: + images_to_display_later.append((f"{col.name}[{idx}]", img)) + elif _is_image_path(img): + previews.append(f"[{idx}] ") + if in_notebook: + images_to_display_later.append((f"{col.name}[{idx}]", img)) + else: + previews.append(f"[{idx}] {str(img)[:30]}") + preview = "\n".join(previews) if previews else "" + # Handle single image (backwards compatibility) + elif _is_base64_image(image_data): preview = f"" if in_notebook: images_to_display_later.append((col.name, image_data)) @@ -408,6 +429,7 @@ def display_sample_record( images_to_display_later.append((col.name, image_data)) else: preview = str(image_data)[:100] + "..." 
if len(str(image_data)) > 100 else str(image_data) + table.add_row(col.name, preview) render_list.append(pad_console_element(table)) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py index 2d24fc2d..db3c9c9e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -36,13 +36,13 @@ def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.CELL_BY_CELL def generate(self, data: dict) -> dict: - """Generate image and optionally save to disk. + """Generate image(s) and optionally save to disk. Args: data: Record data Returns: - Record with image path (create mode) or base64 data (preview mode) added + Record with image path(s) (create mode) or base64 data (preview mode) added """ deserialized_record = deserialize_json_values(data) @@ -63,16 +63,18 @@ def generate(self, data: dict) -> dict: if not prompt or not prompt.strip(): raise ValueError(f"Rendered prompt for column {self.config.name!r} is empty") - # Generate image (returns base64 string) - base64_image = self.model.generate_image(prompt=prompt) + # Generate images (returns list of base64 strings) + base64_images = self.model.generate_image(prompt=prompt) # Store in dataframe based on mode if self.image_storage_manager: - # Create mode: save to disk and store relative path - relative_path = self.image_storage_manager.save_base64_image(base64_image) - data[self.config.name] = relative_path + # Create mode: save each image to disk and store list of relative paths + relative_paths = [ + self.image_storage_manager.save_base64_image(base64_image) for base64_image in base64_images + ] + data[self.config.name] = relative_paths else: - # Preview mode: store base64 directly - data[self.config.name] = 
base64_image + # Preview mode: store list of base64 strings directly + data[self.config.name] = base64_images return data diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 1abd235b..b78d2e1e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -163,22 +163,23 @@ def generate_text_embeddings( self._track_usage_from_embedding(response) @catch_llm_exceptions - def generate_image(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> str: - """Generate image and return base64-encoded data. + def generate_image(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: + """Generate image(s) and return base64-encoded data. Automatically detects the appropriate API based on model name: - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) → image_generation API - All other models → chat/completions API (default) - Both paths return base64-encoded image data. + Both paths return base64-encoded image data. If the API returns multiple images, + all are returned in the list.
Args: prompt: The prompt for image generation skip_usage_tracking: Whether to skip usage tracking - **kwargs: Additional arguments to pass to the model + **kwargs: Additional arguments to pass to the model (including n=number of images) Returns: - Base64-encoded image string (without data URI prefix) + List of base64-encoded image strings (without data URI prefix) Raises: ModelAPIError: If image generation fails or returns invalid data @@ -214,11 +215,11 @@ def _is_diffusion_model(self) -> bool: ] return any(pattern in model_lower for pattern in diffusion_patterns) - def _generate_image_chat_completion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> str: - """Generate image using autoregressive model via chat completions API. + def _generate_image_chat_completion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: + """Generate image(s) using autoregressive model via chat completions API. Returns: - Base64-encoded image string + List of base64-encoded image strings """ kwargs = self.consolidate_kwargs(**kwargs) messages = [ChatMessage.as_user(content=prompt)] @@ -232,7 +233,7 @@ def _generate_image_chat_completion(self, prompt: str, skip_usage_tracking: bool ) logger.debug( - f"Received image from autoregressive model {self.model_name!r}", + f"Received image(s) from autoregressive model {self.model_name!r}", extra={"model": self.model_name, "response": response}, ) @@ -241,42 +242,45 @@ def _generate_image_chat_completion(self, prompt: str, skip_usage_tracking: bool raise ModelAPIError("Response missing choices") message = response.choices[0].message + images = [] # Extract base64 from images attribute (primary path) if hasattr(message, "images") and message.images: - first_image = message.images[0] - - # Handle different response formats - if isinstance(first_image, dict) and "image_url" in first_image: - image_url = first_image["image_url"] - - if isinstance(image_url, dict) and "url" in image_url: - url = 
image_url["url"] - return self._extract_base64_from_data_uri(url) - elif isinstance(image_url, str): - return self._extract_base64_from_data_uri(image_url) - - # Fallback: treat as base64 string - if isinstance(first_image, str): - return self._extract_base64_from_data_uri(first_image) + for image in message.images: + # Handle different response formats + if isinstance(image, dict) and "image_url" in image: + image_url = image["image_url"] + + if isinstance(image_url, dict) and "url" in image_url: + url = image_url["url"] + images.append(self._extract_base64_from_data_uri(url)) + elif isinstance(image_url, str): + images.append(self._extract_base64_from_data_uri(image_url)) + # Fallback: treat as base64 string + elif isinstance(image, str): + images.append(self._extract_base64_from_data_uri(image)) # Fallback: check content field - content = message.content or "" - if content: - return self._extract_base64_from_data_uri(content) + if not images: + content = message.content or "" + if content: + images.append(self._extract_base64_from_data_uri(content)) + + if not images: + raise ModelAPIError("No image data found in response") - raise ModelAPIError("No image data found in response") + return images except Exception: raise - def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> str: - """Generate image using diffusion model via image_generation API. + def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: + """Generate image(s) using diffusion model via image_generation API. Always returns base64. The API is configured to return base64 format. 
Returns: - Base64-encoded image string + List of base64-encoded image strings """ kwargs = self.consolidate_kwargs(**kwargs) @@ -289,7 +293,7 @@ def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = Fal response = self._router.image_generation(prompt=prompt, model=self.model_name, **kwargs) logger.debug( - f"Received image from diffusion model {self.model_name!r}", + f"Received {len(response.data)} image(s) from diffusion model {self.model_name!r}", extra={"model": self.model_name, "response": response}, ) @@ -297,8 +301,8 @@ def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = Fal if not response.data or len(response.data) == 0: raise ModelAPIError("Image generation returned no data") - # Return base64 data - return response.data[0].b64_json + # Return all images as list + return [img.b64_json for img in response.data] except Exception: raise From 7dea87a0e75e242951751cf3af14e94aac77eb46 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 6 Feb 2026 10:28:40 -0700 Subject: [PATCH 23/64] clean up visualization --- .../config/utils/image_helpers.py | 110 ++++++++++++++++++ .../config/utils/visualization.py | 94 +++------------ 2 files changed, 124 insertions(+), 80 deletions(-) create mode 100644 packages/data-designer-config/src/data_designer/config/utils/image_helpers.py diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py new file mode 100644 index 00000000..a32714d3 --- /dev/null +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Helper utilities for working with images.""" + +from __future__ import annotations + +import base64 +from pathlib import Path + +from data_designer.config.models import ImageFormat + + +def is_image_path(value: str) -> bool: + """Check if a string is an image file path. + + Args: + value: String to check + + Returns: + True if the string looks like an image file path, False otherwise + """ + if not isinstance(value, str): + return False + return any(value.lower().endswith(ext) for ext in get_supported_image_extensions()) + + +def is_base64_image(value: str) -> bool: + """Check if a string is base64-encoded image data. + + Args: + value: String to check + + Returns: + True if the string looks like base64-encoded image data, False otherwise + """ + if not isinstance(value, str): + return False + # Check if it starts with data URI scheme + if value.startswith("..." and returns + just the base64 portion. + + Args: + data: Data URI (e.g., "") or plain base64 + + Returns: + Base64 string without data URI prefix + + Raises: + ValueError: If data URI format is invalid + """ + if data.startswith("data:"): + if "," in data: + return data.split(",", 1)[1] + raise ValueError("Invalid data URI format: missing comma separator") + return data + + +def decode_base64_image(base64_data: str) -> bytes: + """Decode base64 string to image bytes. + + Automatically handles data URIs by extracting the base64 portion first. + + Args: + base64_data: Base64 string (with or without data URI prefix) + + Returns: + Decoded image bytes + + Raises: + ValueError: If base64 data is invalid + """ + # Remove data URI prefix if present + base64_data = extract_base64_from_data_uri(base64_data) + + try: + return base64.b64decode(base64_data, validate=True) + except Exception as e: + raise ValueError(f"Invalid base64 data: {e}") from e + + +def detect_image_format(image_bytes: bytes) -> ImageFormat: + """Detect image format from bytes. 
+ + Uses magic bytes for fast detection, falls back to PIL for robust detection. + + Args: + image_bytes: Image data as bytes + + Returns: + Detected format (defaults to PNG if unknown) + """ + # Check magic bytes first (fast) + if image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.PNG]): + return ImageFormat.PNG + elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.JPG]): + return ImageFormat.JPG + elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]: + return ImageFormat.WEBP + + # Fallback to PIL for robust detection + try: + img = PIL.Image.open(io.BytesIO(image_bytes)) + format_str = img.format.lower() if img.format else None + if format_str in ["png", "jpeg", "jpg", "webp"]: + return ImageFormat(format_str if format_str != "jpeg" else "jpg") + except Exception: + pass + + # Default to PNG + return ImageFormat.PNG def is_image_path(value: str) -> bool: diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index dd819aa3..c349ec86 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -29,16 +29,18 @@ from data_designer.config.utils.constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME from data_designer.config.utils.errors import DatasetSampleDisplayError from data_designer.config.utils.image_helpers import ( + extract_base64_from_data_uri, is_base64_image, is_image_path, is_image_url, load_image_path_to_base64, ) -from data_designer.lazy_heavy_imports import np, pd +from data_designer.lazy_heavy_imports import PIL, np, pd if TYPE_CHECKING: import numpy as np import pandas as pd + import PIL from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.dataset_metadata import DatasetMetadata @@ -64,7 +66,6 @@ def 
_display_image_if_in_notebook( try: # Check if we're in a Jupyter environment from IPython.display import HTML, display - from PIL import Image as PILImage get_ipython() # This will raise NameError if not in IPython/Jupyter @@ -77,23 +78,21 @@ def _display_image_if_in_notebook( ) return False base64_data = loaded_base64 - # Decode the image - elif image_data.startswith("" + result = extract_base64_from_data_uri(data_uri) + assert result == "iVBORw0KGgoAAAANS" + + +def test_extract_base64_plain_base64_without_prefix(): + plain_base64 = "iVBORw0KGgoAAAANS" + result = extract_base64_from_data_uri(plain_base64) + assert result == plain_base64 + + +def test_extract_base64_invalid_data_uri_raises_error(): + with pytest.raises(ValueError, match="Invalid data URI format: missing comma separator"): + extract_base64_from_data_uri("data:image/png;base64") + + +# Tests for decode_base64_image + + +def test_decode_base64_image_valid(): + png_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" + base64_data = base64.b64encode(png_bytes).decode() + result = decode_base64_image(base64_data) + assert result == png_bytes + + +def test_decode_base64_image_with_data_uri(): + png_bytes = b"\x89PNG\r\n\x1a\n" + base64_data = base64.b64encode(png_bytes).decode() + data_uri = f"data:image/png;base64,{base64_data}" + result = decode_base64_image(data_uri) + assert result == png_bytes + + +def test_decode_base64_image_invalid_raises_error(): + with pytest.raises(ValueError, match="Invalid base64 data"): + decode_base64_image("not-valid-base64!!!") + + +# Tests for detect_image_format + + +def test_detect_image_format_png(): + png_magic = b"\x89PNG\r\n\x1a\n" + b"\x00" * 10 + assert detect_image_format(png_magic) == ImageFormat.PNG + + +def test_detect_image_format_jpg(): + jpg_magic = b"\xff\xd8\xff" + b"\x00" * 10 + assert detect_image_format(jpg_magic) == ImageFormat.JPG + + +def test_detect_image_format_webp(): + webp_magic = b"RIFF" + b"\x00" * 4 + b"WEBP" + assert 
detect_image_format(webp_magic) == ImageFormat.WEBP + + +def test_detect_image_format_unknown_defaults_to_png(): + unknown_bytes = b"\x00\x00\x00\x00" + b"\x00" * 10 + assert detect_image_format(unknown_bytes) == ImageFormat.PNG + + +# Tests for is_image_path + + +def test_is_image_path_various_extensions(): + assert is_image_path("/path/to/image.png") is True + assert is_image_path("image.PNG") is True + assert is_image_path("image.jpg") is True + assert is_image_path("image.jpeg") is True + + +def test_is_image_path_non_image(): + assert is_image_path("/path/to/file.txt") is False + assert is_image_path("document.pdf") is False + + +def test_is_image_path_extension_in_directory(): + assert is_image_path("/some.png/file.txt") is False + + +# Tests for is_base64_image + + +def test_is_base64_image_data_uri(): + assert is_base64_image("") is True + + +def test_is_base64_image_long_valid_base64(): + long_base64 = base64.b64encode(b"x" * 100).decode() + assert is_base64_image(long_base64) is True + + +def test_is_base64_image_short_string(): + assert is_base64_image("short") is False + + +# Tests for is_image_url + + +def test_is_image_url_http_and_https(): + assert is_image_url("http://example.com/image.png") is True + assert is_image_url("https://example.com/photo.jpg") is True + + +def test_is_image_url_with_query_params(): + assert is_image_url("https://example.com/image.png?size=large") is True + + +def test_is_image_url_without_image_extension(): + assert is_image_url("https://example.com/page.html") is False + + +def test_is_image_url_non_http(): + assert is_image_url("ftp://example.com/image.png") is False + + +# Tests for get_supported_image_extensions + + +def test_get_supported_image_extensions_matches_enum(): + result = get_supported_image_extensions() + enum_values = [f".{fmt.value}" for fmt in ImageFormat] + assert set(result) == set(enum_values) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py 
b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index b78d2e1e..d13273f4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any from data_designer.config.models import GenerationType, ModelConfig, ModelProvider +from data_designer.config.utils.image_helpers import extract_base64_from_data_uri from data_designer.engine.mcp.errors import MCPConfigurationError from data_designer.engine.model_provider import ModelProviderRegistry from data_designer.engine.models.errors import ( @@ -38,6 +39,16 @@ def _identity(x: Any) -> Any: logger = logging.getLogger(__name__) +# Patterns for detecting diffusion-based image generation models +DIFFUSION_MODEL_PATTERNS = [ + "dall-e", + "dalle", + "stable-diffusion", + "sd-", + "sd_", + "imagen", +] + class ModelFacade: def __init__( @@ -205,15 +216,7 @@ def _is_diffusion_model(self) -> bool: True if model is detected as diffusion-based, False otherwise """ model_lower = self.model_name.lower() - diffusion_patterns = [ - "dall-e", - "dalle", - "stable-diffusion", - "sd-", - "sd_", - "imagen", - ] - return any(pattern in model_lower for pattern in diffusion_patterns) + return any(pattern in model_lower for pattern in DIFFUSION_MODEL_PATTERNS) def _generate_image_chat_completion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: """Generate image(s) using autoregressive model via chat completions API. 
@@ -253,18 +256,18 @@ def _generate_image_chat_completion(self, prompt: str, skip_usage_tracking: bool if isinstance(image_url, dict) and "url" in image_url: url = image_url["url"] - images.append(self._extract_base64_from_data_uri(url)) + images.append(extract_base64_from_data_uri(url)) elif isinstance(image_url, str): - images.append(self._extract_base64_from_data_uri(image_url)) + images.append(extract_base64_from_data_uri(image_url)) # Fallback: treat as base64 string elif isinstance(image, str): - images.append(self._extract_base64_from_data_uri(image)) + images.append(extract_base64_from_data_uri(image)) # Fallback: check content field if not images: content = message.content or "" if content: - images.append(self._extract_base64_from_data_uri(content)) + images.append(extract_base64_from_data_uri(content)) if not images: raise ModelAPIError("No image data found in response") @@ -535,28 +538,6 @@ def _track_usage_from_image_diffusion(self, response: litellm.types.utils.ImageR request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) - def _extract_base64_from_data_uri(self, data: str) -> str: - """Extract base64 data from data URI or return as-is. - - Args: - data: Data URI (e.g., "...") or plain base64 - - Returns: - Base64 string without data URI prefix - - Raises: - ModelAPIError: If data URI format is invalid - """ - if data.startswith("data:image/"): - # Extract base64 portion after comma - if "," in data: - return data.split(",", 1)[1] - else: - raise ModelAPIError("Invalid data URI format: missing comma separator") - - # Already plain base64 - return data - def _download_url_to_base64(self, url: str) -> str: """Download image from URL and convert to base64. 
diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/image_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/image_storage.py index d632bbc1..22d4bf84 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/image_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/image_storage.py @@ -3,19 +3,15 @@ from __future__ import annotations -import base64 import uuid -from enum import Enum from pathlib import Path +from typing import TYPE_CHECKING +from data_designer.config.utils.image_helpers import decode_base64_image, detect_image_format +from data_designer.lazy_heavy_imports import PIL -class ImageFormat(str, Enum): - """Supported image formats.""" - - PNG = "png" - JPEG = "jpeg" - JPG = "jpg" - WEBP = "webp" +if TYPE_CHECKING: + import PIL class ImageStorageManager: @@ -61,10 +57,10 @@ def save_base64_image(self, base64_data: str) -> str: OSError: If disk write fails """ # Decode base64 to bytes - image_bytes = self._decode_base64(base64_data) + image_bytes = decode_base64_image(base64_data) # Detect format - image_format = self._detect_format(image_bytes) + image_format = detect_image_format(image_bytes) # Generate unique filename image_id = uuid.uuid4() @@ -82,63 +78,6 @@ def save_base64_image(self, base64_data: str) -> str: return relative_path - def _decode_base64(self, base64_data: str) -> bytes: - """Decode base64 string to bytes. 
- - Args: - base64_data: Base64 string (with or without data URI prefix) - - Returns: - Decoded bytes - - Raises: - ValueError: If base64 data is invalid - """ - # Remove data URI prefix if present (e.g., "data:image/png;base64,") - if base64_data.startswith("data:"): - if "," in base64_data: - base64_data = base64_data.split(",", 1)[1] - else: - raise ValueError("Invalid data URI format: missing comma separator") - - try: - return base64.b64decode(base64_data, validate=True) - except Exception as e: - raise ValueError(f"Invalid base64 data: {e}") from e - - def _detect_format(self, image_bytes: bytes) -> ImageFormat: - """Detect image format from bytes. - - Args: - image_bytes: Image data as bytes - - Returns: - Detected format (defaults to PNG if unknown) - """ - # Check magic bytes first (fast) - if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"): - return ImageFormat.PNG - elif image_bytes.startswith(b"\xff\xd8\xff"): - return ImageFormat.JPG - elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]: - return ImageFormat.WEBP - - # Fallback to PIL for robust detection - try: - import io - - from PIL import Image - - img = Image.open(io.BytesIO(image_bytes)) - format_str = img.format.lower() if img.format else None - if format_str in ["png", "jpeg", "jpg", "webp"]: - return ImageFormat(format_str if format_str != "jpeg" else "jpg") - except Exception: - pass - - # Default to PNG - return ImageFormat.PNG - def _validate_image(self, image_path: Path) -> None: """Validate that saved image is readable. 
@@ -149,9 +88,7 @@ def _validate_image(self, image_path: Path) -> None: ValueError: If image is corrupted or unreadable """ try: - from PIL import Image - - with Image.open(image_path) as img: + with PIL.Image.open(image_path) as img: img.verify() except Exception as e: # Clean up invalid file From 0f07f7b9501aaee80e68b994b301ccd464391f05 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 6 Feb 2026 13:11:54 -0700 Subject: [PATCH 25/64] Streamline integration for image generation --- .../config/utils/image_helpers.py | 20 +- .../tests/config/utils/test_image_helpers.py | 184 +++++++++++++++++- .../column_generators/generators/image.py | 20 +- .../dataset_builders/artifact_storage.py | 22 ++- .../dataset_builders/column_wise_builder.py | 45 +++-- .../data_designer/engine/storage/__init__.py | 4 +- ...image_storage.py => multimedia_storage.py} | 33 ++-- .../generators/test_image.py | 121 ++++++++++++ .../tests/engine/storage/__init__.py | 2 + .../engine/storage/test_multimedia_storage.py | 182 +++++++++++++++++ 10 files changed, 583 insertions(+), 50 deletions(-) rename packages/data-designer-engine/src/data_designer/engine/storage/{image_storage.py => multimedia_storage.py} (80%) create mode 100644 packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py create mode 100644 packages/data-designer-engine/tests/engine/storage/__init__.py create mode 100644 packages/data-designer-engine/tests/engine/storage/test_multimedia_storage.py diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py index 48dacbae..1f5ec332 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -92,8 +92,8 @@ def detect_image_format(image_bytes: bytes) -> ImageFormat: try: img = PIL.Image.open(io.BytesIO(image_bytes)) format_str 
= img.format.lower() if img.format else None - if format_str in ["png", "jpeg", "jpg", "webp"]: - return ImageFormat(format_str if format_str != "jpeg" else "jpg") + if format_str in [ImageFormat.PNG, ImageFormat.JPG, ImageFormat.JPEG, ImageFormat.WEBP]: + return ImageFormat(format_str if format_str != ImageFormat.JPEG else ImageFormat.JPG) except Exception: pass @@ -191,6 +191,22 @@ def load_image_path_to_base64(image_path: str, base_path: str | None = None) -> return None +def validate_image(image_path: Path) -> None: + """Validate that an image file is readable and not corrupted. + + Args: + image_path: Path to image file + + Raises: + ValueError: If image is corrupted or unreadable + """ + try: + with PIL.Image.open(image_path) as img: + img.verify() + except Exception as e: + raise ValueError(f"Image validation failed: {e}") from e + + def get_supported_image_extensions() -> list[str]: """Get list of supported image extensions from ImageFormat enum. diff --git a/packages/data-designer-config/tests/config/utils/test_image_helpers.py b/packages/data-designer-config/tests/config/utils/test_image_helpers.py index 3d6683e4..9c7ccd7f 100644 --- a/packages/data-designer-config/tests/config/utils/test_image_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_image_helpers.py @@ -4,7 +4,14 @@ from __future__ import annotations import base64 - +import io +from typing import TYPE_CHECKING +from unittest.mock import Mock, patch + +# Explicitly import PIL.Image submodule to make it accessible as PIL.Image +# Python doesn't automatically import submodules when you import a package, +# so `import PIL` alone doesn't give you access to PIL.Image +import PIL.Image # noqa: E402 import pytest from data_designer.config.models import ImageFormat @@ -16,7 +23,13 @@ is_base64_image, is_image_path, is_image_url, + load_image_path_to_base64, + validate_image, ) +from data_designer.lazy_heavy_imports import PIL + +if TYPE_CHECKING: + import PIL # Tests for 
extract_base64_from_data_uri @@ -139,6 +152,39 @@ def test_is_image_url_non_http(): assert is_image_url("ftp://example.com/image.png") is False +# Tests for validate_image + + +def test_validate_image_valid_png(tmp_path): + # Create a valid 1x1 PNG using PIL + img = PIL.Image.new("RGB", (1, 1), color="red") + buf = io.BytesIO() + img.save(buf, format="PNG") + png_bytes = buf.getvalue() + + image_path = tmp_path / "test.png" + image_path.write_bytes(png_bytes) + + # Should not raise + validate_image(image_path) + + +def test_validate_image_corrupted_raises_error(tmp_path): + # Create an invalid image file + image_path = tmp_path / "corrupted.png" + image_path.write_bytes(b"not a valid image") + + with pytest.raises(ValueError, match="Image validation failed"): + validate_image(image_path) + + +def test_validate_image_nonexistent_raises_error(tmp_path): + image_path = tmp_path / "nonexistent.png" + + with pytest.raises(ValueError, match="Image validation failed"): + validate_image(image_path) + + # Tests for get_supported_image_extensions @@ -146,3 +192,139 @@ def test_get_supported_image_extensions_matches_enum(): result = get_supported_image_extensions() enum_values = [f".{fmt.value}" for fmt in ImageFormat] assert set(result) == set(enum_values) + + +# Additional tests for uncovered lines + + +def test_detect_image_format_with_pil_fallback_unsupported_format(tmp_path): + # Create a real GIF image that will trigger PIL fallback + # (GIF has different magic bytes not in our fast-path detection) + img = PIL.Image.new("RGB", (1, 1), color="red") + gif_path = tmp_path / "test.gif" + img.save(gif_path, format="GIF") + + gif_bytes = gif_path.read_bytes() + # Should use PIL fallback and default to PNG (GIF not in ImageFormat enum) + result = detect_image_format(gif_bytes) + assert result == ImageFormat.PNG + + +def test_detect_image_format_with_pil_fallback_jpeg(): + # Test PIL fallback path that converts "jpeg" format string to JPG enum + # Use mock since we can't easily 
create valid JPEG bytes without magic bytes + mock_img = Mock() + mock_img.format = "JPEG" + + # Use bytes that don't match our magic bytes to trigger PIL fallback + test_bytes = b"\x00\x00\x00\x00" + + with patch.object(PIL.Image, "open", return_value=mock_img): + result = detect_image_format(test_bytes) + # Should convert JPEG -> JPG via line 96 + assert result == ImageFormat.JPG + + +def test_is_image_path_non_string_input(): + assert is_image_path(123) is False + assert is_image_path(None) is False + assert is_image_path([]) is False + + +def test_is_base64_image_non_string_input(): + assert is_base64_image(123) is False + assert is_base64_image(None) is False + assert is_base64_image([]) is False + + +def test_is_base64_image_invalid_base64_decode(): + # String with valid base64 characters but incorrect padding that causes decode to fail + # Single '=' in middle of string is invalid base64 (padding only allowed at end) + invalid_base64 = "A" * 50 + "=" + "A" * 49 + "more text" + assert is_base64_image(invalid_base64) is False + + +def test_is_image_url_non_string_input(): + assert is_image_url(123) is False + assert is_image_url(None) is False + assert is_image_url([]) is False + + +# Tests for load_image_path_to_base64 + + +def test_load_image_path_to_base64_absolute_path(tmp_path): + # Create a test image file + img = PIL.Image.new("RGB", (1, 1), color="blue") + image_path = tmp_path / "test.png" + img.save(image_path) + + # Load with absolute path + result = load_image_path_to_base64(str(image_path)) + assert result is not None + assert len(result) > 0 + # Verify it's valid base64 + decoded = base64.b64decode(result) + assert len(decoded) > 0 + + +def test_load_image_path_to_base64_relative_with_base_path(tmp_path): + # Create a test image file + img = PIL.Image.new("RGB", (1, 1), color="green") + image_path = tmp_path / "subdir" / "test.png" + image_path.parent.mkdir(exist_ok=True) + img.save(image_path) + + # Load with relative path and base_path + result 
= load_image_path_to_base64("subdir/test.png", base_path=str(tmp_path)) + assert result is not None + assert len(result) > 0 + + +def test_load_image_path_to_base64_nonexistent_file(): + result = load_image_path_to_base64("/nonexistent/path/to/image.png") + assert result is None + + +def test_load_image_path_to_base64_relative_with_cwd_fallback(tmp_path, monkeypatch): + # Create test image in current working directory + + # Change to tmp_path as cwd + monkeypatch.chdir(tmp_path) + + img = PIL.Image.new("RGB", (1, 1), color="yellow") + image_path = tmp_path / "test_cwd.png" + img.save(image_path) + + # Use relative path without base_path - should fall back to cwd + result = load_image_path_to_base64("test_cwd.png") + assert result is not None + assert len(result) > 0 + + +def test_load_image_path_to_base64_base_path_fallback_to_cwd(tmp_path, monkeypatch): + # Test the case where base_path is provided but file isn't there, falls back to cwd + monkeypatch.chdir(tmp_path) + + # Create image in cwd + img = PIL.Image.new("RGB", (1, 1), color="red") + image_path = tmp_path / "test.png" + img.save(image_path) + + # Create a different base_path that doesn't have the image + wrong_base = tmp_path / "wrong" + wrong_base.mkdir() + + # Use relative path with wrong base_path - should fall back to cwd + result = load_image_path_to_base64("test.png", base_path=str(wrong_base)) + assert result is not None + assert len(result) > 0 + + +def test_load_image_path_to_base64_exception_handling(tmp_path): + # Create a directory (not a file) to trigger exception + dir_path = tmp_path / "directory" + dir_path.mkdir() + + result = load_image_path_to_base64(str(dir_path)) + assert result is None diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py index db3c9c9e..7ad7a18c 100644 --- 
a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -11,25 +11,27 @@ from data_designer.engine.processing.utils import deserialize_json_values if TYPE_CHECKING: - from data_designer.engine.storage.image_storage import ImageStorageManager + from data_designer.engine.storage.multimedia_storage import MultimediaStorage class ImageCellGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorWithModel[ImageGenerationColumnConfig]): """Generator for image columns with optional disk persistence. - Behavior depends on whether image_storage_manager is set: - - If set (create mode): Saves images to disk and stores relative paths in dataframe + Behavior depends on whether multimedia storage is available via ResourceProvider: + - If available (create mode): Saves images to disk and stores relative paths in dataframe - If None (preview mode): Stores base64 directly in dataframe API is automatically detected based on the model name: - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) 
β†’ image_generation API - All other models β†’ chat/completions API (default) - Attributes: - image_storage_manager: Optional image storage manager instance (set by dataset builder) + Storage is accessed via ResourceProvider.artifact_storage.multimedia_storage """ - image_storage_manager: ImageStorageManager | None = None + @property + def multimedia_storage(self) -> MultimediaStorage | None: + """Get multimedia storage from resource provider if available.""" + return self._resource_provider.artifact_storage.multimedia_storage @staticmethod def get_generation_strategy() -> GenerationStrategy: @@ -67,11 +69,9 @@ def generate(self, data: dict) -> dict: base64_images = self.model.generate_image(prompt=prompt) # Store in dataframe based on mode - if self.image_storage_manager: + if self.multimedia_storage: # Create mode: save each image to disk and store list of relative paths - relative_paths = [ - self.image_storage_manager.save_base64_image(base64_image) for base64_image in base64_images - ] + relative_paths = [self.multimedia_storage.save_base64_image(base64_image) for base64_image in base64_images] data[self.config.name] = relative_paths else: # Preview mode: store list of base64 strings directly diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py index 35e7d4f8..b5ffaae7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py @@ -11,11 +11,12 @@ from pathlib import Path from typing import TYPE_CHECKING -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from data_designer.config.utils.io_helpers import read_parquet_dataset from data_designer.config.utils.type_helpers import 
StrEnum, resolve_string_enum from data_designer.engine.dataset_builders.errors import ArtifactStorageError +from data_designer.engine.storage.multimedia_storage import MultimediaStorage from data_designer.lazy_heavy_imports import pd if TYPE_CHECKING: @@ -38,12 +39,15 @@ class BatchStage(StrEnum): class ArtifactStorage(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + artifact_path: Path | str dataset_name: str = "dataset" final_dataset_folder_name: str = FINAL_DATASET_FOLDER_NAME partial_results_folder_name: str = "tmp-partial-parquet-files" dropped_columns_folder_name: str = "dropped-columns-parquet-files" processors_outputs_folder_name: str = PROCESSORS_OUTPUTS_FOLDER_NAME + multimedia_storage: MultimediaStorage | None = Field(default=None, exclude=True) @property def artifact_path_exists(self) -> bool: @@ -116,6 +120,22 @@ def validate_folder_names(self): return self + def ensure_multimedia_storage(self) -> MultimediaStorage: + """Lazily create multimedia storage if not already present. 
+ + Returns: + MultimediaStorage instance + + Note: + Creates storage with default settings (images_subdir="images", validate_images=True) + """ + if self.multimedia_storage is None: + self.multimedia_storage = MultimediaStorage( + base_path=self.base_dataset_path, + validate_images=True, + ) + return self.multimedia_storage + @staticmethod def mkdir_if_needed(path: Path | str) -> Path: """Create the directory if it does not exist.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 7a2962eb..ac4469eb 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -26,7 +26,6 @@ ColumnGeneratorWithModel, GenerationStrategy, ) -from data_designer.engine.column_generators.generators.image import ImageCellGenerator from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated from data_designer.engine.compiler import compile_data_designer_config from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage @@ -41,7 +40,6 @@ from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry from data_designer.engine.resources.resource_provider import ResourceProvider -from data_designer.engine.storage.image_storage import ImageStorageManager from data_designer.lazy_heavy_imports import pd if TYPE_CHECKING: @@ -66,7 +64,6 @@ def __init__( self._resource_provider = resource_provider self._records_to_drop: set[int] = set() self._registry = registry or DataDesignerRegistry() - self._image_storage_manager: ImageStorageManager | None = None self._data_designer_config = 
compile_data_designer_config(data_designer_config, resource_provider) self._column_configs = compile_dataset_builder_column_configs(self._data_designer_config) @@ -98,11 +95,31 @@ def build( *, num_records: int, on_batch_complete: Callable[[Path], None] | None = None, + save_multimedia_to_disk: bool = True, ) -> Path: + """Build the dataset. + + Args: + num_records: Number of records to generate. + on_batch_complete: Optional callback function called when each batch completes. + save_multimedia_to_disk: Whether to save generated multimedia (images, audio, video) to disk. + If False, multimedia is stored directly in the DataFrame (e.g., images as base64). + Default is True. + + Returns: + Path to the generated dataset directory. + """ self._run_model_health_check_if_needed() self._run_mcp_tool_check_if_needed() self._write_builder_config() - self._initialize_image_storage_if_needed() + + # Ensure multimedia storage exists if needed + if save_multimedia_to_disk and self._has_image_columns(): + self.artifact_storage.ensure_multimedia_storage() + else: + # Disable storage for preview or when explicitly disabled + self.artifact_storage.multimedia_storage = None + generators = self._initialize_generators() start_time = time.perf_counter() group_id = uuid.uuid4().hex @@ -128,7 +145,7 @@ def build( def build_preview(self, *, num_records: int) -> pd.DataFrame: self._run_model_health_check_if_needed() self._run_mcp_tool_check_if_needed() - # Skip image storage initialization for preview - base64 will be stored directly in DataFrame + # Skip multimedia storage initialization for preview - base64 will be stored directly in DataFrame generators = self._initialize_generators() group_id = uuid.uuid4().hex @@ -155,26 +172,16 @@ def _has_image_columns(self) -> bool: return any(col.column_type == DataDesignerColumnType.IMAGE_GENERATION for col in self.single_column_configs) - def _initialize_image_storage_if_needed(self) -> None: - """Initialize image storage manager if dataset has 
image columns.""" - if self._has_image_columns(): - self._image_storage_manager = ImageStorageManager( - base_path=self.artifact_storage.base_dataset_path, images_subdir="images", validate_images=True - ) - def _initialize_generators(self) -> list[ColumnGenerator]: + """Initialize column generators. + + Generators access multimedia storage via ResourceProvider.artifact_storage.multimedia_storage + """ generators = [] for config in self._column_configs: generator_cls = self._registry.column_generators.get_for_config_type(type(config)) generator = generator_cls(config=config, resource_provider=self._resource_provider) - - # Inject image storage manager for image generators (if available) - # For preview mode, storage manager is None and base64 is stored directly - if isinstance(generator, ImageCellGenerator): - generator.image_storage_manager = self._image_storage_manager - generators.append(generator) - return generators def _write_builder_config(self) -> None: diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py index ad7ef0d5..820d512a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -from data_designer.engine.storage.image_storage import ImageFormat, ImageStorageManager +from data_designer.engine.storage.multimedia_storage import MultimediaStorage -__all__ = ["ImageFormat", "ImageStorageManager"] +__all__ = ["MultimediaStorage"] diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/image_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/multimedia_storage.py similarity index 80% rename from packages/data-designer-engine/src/data_designer/engine/storage/image_storage.py rename to packages/data-designer-engine/src/data_designer/engine/storage/multimedia_storage.py index 22d4bf84..e40c0032 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/image_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/multimedia_storage.py @@ -5,28 +5,32 @@ import uuid from pathlib import Path -from typing import TYPE_CHECKING -from data_designer.config.utils.image_helpers import decode_base64_image, detect_image_format -from data_designer.lazy_heavy_imports import PIL +from data_designer.config.utils.image_helpers import decode_base64_image, detect_image_format, validate_image -if TYPE_CHECKING: - import PIL +IMAGES_SUBDIR = "images" -class ImageStorageManager: - """Manages disk storage of generated images. +class MultimediaStorage: + """Manages disk storage of generated multimedia content. + + Currently supports: + - Images (PNG, JPG, WEBP) + + Future support planned for: + - Audio + - Video Handles: - - Creating images directory + - Creating storage directories - Decoding base64 to bytes - - Detecting image format + - Detecting media format - Saving with UUID filenames - Returning relative paths """ - def __init__(self, base_path: Path, images_subdir: str = "images", validate_images: bool = True) -> None: - """Initialize image storage manager. 
+ def __init__(self, base_path: Path, images_subdir: str = IMAGES_SUBDIR, validate_images: bool = True) -> None: + """Initialize multimedia storage manager. Args: base_path: Base directory for dataset @@ -88,12 +92,11 @@ def _validate_image(self, image_path: Path) -> None: ValueError: If image is corrupted or unreadable """ try: - with PIL.Image.open(image_path) as img: - img.verify() - except Exception as e: + validate_image(image_path) + except ValueError: # Clean up invalid file image_path.unlink(missing_ok=True) - raise ValueError(f"Saved image is invalid or corrupted: {e}") from e + raise def cleanup(self) -> None: """Clean up image directory (for preview mode).""" diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py new file mode 100644 index 00000000..7173ed2d --- /dev/null +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import Mock, patch + +import pytest + +from data_designer.config.column_configs import ImageGenerationColumnConfig +from data_designer.engine.column_generators.generators.base import GenerationStrategy +from data_designer.engine.column_generators.generators.image import ImageCellGenerator +from data_designer.engine.processing.ginja.exceptions import UserTemplateError + + +@pytest.fixture +def stub_image_column_config(): + return ImageGenerationColumnConfig( + name="test_image", prompt="A {{ style }} image of {{ subject }}", model_alias="test_model" + ) + + +@pytest.fixture +def stub_base64_images() -> list[str]: + return ["base64_image_1", "base64_image_2"] + + +def test_image_cell_generator_generation_strategy( + stub_image_column_config: ImageGenerationColumnConfig, stub_resource_provider: None +) -> None: + generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) + assert generator.get_generation_strategy() == GenerationStrategy.CELL_BY_CELL + + +def test_image_cell_generator_multimedia_storage_property( + stub_image_column_config: ImageGenerationColumnConfig, stub_resource_provider: None +) -> None: + generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) + # Should return multimedia_storage from artifact_storage (None by default in stub) + assert generator.multimedia_storage is None + + +def test_image_cell_generator_generate_with_storage( + stub_image_column_config, stub_resource_provider, stub_base64_images +): + """Test generate with multimedia storage (create mode) - saves to disk.""" + # Setup mock multimedia storage + mock_storage = Mock() + mock_storage.save_base64_image.side_effect = ["images/uuid1.png", "images/uuid2.png"] + stub_resource_provider.artifact_storage.multimedia_storage = mock_storage + + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + 
"generate_image", + return_value=stub_base64_images, + ) as mock_generate: + generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) + data = generator.generate(data={"style": "photorealistic", "subject": "cat"}) + + # Check that column was added with relative paths + assert stub_image_column_config.name in data + assert data[stub_image_column_config.name] == ["images/uuid1.png", "images/uuid2.png"] + + # Verify model was called with rendered prompt + mock_generate.assert_called_once_with(prompt="A photorealistic image of cat") + + # Verify storage was called for each image + assert mock_storage.save_base64_image.call_count == 2 + mock_storage.save_base64_image.assert_any_call("base64_image_1") + mock_storage.save_base64_image.assert_any_call("base64_image_2") + + +def test_image_cell_generator_generate_without_storage( + stub_image_column_config, stub_resource_provider, stub_base64_images +): + """Test generate without multimedia storage (preview mode) - stores base64 directly.""" + # Ensure multimedia_storage is None (preview mode) + stub_resource_provider.artifact_storage.multimedia_storage = None + + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + "generate_image", + return_value=stub_base64_images, + ) as mock_generate: + generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) + data = generator.generate(data={"style": "watercolor", "subject": "dog"}) + + # Check that column was added with base64 data + assert stub_image_column_config.name in data + assert data[stub_image_column_config.name] == stub_base64_images + + # Verify model was called with rendered prompt + mock_generate.assert_called_once_with(prompt="A watercolor image of dog") + + +def test_image_cell_generator_missing_columns_error(stub_image_column_config, stub_resource_provider): + """Test that missing required columns raises ValueError.""" + generator = 
ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) + + with pytest.raises(ValueError, match="columns.*missing"): + # Missing 'subject' column + generator.generate(data={"style": "photorealistic"}) + + +def test_image_cell_generator_empty_prompt_error(stub_resource_provider): + """Test that empty rendered prompt raises UserTemplateError.""" + # Create config with template that renders to empty string + config = ImageGenerationColumnConfig(name="test_image", prompt="{{ empty }}", model_alias="test_model") + + generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + + with pytest.raises(UserTemplateError): + generator.generate(data={"empty": ""}) + + +def test_image_cell_generator_whitespace_only_prompt_error(stub_resource_provider): + """Test that whitespace-only rendered prompt raises ValueError.""" + config = ImageGenerationColumnConfig(name="test_image", prompt="{{ spaces }}", model_alias="test_model") + + generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + + with pytest.raises(ValueError, match="empty"): + generator.generate(data={"spaces": " "}) diff --git a/packages/data-designer-engine/tests/engine/storage/__init__.py b/packages/data-designer-engine/tests/engine/storage/__init__.py new file mode 100644 index 00000000..e5725ea5 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/storage/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/data-designer-engine/tests/engine/storage/test_multimedia_storage.py b/packages/data-designer-engine/tests/engine/storage/test_multimedia_storage.py new file mode 100644 index 00000000..ade76b5a --- /dev/null +++ b/packages/data-designer-engine/tests/engine/storage/test_multimedia_storage.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import base64 +import io + +# Explicitly import PIL.Image submodule to make it accessible as PIL.Image +# Python doesn't automatically import submodules when you import a package, +# so `import PIL` alone doesn't give you access to PIL.Image +import PIL.Image # noqa: E402 +import pytest + +from data_designer.engine.storage.multimedia_storage import IMAGES_SUBDIR, MultimediaStorage +from data_designer.lazy_heavy_imports import PIL + + +@pytest.fixture +def multimedia_storage(tmp_path): + """Create a MultimediaStorage instance with a temporary directory.""" + return MultimediaStorage(base_path=tmp_path) + + +@pytest.fixture +def sample_base64_png() -> str: + """Create a valid 1x1 PNG as base64.""" + img = PIL.Image.new("RGB", (1, 1), color="red") + buf = io.BytesIO() + img.save(buf, format="PNG") + png_bytes = buf.getvalue() + return base64.b64encode(png_bytes).decode() + + +@pytest.fixture +def sample_base64_jpg() -> str: + """Create a valid 1x1 JPEG as base64.""" + img = PIL.Image.new("RGB", (1, 1), color="blue") + buf = io.BytesIO() + img.save(buf, format="JPEG") + jpg_bytes = buf.getvalue() + return base64.b64encode(jpg_bytes).decode() + + +def test_multimedia_storage_init(tmp_path): + """Test MultimediaStorage initialization.""" + storage = MultimediaStorage(base_path=tmp_path) + assert storage.base_path == tmp_path + assert storage.images_dir == tmp_path / IMAGES_SUBDIR + assert storage.images_subdir == 
IMAGES_SUBDIR + assert storage.validate_images is True + # Should create images directory on init + assert storage.images_dir.exists() + + +def test_multimedia_storage_init_custom_subdir(tmp_path): + """Test MultimediaStorage initialization with custom subdirectory.""" + custom_subdir = "custom_images" + storage = MultimediaStorage(base_path=tmp_path, images_subdir=custom_subdir, validate_images=False) + assert storage.images_subdir == custom_subdir + assert storage.images_dir == tmp_path / custom_subdir + assert storage.validate_images is False + assert storage.images_dir.exists() + + +def test_save_base64_image_png(multimedia_storage, sample_base64_png): + """Test saving a PNG image from base64.""" + relative_path = multimedia_storage.save_base64_image(sample_base64_png) + + # Check return value format + assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + assert relative_path.endswith(".png") + + # Check file exists on disk + full_path = multimedia_storage.base_path / relative_path + assert full_path.exists() + + # Verify file content + saved_bytes = full_path.read_bytes() + expected_bytes = base64.b64decode(sample_base64_png) + assert saved_bytes == expected_bytes + + +def test_save_base64_image_jpg(multimedia_storage, sample_base64_jpg): + """Test saving a JPEG image from base64.""" + relative_path = multimedia_storage.save_base64_image(sample_base64_jpg) + + # Check return value format + assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + assert relative_path.endswith(".jpg") + + # Check file exists on disk + full_path = multimedia_storage.base_path / relative_path + assert full_path.exists() + + +def test_save_base64_image_with_data_uri(multimedia_storage, sample_base64_png): + """Test saving image from data URI format.""" + data_uri = f"data:image/png;base64,{sample_base64_png}" + relative_path = multimedia_storage.save_base64_image(data_uri) + + # Should successfully extract base64 and save + assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + 
assert relative_path.endswith(".png") + + # Verify file exists and content is correct + full_path = multimedia_storage.base_path / relative_path + assert full_path.exists() + saved_bytes = full_path.read_bytes() + expected_bytes = base64.b64decode(sample_base64_png) + assert saved_bytes == expected_bytes + + +def test_save_base64_image_invalid_base64_raises_error(multimedia_storage): + """Test that invalid base64 data raises ValueError.""" + with pytest.raises(ValueError, match="Invalid base64"): + multimedia_storage.save_base64_image("not-valid-base64!!!") + + +def test_save_base64_image_multiple_images_unique_filenames(multimedia_storage, sample_base64_png): + """Test that multiple images get unique filenames.""" + path1 = multimedia_storage.save_base64_image(sample_base64_png) + path2 = multimedia_storage.save_base64_image(sample_base64_png) + + # Paths should be different (different UUIDs) + assert path1 != path2 + + # Both files should exist + assert (multimedia_storage.base_path / path1).exists() + assert (multimedia_storage.base_path / path2).exists() + + +def test_save_base64_image_validation_enabled(tmp_path, sample_base64_png): + """Test that validation is performed when enabled.""" + storage = MultimediaStorage(base_path=tmp_path, validate_images=True) + # Should succeed with valid image + relative_path = storage.save_base64_image(sample_base64_png) + assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + + +def test_save_base64_image_validation_corrupted_image_raises_error(tmp_path): + """Test that corrupted image fails validation and is cleaned up.""" + storage = MultimediaStorage(base_path=tmp_path, validate_images=True) + + # Create base64 of invalid image data + corrupted_bytes = b"not a valid image" + corrupted_base64 = base64.b64encode(corrupted_bytes).decode() + + with pytest.raises(ValueError, match="Image validation failed"): + storage.save_base64_image(corrupted_base64) + + # Check that no files were left behind + assert 
len(list(storage.images_dir.iterdir())) == 0 + + +def test_save_base64_image_validation_disabled(tmp_path): + """Test that validation can be disabled.""" + storage = MultimediaStorage(base_path=tmp_path, validate_images=False) + + # Create base64 of invalid image data + corrupted_bytes = b"not a valid image" + corrupted_base64 = base64.b64encode(corrupted_bytes).decode() + + # Should succeed without validation + relative_path = storage.save_base64_image(corrupted_base64) + assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + + # File should exist even though it's invalid + full_path = storage.base_path / relative_path + assert full_path.exists() + + +def test_cleanup(multimedia_storage, sample_base64_png): + """Test cleanup removes images directory.""" + # Save an image first + multimedia_storage.save_base64_image(sample_base64_png) + assert multimedia_storage.images_dir.exists() + assert len(list(multimedia_storage.images_dir.iterdir())) > 0 + + # Cleanup should remove directory + multimedia_storage.cleanup() + assert not multimedia_storage.images_dir.exists() From 2aae6ccd6f09064feddc6b14d1faa15dd5c5e417 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 6 Feb 2026 17:30:10 -0700 Subject: [PATCH 26/64] streamline generation --- .../config/utils/image_helpers.py | 8 +- .../config/utils/visualization.py | 28 +-- .../src/data_designer/lazy_heavy_imports.py | 1 + .../tests/config/utils/test_image_helpers.py | 24 +-- .../column_generators/generators/image.py | 31 ++- .../dataset_builders/artifact_storage.py | 28 ++- .../dataset_builders/column_wise_builder.py | 18 +- .../data_designer/engine/storage/__init__.py | 4 +- ...multimedia_storage.py => media_storage.py} | 63 ++++-- .../generators/test_image.py | 24 +-- .../dataset_builders/test_artifact_storage.py | 7 +- .../engine/storage/test_media_storage.py | 174 +++++++++++++++++ .../engine/storage/test_multimedia_storage.py | 182 ------------------ .../tests/engine/test_configurable_task.py | 33 +--- 14 files 
changed, 301 insertions(+), 324 deletions(-) rename packages/data-designer-engine/src/data_designer/engine/storage/{multimedia_storage.py => media_storage.py} (56%) create mode 100644 packages/data-designer-engine/tests/engine/storage/test_media_storage.py delete mode 100644 packages/data-designer-engine/tests/engine/storage/test_multimedia_storage.py diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py index 1f5ec332..67803aff 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -11,10 +11,10 @@ from typing import TYPE_CHECKING from data_designer.config.models import ImageFormat -from data_designer.lazy_heavy_imports import PIL +from data_designer.lazy_heavy_imports import Image if TYPE_CHECKING: - import PIL + from PIL import Image # Magic bytes for image format detection IMAGE_FORMAT_MAGIC_BYTES = { @@ -90,7 +90,7 @@ def detect_image_format(image_bytes: bytes) -> ImageFormat: # Fallback to PIL for robust detection try: - img = PIL.Image.open(io.BytesIO(image_bytes)) + img = Image.open(io.BytesIO(image_bytes)) format_str = img.format.lower() if img.format else None if format_str in [ImageFormat.PNG, ImageFormat.JPG, ImageFormat.JPEG, ImageFormat.WEBP]: return ImageFormat(format_str if format_str != ImageFormat.JPEG else ImageFormat.JPG) @@ -201,7 +201,7 @@ def validate_image(image_path: Path) -> None: ValueError: If image is corrupted or unreadable """ try: - with PIL.Image.open(image_path) as img: + with Image.open(image_path) as img: img.verify() except Exception as e: raise ValueError(f"Image validation failed: {e}") from e diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index c349ec86..6a9e8ee5 
100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -3,8 +3,6 @@ from __future__ import annotations -import base64 -import io import json import os from collections import OrderedDict @@ -35,12 +33,11 @@ is_image_url, load_image_path_to_base64, ) -from data_designer.lazy_heavy_imports import PIL, np, pd +from data_designer.lazy_heavy_imports import np, pd if TYPE_CHECKING: import numpy as np import pandas as pd - import PIL from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.dataset_metadata import DatasetMetadata @@ -49,15 +46,12 @@ console = Console() -def _display_image_if_in_notebook( - image_data: str, col_name: str, max_width: int = 512, base_path: str | None = None -) -> bool: +def _display_image_if_in_notebook(image_data: str, col_name: str, base_path: str | None = None) -> bool: """Display image with caption in Jupyter notebook if available. Args: image_data: Base64-encoded image data, data URI, or file path. col_name: Name of the column (used for caption). - max_width: Maximum width for the displayed image in pixels. base_path: Optional base path to resolve relative image paths. 
Returns: @@ -83,27 +77,15 @@ def _display_image_if_in_notebook( # Extract base64 from data URI if present base64_data = extract_base64_from_data_uri(base64_data) - image_bytes = base64.b64decode(base64_data) - # Open image with PIL and resize if needed - img = PIL.Image.open(io.BytesIO(image_bytes)) - - # Resize if image is too large - if img.width > max_width: - ratio = max_width / img.width - new_height = int(img.height * ratio) - img = img.resize((max_width, new_height), PIL.Image.Resampling.LANCZOS) - - # Convert back to base64 for HTML display - buffered = io.BytesIO() - img.save(buffered, format=img.format or "PNG") - img_base64 = base64.b64encode(buffered.getvalue()).decode() + # Use the base64 data directly without resizing + img_base64 = base64_data # Create HTML with caption and image in left-aligned container html = f"""
🖼️ {col_name}
- +
""" display(HTML(html)) diff --git a/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py b/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py index f7901a7c..0e95f248 100644 --- a/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py +++ b/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py @@ -36,6 +36,7 @@ "scipy": "scipy", "jsonschema": "jsonschema", "PIL": "PIL", + "Image": "PIL.Image", } diff --git a/packages/data-designer-config/tests/config/utils/test_image_helpers.py b/packages/data-designer-config/tests/config/utils/test_image_helpers.py index 9c7ccd7f..e0eb0370 100644 --- a/packages/data-designer-config/tests/config/utils/test_image_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_image_helpers.py @@ -5,13 +5,8 @@ import base64 import io -from typing import TYPE_CHECKING from unittest.mock import Mock, patch -# Explicitly import PIL.Image submodule to make it accessible as PIL.Image -# Python doesn't automatically import submodules when you import a package, -# so `import PIL` alone doesn't give you access to PIL.Image -import PIL.Image # noqa: E402 import pytest from data_designer.config.models import ImageFormat @@ -26,10 +21,7 @@ load_image_path_to_base64, validate_image, ) -from data_designer.lazy_heavy_imports import PIL - -if TYPE_CHECKING: - import PIL +from data_designer.lazy_heavy_imports import Image # Tests for extract_base64_from_data_uri @@ -157,7 +149,7 @@ def test_is_image_url_non_http(): def test_validate_image_valid_png(tmp_path): # Create a valid 1x1 PNG using PIL - img = PIL.Image.new("RGB", (1, 1), color="red") + img = Image.new("RGB", (1, 1), color="red") buf = io.BytesIO() img.save(buf, format="PNG") png_bytes = buf.getvalue() @@ -200,7 +192,7 @@ def test_get_supported_image_extensions_matches_enum(): def test_detect_image_format_with_pil_fallback_unsupported_format(tmp_path): # Create a real GIF image that will trigger PIL fallback # (GIF 
has different magic bytes not in our fast-path detection) - img = PIL.Image.new("RGB", (1, 1), color="red") + img = Image.new("RGB", (1, 1), color="red") gif_path = tmp_path / "test.gif" img.save(gif_path, format="GIF") @@ -219,7 +211,7 @@ def test_detect_image_format_with_pil_fallback_jpeg(): # Use bytes that don't match our magic bytes to trigger PIL fallback test_bytes = b"\x00\x00\x00\x00" - with patch.object(PIL.Image, "open", return_value=mock_img): + with patch.object(Image, "open", return_value=mock_img): result = detect_image_format(test_bytes) # Should convert JPEG -> JPG via line 96 assert result == ImageFormat.JPG @@ -255,7 +247,7 @@ def test_is_image_url_non_string_input(): def test_load_image_path_to_base64_absolute_path(tmp_path): # Create a test image file - img = PIL.Image.new("RGB", (1, 1), color="blue") + img = Image.new("RGB", (1, 1), color="blue") image_path = tmp_path / "test.png" img.save(image_path) @@ -270,7 +262,7 @@ def test_load_image_path_to_base64_absolute_path(tmp_path): def test_load_image_path_to_base64_relative_with_base_path(tmp_path): # Create a test image file - img = PIL.Image.new("RGB", (1, 1), color="green") + img = Image.new("RGB", (1, 1), color="green") image_path = tmp_path / "subdir" / "test.png" image_path.parent.mkdir(exist_ok=True) img.save(image_path) @@ -292,7 +284,7 @@ def test_load_image_path_to_base64_relative_with_cwd_fallback(tmp_path, monkeypa # Change to tmp_path as cwd monkeypatch.chdir(tmp_path) - img = PIL.Image.new("RGB", (1, 1), color="yellow") + img = Image.new("RGB", (1, 1), color="yellow") image_path = tmp_path / "test_cwd.png" img.save(image_path) @@ -307,7 +299,7 @@ def test_load_image_path_to_base64_base_path_fallback_to_cwd(tmp_path, monkeypat monkeypatch.chdir(tmp_path) # Create image in cwd - img = PIL.Image.new("RGB", (1, 1), color="red") + img = Image.new("RGB", (1, 1), color="red") image_path = tmp_path / "test.png" img.save(image_path) diff --git 
a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py index 7ad7a18c..41586e4b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -11,27 +11,27 @@ from data_designer.engine.processing.utils import deserialize_json_values if TYPE_CHECKING: - from data_designer.engine.storage.multimedia_storage import MultimediaStorage + from data_designer.engine.storage.media_storage import MediaStorage class ImageCellGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorWithModel[ImageGenerationColumnConfig]): - """Generator for image columns with optional disk persistence. + """Generator for image columns with disk or dataframe persistence. - Behavior depends on whether multimedia storage is available via ResourceProvider: - - If available (create mode): Saves images to disk and stores relative paths in dataframe - - If None (preview mode): Stores base64 directly in dataframe + Media storage always exists and determines behavior via its mode: + - DISK mode (create): Saves images to disk and stores relative paths in dataframe + - DATAFRAME mode (preview): Stores base64 directly in dataframe API is automatically detected based on the model name: - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) 
β†’ image_generation API - All other models β†’ chat/completions API (default) - Storage is accessed via ResourceProvider.artifact_storage.multimedia_storage + Storage is accessed via ResourceProvider.artifact_storage.media_storage """ @property - def multimedia_storage(self) -> MultimediaStorage | None: - """Get multimedia storage from resource provider if available.""" - return self._resource_provider.artifact_storage.multimedia_storage + def media_storage(self) -> MediaStorage: + """Get media storage from resource provider.""" + return self._resource_provider.artifact_storage.media_storage @staticmethod def get_generation_strategy() -> GenerationStrategy: @@ -68,13 +68,10 @@ def generate(self, data: dict) -> dict: # Generate images (returns list of base64 strings) base64_images = self.model.generate_image(prompt=prompt) - # Store in dataframe based on mode - if self.multimedia_storage: - # Create mode: save each image to disk and store list of relative paths - relative_paths = [self.multimedia_storage.save_base64_image(base64_image) for base64_image in base64_images] - data[self.config.name] = relative_paths - else: - # Preview mode: store list of base64 strings directly - data[self.config.name] = base64_images + # Store via media storage (mode determines disk vs dataframe storage) + # TODO: MediaStorage will check its mode (DISK/DATAFRAME) and act accordingly + # For now, always saves to disk - need to implement mode system + results = [self.media_storage.save_base64_image(base64_image) for base64_image in base64_images] + data[self.config.name] = results return data diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py index b5ffaae7..a7316be3 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py +++ 
b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py @@ -16,7 +16,7 @@ from data_designer.config.utils.io_helpers import read_parquet_dataset from data_designer.config.utils.type_helpers import StrEnum, resolve_string_enum from data_designer.engine.dataset_builders.errors import ArtifactStorageError -from data_designer.engine.storage.multimedia_storage import MultimediaStorage +from data_designer.engine.storage.media_storage import MediaStorage, StorageMode from data_designer.lazy_heavy_imports import pd if TYPE_CHECKING: @@ -47,7 +47,7 @@ class ArtifactStorage(BaseModel): partial_results_folder_name: str = "tmp-partial-parquet-files" dropped_columns_folder_name: str = "dropped-columns-parquet-files" processors_outputs_folder_name: str = PROCESSORS_OUTPUTS_FOLDER_NAME - multimedia_storage: MultimediaStorage | None = Field(default=None, exclude=True) + media_storage: MediaStorage = Field(default=None, exclude=True) @property def artifact_path_exists(self) -> bool: @@ -118,23 +118,21 @@ def validate_folder_names(self): if any(char in invalid_chars for char in name): raise ArtifactStorageError(f"πŸ›‘ Directory name '{name}' contains invalid characters.") - return self + # Initialize media storage with DISK mode by default + self.media_storage = MediaStorage( + base_path=self.base_dataset_path, + mode=StorageMode.DISK, + ) - def ensure_multimedia_storage(self) -> MultimediaStorage: - """Lazily create multimedia storage if not already present. + return self - Returns: - MultimediaStorage instance + def set_media_storage_mode(self, mode: StorageMode) -> None: + """Set media storage mode. 
- Note: - Creates storage with default settings (images_subdir="images", validate_images=True) + Args: + mode: StorageMode.DISK (save to disk) or StorageMode.DATAFRAME (store in memory) """ - if self.multimedia_storage is None: - self.multimedia_storage = MultimediaStorage( - base_path=self.base_dataset_path, - validate_images=True, - ) - return self.multimedia_storage + self.media_storage.mode = mode @staticmethod def mkdir_if_needed(path: Path | str) -> Path: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index ac4469eb..6802f805 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -40,6 +40,7 @@ from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry from data_designer.engine.resources.resource_provider import ResourceProvider +from data_designer.engine.storage.media_storage import StorageMode from data_designer.lazy_heavy_imports import pd if TYPE_CHECKING: @@ -113,12 +114,10 @@ def build( self._run_mcp_tool_check_if_needed() self._write_builder_config() - # Ensure multimedia storage exists if needed - if save_multimedia_to_disk and self._has_image_columns(): - self.artifact_storage.ensure_multimedia_storage() - else: - # Disable storage for preview or when explicitly disabled - self.artifact_storage.multimedia_storage = None + # Set media storage mode based on parameters + if self._has_image_columns(): + mode = StorageMode.DISK if save_multimedia_to_disk else StorageMode.DATAFRAME + self.artifact_storage.set_media_storage_mode(mode) generators = self._initialize_generators() start_time = time.perf_counter() @@ -145,7 +144,10 @@ def 
build( def build_preview(self, *, num_records: int) -> pd.DataFrame: self._run_model_health_check_if_needed() self._run_mcp_tool_check_if_needed() - # Skip multimedia storage initialization for preview - base64 will be stored directly in DataFrame + + # Set media storage to DATAFRAME mode for preview - base64 stored directly in DataFrame + if self._has_image_columns(): + self.artifact_storage.set_media_storage_mode(StorageMode.DATAFRAME) generators = self._initialize_generators() group_id = uuid.uuid4().hex @@ -175,7 +177,7 @@ def _has_image_columns(self) -> bool: def _initialize_generators(self) -> list[ColumnGenerator]: """Initialize column generators. - Generators access multimedia storage via ResourceProvider.artifact_storage.multimedia_storage + Generators access media storage via ResourceProvider.artifact_storage.media_storage """ generators = [] for config in self._column_configs: diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py index 820d512a..34c776d5 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -from data_designer.engine.storage.multimedia_storage import MultimediaStorage +from data_designer.engine.storage.media_storage import MediaStorage, StorageMode -__all__ = ["MultimediaStorage"] +__all__ = ["MediaStorage", "StorageMode"] diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/multimedia_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py similarity index 56% rename from packages/data-designer-engine/src/data_designer/engine/storage/multimedia_storage.py rename to packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py index e40c0032..ddac3459 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/multimedia_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py @@ -3,16 +3,29 @@ from __future__ import annotations +import shutil import uuid from pathlib import Path from data_designer.config.utils.image_helpers import decode_base64_image, detect_image_format, validate_image +from data_designer.config.utils.type_helpers import StrEnum IMAGES_SUBDIR = "images" -class MultimediaStorage: - """Manages disk storage of generated multimedia content. +class StorageMode(StrEnum): + """Storage mode for generated media content. + + - DISK: Save media to disk and store relative paths in dataframe (for dataset creation) + - DATAFRAME: Store base64 data directly in dataframe (for preview mode) + """ + + DISK = "disk" + DATAFRAME = "dataframe" + + +class MediaStorage: + """Manages storage of generated media content. 
Currently supports: - Images (PNG, JPG, WEBP) @@ -21,45 +34,60 @@ class MultimediaStorage: - Audio - Video + Storage modes: + - DISK: Save media to disk and return relative paths (for dataset creation) + - DATAFRAME: Return base64 data directly (for preview mode) + Handles: - Creating storage directories - Decoding base64 to bytes - Detecting media format - - Saving with UUID filenames - - Returning relative paths + - Saving with UUID filenames (DISK mode) + - Returning relative paths or base64 data based on mode + - Always validates images to ensure data quality """ - def __init__(self, base_path: Path, images_subdir: str = IMAGES_SUBDIR, validate_images: bool = True) -> None: - """Initialize multimedia storage manager. + def __init__( + self, base_path: Path, images_subdir: str = IMAGES_SUBDIR, mode: StorageMode = StorageMode.DISK + ) -> None: + """Initialize media storage manager. Args: base_path: Base directory for dataset images_subdir: Subdirectory name for images (default: "images") - validate_images: Whether to validate images after saving (default: True) + mode: Storage mode - DISK (save to disk) or DATAFRAME (return base64) """ self.base_path = Path(base_path) self.images_dir = self.base_path / images_subdir self.images_subdir = images_subdir - self.validate_images = validate_images - self._ensure_images_directory() + self.mode = mode def _ensure_images_directory(self) -> None: - """Create images directory if it doesn't exist.""" + """Create images directory if it doesn't exist (lazy initialization).""" self.images_dir.mkdir(parents=True, exist_ok=True) def save_base64_image(self, base64_data: str) -> str: - """Save base64 image to disk and return relative path. + """Save or return base64 image based on storage mode. 
Args: base64_data: Base64 encoded image string (with or without data URI prefix) Returns: - Relative path to saved image (e.g., "images/f47ac10b-58cc.png") + DISK mode: Relative path to saved image (e.g., "images/f47ac10b-58cc.png") + DATAFRAME mode: Original base64 data string Raises: - ValueError: If base64 data is invalid - OSError: If disk write fails + ValueError: If base64 data is invalid (DISK mode only) + OSError: If disk write fails (DISK mode only) """ + # DATAFRAME mode: return base64 directly without disk operations + if self.mode == StorageMode.DATAFRAME: + return base64_data + + # DISK mode: save to disk, validate, and return relative path + # Ensure images directory exists (lazy initialization) + self._ensure_images_directory() + # Decode base64 to bytes image_bytes = decode_base64_image(base64_data) @@ -76,9 +104,8 @@ def save_base64_image(self, base64_data: str) -> str: with open(full_path, "wb") as f: f.write(image_bytes) - # Optional validation - if self.validate_images: - self._validate_image(full_path) + # Always validate in DISK mode to ensure data quality + self._validate_image(full_path) return relative_path @@ -100,7 +127,5 @@ def _validate_image(self, image_path: Path) -> None: def cleanup(self) -> None: """Clean up image directory (for preview mode).""" - import shutil - if self.images_dir.exists(): shutil.rmtree(self.images_dir) diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py index 7173ed2d..e7055d67 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py @@ -30,22 +30,22 @@ def test_image_cell_generator_generation_strategy( assert generator.get_generation_strategy() == GenerationStrategy.CELL_BY_CELL -def test_image_cell_generator_multimedia_storage_property( +def 
test_image_cell_generator_media_storage_property( stub_image_column_config: ImageGenerationColumnConfig, stub_resource_provider: None ) -> None: generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) - # Should return multimedia_storage from artifact_storage (None by default in stub) - assert generator.multimedia_storage is None + # Should return media_storage from artifact_storage (always exists) + assert generator.media_storage is not None def test_image_cell_generator_generate_with_storage( stub_image_column_config, stub_resource_provider, stub_base64_images ): - """Test generate with multimedia storage (create mode) - saves to disk.""" - # Setup mock multimedia storage + """Test generate with media storage (create mode) - saves to disk.""" + # Setup mock media storage mock_storage = Mock() mock_storage.save_base64_image.side_effect = ["images/uuid1.png", "images/uuid2.png"] - stub_resource_provider.artifact_storage.multimedia_storage = mock_storage + stub_resource_provider.artifact_storage.media_storage = mock_storage with patch.object( stub_resource_provider.model_registry.get_model.return_value, @@ -68,12 +68,14 @@ def test_image_cell_generator_generate_with_storage( mock_storage.save_base64_image.assert_any_call("base64_image_2") -def test_image_cell_generator_generate_without_storage( +def test_image_cell_generator_generate_in_dataframe_mode( stub_image_column_config, stub_resource_provider, stub_base64_images ): - """Test generate without multimedia storage (preview mode) - stores base64 directly.""" - # Ensure multimedia_storage is None (preview mode) - stub_resource_provider.artifact_storage.multimedia_storage = None + """Test generate with media storage in DATAFRAME mode - stores base64 directly.""" + # Mock save_base64_image to return base64 directly (simulating DATAFRAME mode) + mock_storage = Mock() + mock_storage.save_base64_image.side_effect = stub_base64_images + 
stub_resource_provider.artifact_storage.media_storage = mock_storage with patch.object( stub_resource_provider.model_registry.get_model.return_value, @@ -83,7 +85,7 @@ def test_image_cell_generator_generate_without_storage( generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) data = generator.generate(data={"style": "watercolor", "subject": "dog"}) - # Check that column was added with base64 data + # Check that column was added with base64 data (simulating DATAFRAME mode) assert stub_image_column_config.name in data assert data[stub_image_column_config.name] == stub_base64_images diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py index df15b4f7..35edf892 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py @@ -213,10 +213,11 @@ def test_artifact_storage_resolved_dataset_name(mock_datetime, tmp_path): (af_storage.artifact_path / af_storage.dataset_name).mkdir() assert af_storage.resolved_dataset_name == "dataset" - # dataset path exists and is not empty + # dataset path exists and is not empty (create file BEFORE constructing ArtifactStorage) + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir(exist_ok=True) + (dataset_dir / "stub_file.txt").touch() af_storage = ArtifactStorage(artifact_path=tmp_path) - (af_storage.artifact_path / af_storage.dataset_name / "stub_file.txt").touch() - print(af_storage.resolved_dataset_name) assert af_storage.resolved_dataset_name == "dataset_01-01-2025_120304" diff --git a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py new file mode 100644 index 00000000..abd17afe --- /dev/null +++ 
b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import base64 +import io + +import pytest + +from data_designer.engine.storage.media_storage import IMAGES_SUBDIR, MediaStorage, StorageMode +from data_designer.lazy_heavy_imports import Image + + +@pytest.fixture +def media_storage(tmp_path): + """Create a MediaStorage instance with a temporary directory.""" + return MediaStorage(base_path=tmp_path) + + +@pytest.fixture +def sample_base64_png() -> str: + """Create a valid 1x1 PNG as base64.""" + img = Image.new("RGB", (1, 1), color="red") + buf = io.BytesIO() + img.save(buf, format="PNG") + png_bytes = buf.getvalue() + return base64.b64encode(png_bytes).decode() + + +@pytest.fixture +def sample_base64_jpg() -> str: + """Create a valid 1x1 JPEG as base64.""" + img = Image.new("RGB", (1, 1), color="blue") + buf = io.BytesIO() + img.save(buf, format="JPEG") + jpg_bytes = buf.getvalue() + return base64.b64encode(jpg_bytes).decode() + + +def test_media_storage_init(tmp_path): + """Test MediaStorage initialization.""" + storage = MediaStorage(base_path=tmp_path) + assert storage.base_path == tmp_path + assert storage.images_dir == tmp_path / IMAGES_SUBDIR + assert storage.images_subdir == IMAGES_SUBDIR + assert storage.mode == StorageMode.DISK + # Directory should NOT exist until first save (lazy initialization) + assert not storage.images_dir.exists() + + +def test_media_storage_init_custom_subdir(tmp_path): + """Test MediaStorage initialization with custom subdirectory and mode.""" + custom_subdir = "custom_images" + storage = MediaStorage(base_path=tmp_path, images_subdir=custom_subdir, mode=StorageMode.DATAFRAME) + assert storage.images_subdir == custom_subdir + assert storage.images_dir == tmp_path / custom_subdir + assert storage.mode == 
StorageMode.DATAFRAME + # Directory should NOT exist until first save (lazy initialization) + assert not storage.images_dir.exists() + + +def test_save_base64_image_png(media_storage, sample_base64_png): + """Test saving a PNG image from base64.""" + relative_path = media_storage.save_base64_image(sample_base64_png) + + # Check return value format + assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + assert relative_path.endswith(".png") + + # Check file exists on disk + full_path = media_storage.base_path / relative_path + assert full_path.exists() + + # Verify file content + saved_bytes = full_path.read_bytes() + expected_bytes = base64.b64decode(sample_base64_png) + assert saved_bytes == expected_bytes + + +def test_save_base64_image_jpg(media_storage, sample_base64_jpg): + """Test saving a JPEG image from base64.""" + relative_path = media_storage.save_base64_image(sample_base64_jpg) + + # Check return value format + assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + assert relative_path.endswith(".jpg") + + # Check file exists on disk + full_path = media_storage.base_path / relative_path + assert full_path.exists() + + +def test_save_base64_image_with_data_uri(media_storage, sample_base64_png): + """Test saving image from data URI format.""" + data_uri = f"data:image/png;base64,{sample_base64_png}" + relative_path = media_storage.save_base64_image(data_uri) + + # Should successfully extract base64 and save + assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + assert relative_path.endswith(".png") + + # Verify file exists and content is correct + full_path = media_storage.base_path / relative_path + assert full_path.exists() + saved_bytes = full_path.read_bytes() + expected_bytes = base64.b64decode(sample_base64_png) + assert saved_bytes == expected_bytes + + +def test_save_base64_image_invalid_base64_raises_error(media_storage): + """Test that invalid base64 data raises ValueError.""" + with pytest.raises(ValueError, match="Invalid base64"): + 
media_storage.save_base64_image("not-valid-base64!!!") + + +def test_save_base64_image_multiple_images_unique_filenames(media_storage, sample_base64_png): + """Test that multiple images get unique filenames.""" + path1 = media_storage.save_base64_image(sample_base64_png) + path2 = media_storage.save_base64_image(sample_base64_png) + + # Paths should be different (different UUIDs) + assert path1 != path2 + + # Both files should exist + assert (media_storage.base_path / path1).exists() + assert (media_storage.base_path / path2).exists() + + +def test_save_base64_image_disk_mode_validates(tmp_path, sample_base64_png): + """Test that DISK mode validates images.""" + storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DISK) + # Should succeed with valid image + relative_path = storage.save_base64_image(sample_base64_png) + assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + + +def test_save_base64_image_disk_mode_corrupted_image_raises_error(tmp_path): + """Test that DISK mode validates and rejects corrupted images.""" + storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DISK) + + # Create base64 of invalid image data + corrupted_bytes = b"not a valid image" + corrupted_base64 = base64.b64encode(corrupted_bytes).decode() + + with pytest.raises(ValueError, match="Image validation failed"): + storage.save_base64_image(corrupted_base64) + + # Check that no files were left behind (cleanup on validation failure) + assert len(list(storage.images_dir.iterdir())) == 0 + + +def test_save_base64_image_dataframe_mode_returns_base64(tmp_path, sample_base64_png): + """Test that DATAFRAME mode returns base64 directly without disk operations.""" + storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DATAFRAME) + + # Should return the same base64 data + result = storage.save_base64_image(sample_base64_png) + assert result == sample_base64_png + + # Directory should not be created in DATAFRAME mode (lazy initialization) + assert not storage.images_dir.exists() 
+ + +def test_cleanup(media_storage, sample_base64_png): + """Test cleanup removes images directory.""" + # Save an image first + media_storage.save_base64_image(sample_base64_png) + assert media_storage.images_dir.exists() + assert len(list(media_storage.images_dir.iterdir())) > 0 + + # Cleanup should remove directory + media_storage.cleanup() + assert not media_storage.images_dir.exists() diff --git a/packages/data-designer-engine/tests/engine/storage/test_multimedia_storage.py b/packages/data-designer-engine/tests/engine/storage/test_multimedia_storage.py deleted file mode 100644 index ade76b5a..00000000 --- a/packages/data-designer-engine/tests/engine/storage/test_multimedia_storage.py +++ /dev/null @@ -1,182 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import base64 -import io - -# Explicitly import PIL.Image submodule to make it accessible as PIL.Image -# Python doesn't automatically import submodules when you import a package, -# so `import PIL` alone doesn't give you access to PIL.Image -import PIL.Image # noqa: E402 -import pytest - -from data_designer.engine.storage.multimedia_storage import IMAGES_SUBDIR, MultimediaStorage -from data_designer.lazy_heavy_imports import PIL - - -@pytest.fixture -def multimedia_storage(tmp_path): - """Create a MultimediaStorage instance with a temporary directory.""" - return MultimediaStorage(base_path=tmp_path) - - -@pytest.fixture -def sample_base64_png() -> str: - """Create a valid 1x1 PNG as base64.""" - img = PIL.Image.new("RGB", (1, 1), color="red") - buf = io.BytesIO() - img.save(buf, format="PNG") - png_bytes = buf.getvalue() - return base64.b64encode(png_bytes).decode() - - -@pytest.fixture -def sample_base64_jpg() -> str: - """Create a valid 1x1 JPEG as base64.""" - img = PIL.Image.new("RGB", (1, 1), color="blue") - buf = io.BytesIO() - img.save(buf, format="JPEG") - 
jpg_bytes = buf.getvalue() - return base64.b64encode(jpg_bytes).decode() - - -def test_multimedia_storage_init(tmp_path): - """Test MultimediaStorage initialization.""" - storage = MultimediaStorage(base_path=tmp_path) - assert storage.base_path == tmp_path - assert storage.images_dir == tmp_path / IMAGES_SUBDIR - assert storage.images_subdir == IMAGES_SUBDIR - assert storage.validate_images is True - # Should create images directory on init - assert storage.images_dir.exists() - - -def test_multimedia_storage_init_custom_subdir(tmp_path): - """Test MultimediaStorage initialization with custom subdirectory.""" - custom_subdir = "custom_images" - storage = MultimediaStorage(base_path=tmp_path, images_subdir=custom_subdir, validate_images=False) - assert storage.images_subdir == custom_subdir - assert storage.images_dir == tmp_path / custom_subdir - assert storage.validate_images is False - assert storage.images_dir.exists() - - -def test_save_base64_image_png(multimedia_storage, sample_base64_png): - """Test saving a PNG image from base64.""" - relative_path = multimedia_storage.save_base64_image(sample_base64_png) - - # Check return value format - assert relative_path.startswith(f"{IMAGES_SUBDIR}/") - assert relative_path.endswith(".png") - - # Check file exists on disk - full_path = multimedia_storage.base_path / relative_path - assert full_path.exists() - - # Verify file content - saved_bytes = full_path.read_bytes() - expected_bytes = base64.b64decode(sample_base64_png) - assert saved_bytes == expected_bytes - - -def test_save_base64_image_jpg(multimedia_storage, sample_base64_jpg): - """Test saving a JPEG image from base64.""" - relative_path = multimedia_storage.save_base64_image(sample_base64_jpg) - - # Check return value format - assert relative_path.startswith(f"{IMAGES_SUBDIR}/") - assert relative_path.endswith(".jpg") - - # Check file exists on disk - full_path = multimedia_storage.base_path / relative_path - assert full_path.exists() - - -def 
test_save_base64_image_with_data_uri(multimedia_storage, sample_base64_png): - """Test saving image from data URI format.""" - data_uri = f"data:image/png;base64,{sample_base64_png}" - relative_path = multimedia_storage.save_base64_image(data_uri) - - # Should successfully extract base64 and save - assert relative_path.startswith(f"{IMAGES_SUBDIR}/") - assert relative_path.endswith(".png") - - # Verify file exists and content is correct - full_path = multimedia_storage.base_path / relative_path - assert full_path.exists() - saved_bytes = full_path.read_bytes() - expected_bytes = base64.b64decode(sample_base64_png) - assert saved_bytes == expected_bytes - - -def test_save_base64_image_invalid_base64_raises_error(multimedia_storage): - """Test that invalid base64 data raises ValueError.""" - with pytest.raises(ValueError, match="Invalid base64"): - multimedia_storage.save_base64_image("not-valid-base64!!!") - - -def test_save_base64_image_multiple_images_unique_filenames(multimedia_storage, sample_base64_png): - """Test that multiple images get unique filenames.""" - path1 = multimedia_storage.save_base64_image(sample_base64_png) - path2 = multimedia_storage.save_base64_image(sample_base64_png) - - # Paths should be different (different UUIDs) - assert path1 != path2 - - # Both files should exist - assert (multimedia_storage.base_path / path1).exists() - assert (multimedia_storage.base_path / path2).exists() - - -def test_save_base64_image_validation_enabled(tmp_path, sample_base64_png): - """Test that validation is performed when enabled.""" - storage = MultimediaStorage(base_path=tmp_path, validate_images=True) - # Should succeed with valid image - relative_path = storage.save_base64_image(sample_base64_png) - assert relative_path.startswith(f"{IMAGES_SUBDIR}/") - - -def test_save_base64_image_validation_corrupted_image_raises_error(tmp_path): - """Test that corrupted image fails validation and is cleaned up.""" - storage = MultimediaStorage(base_path=tmp_path, 
validate_images=True) - - # Create base64 of invalid image data - corrupted_bytes = b"not a valid image" - corrupted_base64 = base64.b64encode(corrupted_bytes).decode() - - with pytest.raises(ValueError, match="Image validation failed"): - storage.save_base64_image(corrupted_base64) - - # Check that no files were left behind - assert len(list(storage.images_dir.iterdir())) == 0 - - -def test_save_base64_image_validation_disabled(tmp_path): - """Test that validation can be disabled.""" - storage = MultimediaStorage(base_path=tmp_path, validate_images=False) - - # Create base64 of invalid image data - corrupted_bytes = b"not a valid image" - corrupted_base64 = base64.b64encode(corrupted_bytes).decode() - - # Should succeed without validation - relative_path = storage.save_base64_image(corrupted_base64) - assert relative_path.startswith(f"{IMAGES_SUBDIR}/") - - # File should exist even though it's invalid - full_path = storage.base_path / relative_path - assert full_path.exists() - - -def test_cleanup(multimedia_storage, sample_base64_png): - """Test cleanup removes images directory.""" - # Save an image first - multimedia_storage.save_base64_image(sample_base64_png) - assert multimedia_storage.images_dir.exists() - assert len(list(multimedia_storage.images_dir.iterdir())) > 0 - - # Cleanup should remove directory - multimedia_storage.cleanup() - assert not multimedia_storage.images_dir.exists() diff --git a/packages/data-designer-engine/tests/engine/test_configurable_task.py b/packages/data-designer-engine/tests/engine/test_configurable_task.py index f20936a2..6e3673de 100644 --- a/packages/data-designer-engine/tests/engine/test_configurable_task.py +++ b/packages/data-designer-engine/tests/engine/test_configurable_task.py @@ -25,7 +25,7 @@ def test_configurable_task_generic_type_variables() -> None: assert TaskConfigT.__bound__ == ConfigBase -def test_configurable_task_concrete_implementation() -> None: +def test_configurable_task_concrete_implementation(tmp_path) 
-> None: class TestConfig(ConfigBase): value: str @@ -41,13 +41,8 @@ def _initialize(self) -> None: pass config = TestConfig(value="test") - mock_artifact_storage = Mock(spec=ArtifactStorage) - mock_artifact_storage.dataset_name = "test_dataset" - mock_artifact_storage.final_dataset_folder_name = "final_dataset" - mock_artifact_storage.partial_results_folder_name = "partial_results" - mock_artifact_storage.dropped_columns_folder_name = "dropped_columns" - mock_artifact_storage.processors_outputs_folder_name = "processors_outputs" - resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage) + artifact_storage = ArtifactStorage(artifact_path=tmp_path) + resource_provider = ResourceProvider(artifact_storage=artifact_storage) task = TestTask(config=config, resource_provider=resource_provider) @@ -55,7 +50,7 @@ def _initialize(self) -> None: assert task._resource_provider == resource_provider -def test_configurable_task_config_validation() -> None: +def test_configurable_task_config_validation(tmp_path) -> None: class TestConfig(ConfigBase): value: str @@ -69,13 +64,8 @@ def _validate(self) -> None: raise ValueError("Invalid config") config = TestConfig(value="test") - mock_artifact_storage = Mock(spec=ArtifactStorage) - mock_artifact_storage.dataset_name = "test_dataset" - mock_artifact_storage.final_dataset_folder_name = "final_dataset" - mock_artifact_storage.partial_results_folder_name = "partial_results" - mock_artifact_storage.dropped_columns_folder_name = "dropped_columns" - mock_artifact_storage.processors_outputs_folder_name = "processors_outputs" - resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage) + artifact_storage = ArtifactStorage(artifact_path=tmp_path) + resource_provider = ResourceProvider(artifact_storage=artifact_storage) task = TestTask(config=config, resource_provider=resource_provider) assert task._config.value == "test" @@ -85,7 +75,7 @@ def _validate(self) -> None: TestTask(config=invalid_config, 
resource_provider=resource_provider) -def test_configurable_task_resource_validation() -> None: +def test_configurable_task_resource_validation(tmp_path) -> None: class TestConfig(ConfigBase): value: str @@ -102,14 +92,9 @@ def _initialize(self) -> None: config = TestConfig(value="test") - mock_artifact_storage = Mock(spec=ArtifactStorage) - mock_artifact_storage.dataset_name = "test_dataset" - mock_artifact_storage.final_dataset_folder_name = "final_dataset" - mock_artifact_storage.partial_results_folder_name = "partial_results" - mock_artifact_storage.dropped_columns_folder_name = "dropped_columns" - mock_artifact_storage.processors_outputs_folder_name = "processors_outputs" + artifact_storage = ArtifactStorage(artifact_path=tmp_path) mock_model_registry = Mock(spec=ModelRegistry) - resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage, model_registry=mock_model_registry) + resource_provider = ResourceProvider(artifact_storage=artifact_storage, model_registry=mock_model_registry) task = TestTask(config=config, resource_provider=resource_provider) assert task._resource_provider == resource_provider From 1677f066e5a228c418d558633ece69969cd7d122 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 10:17:32 -0700 Subject: [PATCH 27/64] track images generated in usage --- .../config/utils/image_helpers.py | 25 ++ .../tests/config/utils/test_image_helpers.py | 27 ++ .../src/data_designer/engine/models/facade.py | 403 ++++++++---------- .../data_designer/engine/models/registry.py | 4 + .../src/data_designer/engine/models/usage.py | 23 +- .../tests/engine/models/test_facade.py | 147 +++++++ .../tests/engine/models/test_usage.py | 60 ++- 7 files changed, 457 insertions(+), 232 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py index 67803aff..2069d9bf 100644 --- 
a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -23,6 +23,31 @@ # WEBP uses RIFF header - handled separately } +# Patterns for detecting diffusion-based image generation models (DALL-E, Stable Diffusion, Imagen, etc.) +_IMAGE_DIFFUSION_MODEL_PATTERNS = ( + "dall-e", + "dalle", + "stable-diffusion", + "sd-", + "sd_", + "imagen", +) + + +def is_image_diffusion_model(model_name: str) -> bool: + """Return True if the model is a diffusion-based image generation model. + + Diffusion models use the image_generation API (e.g. DALL-E, Stable Diffusion, Imagen). + All other image models are assumed to use the chat/completions API. + + Args: + model_name: Model name or identifier (e.g. from provider). + + Returns: + True if the model is detected as diffusion-based, False otherwise. + """ + return any(pattern in model_name.lower() for pattern in _IMAGE_DIFFUSION_MODEL_PATTERNS) + def extract_base64_from_data_uri(data: str) -> str: """Extract base64 from data URI or return as-is. 
diff --git a/packages/data-designer-config/tests/config/utils/test_image_helpers.py b/packages/data-designer-config/tests/config/utils/test_image_helpers.py index e0eb0370..aa1ca451 100644 --- a/packages/data-designer-config/tests/config/utils/test_image_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_image_helpers.py @@ -16,6 +16,7 @@ extract_base64_from_data_uri, get_supported_image_extensions, is_base64_image, + is_image_diffusion_model, is_image_path, is_image_url, load_image_path_to_base64, @@ -144,6 +145,32 @@ def test_is_image_url_non_http(): assert is_image_url("ftp://example.com/image.png") is False +# Tests for is_image_diffusion_model + + +def test_is_image_diffusion_model_dall_e(): + assert is_image_diffusion_model("dall-e-3") is True + assert is_image_diffusion_model("DALL-E-2") is True + assert is_image_diffusion_model("openai/dalle-2") is True + + +def test_is_image_diffusion_model_stable_diffusion(): + assert is_image_diffusion_model("stable-diffusion-xl") is True + assert is_image_diffusion_model("sd-2.1") is True + assert is_image_diffusion_model("sd_1.5") is True + + +def test_is_image_diffusion_model_imagen(): + assert is_image_diffusion_model("imagen-3") is True + assert is_image_diffusion_model("google/imagen") is True + + +def test_is_image_diffusion_model_chat_completion_image_models(): + assert is_image_diffusion_model("gemini-3-pro-image-preview") is False + assert is_image_diffusion_model("gpt-5-image") is False + assert is_image_diffusion_model("flux.2-pro") is False + + # Tests for validate_image diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index d13273f4..11f6e9ec 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -9,7 +9,10 @@ from typing import TYPE_CHECKING, Any from 
data_designer.config.models import GenerationType, ModelConfig, ModelProvider -from data_designer.config.utils.image_helpers import extract_base64_from_data_uri +from data_designer.config.utils.image_helpers import ( + extract_base64_from_data_uri, + is_image_diffusion_model, +) from data_designer.engine.mcp.errors import MCPConfigurationError from data_designer.engine.model_provider import ModelProviderRegistry from data_designer.engine.models.errors import ( @@ -20,7 +23,7 @@ ) from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs from data_designer.engine.models.parsers.errors import ParserException -from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats +from data_designer.engine.models.usage import ImageUsageStats, ModelUsageStats, RequestUsageStats, TokenUsageStats from data_designer.engine.models.utils import ChatMessage, prompt_to_messages from data_designer.engine.secret_resolver import SecretResolver from data_designer.lazy_heavy_imports import litellm @@ -39,16 +42,6 @@ def _identity(x: Any) -> Any: logger = logging.getLogger(__name__) -# Patterns for detecting diffusion-based image generation models -DIFFUSION_MODEL_PATTERNS = [ - "dall-e", - "dalle", - "stable-diffusion", - "sd-", - "sd_", - "imagen", -] - class ModelFacade: def __init__( @@ -117,7 +110,7 @@ def completion( raise e finally: if not skip_usage_tracking and response is not None: - self._track_usage(response) + self._track_token_usage_from_completion(response) def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: # Remove purpose from kwargs to avoid passing it to the model @@ -129,16 +122,153 @@ def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: kwargs["extra_headers"] = self.model_provider.extra_headers return kwargs - def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: - if tool_alias is None: - return None - if self._mcp_registry is None: - raise 
MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.") + @catch_llm_exceptions + def generate( + self, + prompt: str, + *, + parser: Callable[[str], Any] = _identity, + system_prompt: str | None = None, + multi_modal_context: list[dict[str, Any]] | None = None, + tool_alias: str | None = None, + max_correction_steps: int = 0, + max_conversation_restarts: int = 0, + skip_usage_tracking: bool = False, + purpose: str | None = None, + **kwargs, + ) -> tuple[Any, list[ChatMessage]]: + """Generate a parsed output with correction steps. - try: - return self._mcp_registry.get_mcp(tool_alias=tool_alias) - except ValueError as exc: - raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc + This generation call will attempt to generate an output which is + valid according to the specified parser, where "valid" implies + that the parser can process the LLM response without raising + an exception. + + `ParserExceptions` are routed back + to the LLM as new rounds in the conversation, where the LLM is provided its + earlier response along with the "user" role responding with the exception string + (not traceback). This will continue for the number of rounds specified by + `max_correction_steps`. + + Args: + prompt (str): Task prompt. + system_prompt (str, optional): Optional system instructions. If not specified, + no system message is provided and the model should use its default system + prompt. + parser (func(str) -> Any): A function applied to the LLM response which processes + an LLM response into some output object. Default: identity function. + tool_alias (str | None): Optional tool configuration alias. When provided, + the model may call permitted tools from the configured MCP providers. + The alias must reference a ToolConfig registered in the MCPRegistry. + max_correction_steps (int): Maximum number of correction rounds permitted + within a single conversation. 
Note, many rounds can lead to increasing + context size without necessarily improving performance -- small language + models can enter repeated cycles which will not be solved with more steps. + Default: `0` (no correction). + max_conversation_restarts (int): Maximum number of full conversation restarts permitted + if generation fails. Default: `0` (no restarts). + skip_usage_tracking (bool): Whether to skip usage tracking. Default: `False`. + purpose (str): The purpose of the model usage to show as context in the error message. + It is expected to be used by the @catch_llm_exceptions decorator. + **kwargs: Additional arguments to pass to the model. + + Returns: + A tuple containing: + - The parsed output object from the parser. + - The full trace of ChatMessage entries in the conversation, including any tool calls, + corrections, and reasoning traces. Callers can decide whether to store this. + + Raises: + GenerationValidationFailureError: If the maximum number of retries or + correction steps are met and the last response failures on + generation validation. + MCPConfigurationError: If tool_alias is specified but no MCPRegistry is configured. 
+ """ + output_obj = None + tool_schemas = None + tool_call_turns = 0 + total_tool_calls = 0 + curr_num_correction_steps = 0 + curr_num_restarts = 0 + + mcp_facade = self._get_mcp_facade(tool_alias) + + # Checkpoint for restarts - updated after tool calls so we don't repeat them + restart_checkpoint = prompt_to_messages( + user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context + ) + checkpoint_tool_call_turns = 0 + messages: list[ChatMessage] = deepcopy(restart_checkpoint) + + if mcp_facade is not None: + tool_schemas = mcp_facade.get_tool_schemas() + + while True: + completion_kwargs = dict(kwargs) + if tool_schemas is not None: + completion_kwargs["tools"] = tool_schemas + + completion_response = self.completion( + messages, + skip_usage_tracking=skip_usage_tracking, + **completion_kwargs, + ) + + # Process any tool calls in the response (handles parallel tool calling) + if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response): + tool_call_turns += 1 + total_tool_calls += mcp_facade.tool_call_count(completion_response) + + if tool_call_turns > mcp_facade.max_tool_call_turns: + # Gracefully refuse tool calls when budget is exhausted + messages.extend(mcp_facade.refuse_completion_response(completion_response)) + else: + messages.extend(mcp_facade.process_completion_response(completion_response)) + + # Update checkpoint so restarts don't repeat tool calls + restart_checkpoint = deepcopy(messages) + checkpoint_tool_call_turns = tool_call_turns + + continue # Back to top + + # No tool calls remaining to process + response = completion_response.choices[0].message.content or "" + reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None) + messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) + curr_num_correction_steps += 1 + + try: + output_obj = parser(response) # type: ignore - if not a string will cause a ParserException below + 
break + except ParserException as exc: + if max_correction_steps == 0 and max_conversation_restarts == 0: + raise GenerationValidationFailureError( + "Unsuccessful generation attempt. No retries were attempted." + ) from exc + + if curr_num_correction_steps <= max_correction_steps: + # Add user message with error for correction + messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc)))) + + elif curr_num_restarts < max_conversation_restarts: + curr_num_correction_steps = 0 + curr_num_restarts += 1 + messages = deepcopy(restart_checkpoint) + tool_call_turns = checkpoint_tool_call_turns + + else: + raise GenerationValidationFailureError( + f"Unsuccessful generation despite {max_correction_steps} correction steps " + f"and {max_conversation_restarts} conversation restarts." + ) from exc + + if not skip_usage_tracking and mcp_facade is not None: + self._usage_stats.tool_usage.extend( + tool_calls=total_tool_calls, + tool_call_turns=tool_call_turns, + ) + + return output_obj, messages @catch_llm_exceptions def generate_text_embeddings( @@ -171,7 +301,7 @@ def generate_text_embeddings( raise e finally: if not skip_usage_tracking and response is not None: - self._track_usage_from_embedding(response) + self._track_token_usage_from_embedding(response) @catch_llm_exceptions def generate_image(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: @@ -201,22 +331,27 @@ def generate_image(self, prompt: str, skip_usage_tracking: bool = False, **kwarg ) # Auto-detect API type based on model name - if self._is_diffusion_model(): - return self._generate_image_diffusion(prompt, skip_usage_tracking, **kwargs) + if is_image_diffusion_model(self.model_name): + images = self._generate_image_diffusion(prompt, skip_usage_tracking, **kwargs) else: - return self._generate_image_chat_completion(prompt, skip_usage_tracking, **kwargs) + images = self._generate_image_chat_completion(prompt, skip_usage_tracking, **kwargs) - def 
_is_diffusion_model(self) -> bool: - """Detect if model uses diffusion API based on name patterns. + # Track image usage + if not skip_usage_tracking and len(images) > 0: + self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) - Diffusion models include DALL-E, Stable Diffusion, and Imagen variants. - All other image models are assumed to use chat completions API. + return images - Returns: - True if model is detected as diffusion-based, False otherwise - """ - model_lower = self.model_name.lower() - return any(pattern in model_lower for pattern in DIFFUSION_MODEL_PATTERNS) + def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: + if tool_alias is None: + return None + if self._mcp_registry is None: + raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.") + + try: + return self._mcp_registry.get_mcp(tool_alias=tool_alias) + except ValueError as exc: + raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc def _generate_image_chat_completion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: """Generate image(s) using autoregressive model via chat completions API. @@ -311,155 +446,7 @@ def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = Fal raise finally: if not skip_usage_tracking and response is not None: - self._track_usage_from_image_diffusion(response) - - @catch_llm_exceptions - def generate( - self, - prompt: str, - *, - parser: Callable[[str], Any] = _identity, - system_prompt: str | None = None, - multi_modal_context: list[dict[str, Any]] | None = None, - tool_alias: str | None = None, - max_correction_steps: int = 0, - max_conversation_restarts: int = 0, - skip_usage_tracking: bool = False, - purpose: str | None = None, - **kwargs, - ) -> tuple[Any, list[ChatMessage]]: - """Generate a parsed output with correction steps. 
- - This generation call will attempt to generate an output which is - valid according to the specified parser, where "valid" implies - that the parser can process the LLM response without raising - an exception. - - `ParserExceptions` are routed back - to the LLM as new rounds in the conversation, where the LLM is provided its - earlier response along with the "user" role responding with the exception string - (not traceback). This will continue for the number of rounds specified by - `max_correction_steps`. - - Args: - prompt (str): Task prompt. - system_prompt (str, optional): Optional system instructions. If not specified, - no system message is provided and the model should use its default system - prompt. - parser (func(str) -> Any): A function applied to the LLM response which processes - an LLM response into some output object. Default: identity function. - tool_alias (str | None): Optional tool configuration alias. When provided, - the model may call permitted tools from the configured MCP providers. - The alias must reference a ToolConfig registered in the MCPRegistry. - max_correction_steps (int): Maximum number of correction rounds permitted - within a single conversation. Note, many rounds can lead to increasing - context size without necessarily improving performance -- small language - models can enter repeated cycles which will not be solved with more steps. - Default: `0` (no correction). - max_conversation_restarts (int): Maximum number of full conversation restarts permitted - if generation fails. Default: `0` (no restarts). - skip_usage_tracking (bool): Whether to skip usage tracking. Default: `False`. - purpose (str): The purpose of the model usage to show as context in the error message. - It is expected to be used by the @catch_llm_exceptions decorator. - **kwargs: Additional arguments to pass to the model. - - Returns: - A tuple containing: - - The parsed output object from the parser. 
- - The full trace of ChatMessage entries in the conversation, including any tool calls, - corrections, and reasoning traces. Callers can decide whether to store this. - - Raises: - GenerationValidationFailureError: If the maximum number of retries or - correction steps are met and the last response failures on - generation validation. - MCPConfigurationError: If tool_alias is specified but no MCPRegistry is configured. - """ - output_obj = None - tool_schemas = None - tool_call_turns = 0 - total_tool_calls = 0 - curr_num_correction_steps = 0 - curr_num_restarts = 0 - - mcp_facade = self._get_mcp_facade(tool_alias) - - # Checkpoint for restarts - updated after tool calls so we don't repeat them - restart_checkpoint = prompt_to_messages( - user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context - ) - checkpoint_tool_call_turns = 0 - messages: list[ChatMessage] = deepcopy(restart_checkpoint) - - if mcp_facade is not None: - tool_schemas = mcp_facade.get_tool_schemas() - - while True: - completion_kwargs = dict(kwargs) - if tool_schemas is not None: - completion_kwargs["tools"] = tool_schemas - - completion_response = self.completion( - messages, - skip_usage_tracking=skip_usage_tracking, - **completion_kwargs, - ) - - # Process any tool calls in the response (handles parallel tool calling) - if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response): - tool_call_turns += 1 - total_tool_calls += mcp_facade.tool_call_count(completion_response) - - if tool_call_turns > mcp_facade.max_tool_call_turns: - # Gracefully refuse tool calls when budget is exhausted - messages.extend(mcp_facade.refuse_completion_response(completion_response)) - else: - messages.extend(mcp_facade.process_completion_response(completion_response)) - - # Update checkpoint so restarts don't repeat tool calls - restart_checkpoint = deepcopy(messages) - checkpoint_tool_call_turns = tool_call_turns - - continue # Back to top - - # No tool calls 
remaining to process - response = completion_response.choices[0].message.content or "" - reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None) - messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) - curr_num_correction_steps += 1 - - try: - output_obj = parser(response) # type: ignore - if not a string will cause a ParserException below - break - except ParserException as exc: - if max_correction_steps == 0 and max_conversation_restarts == 0: - raise GenerationValidationFailureError( - "Unsuccessful generation attempt. No retries were attempted." - ) from exc - - if curr_num_correction_steps <= max_correction_steps: - # Add user message with error for correction - messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc)))) - - elif curr_num_restarts < max_conversation_restarts: - curr_num_correction_steps = 0 - curr_num_restarts += 1 - messages = deepcopy(restart_checkpoint) - tool_call_turns = checkpoint_tool_call_turns - - else: - raise GenerationValidationFailureError( - f"Unsuccessful generation despite {max_correction_steps} correction steps " - f"and {max_conversation_restarts} conversation restarts." 
- ) from exc - - if not skip_usage_tracking and mcp_facade is not None: - self._usage_stats.tool_usage.extend( - tool_calls=total_tool_calls, - tool_call_turns=tool_call_turns, - ) - - return output_obj, messages + self._track_token_usage_from_image_diffusion(response) def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict: provider = self._model_provider_registry.get_provider(model_config.provider) @@ -478,7 +465,7 @@ def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.Deployme "litellm_params": litellm_params.model_dump(), } - def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None: + def _track_token_usage_from_completion(self, response: litellm.types.utils.ModelResponse | None) -> None: if response is None: self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) return @@ -495,7 +482,7 @@ def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> No request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) - def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None: + def _track_token_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None: if response is None: self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) return @@ -508,27 +495,12 @@ def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingRes request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) - def _track_usage_from_response(self, response: litellm.types.utils.ResponseResponse | None) -> None: - """Track usage from Responses API response.""" + def _track_token_usage_from_image_diffusion(self, response: litellm.types.utils.ImageResponse | None) -> None: + """Track token usage from image_generation API response.""" if response is None: 
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) return - if response.usage is not None: - input_tokens = getattr(response.usage, "input_tokens", 0) or 0 - output_tokens = getattr(response.usage, "output_tokens", 0) or 0 - self._usage_stats.extend( - token_usage=TokenUsageStats( - input_tokens=input_tokens, - output_tokens=output_tokens, - ), - request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), - ) - def _track_usage_from_image_diffusion(self, response: litellm.types.utils.ImageResponse | None) -> None: - """Track usage from image_generation API response.""" - if response is None: - self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) - return if response.usage is not None and isinstance(response.usage, litellm.types.utils.ImageUsage): self._usage_stats.extend( token_usage=TokenUsageStats( @@ -537,28 +509,3 @@ def _track_usage_from_image_diffusion(self, response: litellm.types.utils.ImageR ), request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) - - def _download_url_to_base64(self, url: str) -> str: - """Download image from URL and convert to base64. 
- - Args: - url: Image URL - - Returns: - Base64-encoded image string - - Raises: - ModelAPIError: If download fails - """ - import base64 - - from data_designer.lazy_heavy_imports import httpx - - try: - with httpx.Client(timeout=30.0) as client: - response = client.get(url) - response.raise_for_status() - image_bytes = response.content - return base64.b64encode(image_bytes).decode("utf-8") - except Exception as e: - raise ModelAPIError(f"Failed to download image from URL {url}: {e}") from e diff --git a/packages/data-designer-engine/src/data_designer/engine/models/registry.py b/packages/data-designer-engine/src/data_designer/engine/models/registry.py index 56945941..2878f64e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/registry.py @@ -120,6 +120,10 @@ def log_model_usage(self, total_time_elapsed: float) -> None: f"turns={tool_usage['total_tool_call_turns']}" ) + if image_usage := stats.get("image_usage"): + total_images = image_usage["total_images"] + logger.info(f"{LOG_INDENT}images: total={total_images}") + if model_index < len(sorted_model_names) - 1: logger.info(LOG_INDENT.rstrip()) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/usage.py b/packages/data-designer-engine/src/data_designer/engine/models/usage.py index f44a31ae..169ef1bb 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/usage.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/usage.py @@ -71,10 +71,23 @@ def merge(self, other: ToolUsageStats) -> ToolUsageStats: return self +class ImageUsageStats(BaseModel): + total_images: int = 0 + + @property + def has_usage(self) -> bool: + return self.total_images > 0 + + def extend(self, *, images: int) -> None: + """Extend stats with generated images count.""" + self.total_images += images + + class ModelUsageStats(BaseModel): token_usage: TokenUsageStats = TokenUsageStats() 
request_usage: RequestUsageStats = RequestUsageStats() tool_usage: ToolUsageStats = ToolUsageStats() + image_usage: ImageUsageStats = ImageUsageStats() @property def has_usage(self) -> bool: @@ -86,6 +99,7 @@ def extend( token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None, tool_usage: ToolUsageStats | None = None, + image_usage: ImageUsageStats | None = None, ) -> None: if token_usage is not None: self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens) @@ -95,9 +109,16 @@ def extend( ) if tool_usage is not None: self.tool_usage.merge(tool_usage) + if image_usage is not None: + self.image_usage.extend(images=image_usage.total_images) def get_usage_stats(self, *, total_time_elapsed: float) -> dict: - exclude = {"tool_usage"} if not self.tool_usage.has_usage else None + exclude = set() + if not self.tool_usage.has_usage: + exclude.add("tool_usage") + if not self.image_usage.has_usage: + exclude.add("image_usage") + exclude = exclude if exclude else None return self.model_dump(exclude=exclude) | { "tokens_per_second": int(self.token_usage.total_tokens / total_time_elapsed) if total_time_elapsed > 0 diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index c0ab9cd3..78473d63 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -989,3 +989,150 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe with patch.object(ModelFacade, "completion", new=_completion): with pytest.raises(MCPToolError, match="Invalid tool arguments"): model.generate(prompt="question", parser=lambda x: x, tool_alias="tools") + + +# ============================================================================= +# Image generation tests +# 
============================================================================= + + +@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) +def test_generate_image_diffusion_tracks_image_usage( + mock_image_generation: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image tracks image usage for diffusion models.""" + from litellm.types.utils import ImageObject, ImageResponse + + # Mock response with 3 images + mock_response = ImageResponse( + data=[ + ImageObject(b64_json="image1_base64"), + ImageObject(b64_json="image2_base64"), + ImageObject(b64_json="image3_base64"), + ] + ) + mock_image_generation.return_value = mock_response + + # Verify initial state + assert stub_model_facade.usage_stats.image_usage.total_images == 0 + + # Generate images + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): + images = stub_model_facade.generate_image(prompt="test prompt", n=3) + + # Verify results + assert len(images) == 3 + assert images == ["image1_base64", "image2_base64", "image3_base64"] + + # Verify image usage was tracked + assert stub_model_facade.usage_stats.image_usage.total_images == 3 + assert stub_model_facade.usage_stats.image_usage.has_usage is True + + +@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +def test_generate_image_chat_completion_tracks_image_usage( + mock_completion: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image tracks image usage for chat completion models.""" + from litellm.types.utils import Choices, ImageURLListItem, Message, ModelResponse + + # Mock response with images attribute (Message requires type and index per ImageURLListItem) + mock_message = Message( + role="assistant", + content="", + images=[ + ImageURLListItem(type="image_url", image_url={"url": ""}, index=0), + ImageURLListItem(type="image_url", image_url={"url": ""}, index=1), + ], + ) + mock_response = 
ModelResponse(choices=[Choices(message=mock_message)]) + mock_completion.return_value = mock_response + + # Verify initial state + assert stub_model_facade.usage_stats.image_usage.total_images == 0 + + # Generate images + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): + images = stub_model_facade.generate_image(prompt="test prompt") + + # Verify results + assert len(images) == 2 + assert images == ["image1", "image2"] + + # Verify image usage was tracked + assert stub_model_facade.usage_stats.image_usage.total_images == 2 + assert stub_model_facade.usage_stats.image_usage.has_usage is True + + +@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) +def test_generate_image_skip_usage_tracking( + mock_image_generation: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image respects skip_usage_tracking flag.""" + from litellm.types.utils import ImageObject, ImageResponse + + mock_response = ImageResponse( + data=[ + ImageObject(b64_json="image1_base64"), + ImageObject(b64_json="image2_base64"), + ] + ) + mock_image_generation.return_value = mock_response + + # Verify initial state + assert stub_model_facade.usage_stats.image_usage.total_images == 0 + + # Generate images with skip_usage_tracking=True + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): + images = stub_model_facade.generate_image(prompt="test prompt", skip_usage_tracking=True) + + # Verify results + assert len(images) == 2 + + # Verify image usage was NOT tracked + assert stub_model_facade.usage_stats.image_usage.total_images == 0 + assert stub_model_facade.usage_stats.image_usage.has_usage is False + + +@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) +def test_generate_image_accumulates_usage( + mock_image_generation: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image accumulates 
image usage across multiple calls.""" + from litellm.types.utils import ImageObject, ImageResponse + + # First call - 2 images + mock_response1 = ImageResponse( + data=[ + ImageObject(b64_json="image1"), + ImageObject(b64_json="image2"), + ] + ) + # Second call - 3 images + mock_response2 = ImageResponse( + data=[ + ImageObject(b64_json="image3"), + ImageObject(b64_json="image4"), + ImageObject(b64_json="image5"), + ] + ) + mock_image_generation.side_effect = [mock_response1, mock_response2] + + # Verify initial state + assert stub_model_facade.usage_stats.image_usage.total_images == 0 + + # First generation + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): + images1 = stub_model_facade.generate_image(prompt="test1") + assert len(images1) == 2 + assert stub_model_facade.usage_stats.image_usage.total_images == 2 + + # Second generation + images2 = stub_model_facade.generate_image(prompt="test2") + assert len(images2) == 3 + # Usage should accumulate + assert stub_model_facade.usage_stats.image_usage.total_images == 5 diff --git a/packages/data-designer-engine/tests/engine/models/test_usage.py b/packages/data-designer-engine/tests/engine/models/test_usage.py index 8e7adb04..2c4f783f 100644 --- a/packages/data-designer-engine/tests/engine/models/test_usage.py +++ b/packages/data-designer-engine/tests/engine/models/test_usage.py @@ -1,7 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats, ToolUsageStats +from data_designer.engine.models.usage import ( + ImageUsageStats, + ModelUsageStats, + RequestUsageStats, + TokenUsageStats, + ToolUsageStats, +) def test_token_usage_stats() -> None: @@ -32,6 +38,20 @@ def test_request_usage_stats() -> None: assert request_usage_stats.has_usage is True +def test_image_usage_stats() -> None: + image_usage_stats = ImageUsageStats() + assert image_usage_stats.total_images == 0 + assert image_usage_stats.has_usage is False + + image_usage_stats.extend(images=5) + assert image_usage_stats.total_images == 5 + assert image_usage_stats.has_usage is True + + image_usage_stats.extend(images=3) + assert image_usage_stats.total_images == 8 + assert image_usage_stats.has_usage is True + + def test_tool_usage_stats_empty_state() -> None: """Test ToolUsageStats initialization with empty state.""" tool_usage = ToolUsageStats() @@ -132,9 +152,10 @@ def test_model_usage_stats() -> None: assert model_usage_stats.token_usage.output_tokens == 0 assert model_usage_stats.request_usage.successful_requests == 0 assert model_usage_stats.request_usage.failed_requests == 0 + assert model_usage_stats.image_usage.total_images == 0 assert model_usage_stats.has_usage is False - # tool_usage is excluded when has_usage is False + # tool_usage and image_usage are excluded when has_usage is False assert model_usage_stats.get_usage_stats(total_time_elapsed=10) == { "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, "request_usage": {"successful_requests": 0, "failed_requests": 0, "total_requests": 0}, @@ -152,7 +173,7 @@ def test_model_usage_stats() -> None: assert model_usage_stats.request_usage.failed_requests == 1 assert model_usage_stats.has_usage is True - # tool_usage is excluded when has_usage is False + # tool_usage and image_usage are excluded when has_usage is False assert 
model_usage_stats.get_usage_stats(total_time_elapsed=2) == { "token_usage": {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, "request_usage": {"successful_requests": 2, "failed_requests": 1, "total_requests": 3}, @@ -177,3 +198,36 @@ def test_model_usage_stats_extend_with_tool_usage() -> None: assert stats1.tool_usage.total_tool_call_turns == 6 assert stats1.tool_usage.total_generations == 4 assert stats1.tool_usage.generations_with_tools == 3 + + +def test_model_usage_stats_with_image_usage() -> None: + """Test that ModelUsageStats includes image_usage when it has usage.""" + model_usage_stats = ModelUsageStats() + model_usage_stats.extend( + token_usage=TokenUsageStats(input_tokens=10, output_tokens=20), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + image_usage=ImageUsageStats(total_images=5), + ) + + assert model_usage_stats.image_usage.total_images == 5 + assert model_usage_stats.image_usage.has_usage is True + + # image_usage should be included in output + usage_stats = model_usage_stats.get_usage_stats(total_time_elapsed=2) + assert "image_usage" in usage_stats + assert usage_stats["image_usage"] == {"total_images": 5} + + +def test_model_usage_stats_exclude_unused_stats() -> None: + """Test that ModelUsageStats excludes tool_usage and image_usage when they have no usage.""" + model_usage_stats = ModelUsageStats() + model_usage_stats.extend( + token_usage=TokenUsageStats(input_tokens=10, output_tokens=20), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) + + usage_stats = model_usage_stats.get_usage_stats(total_time_elapsed=2) + assert "tool_usage" not in usage_stats + assert "image_usage" not in usage_stats + assert "token_usage" in usage_stats + assert "request_usage" in usage_stats From 3b4acf19202db778f6905f0c1bd27ca984cdaffb Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 10:55:49 -0700 Subject: [PATCH 28/64] fix image usage tracking --- 
.../config/utils/image_helpers.py | 9 ++++---- .../src/data_designer/engine/models/usage.py | 2 +- .../tests/engine/models/test_usage.py | 22 +++++++++++++++++++ 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py index 2069d9bf..9fc4e2b0 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -23,8 +23,8 @@ # WEBP uses RIFF header - handled separately } -# Patterns for detecting diffusion-based image generation models (DALL-E, Stable Diffusion, Imagen, etc.) -_IMAGE_DIFFUSION_MODEL_PATTERNS = ( +# Patterns for diffusion-based image models only (use image_generation API). +IMAGE_DIFFUSION_MODEL_PATTERNS = ( "dall-e", "dalle", "stable-diffusion", @@ -37,8 +37,7 @@ def is_image_diffusion_model(model_name: str) -> bool: """Return True if the model is a diffusion-based image generation model. - Diffusion models use the image_generation API (e.g. DALL-E, Stable Diffusion, Imagen). - All other image models are assumed to use the chat/completions API. + Args: model_name: Model name or identifier (e.g. from provider). @@ -46,7 +45,7 @@ def is_image_diffusion_model(model_name: str) -> bool: Returns: True if the model is detected as diffusion-based, False otherwise. 
""" - return any(pattern in model_name.lower() for pattern in _IMAGE_DIFFUSION_MODEL_PATTERNS) + return any(pattern in model_name.lower() for pattern in IMAGE_DIFFUSION_MODEL_PATTERNS) def extract_base64_from_data_uri(data: str) -> str: diff --git a/packages/data-designer-engine/src/data_designer/engine/models/usage.py b/packages/data-designer-engine/src/data_designer/engine/models/usage.py index 169ef1bb..64e82b47 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/usage.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/usage.py @@ -91,7 +91,7 @@ class ModelUsageStats(BaseModel): @property def has_usage(self) -> bool: - return self.token_usage.has_usage and self.request_usage.has_usage + return self.token_usage.has_usage or self.request_usage.has_usage or self.image_usage.has_usage def extend( self, diff --git a/packages/data-designer-engine/tests/engine/models/test_usage.py b/packages/data-designer-engine/tests/engine/models/test_usage.py index 2c4f783f..2bfea4b4 100644 --- a/packages/data-designer-engine/tests/engine/models/test_usage.py +++ b/packages/data-designer-engine/tests/engine/models/test_usage.py @@ -218,6 +218,28 @@ def test_model_usage_stats_with_image_usage() -> None: assert usage_stats["image_usage"] == {"total_images": 5} +def test_model_usage_stats_has_usage_any_of() -> None: + """Test that has_usage is True when any of token, request, or image usage is present.""" + # Only token usage + stats = ModelUsageStats() + stats.extend(token_usage=TokenUsageStats(input_tokens=1, output_tokens=0)) + assert stats.has_usage is True + + # Only request usage (e.g. 
diffusion API without token counts) + stats = ModelUsageStats() + stats.extend(request_usage=RequestUsageStats(successful_requests=1, failed_requests=0)) + assert stats.has_usage is True + + # Only image usage + stats = ModelUsageStats() + stats.extend(image_usage=ImageUsageStats(total_images=2)) + assert stats.has_usage is True + + # None of the three + stats = ModelUsageStats() + assert stats.has_usage is False + + def test_model_usage_stats_exclude_unused_stats() -> None: """Test that ModelUsageStats excludes tool_usage and image_usage when they have no usage.""" model_usage_stats = ModelUsageStats() From 33b4211490519ac6f2e3bd38cd5273d945f58718 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 11:52:15 -0700 Subject: [PATCH 29/64] test clean up --- .../src/data_designer/config/models.py | 28 +------ .../config/utils/image_helpers.py | 2 - .../tests/config/test_models.py | 30 ++++++++ .../data_designer/engine/models/registry.py | 2 +- .../tests/engine/models/test_facade.py | 77 ++++++++++--------- 5 files changed, 76 insertions(+), 63 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 3dab2d8d..8b16b4bc 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -425,21 +425,7 @@ def generate_kwargs(self) -> dict[str, float | int]: class ImageInferenceParams(BaseInferenceParams): """Configuration for image generation models. - Works for all image generation models. The API type is automatically detected - based on the model name: - - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) 
use image_generation API - - All other models use chat/completions API (default) - - Image storage behavior: - - Create mode: Images saved to disk with UUID filenames, paths stored in dataframe - - Preview mode: Images stored as base64 directly in dataframe - - Common parameters like quality and size are provided as optional fields. - For model-specific parameters (including n for number of images), use the `extra_body` - field inherited from BaseInferenceParams. - - If the API returns multiple images (either from prompt or API parameters), all images - will be stored as a list in the dataframe. + Works for both diffusion and autoregressive image generation models. Use extra_body for model-specific parameters. Attributes: generation_type: Type of generation, always "image" for this class. @@ -454,22 +440,14 @@ class ImageInferenceParams(BaseInferenceParams): size="1024x1024" ) - # Generate multiple images using extra_body - dd.ImageInferenceParams( - quality="hd", - size="1024x1024", - extra_body={"n": 3} # Request 3 images from API - ) - # With model-specific params via extra_body dd.ImageInferenceParams( - quality="hd", - size="1024x1024", + quality="auto", extra_body={ "generationConfig": { "imageConfig": { "aspectRatio": "1:1", - "negativePrompt": "blurry, low quality" + "imageSize": "1024" } } } diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py index 9fc4e2b0..678d3b80 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -37,8 +37,6 @@ def is_image_diffusion_model(model_name: str) -> bool: """Return True if the model is a diffusion-based image generation model. - - Args: model_name: Model name or identifier (e.g. from provider). 
diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index 38b8079e..4891c78d 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -17,6 +17,7 @@ GenerationType, ImageContext, ImageFormat, + ImageInferenceParams, ManualDistribution, ManualDistributionParams, ModalityDataType, @@ -412,6 +413,12 @@ def test_model_config_construction(): assert model_config.inference_parameters == embedding_params assert model_config.generation_type == GenerationType.EMBEDDING + # test construction with image inference parameters + image_params = ImageInferenceParams(quality="hd", size="1024x1024") + model_config = ModelConfig(alias="test", model="test", inference_parameters=image_params) + assert model_config.inference_parameters == image_params + assert model_config.generation_type == GenerationType.IMAGE + def test_model_config_generation_type_from_dict(): # Test that generation_type in dict is used to create the right inference params type @@ -435,6 +442,29 @@ def test_model_config_generation_type_from_dict(): assert isinstance(model_config.inference_parameters, ChatCompletionInferenceParams) assert model_config.generation_type == GenerationType.CHAT_COMPLETION + model_config = ModelConfig.model_validate( + { + "alias": "test", + "model": "image-model", + "inference_parameters": {"generation_type": "image", "quality": "hd", "size": "1024x1024"}, + } + ) + assert isinstance(model_config.inference_parameters, ImageInferenceParams) + assert model_config.inference_parameters.quality == "hd" + assert model_config.inference_parameters.size == "1024x1024" + assert model_config.generation_type == GenerationType.IMAGE + + +def test_image_inference_params_generate_kwargs() -> None: + """ImageInferenceParams.generate_kwargs includes quality and size when set.""" + params = ImageInferenceParams() + assert 
params.generate_kwargs.get("quality") is None + assert params.generate_kwargs.get("size") is None + + params = ImageInferenceParams(quality="hd", size="1024x1024") + assert params.generate_kwargs["quality"] == "hd" + assert params.generate_kwargs["size"] == "1024x1024" + def test_chat_completion_params_format_for_display_all_params(): """Test formatting chat completion model with all parameters.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/models/registry.py b/packages/data-designer-engine/src/data_designer/engine/models/registry.py index 2878f64e..c6f2b7c7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/registry.py @@ -187,7 +187,7 @@ def run_health_check(self, model_aliases: list[str]) -> None: skip_usage_tracking=True, purpose="running health checks", ) - elif model.model_generation_type == GenerationType.IMAGE_GENERATION: + elif model.model_generation_type == GenerationType.IMAGE: model.generate_image( prompt="Generate a simple pixel", skip_usage_tracking=True, diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 78473d63..0323ce98 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -1,11 +1,12 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from unittest.mock import patch import pytest -from litellm.types.utils import Choices, EmbeddingResponse, Message, ModelResponse from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError from data_designer.engine.models.errors import ModelGenerationValidationFailureError @@ -13,6 +14,10 @@ from data_designer.engine.models.parsers.errors import ParserException from data_designer.engine.models.utils import ChatMessage from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, StubMessage, StubResponse +from data_designer.lazy_heavy_imports import litellm + +if TYPE_CHECKING: + import litellm def mock_oai_response_object(response_text: str) -> StubResponse: @@ -35,12 +40,14 @@ def stub_completion_messages() -> list[ChatMessage]: @pytest.fixture def stub_expected_completion_response(): - return ModelResponse(choices=Choices(message=Message(content="Test response"))) + return litellm.types.utils.ModelResponse( + choices=litellm.types.utils.Choices(message=litellm.types.utils.Message(content="Test response")) + ) @pytest.fixture def stub_expected_embedding_response(): - return EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2) + return litellm.types.utils.EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2) @pytest.mark.parametrize( @@ -106,9 +113,11 @@ def test_generate_with_system_prompt( # Capture messages at call time since they get mutated after the call captured_messages = [] - def capture_and_return(*args: Any, **kwargs: Any) -> ModelResponse: + def capture_and_return(*args: Any, **kwargs: Any) -> litellm.types.utils.ModelResponse: captured_messages.append(list(args[1])) # Copy the messages list - return ModelResponse(choices=Choices(message=Message(content="Hello!"))) + return litellm.types.utils.ModelResponse( + 
choices=litellm.types.utils.Choices(message=litellm.types.utils.Message(content="Hello!")) + ) mock_completion.side_effect = capture_and_return @@ -166,7 +175,7 @@ def test_completion_success( stub_completion_messages: list[ChatMessage], stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_completion_response: ModelResponse, + stub_expected_completion_response: litellm.types.utils.ModelResponse, skip_usage_tracking: bool, ) -> None: mock_router_completion.side_effect = lambda self, model, messages, **kwargs: stub_expected_completion_response @@ -199,11 +208,13 @@ def test_completion_with_kwargs( stub_completion_messages: list[ChatMessage], stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_completion_response: ModelResponse, + stub_expected_completion_response: litellm.types.utils.ModelResponse, ) -> None: captured_kwargs = {} - def mock_completion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> ModelResponse: + def mock_completion( + self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any + ) -> litellm.types.utils.ModelResponse: captured_kwargs.update(kwargs) return stub_expected_completion_response @@ -1002,14 +1013,12 @@ def test_generate_image_diffusion_tracks_image_usage( stub_model_facade: ModelFacade, ) -> None: """Test that generate_image tracks image usage for diffusion models.""" - from litellm.types.utils import ImageObject, ImageResponse - # Mock response with 3 images - mock_response = ImageResponse( + mock_response = litellm.types.utils.ImageResponse( data=[ - ImageObject(b64_json="image1_base64"), - ImageObject(b64_json="image2_base64"), - ImageObject(b64_json="image3_base64"), + litellm.types.utils.ImageObject(b64_json="image1_base64"), + litellm.types.utils.ImageObject(b64_json="image2_base64"), + litellm.types.utils.ImageObject(b64_json="image3_base64"), ] ) mock_image_generation.return_value = mock_response @@ -1036,18 +1045,20 @@ def 
test_generate_image_chat_completion_tracks_image_usage( stub_model_facade: ModelFacade, ) -> None: """Test that generate_image tracks image usage for chat completion models.""" - from litellm.types.utils import Choices, ImageURLListItem, Message, ModelResponse - # Mock response with images attribute (Message requires type and index per ImageURLListItem) - mock_message = Message( + mock_message = litellm.types.utils.Message( role="assistant", content="", images=[ - ImageURLListItem(type="image_url", image_url={"url": ""}, index=0), - ImageURLListItem(type="image_url", image_url={"url": ""}, index=1), + litellm.types.utils.ImageURLListItem( + type="image_url", image_url={"url": ""}, index=0 + ), + litellm.types.utils.ImageURLListItem( + type="image_url", image_url={"url": ""}, index=1 + ), ], ) - mock_response = ModelResponse(choices=[Choices(message=mock_message)]) + mock_response = litellm.types.utils.ModelResponse(choices=[litellm.types.utils.Choices(message=mock_message)]) mock_completion.return_value = mock_response # Verify initial state @@ -1072,12 +1083,10 @@ def test_generate_image_skip_usage_tracking( stub_model_facade: ModelFacade, ) -> None: """Test that generate_image respects skip_usage_tracking flag.""" - from litellm.types.utils import ImageObject, ImageResponse - - mock_response = ImageResponse( + mock_response = litellm.types.utils.ImageResponse( data=[ - ImageObject(b64_json="image1_base64"), - ImageObject(b64_json="image2_base64"), + litellm.types.utils.ImageObject(b64_json="image1_base64"), + litellm.types.utils.ImageObject(b64_json="image2_base64"), ] ) mock_image_generation.return_value = mock_response @@ -1103,21 +1112,19 @@ def test_generate_image_accumulates_usage( stub_model_facade: ModelFacade, ) -> None: """Test that generate_image accumulates image usage across multiple calls.""" - from litellm.types.utils import ImageObject, ImageResponse - # First call - 2 images - mock_response1 = ImageResponse( + mock_response1 = 
litellm.types.utils.ImageResponse( data=[ - ImageObject(b64_json="image1"), - ImageObject(b64_json="image2"), + litellm.types.utils.ImageObject(b64_json="image1"), + litellm.types.utils.ImageObject(b64_json="image2"), ] ) # Second call - 3 images - mock_response2 = ImageResponse( + mock_response2 = litellm.types.utils.ImageResponse( data=[ - ImageObject(b64_json="image3"), - ImageObject(b64_json="image4"), - ImageObject(b64_json="image5"), + litellm.types.utils.ImageObject(b64_json="image3"), + litellm.types.utils.ImageObject(b64_json="image4"), + litellm.types.utils.ImageObject(b64_json="image5"), ] ) mock_image_generation.side_effect = [mock_response1, mock_response2] From fad791ee9c9073e590dd74b4febb7f8cc72b2064 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 12:16:31 -0700 Subject: [PATCH 30/64] Small refactor for simplicity --- .../src/data_designer/config/__init__.py | 4 ++-- .../src/data_designer/config/column_configs.py | 14 +++----------- .../src/data_designer/config/column_types.py | 8 ++++---- .../data_designer/config/utils/visualization.py | 2 +- .../tests/config/test_columns.py | 2 +- .../engine/column_generators/generators/image.py | 16 ++++------------ .../engine/column_generators/registry.py | 4 ++-- .../utils/generator_classification.py | 2 ++ .../dataset_builders/column_wise_builder.py | 2 +- .../column_generators/generators/test_image.py | 14 ++++++-------- .../utils/test_generator_classification.py | 2 ++ 11 files changed, 28 insertions(+), 42 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index 5686b506..42afae81 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -17,7 +17,7 @@ EmbeddingColumnConfig, ExpressionColumnConfig, GenerationStrategy, - ImageGenerationColumnConfig, + ImageColumnConfig, 
LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -123,7 +123,7 @@ "CustomColumnConfig": (_MOD_COLUMN_CONFIGS, "CustomColumnConfig"), "EmbeddingColumnConfig": (_MOD_COLUMN_CONFIGS, "EmbeddingColumnConfig"), "ExpressionColumnConfig": (_MOD_COLUMN_CONFIGS, "ExpressionColumnConfig"), - "ImageGenerationColumnConfig": (_MOD_COLUMN_CONFIGS, "ImageGenerationColumnConfig"), + "ImageColumnConfig": (_MOD_COLUMN_CONFIGS, "ImageColumnConfig"), "GenerationStrategy": (_MOD_COLUMN_CONFIGS, "GenerationStrategy"), "LLMCodeColumnConfig": (_MOD_COLUMN_CONFIGS, "LLMCodeColumnConfig"), "LLMJudgeColumnConfig": (_MOD_COLUMN_CONFIGS, "LLMJudgeColumnConfig"), diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index facbc4cf..e9d89f4e 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -485,22 +485,14 @@ def side_effect_columns(self) -> list[str]: return [] -class ImageGenerationColumnConfig(SingleColumnConfig): +class ImageColumnConfig(SingleColumnConfig): """Configuration for image generation columns. Image columns generate images using either autoregressive or diffusion models. The API used is automatically determined based on the model name: - - **Diffusion models** (DALL-E, Stable Diffusion, Imagen, etc.) → image_generation API - - **All other models** → chat/completions API (default) - - Image storage behavior: - - **Create mode**: Images saved to disk with UUID filenames in `images/` folder, - dataframe stores relative paths (e.g., "images/abc123.png") - - **Preview mode**: Images stored as base64 directly in dataframe - Attributes: - column_type: Discriminator field, always "image-generation" for this configuration type. + column_type: Discriminator field, always "image" for this configuration type. 
prompt: Prompt template for image generation. Supports Jinja2 templating to reference other columns (e.g., "Generate an image of a {{ character_name }}"). Must be a valid Jinja2 template. @@ -509,7 +501,7 @@ class ImageGenerationColumnConfig(SingleColumnConfig): prompt: str model_alias: str - column_type: Literal["image-generation"] = "image-generation" + column_type: Literal["image"] = "image" @staticmethod def get_column_emoji() -> str: diff --git a/packages/data-designer-config/src/data_designer/config/column_types.py b/packages/data-designer-config/src/data_designer/config/column_types.py index 9b01e7d7..baba25dd 100644 --- a/packages/data-designer-config/src/data_designer/config/column_types.py +++ b/packages/data-designer-config/src/data_designer/config/column_types.py @@ -9,7 +9,7 @@ CustomColumnConfig, EmbeddingColumnConfig, ExpressionColumnConfig, - ImageGenerationColumnConfig, + ImageColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -40,7 +40,7 @@ | SeedDatasetColumnConfig | ValidationColumnConfig | EmbeddingColumnConfig - | ImageGenerationColumnConfig + | ImageColumnConfig ) ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT) @@ -89,7 +89,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]: DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, DataDesignerColumnType.EMBEDDING, - DataDesignerColumnType.IMAGE_GENERATION, + DataDesignerColumnType.IMAGE, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, DataDesignerColumnType.CUSTOM, @@ -145,5 +145,5 @@ def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict: DataDesignerColumnType.SAMPLER: SamplerColumnConfig, DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnConfig, DataDesignerColumnType.EMBEDDING: EmbeddingColumnConfig, - DataDesignerColumnType.IMAGE_GENERATION: ImageGenerationColumnConfig, + DataDesignerColumnType.IMAGE: ImageColumnConfig, } diff --git 
a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index 6a9e8ee5..910bc467 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -290,7 +290,7 @@ def display_sample_record( render_list.append(pad_console_element(table)) # Collect image generation columns (will be displayed at the end) - image_columns = config_builder.get_columns_of_type(DataDesignerColumnType.IMAGE_GENERATION) + image_columns = config_builder.get_columns_of_type(DataDesignerColumnType.IMAGE) images_to_display_later = [] if len(image_columns) > 0: # Check if we're in a notebook to decide display style diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index 56bb912d..e633518d 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -53,7 +53,7 @@ def test_data_designer_column_type_get_display_order(): DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, DataDesignerColumnType.EMBEDDING, - DataDesignerColumnType.IMAGE_GENERATION, + DataDesignerColumnType.IMAGE, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, DataDesignerColumnType.CUSTOM, diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py index 41586e4b..c8396e24 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING -from data_designer.config.column_configs import 
ImageGenerationColumnConfig +from data_designer.config.column_configs import ImageColumnConfig from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering from data_designer.engine.processing.utils import deserialize_json_values @@ -14,18 +14,12 @@ from data_designer.engine.storage.media_storage import MediaStorage -class ImageCellGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorWithModel[ImageGenerationColumnConfig]): +class ImageCellGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorWithModel[ImageColumnConfig]): """Generator for image columns with disk or dataframe persistence. Media storage always exists and determines behavior via its mode: - - DISK mode (create): Saves images to disk and stores relative paths in dataframe - - DATAFRAME mode (preview): Stores base64 directly in dataframe - - API is automatically detected based on the model name: - - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) 
β†’ image_generation API - - All other models β†’ chat/completions API (default) - - Storage is accessed via ResourceProvider.artifact_storage.media_storage + - DISK mode: Saves images to disk and stores relative paths in dataframe + - DATAFRAME mode: Stores base64 directly in dataframe """ @property @@ -69,8 +63,6 @@ def generate(self, data: dict) -> dict: base64_images = self.model.generate_image(prompt=prompt) # Store via media storage (mode determines disk vs dataframe storage) - # TODO: MediaStorage will check its mode (DISK/DATAFRAME) and act accordingly - # For now, always saves to disk - need to implement mode system results = [self.media_storage.save_base64_image(base64_image) for base64_image in base64_images] data[self.config.name] = results diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py index a4538ad6..f4fc27b9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py @@ -8,7 +8,7 @@ CustomColumnConfig, EmbeddingColumnConfig, ExpressionColumnConfig, - ImageGenerationColumnConfig, + ImageColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -54,7 +54,7 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig) registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig) registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig) - registry.register(DataDesignerColumnType.IMAGE_GENERATION, ImageCellGenerator, ImageGenerationColumnConfig) + registry.register(DataDesignerColumnType.IMAGE, ImageCellGenerator, ImageColumnConfig) if 
with_plugins: for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR): registry.register( diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py index 2e082779..7a45fc71 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py @@ -22,6 +22,7 @@ def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE, } dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) return column_type in dag_column_types @@ -36,6 +37,7 @@ def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE, } for plugin in plugin_manager.get_column_generator_plugins(): if issubclass(plugin.impl_cls, ColumnGeneratorWithModelRegistry): diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 6802f805..ad9e265b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -172,7 +172,7 @@ def _has_image_columns(self) -> bool: """Check if config has any image generation columns.""" from data_designer.config.column_types import DataDesignerColumnType - return any(col.column_type == 
DataDesignerColumnType.IMAGE_GENERATION for col in self.single_column_configs) + return any(col.column_type == DataDesignerColumnType.IMAGE for col in self.single_column_configs) def _initialize_generators(self) -> list[ColumnGenerator]: """Initialize column generators. diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py index e7055d67..80523ff5 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py @@ -5,7 +5,7 @@ import pytest -from data_designer.config.column_configs import ImageGenerationColumnConfig +from data_designer.config.column_configs import ImageColumnConfig from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.column_generators.generators.image import ImageCellGenerator from data_designer.engine.processing.ginja.exceptions import UserTemplateError @@ -13,9 +13,7 @@ @pytest.fixture def stub_image_column_config(): - return ImageGenerationColumnConfig( - name="test_image", prompt="A {{ style }} image of {{ subject }}", model_alias="test_model" - ) + return ImageColumnConfig(name="test_image", prompt="A {{ style }} image of {{ subject }}", model_alias="test_model") @pytest.fixture @@ -24,14 +22,14 @@ def stub_base64_images() -> list[str]: def test_image_cell_generator_generation_strategy( - stub_image_column_config: ImageGenerationColumnConfig, stub_resource_provider: None + stub_image_column_config: ImageColumnConfig, stub_resource_provider: None ) -> None: generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) assert generator.get_generation_strategy() == GenerationStrategy.CELL_BY_CELL def test_image_cell_generator_media_storage_property( - stub_image_column_config: 
ImageGenerationColumnConfig, stub_resource_provider: None + stub_image_column_config: ImageColumnConfig, stub_resource_provider: None ) -> None: generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) # Should return media_storage from artifact_storage (always exists) @@ -105,7 +103,7 @@ def test_image_cell_generator_missing_columns_error(stub_image_column_config, st def test_image_cell_generator_empty_prompt_error(stub_resource_provider): """Test that empty rendered prompt raises UserTemplateError.""" # Create config with template that renders to empty string - config = ImageGenerationColumnConfig(name="test_image", prompt="{{ empty }}", model_alias="test_model") + config = ImageColumnConfig(name="test_image", prompt="{{ empty }}", model_alias="test_model") generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) @@ -115,7 +113,7 @@ def test_image_cell_generator_empty_prompt_error(stub_resource_provider): def test_image_cell_generator_whitespace_only_prompt_error(stub_resource_provider): """Test that whitespace-only rendered prompt raises ValueError.""" - config = ImageGenerationColumnConfig(name="test_image", prompt="{{ spaces }}", model_alias="test_model") + config = ImageColumnConfig(name="test_image", prompt="{{ spaces }}", model_alias="test_model") generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) diff --git a/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py b/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py index bdf15e5d..0be26b11 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py +++ b/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py @@ -14,6 +14,7 @@ def test_column_type_is_model_generated() -> None: assert 
column_type_is_model_generated(DataDesignerColumnType.LLM_STRUCTURED) assert column_type_is_model_generated(DataDesignerColumnType.LLM_JUDGE) assert column_type_is_model_generated(DataDesignerColumnType.EMBEDDING) + assert column_type_is_model_generated(DataDesignerColumnType.IMAGE) assert not column_type_is_model_generated(DataDesignerColumnType.SAMPLER) assert not column_type_is_model_generated(DataDesignerColumnType.VALIDATION) assert not column_type_is_model_generated(DataDesignerColumnType.EXPRESSION) @@ -28,5 +29,6 @@ def test_column_type_used_in_execution_dag() -> None: assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT) assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION) assert column_type_used_in_execution_dag(DataDesignerColumnType.EMBEDDING) + assert column_type_used_in_execution_dag(DataDesignerColumnType.IMAGE) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET) From 54ebcc80cb2a9b62fdc98804631193205faa2414 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 12:59:43 -0700 Subject: [PATCH 31/64] update ImageInferenceParams --- .../src/data_designer/config/models.py | 21 +++++-------------- .../tests/config/test_models.py | 21 ++++++++++--------- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 8b16b4bc..0542a8b8 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -425,24 +425,20 @@ def generate_kwargs(self) -> dict[str, float | int]: class ImageInferenceParams(BaseInferenceParams): """Configuration for image generation models. - Works for both diffusion and autoregressive image generation models. 
Use extra_body for model-specific parameters. + Works for both diffusion and autoregressive image generation models. Pass all model-specific image options via `extra_body`. Attributes: generation_type: Type of generation, always "image" for this class. - quality: Image quality setting (e.g., "standard", "hd"). Optional and model-specific. - size: Image size specification (e.g., "1024x1024", "1792x1024"). Optional and model-specific. Example: ```python - # Standard usage with common params + # OpenAI-style (DALL·E): quality and size in extra_body or as top-level kwargs dd.ImageInferenceParams( - quality="hd", - size="1024x1024" + extra_body={"size": "1024x1024", "quality": "hd"} ) - # With model-specific params via extra_body + # Gemini-style: generationConfig.imageConfig dd.ImageInferenceParams( - quality="auto", extra_body={ "generationConfig": { "imageConfig": { @@ -456,17 +452,10 @@ class ImageInferenceParams(BaseInferenceParams): """ generation_type: Literal[GenerationType.IMAGE] = GenerationType.IMAGE - quality: str | None = None - size: str | None = None @property def generate_kwargs(self) -> dict[str, Any]: - result = super().generate_kwargs - if self.quality is not None: - result["quality"] = self.quality - if self.size is not None: - result["size"] = self.size - return result + return super().generate_kwargs InferenceParamsT: TypeAlias = Annotated[ diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index 4891c78d..564b235c 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -414,7 +414,7 @@ def test_model_config_construction(): assert model_config.generation_type == GenerationType.EMBEDDING # test construction with image inference parameters - image_params = ImageInferenceParams(quality="hd", size="1024x1024") + image_params = ImageInferenceParams(extra_body={"size": "1024x1024", "quality":
"hd"}) model_config = ModelConfig(alias="test", model="test", inference_parameters=image_params) assert model_config.inference_parameters == image_params assert model_config.generation_type == GenerationType.IMAGE @@ -446,24 +446,25 @@ def test_model_config_generation_type_from_dict(): { "alias": "test", "model": "image-model", - "inference_parameters": {"generation_type": "image", "quality": "hd", "size": "1024x1024"}, + "inference_parameters": { + "generation_type": "image", + "extra_body": {"size": "1024x1024", "quality": "hd"}, + }, } ) assert isinstance(model_config.inference_parameters, ImageInferenceParams) - assert model_config.inference_parameters.quality == "hd" - assert model_config.inference_parameters.size == "1024x1024" + assert model_config.inference_parameters.extra_body == {"size": "1024x1024", "quality": "hd"} assert model_config.generation_type == GenerationType.IMAGE def test_image_inference_params_generate_kwargs() -> None: - """ImageInferenceParams.generate_kwargs includes quality and size when set.""" + """ImageInferenceParams.generate_kwargs delegates to base; image params go via extra_body.""" params = ImageInferenceParams() - assert params.generate_kwargs.get("quality") is None - assert params.generate_kwargs.get("size") is None + assert "quality" not in params.generate_kwargs + assert "size" not in params.generate_kwargs - params = ImageInferenceParams(quality="hd", size="1024x1024") - assert params.generate_kwargs["quality"] == "hd" - assert params.generate_kwargs["size"] == "1024x1024" + params = ImageInferenceParams(extra_body={"size": "1024x1024", "quality": "hd"}) + assert params.generate_kwargs.get("extra_body") == {"size": "1024x1024", "quality": "hd"} def test_chat_completion_params_format_for_display_all_params(): From 3aad6081dca8582d88f06b5e464c4b0f0ac79bf2 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 13:01:38 -0700 Subject: [PATCH 32/64] add example tutorial for image generation --- 
docs/notebook_source/1-the-basics.py | 2 + ...tructured-outputs-and-jinja-expressions.py | 2 + .../3-seeding-with-a-dataset.py | 2 + .../4-providing-images-as-context.py | 2 + docs/notebook_source/5-generating-images.py | 212 ++++++++++++++++++ docs/notebook_source/_README.md | 9 + 6 files changed, 229 insertions(+) create mode 100644 docs/notebook_source/5-generating-images.py diff --git a/docs/notebook_source/1-the-basics.py b/docs/notebook_source/1-the-basics.py index 392efb34..8735d582 100644 --- a/docs/notebook_source/1-the-basics.py +++ b/docs/notebook_source/1-the-basics.py @@ -330,3 +330,5 @@ # # - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/) # +# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) +# diff --git a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py index 66b3773f..df581612 100644 --- a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py +++ b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py @@ -372,3 +372,5 @@ class ProductReview(BaseModel): # # - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/) # +# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) +# diff --git a/docs/notebook_source/3-seeding-with-a-dataset.py b/docs/notebook_source/3-seeding-with-a-dataset.py index c9d694a8..e4f9218e 100644 --- a/docs/notebook_source/3-seeding-with-a-dataset.py +++ b/docs/notebook_source/3-seeding-with-a-dataset.py @@ -274,3 +274,5 @@ # # - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/) # +# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) +# diff --git 
a/docs/notebook_source/4-providing-images-as-context.py b/docs/notebook_source/4-providing-images-as-context.py index a11880ba..1fd68dac 100644 --- a/docs/notebook_source/4-providing-images-as-context.py +++ b/docs/notebook_source/4-providing-images-as-context.py @@ -299,3 +299,5 @@ def convert_image_to_chat_format(record, height: int) -> dict: # - Combine vision-based summaries with other column types for multi-modal workflows # - Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering # +# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) with Data Designer +# diff --git a/docs/notebook_source/5-generating-images.py b/docs/notebook_source/5-generating-images.py new file mode 100644 index 00000000..aee5a0c1 --- /dev/null +++ b/docs/notebook_source/5-generating-images.py @@ -0,0 +1,212 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.18.1 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # 🎨 Data Designer Tutorial: Generating Images +# +# #### πŸ“š What you'll learn +# +# This notebook shows how to generate synthetic image data with Data Designer using image-generation models. +# +# - πŸ–ΌοΈ **Image generation columns**: Add columns that produce images from text prompts +# - πŸ“ **Jinja2 prompts**: Drive diversity by referencing other columns in your prompt template +# - πŸ’Ύ **Preview vs create**: Preview stores base64 in the dataframe; create saves images to disk and stores paths +# +# Data Designer supports both **diffusion** (e.g. DALLΒ·E, Stable Diffusion, Imagen) and **autoregressive** (e.g. Gemini image, GPT image) models; the API is chosen automatically from the model name. 
+# +# If this is your first time using Data Designer, we recommend starting with the [first notebook](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/) in this tutorial series. +# + +# %% [markdown] +# ### 📦 Import Data Designer +# +# - `data_designer.config` provides the configuration API. +# - `DataDesigner` is the main interface for generation. +# + +# %% +from IPython.display import Image as IPImage +from IPython.display import display + +import data_designer.config as dd +from data_designer.interface import DataDesigner + +# %% [markdown] +# ### ⚙️ Initialize the Data Designer interface +# +# When initialized without arguments, [default model providers](https://nvidia-nemo.github.io/DataDesigner/latest/concepts/models/default-model-settings/) are used. This tutorial uses [OpenRouter](https://openrouter.ai) with the Flux 2 Pro image model; set `OPENROUTER_API_KEY` in your environment. +# + +# %% +data_designer = DataDesigner() + +# %% [markdown] +# ### 🎛️ Define an image-generation model +# +# - Use `ImageInferenceParams` so Data Designer treats this model as an image generator. +# - Image options (size, quality, aspect ratio, etc.) are model-specific; pass them via `extra_body`. +# + +# %% +MODEL_PROVIDER = "openrouter" +MODEL_ID = "black-forest-labs/flux.2-pro" +MODEL_ALIAS = "image-model" + +model_configs = [ + dd.ModelConfig( + alias=MODEL_ALIAS, + model=MODEL_ID, + provider=MODEL_PROVIDER, + inference_parameters=dd.ImageInferenceParams( + extra_body={"size": "1024x1024"}, + ), + ) +] + +# %% [markdown] +# ### 🏗️ Build the config: samplers + image column +# +# We'll generate diverse **dog portrait** images: sampler columns drive subject (breed), age, style, look direction, and emotion. The image-generation column uses a Jinja2 prompt that references all of them.
+# + +# %% +config_builder = dd.DataDesignerConfigBuilder(model_configs=model_configs) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="subject", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=[ + "a Golden Retriever", + "a German Shepherd", + "a Labrador Retriever", + "a Bulldog", + "a Beagle", + "a Poodle", + "a Corgi", + "a Siberian Husky", + "a Dalmatian", + ], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="age", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["1-3", "3-6", "6-9", "9-12", "12-15"], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="style", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=[ + "photorealistic", + "oil painting", + "watercolor", + "digital art", + "sketch", + "anime", + ], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="look_direction", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["left", "right", "front", "up", "down"], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="emotion", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["happy", "curious", "serious", "sleepy", "excited"], + ), + ) +) + +config_builder.add_column( + dd.ImageColumnConfig( + name="generated_image", + prompt=( + "A {{ style }} portrait of {{ subject }} {{ age }} years old looking {{ look_direction }} " + "towards a crowd of the same kind with an {{ emotion }} expression." + ), + model_alias=MODEL_ALIAS, + ) +) + +data_designer.validate(config_builder) + +# %% [markdown] +# ### πŸ” Preview: images as base64 +# +# In **preview** mode, generated images are stored as base64 strings in the dataframe. Run the next cell to step through each record (images are shown in the sample record display, but only in a notebook environment). 
+# + +# %% +preview = data_designer.preview(config_builder, num_records=2) + +# %% +for i in range(len(preview.dataset)): + preview.display_sample_record() + +# %% +preview.dataset + +# %% [markdown] +# ### 🆙 Create: images saved to disk +# +# In **create** mode, images are written to an `images/` folder with UUID filenames; the dataframe stores relative paths (e.g. `images/1d16b6e2-562f-4f51-91e5-baaa999ea916.png`). +# + +# %% +results = data_designer.create(config_builder, num_records=5, dataset_name="tutorial-5-images") + +# %% +dataset = results.load_dataset() +dataset.head() + +# %% +# Display all images from the created dataset. Paths are relative to the artifact output directory. +for index, row in dataset.iterrows(): + path_or_list = row.get("generated_image") + if path_or_list is not None: + for path in path_or_list: + base = results.artifact_storage.base_dataset_path + full_path = base / path + display(IPImage(data=full_path)) + +# %% [markdown] +# ## ⏭️ Next steps +# +# - [The basics](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/): samplers and LLM text columns +# - [Structured outputs and Jinja](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/2-structured-outputs-and-jinja-expressions/) +# - [Seeding with a dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/) +# - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/) +# diff --git a/docs/notebook_source/_README.md b/docs/notebook_source/_README.md index 09053c22..7bcd77d1 100644 --- a/docs/notebook_source/_README.md +++ b/docs/notebook_source/_README.md @@ -97,6 +97,15 @@ Learn how to use vision-language models to generate text descriptions from image - Generating detailed summaries from document images - Inspecting and validating vision-based generation results +### [5.
Generating Images](5-generating-images.ipynb) + +Generate synthetic image data with Data Designer: + +- Configuring image-generation models with `ImageInferenceParams` +- Adding image columns with Jinja2 prompts and sampler-driven diversity +- Preview (base64 in dataframe) vs create (images saved to disk, paths in dataframe) +- Displaying generated images in the notebook + ## πŸ“– Important Documentation Sections Before diving into the tutorials, familiarize yourself with these key documentation sections: From f252c376917bcb791113318a84ee4e84d005d793 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 18:15:14 -0700 Subject: [PATCH 33/64] support multi-modal context in ImageColumnConfig --- .../data_designer/config/column_configs.py | 4 + .../column_generators/generators/image.py | 9 +- .../src/data_designer/engine/models/facade.py | 28 +++++- .../data_designer/engine/models/registry.py | 2 +- .../generators/test_image.py | 90 ++++++++++++++++++- 5 files changed, 125 insertions(+), 8 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index e9d89f4e..e3ea013d 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -497,10 +497,14 @@ class ImageColumnConfig(SingleColumnConfig): reference other columns (e.g., "Generate an image of a {{ character_name }}"). Must be a valid Jinja2 template. model_alias: The model to use for image generation. + multi_modal_context: Optional list of image contexts for multi-modal generation. + Enables autoregressive multi-modal models to generate images based on image inputs. + Only works with autoregressive models that support image-to-image generation. 
""" prompt: str model_alias: str + multi_modal_context: list[ImageContext] | None = None column_type: Literal["image"] = "image" @staticmethod diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py index c8396e24..11bc732c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -59,8 +59,15 @@ def generate(self, data: dict) -> dict: if not prompt or not prompt.strip(): raise ValueError(f"Rendered prompt for column {self.config.name!r} is empty") + # Process multi-modal context if provided + multi_modal_context = None + if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0: + multi_modal_context = [] + for context in self.config.multi_modal_context: + multi_modal_context.extend(context.get_contexts(deserialized_record)) + # Generate images (returns list of base64 strings) - base64_images = self.model.generate_image(prompt=prompt) + base64_images = self.model.generate_image(prompt=prompt, multi_modal_context=multi_modal_context) # Store via media storage (mode determines disk vs dataframe storage) results = [self.media_storage.save_base64_image(base64_image) for base64_image in base64_images] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 11f6e9ec..51940e99 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -304,7 +304,13 @@ def generate_text_embeddings( self._track_token_usage_from_embedding(response) @catch_llm_exceptions - def generate_image(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) 
-> list[str]: + def generate_image( + self, + prompt: str, + multi_modal_context: list[dict[str, Any]] | None = None, + skip_usage_tracking: bool = False, + **kwargs, + ) -> list[str]: """Generate image(s) and return base64-encoded data. Automatically detects the appropriate API based on model name: @@ -316,6 +322,8 @@ def generate_image(self, prompt: str, skip_usage_tracking: bool = False, **kwarg Args: prompt: The prompt for image generation + multi_modal_context: Optional list of image contexts for multi-modal generation. + Only used with autoregressive models via chat completions API. skip_usage_tracking: Whether to skip usage tracking **kwargs: Additional arguments to pass to the model (including n=number of images) @@ -334,7 +342,7 @@ def generate_image(self, prompt: str, skip_usage_tracking: bool = False, **kwarg if is_image_diffusion_model(self.model_name): images = self._generate_image_diffusion(prompt, skip_usage_tracking, **kwargs) else: - images = self._generate_image_chat_completion(prompt, skip_usage_tracking, **kwargs) + images = self._generate_image_chat_completion(prompt, multi_modal_context, skip_usage_tracking, **kwargs) # Track image usage if not skip_usage_tracking and len(images) > 0: @@ -353,14 +361,26 @@ def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: except ValueError as exc: raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc - def _generate_image_chat_completion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: + def _generate_image_chat_completion( + self, + prompt: str, + multi_modal_context: list[dict[str, Any]] | None = None, + skip_usage_tracking: bool = False, + **kwargs, + ) -> list[str]: """Generate image(s) using autoregressive model via chat completions API. 
+ Args: + prompt: The prompt for image generation + multi_modal_context: Optional list of image contexts for multi-modal generation + skip_usage_tracking: Whether to skip usage tracking + **kwargs: Additional arguments to pass to the model + Returns: List of base64-encoded image strings """ kwargs = self.consolidate_kwargs(**kwargs) - messages = [ChatMessage.as_user(content=prompt)] + messages = prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context) response = None try: diff --git a/packages/data-designer-engine/src/data_designer/engine/models/registry.py b/packages/data-designer-engine/src/data_designer/engine/models/registry.py index c6f2b7c7..0b103e76 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/registry.py @@ -189,7 +189,7 @@ def run_health_check(self, model_aliases: list[str]) -> None: ) elif model.model_generation_type == GenerationType.IMAGE: model.generate_image( - prompt="Generate a simple pixel", + prompt="Generate a simple illustration of a thumbs up sign.", skip_usage_tracking=True, purpose="running health checks", ) diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py index 80523ff5..b433bc55 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py @@ -6,6 +6,7 @@ import pytest from data_designer.config.column_configs import ImageColumnConfig +from data_designer.config.models import ImageContext, ImageFormat, ModalityDataType from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.column_generators.generators.image import ImageCellGenerator from data_designer.engine.processing.ginja.exceptions import 
UserTemplateError @@ -58,7 +59,7 @@ def test_image_cell_generator_generate_with_storage( assert data[stub_image_column_config.name] == ["images/uuid1.png", "images/uuid2.png"] # Verify model was called with rendered prompt - mock_generate.assert_called_once_with(prompt="A photorealistic image of cat") + mock_generate.assert_called_once_with(prompt="A photorealistic image of cat", multi_modal_context=None) # Verify storage was called for each image assert mock_storage.save_base64_image.call_count == 2 @@ -88,7 +89,7 @@ def test_image_cell_generator_generate_in_dataframe_mode( assert data[stub_image_column_config.name] == stub_base64_images # Verify model was called with rendered prompt - mock_generate.assert_called_once_with(prompt="A watercolor image of dog") + mock_generate.assert_called_once_with(prompt="A watercolor image of dog", multi_modal_context=None) def test_image_cell_generator_missing_columns_error(stub_image_column_config, stub_resource_provider): @@ -119,3 +120,88 @@ def test_image_cell_generator_whitespace_only_prompt_error(stub_resource_provide with pytest.raises(ValueError, match="empty"): generator.generate(data={"spaces": " "}) + + +def test_image_cell_generator_with_multi_modal_context(stub_resource_provider): + """Test generate with multi-modal context for autoregressive models.""" + # Create image context that references a column with URL + image_context = ImageContext(column_name="reference_image", data_type=ModalityDataType.URL) + + config = ImageColumnConfig( + name="test_image", + prompt="Generate a similar image to the reference", + model_alias="test_model", + multi_modal_context=[image_context], + ) + + # Setup mock media storage + mock_storage = Mock() + mock_storage.save_base64_image.return_value = "images/generated.png" + stub_resource_provider.artifact_storage.media_storage = mock_storage + + stub_base64_images = ["base64_generated_image"] + + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + 
"generate_image", + return_value=stub_base64_images, + ) as mock_generate: + generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + data = generator.generate(data={"reference_image": "https://example.com/image.png"}) + + # Check that column was added + assert config.name in data + assert data[config.name] == ["images/generated.png"] + + # Verify model was called with prompt and multi_modal_context + mock_generate.assert_called_once() + call_args = mock_generate.call_args + assert call_args.kwargs["prompt"] == "Generate a similar image to the reference" + assert call_args.kwargs["multi_modal_context"] is not None + assert len(call_args.kwargs["multi_modal_context"]) == 1 + assert call_args.kwargs["multi_modal_context"][0]["type"] == "image_url" + assert call_args.kwargs["multi_modal_context"][0]["image_url"] == "https://example.com/image.png" + + +def test_image_cell_generator_with_base64_multi_modal_context(stub_resource_provider): + """Test generate with base64 multi-modal context.""" + # Create image context that references a column with base64 data + image_context = ImageContext( + column_name="reference_image", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG + ) + + config = ImageColumnConfig( + name="test_image", + prompt="Generate a variation of this image", + model_alias="test_model", + multi_modal_context=[image_context], + ) + + # Setup mock media storage + mock_storage = Mock() + mock_storage.save_base64_image.return_value = "images/generated.png" + stub_resource_provider.artifact_storage.media_storage = mock_storage + + stub_base64_images = ["base64_generated_image"] + + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + "generate_image", + return_value=stub_base64_images, + ) as mock_generate: + generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + data = generator.generate(data={"reference_image": "iVBORw0KGgoAAAANS"}) + + # Check 
that column was added + assert config.name in data + assert data[config.name] == ["images/generated.png"] + + # Verify model was called with prompt and multi_modal_context + mock_generate.assert_called_once() + call_args = mock_generate.call_args + assert call_args.kwargs["prompt"] == "Generate a variation of this image" + assert call_args.kwargs["multi_modal_context"] is not None + assert len(call_args.kwargs["multi_modal_context"]) == 1 + assert call_args.kwargs["multi_modal_context"][0]["type"] == "image_url" + # Should be formatted as data URI + assert "data:image/png;base64," in call_args.kwargs["multi_modal_context"][0]["image_url"]["url"] From d6a0f2fcb8b0acb664e6c101a14ae1910095fb62 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 18:39:53 -0700 Subject: [PATCH 34/64] updated tutorial notebook --- docs/notebook_source/5-generating-images.py | 116 +++++++++++++++++--- 1 file changed, 100 insertions(+), 16 deletions(-) diff --git a/docs/notebook_source/5-generating-images.py b/docs/notebook_source/5-generating-images.py index aee5a0c1..28638ff9 100644 --- a/docs/notebook_source/5-generating-images.py +++ b/docs/notebook_source/5-generating-images.py @@ -69,7 +69,7 @@ model=MODEL_ID, provider=MODEL_PROVIDER, inference_parameters=dd.ImageInferenceParams( - extra_body={"size": "1024x1024"}, + extra_body={"height": 512, "width": 512}, ), ) ] @@ -85,7 +85,24 @@ config_builder.add_column( dd.SamplerColumnConfig( - name="subject", + name="style", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=[ + "photorealistic", + "oil painting", + "watercolor", + "digital art", + "sketch", + "anime", + ], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="dog_breed", sampler_type=dd.SamplerType.CATEGORY, params=dd.CategorySamplerParams( values=[ @@ -98,6 +115,58 @@ "a Corgi", "a Siberian Husky", "a Dalmatian", + "a Yorkshire Terrier", + "a Boxer", + "a Dachshund", + "a Doberman Pinscher", + "a Shih 
Tzu", + "a Chihuahua", + "a Border Collie", + "an Australian Shepherd", + "a Cocker Spaniel", + "a Maltese", + "a Pomeranian", + "a Saint Bernard", + "a Great Dane", + "an Akita", + "a Samoyed", + "a Boston Terrier", + ], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="cat_breed", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=[ + "a Persian", + "a Maine Coon", + "a Siamese", + "a Ragdoll", + "a Bengal", + "an Abyssinian", + "a British Shorthair", + "a Sphynx", + "a Scottish Fold", + "a Russian Blue", + "a Birman", + "an Oriental Shorthair", + "a Norwegian Forest Cat", + "a Devon Rex", + "a Burmese", + "an Egyptian Mau", + "a Tonkinese", + "a Himalayan", + "a Savannah", + "a Chartreux", + "a Somali", + "a Manx", + "a Turkish Angora", + "a Balinese", + "an American Shorthair", ], ), ) @@ -105,7 +174,7 @@ config_builder.add_column( dd.SamplerColumnConfig( - name="age", + name="dog_age", sampler_type=dd.SamplerType.CATEGORY, params=dd.CategorySamplerParams( values=["1-3", "3-6", "6-9", "9-12", "12-15"], @@ -115,24 +184,27 @@ config_builder.add_column( dd.SamplerColumnConfig( - name="style", + name="cat_age", sampler_type=dd.SamplerType.CATEGORY, params=dd.CategorySamplerParams( - values=[ - "photorealistic", - "oil painting", - "watercolor", - "digital art", - "sketch", - "anime", - ], + values=["1-3", "3-6", "6-9", "9-12", "12-18"], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="dog_look_direction", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["left", "right", "front", "up", "down"], ), ) ) config_builder.add_column( dd.SamplerColumnConfig( - name="look_direction", + name="cat_look_direction", sampler_type=dd.SamplerType.CATEGORY, params=dd.CategorySamplerParams( values=["left", "right", "front", "up", "down"], @@ -142,7 +214,7 @@ config_builder.add_column( dd.SamplerColumnConfig( - name="emotion", + name="dog_emotion", 
sampler_type=dd.SamplerType.CATEGORY, params=dd.CategorySamplerParams( values=["happy", "curious", "serious", "sleepy", "excited"], @@ -150,12 +222,24 @@ ) ) +config_builder.add_column( + dd.SamplerColumnConfig( + name="cat_emotion", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["aloof", "curious", "content", "sleepy", "playful"], + ), + ) +) + config_builder.add_column( dd.ImageColumnConfig( name="generated_image", prompt=( - "A {{ style }} portrait of {{ subject }} {{ age }} years old looking {{ look_direction }} " - "towards a crowd of the same kind with an {{ emotion }} expression." + """ +A {{ style }} family pet portrait of a {{ dog_breed }} dog of {{ dog_age }} years old looking {{dog_look_direction}} with an {{ dog_emotion }} expression and +{{ cat_breed }} cat of {{ cat_age }} years old looking {{ cat_look_direction }} with an {{ cat_emotion }} expression in the background. Both subjects should be in focus. + """ ), model_alias=MODEL_ALIAS, ) From f5c6cf9418bd432d4b121d6ce82e7f30586e43df Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 18:57:09 -0700 Subject: [PATCH 35/64] organize image artifacts by column name --- .../column_generators/generators/image.py | 6 +- .../engine/storage/media_storage.py | 18 ++-- .../generators/test_image.py | 23 +++-- .../engine/storage/test_media_storage.py | 92 +++++++++++++++---- 4 files changed, 107 insertions(+), 32 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py index 11bc732c..55721916 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -70,7 +70,11 @@ def generate(self, data: dict) -> dict: base64_images = 
self.model.generate_image(prompt=prompt, multi_modal_context=multi_modal_context) # Store via media storage (mode determines disk vs dataframe storage) - results = [self.media_storage.save_base64_image(base64_image) for base64_image in base64_images] + # Use column name as subfolder to organize images + results = [ + self.media_storage.save_base64_image(base64_image, subfolder_name=self.config.name) + for base64_image in base64_images + ] data[self.config.name] = results return data diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py index ddac3459..df83e331 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py @@ -66,14 +66,15 @@ def _ensure_images_directory(self) -> None: """Create images directory if it doesn't exist (lazy initialization).""" self.images_dir.mkdir(parents=True, exist_ok=True) - def save_base64_image(self, base64_data: str) -> str: + def save_base64_image(self, base64_data: str, subfolder_name: str) -> str: """Save or return base64 image based on storage mode. 
Args: base64_data: Base64 encoded image string (with or without data URI prefix) + subfolder_name: Subfolder name to organize images (e.g., "images//") Returns: - DISK mode: Relative path to saved image (e.g., "images/f47ac10b-58cc.png") + DISK mode: Relative path to saved image (e.g., "images/subfolder_name/f47ac10b-58cc.png") DATAFRAME mode: Original base64 data string Raises: @@ -85,8 +86,11 @@ def save_base64_image(self, base64_data: str) -> str: return base64_data # DISK mode: save to disk, validate, and return relative path - # Ensure images directory exists (lazy initialization) - self._ensure_images_directory() + # Determine the target directory (organized by subfolder) + target_dir = self.images_dir / subfolder_name + + # Ensure target directory exists (lazy initialization) + target_dir.mkdir(parents=True, exist_ok=True) # Decode base64 to bytes image_bytes = decode_base64_image(base64_data) @@ -97,8 +101,10 @@ def save_base64_image(self, base64_data: str) -> str: # Generate unique filename image_id = uuid.uuid4() filename = f"{image_id}.{image_format.value}" - full_path = self.images_dir / filename - relative_path = f"{self.images_subdir}/{filename}" + full_path = target_dir / filename + + # Build relative path + relative_path = f"{self.images_subdir}/{subfolder_name}/{filename}" # Write to disk with open(full_path, "wb") as f: diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py index b433bc55..ca5cbfae 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py @@ -43,7 +43,10 @@ def test_image_cell_generator_generate_with_storage( """Test generate with media storage (create mode) - saves to disk.""" # Setup mock media storage mock_storage = Mock() - mock_storage.save_base64_image.side_effect 
= ["images/uuid1.png", "images/uuid2.png"] + mock_storage.save_base64_image.side_effect = [ + "images/test_image/uuid1.png", + "images/test_image/uuid2.png", + ] stub_resource_provider.artifact_storage.media_storage = mock_storage with patch.object( @@ -54,17 +57,20 @@ def test_image_cell_generator_generate_with_storage( generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) data = generator.generate(data={"style": "photorealistic", "subject": "cat"}) - # Check that column was added with relative paths + # Check that column was added with relative paths (organized in subfolder) assert stub_image_column_config.name in data - assert data[stub_image_column_config.name] == ["images/uuid1.png", "images/uuid2.png"] + assert data[stub_image_column_config.name] == [ + "images/test_image/uuid1.png", + "images/test_image/uuid2.png", + ] # Verify model was called with rendered prompt mock_generate.assert_called_once_with(prompt="A photorealistic image of cat", multi_modal_context=None) - # Verify storage was called for each image + # Verify storage was called for each image with subfolder name assert mock_storage.save_base64_image.call_count == 2 - mock_storage.save_base64_image.assert_any_call("base64_image_1") - mock_storage.save_base64_image.assert_any_call("base64_image_2") + mock_storage.save_base64_image.assert_any_call("base64_image_1", subfolder_name="test_image") + mock_storage.save_base64_image.assert_any_call("base64_image_2", subfolder_name="test_image") def test_image_cell_generator_generate_in_dataframe_mode( @@ -91,6 +97,11 @@ def test_image_cell_generator_generate_in_dataframe_mode( # Verify model was called with rendered prompt mock_generate.assert_called_once_with(prompt="A watercolor image of dog", multi_modal_context=None) + # Verify storage was called for each image with subfolder name (even in DATAFRAME mode) + assert mock_storage.save_base64_image.call_count == 2 + 
mock_storage.save_base64_image.assert_any_call("base64_image_1", subfolder_name="test_image") + mock_storage.save_base64_image.assert_any_call("base64_image_2", subfolder_name="test_image") + def test_image_cell_generator_missing_columns_error(stub_image_column_config, stub_resource_provider): """Test that missing required columns raises ValueError.""" diff --git a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py index abd17afe..105348d2 100644 --- a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py +++ b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py @@ -62,10 +62,10 @@ def test_media_storage_init_custom_subdir(tmp_path): def test_save_base64_image_png(media_storage, sample_base64_png): """Test saving a PNG image from base64.""" - relative_path = media_storage.save_base64_image(sample_base64_png) + relative_path = media_storage.save_base64_image(sample_base64_png, subfolder_name="test_column") - # Check return value format - assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + # Check return value format (organized by column name) + assert relative_path.startswith(f"{IMAGES_SUBDIR}/test_column/") assert relative_path.endswith(".png") # Check file exists on disk @@ -80,10 +80,10 @@ def test_save_base64_image_png(media_storage, sample_base64_png): def test_save_base64_image_jpg(media_storage, sample_base64_jpg): """Test saving a JPEG image from base64.""" - relative_path = media_storage.save_base64_image(sample_base64_jpg) + relative_path = media_storage.save_base64_image(sample_base64_jpg, subfolder_name="test_column") - # Check return value format - assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + # Check return value format (organized by column name) + assert relative_path.startswith(f"{IMAGES_SUBDIR}/test_column/") assert relative_path.endswith(".jpg") # Check file exists on disk @@ -94,10 +94,10 @@ def 
test_save_base64_image_jpg(media_storage, sample_base64_jpg): def test_save_base64_image_with_data_uri(media_storage, sample_base64_png): """Test saving image from data URI format.""" data_uri = f"data:image/png;base64,{sample_base64_png}" - relative_path = media_storage.save_base64_image(data_uri) + relative_path = media_storage.save_base64_image(data_uri, subfolder_name="test_column") - # Should successfully extract base64 and save - assert relative_path.startswith(f"{IMAGES_SUBDIR}/") + # Should successfully extract base64 and save (organized by column name) + assert relative_path.startswith(f"{IMAGES_SUBDIR}/test_column/") assert relative_path.endswith(".png") # Verify file exists and content is correct @@ -111,13 +111,13 @@ def test_save_base64_image_with_data_uri(media_storage, sample_base64_png): def test_save_base64_image_invalid_base64_raises_error(media_storage): """Test that invalid base64 data raises ValueError.""" with pytest.raises(ValueError, match="Invalid base64"): - media_storage.save_base64_image("not-valid-base64!!!") + media_storage.save_base64_image("not-valid-base64!!!", subfolder_name="test_column") def test_save_base64_image_multiple_images_unique_filenames(media_storage, sample_base64_png): """Test that multiple images get unique filenames.""" - path1 = media_storage.save_base64_image(sample_base64_png) - path2 = media_storage.save_base64_image(sample_base64_png) + path1 = media_storage.save_base64_image(sample_base64_png, subfolder_name="test_column") + path2 = media_storage.save_base64_image(sample_base64_png, subfolder_name="test_column") # Paths should be different (different UUIDs) assert path1 != path2 @@ -131,8 +131,8 @@ def test_save_base64_image_disk_mode_validates(tmp_path, sample_base64_png): """Test that DISK mode validates images.""" storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DISK) # Should succeed with valid image - relative_path = storage.save_base64_image(sample_base64_png) - assert 
relative_path.startswith(f"{IMAGES_SUBDIR}/") + relative_path = storage.save_base64_image(sample_base64_png, subfolder_name="test_column") + assert relative_path.startswith(f"{IMAGES_SUBDIR}/test_column/") def test_save_base64_image_disk_mode_corrupted_image_raises_error(tmp_path): @@ -144,18 +144,20 @@ def test_save_base64_image_disk_mode_corrupted_image_raises_error(tmp_path): corrupted_base64 = base64.b64encode(corrupted_bytes).decode() with pytest.raises(ValueError, match="Image validation failed"): - storage.save_base64_image(corrupted_base64) + storage.save_base64_image(corrupted_base64, subfolder_name="test_column") # Check that no files were left behind (cleanup on validation failure) - assert len(list(storage.images_dir.iterdir())) == 0 + column_dir = storage.images_dir / "test_column" + if column_dir.exists(): + assert len(list(column_dir.iterdir())) == 0 def test_save_base64_image_dataframe_mode_returns_base64(tmp_path, sample_base64_png): """Test that DATAFRAME mode returns base64 directly without disk operations.""" storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DATAFRAME) - # Should return the same base64 data - result = storage.save_base64_image(sample_base64_png) + # Should return the same base64 data (column_name is ignored in DATAFRAME mode) + result = storage.save_base64_image(sample_base64_png, subfolder_name="test_column") assert result == sample_base64_png # Directory should not be created in DATAFRAME mode (lazy initialization) @@ -165,10 +167,62 @@ def test_save_base64_image_dataframe_mode_returns_base64(tmp_path, sample_base64 def test_cleanup(media_storage, sample_base64_png): """Test cleanup removes images directory.""" # Save an image first - media_storage.save_base64_image(sample_base64_png) + media_storage.save_base64_image(sample_base64_png, subfolder_name="test_column") assert media_storage.images_dir.exists() assert len(list(media_storage.images_dir.iterdir())) > 0 # Cleanup should remove directory 
media_storage.cleanup() assert not media_storage.images_dir.exists() + + +def test_save_base64_image_with_subfolder_name(media_storage, sample_base64_png): + """Test saving image with subfolder name organizes into subdirectory.""" + subfolder = "test_subfolder" + relative_path = media_storage.save_base64_image(sample_base64_png, subfolder_name=subfolder) + + # Check return value format includes subfolder + assert relative_path.startswith(f"{IMAGES_SUBDIR}/{subfolder}/") + assert relative_path.endswith(".png") + + # Check file exists in correct subdirectory + full_path = media_storage.base_path / relative_path + assert full_path.exists() + assert full_path.parent.name == subfolder + + # Verify file content + saved_bytes = full_path.read_bytes() + expected_bytes = base64.b64decode(sample_base64_png) + assert saved_bytes == expected_bytes + + +def test_save_base64_image_with_different_subfolder_names(media_storage, sample_base64_png, sample_base64_jpg): + """Test that images with different subfolder names are stored in separate subdirectories.""" + path1 = media_storage.save_base64_image(sample_base64_png, subfolder_name="subfolder_a") + path2 = media_storage.save_base64_image(sample_base64_jpg, subfolder_name="subfolder_b") + + # Check paths are in different subdirectories + assert "subfolder_a" in path1 + assert "subfolder_b" in path2 + + # Check both directories exist + subfolder_a_dir = media_storage.images_dir / "subfolder_a" + subfolder_b_dir = media_storage.images_dir / "subfolder_b" + assert subfolder_a_dir.exists() + assert subfolder_b_dir.exists() + + # Check files exist in their respective directories + assert (media_storage.base_path / path1).exists() + assert (media_storage.base_path / path2).exists() + + +def test_save_base64_image_dataframe_mode_with_subfolder_name(tmp_path, sample_base64_png): + """Test that DATAFRAME mode returns base64 directly even with subfolder name.""" + storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DATAFRAME) + + # 
Should return the same base64 data regardless of subfolder name + result = storage.save_base64_image(sample_base64_png, subfolder_name="test_subfolder") + assert result == sample_base64_png + + # Directory should not be created in DATAFRAME mode + assert not storage.images_dir.exists() From 71e2bac46a4e952980498da69845d515cc024635 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 19:33:56 -0700 Subject: [PATCH 36/64] address pr comments --- .../config/utils/visualization.py | 10 ++++--- .../src/data_designer/engine/models/facade.py | 26 ++++++++++++++----- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index 910bc467..9d65cca5 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -3,6 +3,7 @@ from __future__ import annotations +import html import json import os from collections import OrderedDict @@ -81,14 +82,17 @@ def _display_image_if_in_notebook(image_data: str, col_name: str, base_path: str # Use the base64 data directly without resizing img_base64 = base64_data + # Escape column name to prevent HTML injection + escaped_col_name = html.escape(col_name) + # Create HTML with caption and image in left-aligned container - html = f""" + html_content = f"""
-
πŸ–ΌοΈ {col_name}
+
πŸ–ΌοΈ {escaped_col_name}
""" - display(HTML(html)) + display(HTML(html_content)) return True except (ImportError, NameError): # Not in a notebook environment diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 51940e99..a14231ab 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -11,6 +11,7 @@ from data_designer.config.models import GenerationType, ModelConfig, ModelProvider from data_designer.config.utils.image_helpers import ( extract_base64_from_data_uri, + is_base64_image, is_image_diffusion_model, ) from data_designer.engine.mcp.errors import MCPConfigurationError @@ -40,6 +41,14 @@ def _identity(x: Any) -> Any: return x +def _try_extract_base64(data: str) -> str | None: + """Try to extract base64 image data from a data URI, returning None on failure.""" + try: + return extract_base64_from_data_uri(data) + except ValueError: + return None + + logger = logging.getLogger(__name__) @@ -410,19 +419,22 @@ def _generate_image_chat_completion( image_url = image["image_url"] if isinstance(image_url, dict) and "url" in image_url: - url = image_url["url"] - images.append(extract_base64_from_data_uri(url)) + if (b64 := _try_extract_base64(image_url["url"])) is not None: + images.append(b64) elif isinstance(image_url, str): - images.append(extract_base64_from_data_uri(image_url)) + if (b64 := _try_extract_base64(image_url)) is not None: + images.append(b64) # Fallback: treat as base64 string elif isinstance(image, str): - images.append(extract_base64_from_data_uri(image)) + if (b64 := _try_extract_base64(image)) is not None: + images.append(b64) - # Fallback: check content field + # Fallback: check content field if it looks like image data if not images: content = message.content or "" - if content: - images.append(extract_base64_from_data_uri(content)) + if content and 
(content.startswith("data:image/") or is_base64_image(content)): + if (b64 := _try_extract_base64(content)) is not None: + images.append(b64) if not images: raise ModelAPIError("No image data found in response") From 46138d81c3917a6abb967ae321b2216770fdc6fa Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 19:34:30 -0700 Subject: [PATCH 37/64] fix license headers --- .../data_designer/engine/column_generators/generators/image.py | 2 +- .../src/data_designer/engine/storage/__init__.py | 2 +- .../src/data_designer/engine/storage/media_storage.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py index 55721916..730e73bb 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py index 34c776d5..9d416c65 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 from data_designer.engine.storage.media_storage import MediaStorage, StorageMode diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py index df83e331..9adefc89 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations From deb5fc2bdd7f05373ca5559a558441644581c3b3 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 19:37:19 -0700 Subject: [PATCH 38/64] generate collab notebooks --- docs/colab_notebooks/1-the-basics.ipynb | 66 +-- ...ctured-outputs-and-jinja-expressions.ipynb | 62 +-- .../3-seeding-with-a-dataset.ipynb | 58 +-- .../4-providing-images-as-context.ipynb | 70 +-- .../colab_notebooks/5-generating-images.ipynb | 437 ++++++++++++++++++ 5 files changed, 569 insertions(+), 124 deletions(-) create mode 100644 docs/colab_notebooks/5-generating-images.ipynb diff --git a/docs/colab_notebooks/1-the-basics.ipynb b/docs/colab_notebooks/1-the-basics.ipynb index ec2c5a99..ed8942df 100644 --- a/docs/colab_notebooks/1-the-basics.ipynb +++ b/docs/colab_notebooks/1-the-basics.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "c79eea7a", + "id": "945eebf8", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: The Basics\n", @@ -14,7 +14,7 @@ }, { "cell_type": "markdown", - "id": "2476f160", + "id": "8e8f2e22", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -26,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "3646f62e", + "id": "92d91bf1", "metadata": 
{}, "source": [ "### ⚑ Colab Setup\n", @@ -37,7 +37,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3348e5c8", + "id": "0b9b4427", "metadata": {}, "outputs": [], "source": [ @@ -48,7 +48,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19cd9249", + "id": "8878d172", "metadata": {}, "outputs": [], "source": [ @@ -66,7 +66,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5a6d13a9", + "id": "4c92bfb3", "metadata": {}, "outputs": [], "source": [ @@ -76,7 +76,7 @@ }, { "cell_type": "markdown", - "id": "d445af5b", + "id": "4e39eed1", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -89,7 +89,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4df0031d", + "id": "70c96cfb", "metadata": {}, "outputs": [], "source": [ @@ -98,7 +98,7 @@ }, { "cell_type": "markdown", - "id": "0f69b576", + "id": "99d975c9", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -115,7 +115,7 @@ { "cell_type": "code", "execution_count": null, - "id": "65d9be99", + "id": "851228c8", "metadata": {}, "outputs": [], "source": [ @@ -145,7 +145,7 @@ }, { "cell_type": "markdown", - "id": "72582d09", + "id": "fefb639d", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d7992b4", + "id": "0ba52672", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "markdown", - "id": "741a15a0", + "id": "7cc2aefc", "metadata": {}, "source": [ "## 🎲 Getting started with sampler columns\n", @@ -186,7 +186,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c3879c70", + "id": "a5a34b1a", "metadata": {}, "outputs": [], "source": [ @@ -195,7 +195,7 @@ }, { "cell_type": "markdown", - "id": "1575ef81", + "id": "ee4d1b6a", "metadata": {}, "source": [ "Let's start designing our product review dataset by adding product category and subcategory columns.\n" @@ -204,7 
+204,7 @@ { "cell_type": "code", "execution_count": null, - "id": "87a88d7b", + "id": "7782d790", "metadata": {}, "outputs": [], "source": [ @@ -285,7 +285,7 @@ }, { "cell_type": "markdown", - "id": "8c74b738", + "id": "f88e8b18", "metadata": {}, "source": [ "Next, let's add samplers to generate data related to the customer and their review.\n" @@ -294,7 +294,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4eb1da1f", + "id": "19174a73", "metadata": {}, "outputs": [], "source": [ @@ -331,7 +331,7 @@ }, { "cell_type": "markdown", - "id": "4324d869", + "id": "01438115", "metadata": {}, "source": [ "## 🦜 LLM-generated columns\n", @@ -346,7 +346,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1302a503", + "id": "9c8f1275", "metadata": {}, "outputs": [], "source": [ @@ -382,7 +382,7 @@ }, { "cell_type": "markdown", - "id": "7cf8241b", + "id": "f61e3771", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -399,7 +399,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6fc6cf39", + "id": "7f8dc56e", "metadata": {}, "outputs": [], "source": [ @@ -409,7 +409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c929e068", + "id": "5b66172a", "metadata": {}, "outputs": [], "source": [ @@ -420,7 +420,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dfb04e2a", + "id": "b0eaa931", "metadata": {}, "outputs": [], "source": [ @@ -430,7 +430,7 @@ }, { "cell_type": "markdown", - "id": "adb879da", + "id": "122d099d", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -443,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ff58dd9f", + "id": "f40f7ba0", "metadata": {}, "outputs": [], "source": [ @@ -453,7 +453,7 @@ }, { "cell_type": "markdown", - "id": "57c7355d", + "id": "597c41ec", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -466,7 +466,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df49db99", + "id": "acf8caa3", "metadata": {}, 
"outputs": [], "source": [ @@ -476,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2bbc48dd", + "id": "697e9090", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +489,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dc0673fa", + "id": "18f34e66", "metadata": {}, "outputs": [], "source": [ @@ -501,7 +501,7 @@ }, { "cell_type": "markdown", - "id": "7688217b", + "id": "4c498f62", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", @@ -512,7 +512,9 @@ "\n", "- [Seeding synthetic data generation with an external dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/)\n", "\n", - "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n" + "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n", + "\n", + "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)\n" ] } ], diff --git a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb index c813ea50..49be6edb 100644 --- a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb +++ b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "258752cd", + "id": "bd333de9", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Structured Outputs and Jinja Expressions\n", @@ -16,7 +16,7 @@ }, { "cell_type": "markdown", - "id": "fc4217c3", + "id": "28fb2ee3", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -28,7 +28,7 @@ }, { "cell_type": "markdown", - "id": "2b831130", + "id": "fbeb3b2d", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -39,7 +39,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fa1eda43", + "id": "6ef3d2ae", "metadata": {}, "outputs": 
[], "source": [ @@ -50,7 +50,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5f014571", + "id": "07546806", "metadata": {}, "outputs": [], "source": [ @@ -68,7 +68,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f7409282", + "id": "81b00725", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "8234dd4b", + "id": "a5cf694f", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -91,7 +91,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21633aed", + "id": "8320e2b0", "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "9b215265", + "id": "348e2c5a", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -117,7 +117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76260638", + "id": "21019fc5", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "e6bfd93d", + "id": "7bf9d9af", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -162,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a0fbd497", + "id": "88abb685", "metadata": {}, "outputs": [], "source": [ @@ -171,7 +171,7 @@ }, { "cell_type": "markdown", - "id": "7faae40e", + "id": "d8e790c6", "metadata": {}, "source": [ "### πŸ§‘β€πŸŽ¨ Designing our data\n", @@ -198,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f2f94909", + "id": "64465ab1", "metadata": {}, "outputs": [], "source": [ @@ -226,7 +226,7 @@ }, { "cell_type": "markdown", - "id": "696f19f4", + "id": "cfbad124", "metadata": {}, "source": [ "Next, let's design our product review dataset using a few more tricks compared to the previous notebook.\n" @@ -235,7 +235,7 @@ { "cell_type": "code", "execution_count": null, - "id": "312b50cd", + "id": "aa93a4c9", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +344,7 @@ }, { 
"cell_type": "markdown", - "id": "ecd971ca", + "id": "74aa72fc", "metadata": {}, "source": [ "Next, we will use more advanced Jinja expressions to create new columns.\n", @@ -361,7 +361,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bda01ffc", + "id": "9ae978cc", "metadata": {}, "outputs": [], "source": [ @@ -414,7 +414,7 @@ }, { "cell_type": "markdown", - "id": "059613e1", + "id": "ec850f14", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -431,7 +431,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23c9b839", + "id": "cb18575e", "metadata": {}, "outputs": [], "source": [ @@ -441,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5adcdbd", + "id": "eee46dc6", "metadata": {}, "outputs": [], "source": [ @@ -452,7 +452,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1cc39cae", + "id": "082d0fc4", "metadata": {}, "outputs": [], "source": [ @@ -462,7 +462,7 @@ }, { "cell_type": "markdown", - "id": "bcca3f06", + "id": "e8d80b94", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -475,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6e1957ca", + "id": "4b0a7299", "metadata": {}, "outputs": [], "source": [ @@ -485,7 +485,7 @@ }, { "cell_type": "markdown", - "id": "9db283d3", + "id": "d7e0c925", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -498,7 +498,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30826883", + "id": "b599d759", "metadata": {}, "outputs": [], "source": [ @@ -508,7 +508,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88d4d3bd", + "id": "07a7c0da", "metadata": {}, "outputs": [], "source": [ @@ -521,7 +521,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8762a2bb", + "id": "7760dffa", "metadata": {}, "outputs": [], "source": [ @@ -533,7 +533,7 @@ }, { "cell_type": "markdown", - "id": "0375fcd2", + "id": "6d19000a", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", @@ 
-542,7 +542,9 @@ "\n", "- [Seeding synthetic data generation with an external dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/)\n", "\n", - "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n" + "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n", + "\n", + "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)\n" ] } ], diff --git a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb index c5d427d0..468aa795 100644 --- a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb +++ b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "b2a3e544", + "id": "573c3e7b", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Seeding Synthetic Data Generation with an External Dataset\n", @@ -16,7 +16,7 @@ }, { "cell_type": "markdown", - "id": "d57c4f0a", + "id": "63f6c36d", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -28,7 +28,7 @@ }, { "cell_type": "markdown", - "id": "f7da8723", + "id": "02cc81c7", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -39,7 +39,7 @@ { "cell_type": "code", "execution_count": null, - "id": "90a12556", + "id": "18d51631", "metadata": {}, "outputs": [], "source": [ @@ -50,7 +50,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8fcdfde5", + "id": "67c55f6b", "metadata": {}, "outputs": [], "source": [ @@ -68,7 +68,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5899e85c", + "id": "cfe2ff62", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "6c093c90", + "id": "bdbc5b03", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -91,7 +91,7 @@ { "cell_type": "code", 
"execution_count": null, - "id": "6a2066fe", + "id": "55d9caf1", "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "f5e81142", + "id": "aa1623bc", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -117,7 +117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "880012ea", + "id": "9d1310cf", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "4b77a92c", + "id": "e64ce3b7", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -162,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f4ab6628", + "id": "dafd6155", "metadata": {}, "outputs": [], "source": [ @@ -171,7 +171,7 @@ }, { "cell_type": "markdown", - "id": "26fb0a63", + "id": "7c01f11c", "metadata": {}, "source": [ "## πŸ₯ Prepare a seed dataset\n", @@ -196,7 +196,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84908e88", + "id": "7941073f", "metadata": {}, "outputs": [], "source": [ @@ -214,7 +214,7 @@ }, { "cell_type": "markdown", - "id": "1947e70a", + "id": "a68c7d55", "metadata": {}, "source": [ "## 🎨 Designing our synthetic patient notes dataset\n", @@ -227,7 +227,7 @@ { "cell_type": "code", "execution_count": null, - "id": "be2fbad1", + "id": "f1b3d4d4", "metadata": {}, "outputs": [], "source": [ @@ -308,7 +308,7 @@ }, { "cell_type": "markdown", - "id": "8fcce5dc", + "id": "eff1bf9f", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -325,7 +325,7 @@ { "cell_type": "code", "execution_count": null, - "id": "82dc02f8", + "id": "b5955230", "metadata": {}, "outputs": [], "source": [ @@ -335,7 +335,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f2d1583", + "id": "062a7294", "metadata": {}, "outputs": [], "source": [ @@ -346,7 +346,7 @@ { "cell_type": "code", "execution_count": null, - "id": "62a9173b", + "id": "6378e1be", "metadata": {}, "outputs": [], 
"source": [ @@ -356,7 +356,7 @@ }, { "cell_type": "markdown", - "id": "5263e705", + "id": "51e5175e", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -369,7 +369,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5295320f", + "id": "891b6860", "metadata": {}, "outputs": [], "source": [ @@ -379,7 +379,7 @@ }, { "cell_type": "markdown", - "id": "3ecc195f", + "id": "0f52668f", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -392,7 +392,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3865fb59", + "id": "ed083bd8", "metadata": {}, "outputs": [], "source": [ @@ -402,7 +402,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a7acf2b0", + "id": "039c42e4", "metadata": {}, "outputs": [], "source": [ @@ -415,7 +415,7 @@ { "cell_type": "code", "execution_count": null, - "id": "81a6e999", + "id": "623ca205", "metadata": {}, "outputs": [], "source": [ @@ -427,14 +427,16 @@ }, { "cell_type": "markdown", - "id": "4503b1cf", + "id": "0a7e7d42", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", "\n", "Check out the following notebook to learn more about:\n", "\n", - "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n" + "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n", + "\n", + "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)\n" ] } ], diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb index cd175537..62ac63e8 100644 --- a/docs/colab_notebooks/4-providing-images-as-context.ipynb +++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "90dda708", + "id": "731384ed", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Providing Images as Context for Vision-Based 
Data Generation" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "52ccb1e5", + "id": "bc66dd23", "metadata": {}, "source": [ "#### πŸ“š What you'll learn\n", @@ -25,7 +25,7 @@ }, { "cell_type": "markdown", - "id": "9627c4eb", + "id": "4539a931", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -37,7 +37,7 @@ }, { "cell_type": "markdown", - "id": "1817171a", + "id": "f88809bf", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -48,7 +48,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f15a669", + "id": "3628d4c4", "metadata": {}, "outputs": [], "source": [ @@ -59,7 +59,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1201c93b", + "id": "7fcf0f75", "metadata": {}, "outputs": [], "source": [ @@ -77,7 +77,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f814b76c", + "id": "6654714a", "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "ac423d57", + "id": "22488cb7", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -113,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3c655c2d", + "id": "39913ca0", "metadata": {}, "outputs": [], "source": [ @@ -122,7 +122,7 @@ }, { "cell_type": "markdown", - "id": "7d41e922", + "id": "fba112ab", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -139,7 +139,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a8b5f4bf", + "id": "70fd86dd", "metadata": {}, "outputs": [], "source": [ @@ -162,7 +162,7 @@ }, { "cell_type": "markdown", - "id": "6455fc58", + "id": "810c7457", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -177,7 +177,7 @@ { "cell_type": "code", "execution_count": null, - "id": "462c2e01", + "id": "9b2204d0", "metadata": {}, "outputs": [], "source": [ @@ -186,7 +186,7 @@ }, { "cell_type": "markdown", - "id": "31369d10", + "id": "29e3dae5", "metadata": {}, "source": [ "### 🌱 
Seed Dataset Creation\n", @@ -203,7 +203,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55d9432a", + "id": "e2cc3506", "metadata": {}, "outputs": [], "source": [ @@ -218,7 +218,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8614c4e9", + "id": "7a821067", "metadata": {}, "outputs": [], "source": [ @@ -266,7 +266,7 @@ { "cell_type": "code", "execution_count": null, - "id": "80550e46", + "id": "359d144b", "metadata": {}, "outputs": [], "source": [ @@ -284,7 +284,7 @@ { "cell_type": "code", "execution_count": null, - "id": "65ced9bb", + "id": "985cd308", "metadata": {}, "outputs": [], "source": [ @@ -294,7 +294,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34b210e8", + "id": "6a8cb414", "metadata": {}, "outputs": [], "source": [ @@ -306,7 +306,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d506903d", + "id": "a57e1b73", "metadata": {}, "outputs": [], "source": [ @@ -335,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "b91032a2", + "id": "7518100a", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -352,7 +352,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4bd947de", + "id": "4c1fe540", "metadata": {}, "outputs": [], "source": [ @@ -362,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0ff4c07", + "id": "bceafe91", "metadata": {}, "outputs": [], "source": [ @@ -373,7 +373,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e97e4dfe", + "id": "20f4ace5", "metadata": {}, "outputs": [], "source": [ @@ -383,7 +383,7 @@ }, { "cell_type": "markdown", - "id": "0a284c12", + "id": "16a86d56", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -396,7 +396,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2570e7fd", + "id": "c1bbae97", "metadata": {}, "outputs": [], "source": [ @@ -406,7 +406,7 @@ }, { "cell_type": "markdown", - "id": "28b8eb5a", + "id": "d8d7604f", "metadata": {}, "source": [ "### πŸ”Ž 
Visual Inspection\n", @@ -417,7 +417,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d0d9336", + "id": "27c0636c", "metadata": { "lines_to_next_cell": 2 }, @@ -441,7 +441,7 @@ }, { "cell_type": "markdown", - "id": "1c257a81", + "id": "f6b99539", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -454,7 +454,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6d840e9", + "id": "e5d53787", "metadata": {}, "outputs": [], "source": [ @@ -464,7 +464,7 @@ { "cell_type": "code", "execution_count": null, - "id": "909e6f3f", + "id": "1f859e49", "metadata": {}, "outputs": [], "source": [ @@ -477,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "adbb4cae", + "id": "6688e3c5", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +489,7 @@ }, { "cell_type": "markdown", - "id": "d085584c", + "id": "28635b09", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", @@ -499,7 +499,9 @@ "- Experiment with different vision models for specific document types\n", "- Try different prompt variations to generate specialized descriptions (e.g., technical details, key findings)\n", "- Combine vision-based summaries with other column types for multi-modal workflows\n", - "- Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering\n" + "- Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering\n", + "\n", + "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) with Data Designer\n" ] } ], diff --git a/docs/colab_notebooks/5-generating-images.ipynb b/docs/colab_notebooks/5-generating-images.ipynb new file mode 100644 index 00000000..485fe258 --- /dev/null +++ b/docs/colab_notebooks/5-generating-images.ipynb @@ -0,0 +1,437 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0ee289e6", + "metadata": {}, + "source": [ + "# 🎨 Data Designer Tutorial: Generating Images\n", + "\n", + "#### πŸ“š 
What you'll learn\n", + "\n", + "This notebook shows how to generate synthetic image data with Data Designer using image-generation models.\n", + "\n", + "- πŸ–ΌοΈ **Image generation columns**: Add columns that produce images from text prompts\n", + "- πŸ“ **Jinja2 prompts**: Drive diversity by referencing other columns in your prompt template\n", + "- πŸ’Ύ **Preview vs create**: Preview stores base64 in the dataframe; create saves images to disk and stores paths\n", + "\n", + "Data Designer supports both **diffusion** (e.g. DALLΒ·E, Stable Diffusion, Imagen) and **autoregressive** (e.g. Gemini image, GPT image) models; the API is chosen automatically from the model name.\n", + "\n", + "If this is your first time using Data Designer, we recommend starting with the [first notebook](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/) in this tutorial series.\n" + ] + }, + { + "cell_type": "markdown", + "id": "86f748c1", + "metadata": {}, + "source": [ + "### πŸ“¦ Import Data Designer\n", + "\n", + "- `data_designer.config` provides the configuration API.\n", + "- `DataDesigner` is the main interface for generation.\n" + ] + }, + { + "cell_type": "markdown", + "id": "c610ee22", + "metadata": {}, + "source": [ + "### ⚑ Colab Setup\n", + "\n", + "Run the cells below to install the dependencies and set up the API key. 
If you don't have an API key, you can generate one from [build.nvidia.com](https://build.nvidia.com).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "818ca495", + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install -U data-designer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f165bb15", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "from google.colab import userdata\n", + "\n", + "try:\n", + " os.environ[\"NVIDIA_API_KEY\"] = userdata.get(\"NVIDIA_API_KEY\")\n", + "except userdata.SecretNotFoundError:\n", + " os.environ[\"NVIDIA_API_KEY\"] = getpass.getpass(\"Enter your NVIDIA API key: \")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5decfc83", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image as IPImage\n", + "from IPython.display import display\n", + "\n", + "import data_designer.config as dd\n", + "from data_designer.interface import DataDesigner" + ] + }, + { + "cell_type": "markdown", + "id": "929f35d6", + "metadata": {}, + "source": [ + "### βš™οΈ Initialize the Data Designer interface\n", + "\n", + "When initialized without arguments, [default model providers](https://nvidia-nemo.github.io/DataDesigner/latest/concepts/models/default-model-settings/) are used. This tutorial uses [OpenRouter](https://openrouter.ai) with the Flux 2 Pro image model; set `OPENROUTER_API_KEY` in your environment.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4c8b7d7", + "metadata": {}, + "outputs": [], + "source": [ + "data_designer = DataDesigner()" + ] + }, + { + "cell_type": "markdown", + "id": "8ed7b0b6", + "metadata": {}, + "source": [ + "### πŸŽ›οΈ Define an image-generation model\n", + "\n", + "- Use `ImageInferenceParams` so Data Designer treats this model as an image generator.\n", + "- Image options (size, quality, aspect ratio, etc.) 
are model-specific; pass them via `extra_body`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6b1ca66", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PROVIDER = \"openrouter\"\n", + "MODEL_ID = \"black-forest-labs/flux.2-pro\"\n", + "MODEL_ALIAS = \"image-model\"\n", + "\n", + "model_configs = [\n", + " dd.ModelConfig(\n", + " alias=MODEL_ALIAS,\n", + " model=MODEL_ID,\n", + " provider=MODEL_PROVIDER,\n", + " inference_parameters=dd.ImageInferenceParams(\n", + " extra_body={\"height\": 512, \"width\": 512},\n", + " ),\n", + " )\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "498cfecf", + "metadata": {}, + "source": [ + "### πŸ—οΈ Build the config: samplers + image column\n", + "\n", + "We'll generate diverse **dog portrait** images: sampler columns drive subject (breed), age, style, look direction, and emotion. The image-generation column uses a Jinja2 prompt that references all of them.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e74fc7ab", + "metadata": {}, + "outputs": [], + "source": [ + "config_builder = dd.DataDesignerConfigBuilder(model_configs=model_configs)\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"style\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\n", + " \"photorealistic\",\n", + " \"oil painting\",\n", + " \"watercolor\",\n", + " \"digital art\",\n", + " \"sketch\",\n", + " \"anime\",\n", + " ],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"dog_breed\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\n", + " \"a Golden Retriever\",\n", + " \"a German Shepherd\",\n", + " \"a Labrador Retriever\",\n", + " \"a Bulldog\",\n", + " \"a Beagle\",\n", + " \"a Poodle\",\n", + " \"a Corgi\",\n", + " \"a Siberian Husky\",\n", + " \"a 
Dalmatian\",\n", + " \"a Yorkshire Terrier\",\n", + " \"a Boxer\",\n", + " \"a Dachshund\",\n", + " \"a Doberman Pinscher\",\n", + " \"a Shih Tzu\",\n", + " \"a Chihuahua\",\n", + " \"a Border Collie\",\n", + " \"an Australian Shepherd\",\n", + " \"a Cocker Spaniel\",\n", + " \"a Maltese\",\n", + " \"a Pomeranian\",\n", + " \"a Saint Bernard\",\n", + " \"a Great Dane\",\n", + " \"an Akita\",\n", + " \"a Samoyed\",\n", + " \"a Boston Terrier\",\n", + " ],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"cat_breed\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\n", + " \"a Persian\",\n", + " \"a Maine Coon\",\n", + " \"a Siamese\",\n", + " \"a Ragdoll\",\n", + " \"a Bengal\",\n", + " \"an Abyssinian\",\n", + " \"a British Shorthair\",\n", + " \"a Sphynx\",\n", + " \"a Scottish Fold\",\n", + " \"a Russian Blue\",\n", + " \"a Birman\",\n", + " \"an Oriental Shorthair\",\n", + " \"a Norwegian Forest Cat\",\n", + " \"a Devon Rex\",\n", + " \"a Burmese\",\n", + " \"an Egyptian Mau\",\n", + " \"a Tonkinese\",\n", + " \"a Himalayan\",\n", + " \"a Savannah\",\n", + " \"a Chartreux\",\n", + " \"a Somali\",\n", + " \"a Manx\",\n", + " \"a Turkish Angora\",\n", + " \"a Balinese\",\n", + " \"an American Shorthair\",\n", + " ],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"dog_age\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"1-3\", \"3-6\", \"6-9\", \"9-12\", \"12-15\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"cat_age\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"1-3\", \"3-6\", \"6-9\", \"9-12\", \"12-18\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + 
"config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"dog_look_direction\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"left\", \"right\", \"front\", \"up\", \"down\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"cat_look_direction\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"left\", \"right\", \"front\", \"up\", \"down\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"dog_emotion\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"happy\", \"curious\", \"serious\", \"sleepy\", \"excited\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"cat_emotion\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"aloof\", \"curious\", \"content\", \"sleepy\", \"playful\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.ImageColumnConfig(\n", + " name=\"generated_image\",\n", + " prompt=(\n", + " \"\"\"\n", + "A {{ style }} family pet portrait of a {{ dog_breed }} dog of {{ dog_age }} years old looking {{dog_look_direction}} with an {{ dog_emotion }} expression and\n", + "{{ cat_breed }} cat of {{ cat_age }} years old looking {{ cat_look_direction }} with an {{ cat_emotion }} expression in the background. 
Both subjects should be in focus.\n", + " \"\"\"\n", + " ),\n", + " model_alias=MODEL_ALIAS,\n", + " )\n", + ")\n", + "\n", + "data_designer.validate(config_builder)" + ] + }, + { + "cell_type": "markdown", + "id": "c592c820", + "metadata": {}, + "source": [ + "### πŸ” Preview: images as base64\n", + "\n", + "In **preview** mode, generated images are stored as base64 strings in the dataframe. Run the next cell to step through each record (images are shown in the sample record display, but only in a notebook environment).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eee17bb1", + "metadata": {}, + "outputs": [], + "source": [ + "preview = data_designer.preview(config_builder, num_records=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3cd320cc", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(len(preview.dataset)):\n", + " preview.display_sample_record()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffb5e188", + "metadata": {}, + "outputs": [], + "source": [ + "preview.dataset" + ] + }, + { + "cell_type": "markdown", + "id": "87b83328", + "metadata": {}, + "source": [ + "### πŸ†™ Create: images saved to disk\n", + "\n", + "In **create** mode, images are written to an `images/` folder with UUID filenames; the dataframe stores relative paths (e.g. 
`images/1d16b6e2-562f-4f51-91e5-baaa999ea916.png`).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8f9cc41", + "metadata": {}, + "outputs": [], + "source": [ + "results = data_designer.create(config_builder, num_records=5, dataset_name=\"tutorial-5-images\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d4453e5", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = results.load_dataset()\n", + "dataset.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "198301ab", + "metadata": {}, + "outputs": [], + "source": [ + "# Display all image from the created dataset. Paths are relative to the artifact output directory.\n", + "for index, row in dataset.iterrows():\n", + " path_or_list = row.get(\"generated_image\")\n", + " if path_or_list is not None:\n", + " for path in path_or_list:\n", + " base = results.artifact_storage.base_dataset_path\n", + " full_path = base / path\n", + " display(IPImage(data=full_path))" + ] + }, + { + "cell_type": "markdown", + "id": "2bdcef2b", + "metadata": {}, + "source": [ + "## ⏭️ Next steps\n", + "\n", + "- [The basics](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/): samplers and LLM text columns\n", + "- [Structured outputs and Jinja](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/2-structured-outputs-and-jinja-expressions/)\n", + "- [Seeding with a dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/)\n", + "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From d11d049aad3474476a03a099c701f596c24c8ca2 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 19:39:10 -0700 Subject: [PATCH 39/64] move pillow to 
lib dep from notebook --- packages/data-designer-config/pyproject.toml | 1 + uv.lock | 2 ++ 2 files changed, 3 insertions(+) diff --git a/packages/data-designer-config/pyproject.toml b/packages/data-designer-config/pyproject.toml index 04af4adc..569c8fe0 100644 --- a/packages/data-designer-config/pyproject.toml +++ b/packages/data-designer-config/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "jinja2>=3.1.6,<4", "numpy>=1.23.5,<3", "pandas>=2.3.3,<3", + "pillow>=12.0.0,<13", "pyarrow>=19.0.1,<20", # Required for parquet I/O operations "pydantic[email]>=2.9.2,<3", "pygments>=2.19.2,<3", diff --git a/uv.lock b/uv.lock index b26a9385..6a4a8432 100644 --- a/uv.lock +++ b/uv.lock @@ -965,6 +965,7 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pandas" }, + { name = "pillow" }, { name = "pyarrow" }, { name = "pydantic", extra = ["email"] }, { name = "pygments" }, @@ -978,6 +979,7 @@ requires-dist = [ { name = "jinja2", specifier = ">=3.1.6,<4" }, { name = "numpy", specifier = ">=1.23.5,<3" }, { name = "pandas", specifier = ">=2.3.3,<3" }, + { name = "pillow", specifier = ">=12.0.0,<13" }, { name = "pyarrow", specifier = ">=19.0.1,<20" }, { name = "pydantic", extras = ["email"], specifier = ">=2.9.2,<3" }, { name = "pygments", specifier = ">=2.19.2,<3" }, From 511e1f26a3180579227d01598bdf7e90b2e07b26 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 19:41:49 -0700 Subject: [PATCH 40/64] update uv lock" --- pyproject.toml | 7 +++---- uv.lock | 37 ++++++++++++------------------------- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f7b71536..988e6f04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,11 +42,11 @@ dev-dependencies = [ [dependency-groups] 
dev = [ "jsonpath-ng>=1.5.3,<2", - "pytest>=9.0.2,<10", - "pytest-asyncio>=1.3.0,<2", + "pytest>=8.3.3,<9", + "pytest-asyncio>=0.24.0,<1", "pytest-cov>=7.0.0,<8", "pytest-env>=1.2.0,<2", - "pytest-httpx>=0.36.0,<1", + "pytest-httpx>=0.35.0,<1", "pre-commit>=4.0.0,<5", ] docs = [ @@ -63,7 +63,6 @@ notebooks = [ "datasets>=4.0.0,<5", "ipykernel>=6.29.0,<7", "jupyter>=1.0.0,<2", - "pillow>=12.0.0,<13", ] recipes = [ "bm25s>=0.2.0,<1", diff --git a/uv.lock b/uv.lock index 6a4a8432..9a111de5 100644 --- a/uv.lock +++ b/uv.lock @@ -308,15 +308,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, ] -[[package]] -name = "backports-asyncio-runner" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, -] - [[package]] name = "backrefs" version = "6.1" @@ -1071,7 +1062,6 @@ notebooks = [ { name = "datasets" }, { name = "ipykernel" }, { name = "jupyter" }, - { name = "pillow" }, ] recipes = [ { name = "bm25s" }, @@ -1085,11 +1075,11 @@ requires-dist = [{ name = "matplotlib", specifier = ">=3.10.8" }] dev = [ { name = "jsonpath-ng", specifier = ">=1.5.3,<2" }, { name = "pre-commit", specifier = 
">=4.0.0,<5" }, - { name = "pytest", specifier = ">=9.0.2,<10" }, - { name = "pytest-asyncio", specifier = ">=1.3.0,<2" }, + { name = "pytest", specifier = ">=8.3.3,<9" }, + { name = "pytest-asyncio", specifier = ">=0.24.0,<1" }, { name = "pytest-cov", specifier = ">=7.0.0,<8" }, { name = "pytest-env", specifier = ">=1.2.0,<2" }, - { name = "pytest-httpx", specifier = ">=0.36.0,<1" }, + { name = "pytest-httpx", specifier = ">=0.35.0,<1" }, { name = "ruff", specifier = ">=0.14.10,<1" }, ] docs = [ @@ -1106,7 +1096,6 @@ notebooks = [ { name = "datasets", specifier = ">=4.0.0,<5" }, { name = "ipykernel", specifier = ">=6.29.0,<7" }, { name = "jupyter", specifier = ">=1.0.0,<2" }, - { name = "pillow", specifier = ">=12.0.0,<13" }, ] recipes = [ { name = "bm25s", specifier = ">=0.2.0,<1" }, @@ -4407,7 +4396,7 @@ wheels = [ [[package]] name = "pytest" -version = "9.0.2" +version = "8.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, @@ -4418,23 +4407,21 @@ dependencies = [ { name = "pygments" }, { name = "tomli", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = 
"2025-12-06T21:30:49.154Z" }, + { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "0.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, { name = "pytest" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156, upload-time = "2025-03-25T06:22:28.883Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, + { url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694, upload-time = "2025-03-25T06:22:27.807Z" }, ] [[package]] @@ -4466,15 +4453,15 @@ wheels = [ [[package]] name = "pytest-httpx" -version = "0.36.0" +version = "0.35.0" source = { registry = 
"https://pypi.org/simple" } dependencies = [ { name = "httpx" }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/5574834da9499066fa1a5ea9c336f94dba2eae02298d36dab192fcf95c86/pytest_httpx-0.36.0.tar.gz", hash = "sha256:9edb66a5fd4388ce3c343189bc67e7e1cb50b07c2e3fc83b97d511975e8a831b", size = 56793, upload-time = "2025-12-02T16:34:57.414Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1f/89/5b12b7b29e3d0af3a4b9c071ee92fa25a9017453731a38f08ba01c280f4c/pytest_httpx-0.35.0.tar.gz", hash = "sha256:d619ad5d2e67734abfbb224c3d9025d64795d4b8711116b1a13f72a251ae511f", size = 54146, upload-time = "2024-11-28T19:16:54.237Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e2/d2/1eb1ea9c84f0d2033eb0b49675afdc71aa4ea801b74615f00f3c33b725e3/pytest_httpx-0.36.0-py3-none-any.whl", hash = "sha256:bd4c120bb80e142df856e825ec9f17981effb84d159f9fa29ed97e2357c3a9c8", size = 20229, upload-time = "2025-12-02T16:34:56.45Z" }, + { url = "https://files.pythonhosted.org/packages/b0/ed/026d467c1853dd83102411a78126b4842618e86c895f93528b0528c7a620/pytest_httpx-0.35.0-py3-none-any.whl", hash = "sha256:ee11a00ffcea94a5cbff47af2114d34c5b231c326902458deed73f9c459fd744", size = 19442, upload-time = "2024-11-28T19:16:52.787Z" }, ] [[package]] From 2b22df8517fd225a462608b19b568a7561968e92 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 9 Feb 2026 20:00:24 -0700 Subject: [PATCH 41/64] remove legacy flag from display_sample_record --- .../data_designer/config/utils/visualization.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index 9d65cca5..2132b83b 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -47,13 +47,12 
@@ console = Console() -def _display_image_if_in_notebook(image_data: str, col_name: str, base_path: str | None = None) -> bool: +def _display_image_if_in_notebook(image_data: str, col_name: str) -> bool: """Display image with caption in Jupyter notebook if available. Args: image_data: Base64-encoded image data, data URI, or file path. col_name: Name of the column (used for caption). - base_path: Optional base path to resolve relative image paths. Returns: True if image was displayed, False otherwise. @@ -66,7 +65,7 @@ def _display_image_if_in_notebook(image_data: str, col_name: str, base_path: str # Check if it's a file path and load it if is_image_path(image_data) and not image_data.startswith(""}, + {"image_url": ""}, + ] + + mock_choice = MagicMock() + mock_choice.message = mock_message + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + mock_completion.return_value = mock_response + + # Generate images + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): + images = stub_model_facade.generate_image(prompt="test prompt") + + # Verify results + assert len(images) == 2 + assert images == ["image1", "image2"] + + +@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +def test_generate_image_chat_completion_with_plain_strings( + mock_completion: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image handles images as plain strings.""" + # Create mock message with images as plain strings + mock_message = MagicMock() + mock_message.role = "assistant" + mock_message.content = "" + mock_message.images = [ + "", + "image2", # Plain base64 without data URI prefix + ] + + mock_choice = MagicMock() + mock_choice.message = mock_message + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + mock_completion.return_value = mock_response + + # Generate images + with patch("data_designer.engine.models.facade.is_image_diffusion_model", 
return_value=False): + images = stub_model_facade.generate_image(prompt="test prompt") + + # Verify results + assert len(images) == 2 + assert images == ["image1", "image2"] + + @patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) def test_generate_image_skip_usage_tracking( mock_image_generation: Any, diff --git a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py index f908a4c2..9d74734a 100644 --- a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py +++ b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py @@ -60,13 +60,23 @@ def test_media_storage_init_custom_subdir(tmp_path): assert not storage.images_dir.exists() -def test_save_base64_image_png(media_storage, sample_base64_png): - """Test saving a PNG image from base64.""" - relative_path = media_storage.save_base64_image(sample_base64_png, subfolder_name="test_column") +@pytest.mark.parametrize( + "image_fixture,expected_extension", + [ + ("sample_base64_png", ".png"), + ("sample_base64_jpg", ".jpg"), + ], +) +def test_save_base64_image_format(media_storage, image_fixture, expected_extension, request): + """Test saving images from base64 in different formats.""" + # Get the actual fixture value using request.getfixturevalue + sample_base64 = request.getfixturevalue(image_fixture) + + relative_path = media_storage.save_base64_image(sample_base64, subfolder_name="test_column") # Check return value format (organized by column name) assert relative_path.startswith(f"{IMAGES_SUBDIR}/test_column/") - assert relative_path.endswith(".png") + assert relative_path.endswith(expected_extension) # Check file exists on disk full_path = media_storage.base_path / relative_path @@ -74,23 +84,10 @@ def test_save_base64_image_png(media_storage, sample_base64_png): # Verify file content saved_bytes = full_path.read_bytes() - expected_bytes = 
base64.b64decode(sample_base64_png) + expected_bytes = base64.b64decode(sample_base64) assert saved_bytes == expected_bytes -def test_save_base64_image_jpg(media_storage, sample_base64_jpg): - """Test saving a JPEG image from base64.""" - relative_path = media_storage.save_base64_image(sample_base64_jpg, subfolder_name="test_column") - - # Check return value format (organized by column name) - assert relative_path.startswith(f"{IMAGES_SUBDIR}/test_column/") - assert relative_path.endswith(".jpg") - - # Check file exists on disk - full_path = media_storage.base_path / relative_path - assert full_path.exists() - - def test_save_base64_image_with_data_uri(media_storage, sample_base64_png): """Test saving image from data URI format.""" data_uri = f"data:image/png;base64,{sample_base64_png}" From 5aa7e109faab4889e0fb5b1fa6121033caefd565 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 10 Feb 2026 15:26:17 -0700 Subject: [PATCH 54/64] Use regex for base64 character validation in is_base64_image --- .../src/data_designer/config/utils/image_helpers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py index c20c81ea..69ee5310 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -7,6 +7,7 @@ import base64 import io +import re from pathlib import Path from typing import TYPE_CHECKING @@ -23,6 +24,8 @@ # WEBP uses RIFF header - handled separately } +_BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$") + # Patterns for diffusion-based image models only (use image_generation API). 
IMAGE_DIFFUSION_MODEL_PATTERNS = ( "dall-e", @@ -152,9 +155,7 @@ def is_base64_image(value: str) -> bool: if value.startswith("data:image/"): return True # Check if it looks like base64 (at least 100 chars, contains only base64 chars) - if len(value) > 100 and all( - c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=" for c in value[:100] - ): + if len(value) > 100 and _BASE64_PATTERN.match(value[:100]): try: # Try to decode a small portion to verify it's valid base64 base64.b64decode(value[:100]) From ecaeb727b239427a63c0aaaf178145fbf321bf7c Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 10 Feb 2026 15:38:46 -0700 Subject: [PATCH 55/64] move to a constant --- .../data_designer/config/utils/image_helpers.py | 17 ++++------------- .../tests/config/utils/test_image_helpers.py | 10 ---------- 2 files changed, 4 insertions(+), 23 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py index 69ee5310..0fb949af 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -36,6 +36,8 @@ "imagen", ) +SUPPORTED_IMAGE_EXTENSIONS = [f".{fmt.value.lower()}" for fmt in ImageFormat] + def is_image_diffusion_model(model_name: str) -> bool: """Return True if the model is a diffusion-based image generation model. 
@@ -137,7 +139,7 @@ def is_image_path(value: str) -> bool: """ if not isinstance(value, str): return False - return any(value.lower().endswith(ext) for ext in get_supported_image_extensions()) + return any(value.lower().endswith(ext) for ext in SUPPORTED_IMAGE_EXTENSIONS) def is_base64_image(value: str) -> bool: @@ -176,9 +178,7 @@ def is_image_url(value: str) -> bool: """ if not isinstance(value, str): return False - return value.startswith(("http://", "https://")) and any( - ext in value.lower() for ext in get_supported_image_extensions() - ) + return value.startswith(("http://", "https://")) and any(ext in value.lower() for ext in SUPPORTED_IMAGE_EXTENSIONS) def load_image_path_to_base64(image_path: str, base_path: str | None = None) -> str | None: @@ -228,12 +228,3 @@ def validate_image(image_path: Path) -> None: img.verify() except Exception as e: raise ValueError(f"Image validation failed: {e}") from e - - -def get_supported_image_extensions() -> list[str]: - """Get list of supported image extensions from ImageFormat enum. 
- - Returns: - List of extensions with leading dot (e.g., [".png", ".jpg", ...]) - """ - return [f".{fmt.value}" for fmt in ImageFormat] diff --git a/packages/data-designer-config/tests/config/utils/test_image_helpers.py b/packages/data-designer-config/tests/config/utils/test_image_helpers.py index f24696e4..08ea3b50 100644 --- a/packages/data-designer-config/tests/config/utils/test_image_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_image_helpers.py @@ -14,7 +14,6 @@ decode_base64_image, detect_image_format, extract_base64_from_data_uri, - get_supported_image_extensions, is_base64_image, is_image_diffusion_model, is_image_path, @@ -204,15 +203,6 @@ def test_validate_image_nonexistent_raises_error(tmp_path): validate_image(image_path) -# Tests for get_supported_image_extensions - - -def test_get_supported_image_extensions_matches_enum(): - result = get_supported_image_extensions() - enum_values = [f".{fmt.value}" for fmt in ImageFormat] - assert set(result) == set(enum_values) - - # Additional tests for uncovered lines From 622b1c4f75013d184fefb46da95e2bf5a69a285c Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 10 Feb 2026 15:43:21 -0700 Subject: [PATCH 56/64] fix pyproject.toml --- pyproject.toml | 6 +++--- uv.lock | 35 +++++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 35566648..99d7b78c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,11 +39,11 @@ dev-dependencies = [ [dependency-groups] dev = [ "jsonpath-ng>=1.5.3,<2", - "pytest>=8.3.3,<9", - "pytest-asyncio>=0.24.0,<1", + "pytest>=9.0.2,<10", + "pytest-asyncio>=1.3.0,<2", "pytest-cov>=7.0.0,<8", "pytest-env>=1.2.0,<2", - "pytest-httpx>=0.35.0,<1", + "pytest-httpx>=0.36.0,<1", "pre-commit>=4.0.0,<5", ] docs = [ diff --git a/uv.lock b/uv.lock index 17306f0e..200d0b12 100644 --- a/uv.lock +++ b/uv.lock @@ -308,6 +308,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, +] + [[package]] name = "backrefs" version = "6.1" @@ -905,11 +914,11 @@ recipes = [ dev = [ { name = "jsonpath-ng", specifier = ">=1.5.3,<2" }, { name = "pre-commit", specifier = ">=4.0.0,<5" }, - { name = "pytest", specifier = ">=8.3.3,<9" }, - { name = "pytest-asyncio", specifier = ">=0.24.0,<1" }, + { name = "pytest", specifier = ">=9.0.2,<10" }, + { name = "pytest-asyncio", specifier = ">=1.3.0,<2" }, { name = "pytest-cov", specifier = ">=7.0.0,<8" }, { name = "pytest-env", specifier = ">=1.2.0,<2" }, - { name = "pytest-httpx", specifier = ">=0.35.0,<1" }, + { name = "pytest-httpx", specifier = ">=0.36.0,<1" }, { name = "ruff", specifier = ">=0.14.10,<1" }, ] docs = [ @@ -3986,7 +3995,7 @@ wheels = [ [[package]] name = "pytest" -version = "8.4.2" +version = "9.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, @@ -3997,21 +4006,23 @@ dependencies = [ { name = "pygments" }, { 
name = "tomli", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] [[package]] name = "pytest-asyncio" -version = "0.26.0" +version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156, upload-time = "2025-03-25T06:22:28.883Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694, upload-time = "2025-03-25T06:22:27.807Z" }, + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, ] [[package]] @@ -4043,15 +4054,15 @@ wheels = [ [[package]] name = "pytest-httpx" -version = "0.35.0" +version = "0.36.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1f/89/5b12b7b29e3d0af3a4b9c071ee92fa25a9017453731a38f08ba01c280f4c/pytest_httpx-0.35.0.tar.gz", hash = "sha256:d619ad5d2e67734abfbb224c3d9025d64795d4b8711116b1a13f72a251ae511f", size = 54146, upload-time = "2024-11-28T19:16:54.237Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/5574834da9499066fa1a5ea9c336f94dba2eae02298d36dab192fcf95c86/pytest_httpx-0.36.0.tar.gz", hash = "sha256:9edb66a5fd4388ce3c343189bc67e7e1cb50b07c2e3fc83b97d511975e8a831b", size = 56793, upload-time = "2025-12-02T16:34:57.414Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b0/ed/026d467c1853dd83102411a78126b4842618e86c895f93528b0528c7a620/pytest_httpx-0.35.0-py3-none-any.whl", hash = "sha256:ee11a00ffcea94a5cbff47af2114d34c5b231c326902458deed73f9c459fd744", size = 19442, upload-time = 
"2024-11-28T19:16:52.787Z" }, + { url = "https://files.pythonhosted.org/packages/e2/d2/1eb1ea9c84f0d2033eb0b49675afdc71aa4ea801b74615f00f3c33b725e3/pytest_httpx-0.36.0-py3-none-any.whl", hash = "sha256:bd4c120bb80e142df856e825ec9f17981effb84d159f9fa29ed97e2357c3a9c8", size = 20229, upload-time = "2025-12-02T16:34:56.45Z" }, ] [[package]] From 400e97b55a555c6fc788245c38862599a29b3189 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 10 Feb 2026 15:50:57 -0700 Subject: [PATCH 57/64] regen colab notebooks --- docs/colab_notebooks/1-the-basics.ipynb | 62 ++++++++--------- ...ctured-outputs-and-jinja-expressions.ipynb | 58 ++++++++-------- .../3-seeding-with-a-dataset.ipynb | 54 +++++++-------- .../4-providing-images-as-context.ipynb | 66 +++++++++---------- .../colab_notebooks/5-generating-images.ipynb | 44 ++++++------- 5 files changed, 142 insertions(+), 142 deletions(-) diff --git a/docs/colab_notebooks/1-the-basics.ipynb b/docs/colab_notebooks/1-the-basics.ipynb index ed8942df..f50209f7 100644 --- a/docs/colab_notebooks/1-the-basics.ipynb +++ b/docs/colab_notebooks/1-the-basics.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "945eebf8", + "id": "96178d08", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: The Basics\n", @@ -14,7 +14,7 @@ }, { "cell_type": "markdown", - "id": "8e8f2e22", + "id": "1d02a1d6", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -26,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "92d91bf1", + "id": "2292d817", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -37,7 +37,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0b9b4427", + "id": "8af621fc", "metadata": {}, "outputs": [], "source": [ @@ -48,7 +48,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8878d172", + "id": "70e6a11c", "metadata": {}, "outputs": [], "source": [ @@ -66,7 +66,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4c92bfb3", + "id": "41031828", "metadata": {}, "outputs": [], 
"source": [ @@ -76,7 +76,7 @@ }, { "cell_type": "markdown", - "id": "4e39eed1", + "id": "0b480b10", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -89,7 +89,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70c96cfb", + "id": "d434a8e2", "metadata": {}, "outputs": [], "source": [ @@ -98,7 +98,7 @@ }, { "cell_type": "markdown", - "id": "99d975c9", + "id": "f88f6792", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -115,7 +115,7 @@ { "cell_type": "code", "execution_count": null, - "id": "851228c8", + "id": "4261574c", "metadata": {}, "outputs": [], "source": [ @@ -145,7 +145,7 @@ }, { "cell_type": "markdown", - "id": "fefb639d", + "id": "bbbc3d58", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ba52672", + "id": "92c0cf35", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "markdown", - "id": "7cc2aefc", + "id": "44246c7d", "metadata": {}, "source": [ "## 🎲 Getting started with sampler columns\n", @@ -186,7 +186,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5a34b1a", + "id": "07d20f3f", "metadata": {}, "outputs": [], "source": [ @@ -195,7 +195,7 @@ }, { "cell_type": "markdown", - "id": "ee4d1b6a", + "id": "9d3c87b0", "metadata": {}, "source": [ "Let's start designing our product review dataset by adding product category and subcategory columns.\n" @@ -204,7 +204,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7782d790", + "id": "c646b021", "metadata": {}, "outputs": [], "source": [ @@ -285,7 +285,7 @@ }, { "cell_type": "markdown", - "id": "f88e8b18", + "id": "ff18b032", "metadata": {}, "source": [ "Next, let's add samplers to generate data related to the customer and their review.\n" @@ -294,7 +294,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19174a73", + "id": "78846d99", "metadata": {}, 
"outputs": [], "source": [ @@ -331,7 +331,7 @@ }, { "cell_type": "markdown", - "id": "01438115", + "id": "97059bfc", "metadata": {}, "source": [ "## 🦜 LLM-generated columns\n", @@ -346,7 +346,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9c8f1275", + "id": "98c66eff", "metadata": {}, "outputs": [], "source": [ @@ -382,7 +382,7 @@ }, { "cell_type": "markdown", - "id": "f61e3771", + "id": "ff2d52b9", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -399,7 +399,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7f8dc56e", + "id": "6e622478", "metadata": {}, "outputs": [], "source": [ @@ -409,7 +409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5b66172a", + "id": "1addc7d8", "metadata": {}, "outputs": [], "source": [ @@ -420,7 +420,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b0eaa931", + "id": "7af4b9c3", "metadata": {}, "outputs": [], "source": [ @@ -430,7 +430,7 @@ }, { "cell_type": "markdown", - "id": "122d099d", + "id": "91d0ee89", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -443,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f40f7ba0", + "id": "e1e3aed0", "metadata": {}, "outputs": [], "source": [ @@ -453,7 +453,7 @@ }, { "cell_type": "markdown", - "id": "597c41ec", + "id": "6eaa402e", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -466,7 +466,7 @@ { "cell_type": "code", "execution_count": null, - "id": "acf8caa3", + "id": "f6b148d4", "metadata": {}, "outputs": [], "source": [ @@ -476,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "697e9090", + "id": "f4e62e5b", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +489,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18f34e66", + "id": "7d426ab0", "metadata": {}, "outputs": [], "source": [ @@ -501,7 +501,7 @@ }, { "cell_type": "markdown", - "id": "4c498f62", + "id": "449d003c", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", 
diff --git a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb index 49be6edb..a6e04680 100644 --- a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb +++ b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "bd333de9", + "id": "ba22504d", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Structured Outputs and Jinja Expressions\n", @@ -16,7 +16,7 @@ }, { "cell_type": "markdown", - "id": "28fb2ee3", + "id": "c176fe63", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -28,7 +28,7 @@ }, { "cell_type": "markdown", - "id": "fbeb3b2d", + "id": "32c80f72", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -39,7 +39,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6ef3d2ae", + "id": "4ab45e3a", "metadata": {}, "outputs": [], "source": [ @@ -50,7 +50,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07546806", + "id": "2ae70d67", "metadata": {}, "outputs": [], "source": [ @@ -68,7 +68,7 @@ { "cell_type": "code", "execution_count": null, - "id": "81b00725", + "id": "2cdc070b", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "a5cf694f", + "id": "a04261b9", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -91,7 +91,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8320e2b0", + "id": "c8bef18a", "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "348e2c5a", + "id": "ed555636", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -117,7 +117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21019fc5", + "id": "47208094", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "7bf9d9af", + "id": "36c200d9", 
"metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -162,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88abb685", + "id": "57c0d82f", "metadata": {}, "outputs": [], "source": [ @@ -171,7 +171,7 @@ }, { "cell_type": "markdown", - "id": "d8e790c6", + "id": "01ff63ca", "metadata": {}, "source": [ "### πŸ§‘β€πŸŽ¨ Designing our data\n", @@ -198,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "64465ab1", + "id": "4fb0f1ca", "metadata": {}, "outputs": [], "source": [ @@ -226,7 +226,7 @@ }, { "cell_type": "markdown", - "id": "cfbad124", + "id": "8f35bd87", "metadata": {}, "source": [ "Next, let's design our product review dataset using a few more tricks compared to the previous notebook.\n" @@ -235,7 +235,7 @@ { "cell_type": "code", "execution_count": null, - "id": "aa93a4c9", + "id": "43341f16", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +344,7 @@ }, { "cell_type": "markdown", - "id": "74aa72fc", + "id": "34c3e08b", "metadata": {}, "source": [ "Next, we will use more advanced Jinja expressions to create new columns.\n", @@ -361,7 +361,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9ae978cc", + "id": "c168c089", "metadata": {}, "outputs": [], "source": [ @@ -414,7 +414,7 @@ }, { "cell_type": "markdown", - "id": "ec850f14", + "id": "7e6521a2", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -431,7 +431,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cb18575e", + "id": "03510f78", "metadata": {}, "outputs": [], "source": [ @@ -441,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eee46dc6", + "id": "ad599c43", "metadata": {}, "outputs": [], "source": [ @@ -452,7 +452,7 @@ { "cell_type": "code", "execution_count": null, - "id": "082d0fc4", + "id": "dbd3e17c", "metadata": {}, "outputs": [], "source": [ @@ -462,7 +462,7 @@ }, { "cell_type": "markdown", - "id": "e8d80b94", + "id": "4db52c26", 
"metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -475,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4b0a7299", + "id": "f1007ac4", "metadata": {}, "outputs": [], "source": [ @@ -485,7 +485,7 @@ }, { "cell_type": "markdown", - "id": "d7e0c925", + "id": "dcd68de4", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -498,7 +498,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b599d759", + "id": "27b6bfe8", "metadata": {}, "outputs": [], "source": [ @@ -508,7 +508,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07a7c0da", + "id": "d4e9a395", "metadata": {}, "outputs": [], "source": [ @@ -521,7 +521,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7760dffa", + "id": "946b3aa8", "metadata": {}, "outputs": [], "source": [ @@ -533,7 +533,7 @@ }, { "cell_type": "markdown", - "id": "6d19000a", + "id": "f50d996e", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb index 468aa795..639e88df 100644 --- a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb +++ b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "573c3e7b", + "id": "25501772", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Seeding Synthetic Data Generation with an External Dataset\n", @@ -16,7 +16,7 @@ }, { "cell_type": "markdown", - "id": "63f6c36d", + "id": "67ffc49e", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -28,7 +28,7 @@ }, { "cell_type": "markdown", - "id": "02cc81c7", + "id": "54a42504", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -39,7 +39,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18d51631", + "id": "05b45354", "metadata": {}, "outputs": [], "source": [ @@ -50,7 +50,7 @@ { "cell_type": "code", "execution_count": null, - "id": "67c55f6b", + "id": "039360fe", "metadata": {}, 
"outputs": [], "source": [ @@ -68,7 +68,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cfe2ff62", + "id": "028d5e8a", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "bdbc5b03", + "id": "15a1df61", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -91,7 +91,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55d9caf1", + "id": "a87b6ff6", "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "aa1623bc", + "id": "b9166cfd", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -117,7 +117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d1310cf", + "id": "4961d3b0", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "e64ce3b7", + "id": "b1d8588a", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -162,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dafd6155", + "id": "cf42a4dd", "metadata": {}, "outputs": [], "source": [ @@ -171,7 +171,7 @@ }, { "cell_type": "markdown", - "id": "7c01f11c", + "id": "8d6b26aa", "metadata": {}, "source": [ "## πŸ₯ Prepare a seed dataset\n", @@ -196,7 +196,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7941073f", + "id": "fc90401d", "metadata": {}, "outputs": [], "source": [ @@ -214,7 +214,7 @@ }, { "cell_type": "markdown", - "id": "a68c7d55", + "id": "6f5ee960", "metadata": {}, "source": [ "## 🎨 Designing our synthetic patient notes dataset\n", @@ -227,7 +227,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f1b3d4d4", + "id": "e9db2ff0", "metadata": {}, "outputs": [], "source": [ @@ -308,7 +308,7 @@ }, { "cell_type": "markdown", - "id": "eff1bf9f", + "id": "00efc894", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -325,7 +325,7 @@ { "cell_type": "code", 
"execution_count": null, - "id": "b5955230", + "id": "3e3d824e", "metadata": {}, "outputs": [], "source": [ @@ -335,7 +335,7 @@ { "cell_type": "code", "execution_count": null, - "id": "062a7294", + "id": "27785af7", "metadata": {}, "outputs": [], "source": [ @@ -346,7 +346,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6378e1be", + "id": "430998d1", "metadata": {}, "outputs": [], "source": [ @@ -356,7 +356,7 @@ }, { "cell_type": "markdown", - "id": "51e5175e", + "id": "dda6458b", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -369,7 +369,7 @@ { "cell_type": "code", "execution_count": null, - "id": "891b6860", + "id": "f45bc088", "metadata": {}, "outputs": [], "source": [ @@ -379,7 +379,7 @@ }, { "cell_type": "markdown", - "id": "0f52668f", + "id": "1e913fd8", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -392,7 +392,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ed083bd8", + "id": "30b8b7f7", "metadata": {}, "outputs": [], "source": [ @@ -402,7 +402,7 @@ { "cell_type": "code", "execution_count": null, - "id": "039c42e4", + "id": "b7ff96d1", "metadata": {}, "outputs": [], "source": [ @@ -415,7 +415,7 @@ { "cell_type": "code", "execution_count": null, - "id": "623ca205", + "id": "dbfef8a8", "metadata": {}, "outputs": [], "source": [ @@ -427,7 +427,7 @@ }, { "cell_type": "markdown", - "id": "0a7e7d42", + "id": "5db3f38d", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb index 62ac63e8..9797695e 100644 --- a/docs/colab_notebooks/4-providing-images-as-context.ipynb +++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "731384ed", + "id": "19e57933", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Providing Images as Context for Vision-Based Data Generation" @@ -10,7 +10,7 @@ }, { "cell_type": 
"markdown", - "id": "bc66dd23", + "id": "25e3cc64", "metadata": {}, "source": [ "#### πŸ“š What you'll learn\n", @@ -25,7 +25,7 @@ }, { "cell_type": "markdown", - "id": "4539a931", + "id": "4aae5c82", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -37,7 +37,7 @@ }, { "cell_type": "markdown", - "id": "f88809bf", + "id": "24dfae6c", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -48,7 +48,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3628d4c4", + "id": "619b1aae", "metadata": {}, "outputs": [], "source": [ @@ -59,7 +59,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7fcf0f75", + "id": "0d49a542", "metadata": {}, "outputs": [], "source": [ @@ -77,7 +77,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6654714a", + "id": "1b28f160", "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "22488cb7", + "id": "63dc34de", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -113,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39913ca0", + "id": "672155c8", "metadata": {}, "outputs": [], "source": [ @@ -122,7 +122,7 @@ }, { "cell_type": "markdown", - "id": "fba112ab", + "id": "4b32c25e", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -139,7 +139,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70fd86dd", + "id": "72971915", "metadata": {}, "outputs": [], "source": [ @@ -162,7 +162,7 @@ }, { "cell_type": "markdown", - "id": "810c7457", + "id": "115ad20f", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -177,7 +177,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9b2204d0", + "id": "11e844d2", "metadata": {}, "outputs": [], "source": [ @@ -186,7 +186,7 @@ }, { "cell_type": "markdown", - "id": "29e3dae5", + "id": "77862fce", "metadata": {}, "source": [ "### 🌱 Seed Dataset Creation\n", @@ -203,7 +203,7 @@ { 
"cell_type": "code", "execution_count": null, - "id": "e2cc3506", + "id": "e415a502", "metadata": {}, "outputs": [], "source": [ @@ -218,7 +218,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7a821067", + "id": "335f2611", "metadata": {}, "outputs": [], "source": [ @@ -266,7 +266,7 @@ { "cell_type": "code", "execution_count": null, - "id": "359d144b", + "id": "f055e88d", "metadata": {}, "outputs": [], "source": [ @@ -284,7 +284,7 @@ { "cell_type": "code", "execution_count": null, - "id": "985cd308", + "id": "47a1c586", "metadata": {}, "outputs": [], "source": [ @@ -294,7 +294,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6a8cb414", + "id": "3a77fc52", "metadata": {}, "outputs": [], "source": [ @@ -306,7 +306,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a57e1b73", + "id": "c0941cc7", "metadata": {}, "outputs": [], "source": [ @@ -335,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "7518100a", + "id": "578e77dc", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -352,7 +352,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4c1fe540", + "id": "9f0c11ce", "metadata": {}, "outputs": [], "source": [ @@ -362,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bceafe91", + "id": "b10412c1", "metadata": {}, "outputs": [], "source": [ @@ -373,7 +373,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20f4ace5", + "id": "766ee2d7", "metadata": {}, "outputs": [], "source": [ @@ -383,7 +383,7 @@ }, { "cell_type": "markdown", - "id": "16a86d56", + "id": "6370bfa5", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -396,7 +396,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c1bbae97", + "id": "d57ded0e", "metadata": {}, "outputs": [], "source": [ @@ -406,7 +406,7 @@ }, { "cell_type": "markdown", - "id": "d8d7604f", + "id": "5afd8e8c", "metadata": {}, "source": [ "### πŸ”Ž Visual Inspection\n", @@ -417,7 +417,7 @@ { 
"cell_type": "code", "execution_count": null, - "id": "27c0636c", + "id": "aa4bfcc3", "metadata": { "lines_to_next_cell": 2 }, @@ -441,7 +441,7 @@ }, { "cell_type": "markdown", - "id": "f6b99539", + "id": "4eeaada6", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -454,7 +454,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5d53787", + "id": "0ee5b1b9", "metadata": {}, "outputs": [], "source": [ @@ -464,7 +464,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f859e49", + "id": "e5e8b241", "metadata": {}, "outputs": [], "source": [ @@ -477,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6688e3c5", + "id": "23ebb3ca", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +489,7 @@ }, { "cell_type": "markdown", - "id": "28635b09", + "id": "14a78533", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/5-generating-images.ipynb b/docs/colab_notebooks/5-generating-images.ipynb index 485fe258..c8092938 100644 --- a/docs/colab_notebooks/5-generating-images.ipynb +++ b/docs/colab_notebooks/5-generating-images.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "0ee289e6", + "id": "735e6197", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Generating Images\n", @@ -22,7 +22,7 @@ }, { "cell_type": "markdown", - "id": "86f748c1", + "id": "92ae4afe", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -33,7 +33,7 @@ }, { "cell_type": "markdown", - "id": "c610ee22", + "id": "ccc77347", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -44,7 +44,7 @@ { "cell_type": "code", "execution_count": null, - "id": "818ca495", + "id": "23627c23", "metadata": {}, "outputs": [], "source": [ @@ -55,7 +55,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f165bb15", + "id": "bf958dc6", "metadata": {}, "outputs": [], "source": [ @@ -73,7 +73,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5decfc83", + "id": "ab0cfff8", "metadata": {}, 
"outputs": [], "source": [ @@ -86,7 +86,7 @@ }, { "cell_type": "markdown", - "id": "929f35d6", + "id": "a18ef5ce", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -97,7 +97,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b4c8b7d7", + "id": "5fe11301", "metadata": {}, "outputs": [], "source": [ @@ -106,7 +106,7 @@ }, { "cell_type": "markdown", - "id": "8ed7b0b6", + "id": "b913d454", "metadata": {}, "source": [ "### πŸŽ›οΈ Define an image-generation model\n", @@ -118,7 +118,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d6b1ca66", + "id": "a50d26ee", "metadata": {}, "outputs": [], "source": [ @@ -140,7 +140,7 @@ }, { "cell_type": "markdown", - "id": "498cfecf", + "id": "122374d9", "metadata": {}, "source": [ "### πŸ—οΈ Build the config: samplers + image column\n", @@ -151,7 +151,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e74fc7ab", + "id": "940f2b70", "metadata": {}, "outputs": [], "source": [ @@ -324,7 +324,7 @@ }, { "cell_type": "markdown", - "id": "c592c820", + "id": "e13e0bb4", "metadata": {}, "source": [ "### πŸ” Preview: images as base64\n", @@ -335,7 +335,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eee17bb1", + "id": "2a60a76f", "metadata": {}, "outputs": [], "source": [ @@ -345,7 +345,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3cd320cc", + "id": "3c831ee8", "metadata": {}, "outputs": [], "source": [ @@ -356,7 +356,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ffb5e188", + "id": "143e762f", "metadata": {}, "outputs": [], "source": [ @@ -365,7 +365,7 @@ }, { "cell_type": "markdown", - "id": "87b83328", + "id": "a84606b4", "metadata": {}, "source": [ "### πŸ†™ Create: images saved to disk\n", @@ -376,17 +376,17 @@ { "cell_type": "code", "execution_count": null, - "id": "a8f9cc41", + "id": "89147954", "metadata": {}, "outputs": [], "source": [ - "results = data_designer.create(config_builder, num_records=5, 
dataset_name=\"tutorial-5-images\")" + "results = data_designer.create(config_builder, num_records=2, dataset_name=\"tutorial-5-images\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "0d4453e5", + "id": "04c96063", "metadata": {}, "outputs": [], "source": [ @@ -397,7 +397,7 @@ { "cell_type": "code", "execution_count": null, - "id": "198301ab", + "id": "edb794bb", "metadata": {}, "outputs": [], "source": [ @@ -413,7 +413,7 @@ }, { "cell_type": "markdown", - "id": "2bdcef2b", + "id": "e0a72bf6", "metadata": {}, "source": [ "## ⏭️ Next steps\n", From 469a3d295fd24afbce9179e24889ff4947c95c6d Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Tue, 10 Feb 2026 16:01:01 -0700 Subject: [PATCH 58/64] raise a ValueError if we fail to detect image format --- .../config/utils/image_helpers.py | 28 +++++++++++++++---- .../tests/config/utils/test_image_helpers.py | 17 +++++------ .../engine/storage/test_media_storage.py | 2 +- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py index 0fb949af..c91974d8 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -21,9 +21,20 @@ IMAGE_FORMAT_MAGIC_BYTES = { ImageFormat.PNG: b"\x89PNG\r\n\x1a\n", ImageFormat.JPG: b"\xff\xd8\xff", + ImageFormat.GIF: b"GIF8", # WEBP uses RIFF header - handled separately } +# Maps PIL format name (lowercase) to our ImageFormat enum. +# PIL reports "JPEG" (not "JPG"), so we normalize it here. 
+_PIL_FORMAT_TO_IMAGE_FORMAT: dict[str, ImageFormat] = { + "png": ImageFormat.PNG, + "jpeg": ImageFormat.JPG, + "jpg": ImageFormat.JPG, + "gif": ImageFormat.GIF, + "webp": ImageFormat.WEBP, +} + _BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$") # Patterns for diffusion-based image models only (use image_generation API). @@ -105,13 +116,18 @@ def detect_image_format(image_bytes: bytes) -> ImageFormat: image_bytes: Image data as bytes Returns: - Detected format (defaults to PNG if unknown) + Detected ImageFormat + + Raises: + ValueError: If the image format cannot be determined """ # Check magic bytes first (fast) if image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.PNG]): return ImageFormat.PNG elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.JPG]): return ImageFormat.JPG + elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.GIF]): + return ImageFormat.GIF elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]: return ImageFormat.WEBP @@ -119,13 +135,15 @@ def detect_image_format(image_bytes: bytes) -> ImageFormat: try: img = Image.open(io.BytesIO(image_bytes)) format_str = img.format.lower() if img.format else None - if format_str in [fmt.value for fmt in ImageFormat]: - return ImageFormat(format_str if format_str != ImageFormat.JPEG.value else ImageFormat.JPG.value) + if format_str in _PIL_FORMAT_TO_IMAGE_FORMAT: + return _PIL_FORMAT_TO_IMAGE_FORMAT[format_str] except Exception: pass - # Default to PNG - return ImageFormat.PNG + raise ValueError( + f"Unable to detect image format (first 8 bytes: {image_bytes[:8]!r}). " + f"Supported formats: {', '.join(SUPPORTED_IMAGE_EXTENSIONS)}." 
+ ) def is_image_path(value: str) -> bool: diff --git a/packages/data-designer-config/tests/config/utils/test_image_helpers.py b/packages/data-designer-config/tests/config/utils/test_image_helpers.py index 08ea3b50..fe2f40b7 100644 --- a/packages/data-designer-config/tests/config/utils/test_image_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_image_helpers.py @@ -84,9 +84,10 @@ def test_detect_image_format_webp(): assert detect_image_format(webp_magic) == ImageFormat.WEBP -def test_detect_image_format_unknown_defaults_to_png(): +def test_detect_image_format_unknown_raises_error(): unknown_bytes = b"\x00\x00\x00\x00" + b"\x00" * 10 - assert detect_image_format(unknown_bytes) == ImageFormat.PNG + with pytest.raises(ValueError, match="Unable to detect image format"): + detect_image_format(unknown_bytes) # Tests for is_image_path @@ -206,31 +207,27 @@ def test_validate_image_nonexistent_raises_error(tmp_path): # Additional tests for uncovered lines -def test_detect_image_format_with_pil_fallback_unsupported_format(tmp_path): - # Create a real GIF image that will trigger PIL fallback - # (GIF has different magic bytes not in our fast-path detection) +def test_detect_image_format_gif_magic_bytes(tmp_path): + # GIF files start with "GIF87a" or "GIF89a" and are now detected via magic bytes img = Image.new("RGB", (1, 1), color="red") gif_path = tmp_path / "test.gif" img.save(gif_path, format="GIF") gif_bytes = gif_path.read_bytes() - # Should use PIL fallback and correctly detect GIF format result = detect_image_format(gif_bytes) assert result == ImageFormat.GIF def test_detect_image_format_with_pil_fallback_jpeg(): - # Test PIL fallback path that converts "jpeg" format string to JPG enum - # Use mock since we can't easily create valid JPEG bytes without magic bytes + # Test PIL fallback path that normalizes "jpeg" -> JPG enum mock_img = Mock() mock_img.format = "JPEG" - # Use bytes that don't match our magic bytes to trigger PIL fallback + # Use 
bytes that don't match any magic bytes to trigger PIL fallback test_bytes = b"\x00\x00\x00\x00" with patch.object(Image, "open", return_value=mock_img): result = detect_image_format(test_bytes) - # Should convert JPEG -> JPG via line 96 assert result == ImageFormat.JPG diff --git a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py index 9d74734a..2e690fb4 100644 --- a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py +++ b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py @@ -140,7 +140,7 @@ def test_save_base64_image_disk_mode_corrupted_image_raises_error(tmp_path): corrupted_bytes = b"not a valid image" corrupted_base64 = base64.b64encode(corrupted_bytes).decode() - with pytest.raises(ValueError, match="Image validation failed"): + with pytest.raises(ValueError, match="Unable to detect image format"): storage.save_base64_image(corrupted_base64, subfolder_name="test_column") # Check that no files were left behind (cleanup on validation failure) From 1e43394b142b77acfb589f0e0cec0d567587acf9 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 11 Feb 2026 10:19:48 -0700 Subject: [PATCH 59/64] Fix diffusion image gen --- .../config/utils/image_helpers.py | 23 ++++++++++- .../src/data_designer/engine/models/facade.py | 40 ++++++++++++++----- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py index c91974d8..45f43622 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -11,6 +11,8 @@ from pathlib import Path from typing import TYPE_CHECKING +import requests + from data_designer.config.models import ImageFormat from 
data_designer.lazy_heavy_imports import Image @@ -39,12 +41,13 @@ # Patterns for diffusion-based image models only (use image_generation API). IMAGE_DIFFUSION_MODEL_PATTERNS = ( - "dall-e", + "dall-e-", "dalle", "stable-diffusion", "sd-", "sd_", "imagen", + "gpt-image-", ) SUPPORTED_IMAGE_EXTENSIONS = [f".{fmt.value.lower()}" for fmt in ImageFormat] @@ -232,6 +235,24 @@ def load_image_path_to_base64(image_path: str, base_path: str | None = None) -> return None +def load_image_url_to_base64(url: str, timeout: int = 60) -> str: + """Download an image from a URL and return as base64. + + Args: + url: HTTP(S) URL pointing to an image. + timeout: Request timeout in seconds. + + Returns: + Base64-encoded image data. + + Raises: + requests.HTTPError: If the download fails with a non-2xx status. + """ + resp = requests.get(url, timeout=timeout) + resp.raise_for_status() + return base64.b64encode(resp.content).decode() + + def validate_image(image_path: Path) -> None: """Validate that an image file is readable and not corrupted. 
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index e637d9f4..902ac80a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -13,6 +13,7 @@ extract_base64_from_data_uri, is_base64_image, is_image_diffusion_model, + load_image_url_to_base64, ) from data_designer.engine.mcp.errors import MCPConfigurationError from data_designer.engine.model_provider import ModelProviderRegistry @@ -41,13 +42,30 @@ def _identity(x: Any) -> Any: return x -def _try_extract_base64(data: str) -> str | None: - """Try to extract base64 image data from a data URI, returning None on failure.""" +def _try_extract_base64(source: str | litellm.types.utils.ImageObject) -> str | None: + """Try to extract base64 image data from a data URI string or image response object. + + Args: + source: Either a data URI string (e.g. "data:image/png;base64,...") + or a litellm ImageObject with b64_json/url attributes. + + Returns: + Base64-encoded image string, or None if extraction fails. + """ try: - return extract_base64_from_data_uri(data) - except ValueError: + if isinstance(source, str): + return extract_base64_from_data_uri(source) + + if getattr(source, "b64_json", None): + return source.b64_json + + if getattr(source, "url", None): + return load_image_url_to_base64(source.url) + except Exception: return None + return None + logger = logging.getLogger(__name__) @@ -447,16 +465,14 @@ def _generate_image_chat_completion( def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: """Generate image(s) using diffusion model via image_generation API. - Always returns base64. The API is configured to return base64 format. + Always returns base64. 
If the API returns URLs instead of inline base64, + the images are downloaded and converted automatically. Returns: List of base64-encoded image strings """ kwargs = self.consolidate_kwargs(**kwargs) - # Always request base64 format - kwargs["response_format"] = "b64_json" - response = None try: @@ -471,8 +487,12 @@ def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = Fal if not response.data or len(response.data) == 0: raise ImageGenerationError("Image generation returned no data") - # Return all images as list - return [img.b64_json for img in response.data] + images = [b64 for img in response.data if (b64 := _try_extract_base64(img)) is not None] + + if not images: + raise ImageGenerationError("No image data could be extracted from response") + + return images except Exception: raise From 8f6be9bae09623873d805207750e7d27d68b5e8f Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 11 Feb 2026 10:28:38 -0700 Subject: [PATCH 60/64] Add requests to config pyproject.toml --- packages/data-designer-config/pyproject.toml | 1 + uv.lock | 2 ++ 2 files changed, 3 insertions(+) diff --git a/packages/data-designer-config/pyproject.toml b/packages/data-designer-config/pyproject.toml index 569c8fe0..dc980798 100644 --- a/packages/data-designer-config/pyproject.toml +++ b/packages/data-designer-config/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "pygments>=2.19.2,<3", "python-json-logger>=3,<4", "pyyaml>=6.0.1,<7", + "requests>=2.32.0,<3", "rich>=13.7.1,<15", ] diff --git a/uv.lock b/uv.lock index 200d0b12..d92497dd 100644 --- a/uv.lock +++ b/uv.lock @@ -805,6 +805,7 @@ dependencies = [ { name = "pygments" }, { name = "python-json-logger" }, { name = "pyyaml" }, + { name = "requests" }, { name = "rich" }, ] @@ -819,6 +820,7 @@ requires-dist = [ { name = "pygments", specifier = ">=2.19.2,<3" }, { name = "python-json-logger", specifier = ">=3,<4" }, { name = "pyyaml", specifier = ">=6.0.1,<7" }, + { name = "requests", specifier = ">=2.32.0,<3" 
}, { name = "rich", specifier = ">=13.7.1,<15" }, ] From 87dcab1e9bfb4e0cbc4a4cc8ee3ab60979cae4cd Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 11 Feb 2026 15:46:01 -0700 Subject: [PATCH 61/64] address pr feedback from andre --- .../tests/config/utils/test_image_helpers.py | 325 +++++++++--------- .../dataset_builders/column_wise_builder.py | 25 ++ .../src/data_designer/engine/models/facade.py | 4 + .../engine/storage/media_storage.py | 36 +- .../engine/storage/test_media_storage.py | 82 +++-- 5 files changed, 267 insertions(+), 205 deletions(-) diff --git a/packages/data-designer-config/tests/config/utils/test_image_helpers.py b/packages/data-designer-config/tests/config/utils/test_image_helpers.py index fe2f40b7..8b2f557a 100644 --- a/packages/data-designer-config/tests/config/utils/test_image_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_image_helpers.py @@ -5,6 +5,7 @@ import base64 import io +from pathlib import Path from unittest.mock import Mock, patch import pytest @@ -23,37 +24,51 @@ ) from data_designer.lazy_heavy_imports import Image -# Tests for extract_base64_from_data_uri +@pytest.fixture +def sample_png_bytes() -> bytes: + """Create a valid 1x1 PNG as raw bytes.""" + img = Image.new("RGB", (1, 1), color="red") + buf = io.BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + +# --------------------------------------------------------------------------- +# extract_base64_from_data_uri +# --------------------------------------------------------------------------- -def test_extract_base64_from_data_uri_with_prefix(): + +def test_extract_base64_from_data_uri_with_prefix() -> None: data_uri = "" result = extract_base64_from_data_uri(data_uri) assert result == "iVBORw0KGgoAAAANS" -def test_extract_base64_plain_base64_without_prefix(): +def test_extract_base64_plain_base64_without_prefix() -> None: plain_base64 = "iVBORw0KGgoAAAANS" result = extract_base64_from_data_uri(plain_base64) assert result == 
plain_base64 -def test_extract_base64_invalid_data_uri_raises_error(): +def test_extract_base64_invalid_data_uri_raises_error() -> None: with pytest.raises(ValueError, match="Invalid data URI format: missing comma separator"): extract_base64_from_data_uri("data:image/png;base64") -# Tests for decode_base64_image +# --------------------------------------------------------------------------- +# decode_base64_image +# --------------------------------------------------------------------------- -def test_decode_base64_image_valid(): +def test_decode_base64_image_valid() -> None: png_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" base64_data = base64.b64encode(png_bytes).decode() result = decode_base64_image(base64_data) assert result == png_bytes -def test_decode_base64_image_with_data_uri(): +def test_decode_base64_image_with_data_uri() -> None: png_bytes = b"\x89PNG\r\n\x1a\n" base64_data = base64.b64encode(png_bytes).decode() data_uri = f"data:image/png;base64,{base64_data}" @@ -61,274 +76,258 @@ def test_decode_base64_image_with_data_uri(): assert result == png_bytes -def test_decode_base64_image_invalid_raises_error(): +def test_decode_base64_image_invalid_raises_error() -> None: with pytest.raises(ValueError, match="Invalid base64 data"): decode_base64_image("not-valid-base64!!!") -# Tests for detect_image_format +# --------------------------------------------------------------------------- +# detect_image_format (magic bytes) +# --------------------------------------------------------------------------- -def test_detect_image_format_png(): - png_magic = b"\x89PNG\r\n\x1a\n" + b"\x00" * 10 - assert detect_image_format(png_magic) == ImageFormat.PNG +@pytest.mark.parametrize( + "header_bytes,expected_format", + [ + (b"\x89PNG\r\n\x1a\n" + b"\x00" * 10, ImageFormat.PNG), + (b"\xff\xd8\xff" + b"\x00" * 10, ImageFormat.JPG), + (b"RIFF" + b"\x00" * 4 + b"WEBP", ImageFormat.WEBP), + ], + ids=["png", "jpg", "webp"], +) +def 
test_detect_image_format_magic_bytes(header_bytes: bytes, expected_format: ImageFormat) -> None: + assert detect_image_format(header_bytes) == expected_format -def test_detect_image_format_jpg(): - jpg_magic = b"\xff\xd8\xff" + b"\x00" * 10 - assert detect_image_format(jpg_magic) == ImageFormat.JPG +def test_detect_image_format_gif_magic_bytes(tmp_path: Path) -> None: + img = Image.new("RGB", (1, 1), color="red") + gif_path = tmp_path / "test.gif" + img.save(gif_path, format="GIF") + gif_bytes = gif_path.read_bytes() + assert detect_image_format(gif_bytes) == ImageFormat.GIF -def test_detect_image_format_webp(): - webp_magic = b"RIFF" + b"\x00" * 4 + b"WEBP" - assert detect_image_format(webp_magic) == ImageFormat.WEBP +def test_detect_image_format_with_pil_fallback_jpeg() -> None: + mock_img = Mock() + mock_img.format = "JPEG" + test_bytes = b"\x00\x00\x00\x00" + with patch.object(Image, "open", return_value=mock_img): + result = detect_image_format(test_bytes) + assert result == ImageFormat.JPG -def test_detect_image_format_unknown_raises_error(): + +def test_detect_image_format_unknown_raises_error() -> None: unknown_bytes = b"\x00\x00\x00\x00" + b"\x00" * 10 with pytest.raises(ValueError, match="Unable to detect image format"): detect_image_format(unknown_bytes) -# Tests for is_image_path +# --------------------------------------------------------------------------- +# is_image_path +# --------------------------------------------------------------------------- -def test_is_image_path_various_extensions(): - assert is_image_path("/path/to/image.png") is True - assert is_image_path("image.PNG") is True - assert is_image_path("image.jpg") is True - assert is_image_path("image.jpeg") is True +@pytest.mark.parametrize( + "value,expected", + [ + ("/path/to/image.png", True), + ("image.PNG", True), + ("image.jpg", True), + ("image.jpeg", True), + ("/path/to/file.txt", False), + ("document.pdf", False), + ("/some.png/file.txt", False), + ], + ids=["png", "png-upper", 
"jpg", "jpeg", "txt", "pdf", "ext-in-dir"], +) +def test_is_image_path(value: str, expected: bool) -> None: + assert is_image_path(value) is expected -def test_is_image_path_non_image(): - assert is_image_path("/path/to/file.txt") is False - assert is_image_path("document.pdf") is False +# --------------------------------------------------------------------------- +# is_image_url +# --------------------------------------------------------------------------- -def test_is_image_path_extension_in_directory(): - assert is_image_path("/some.png/file.txt") is False +@pytest.mark.parametrize( + "value,expected", + [ + ("http://example.com/image.png", True), + ("https://example.com/photo.jpg", True), + ("https://example.com/image.png?size=large", True), + ("https://example.com/page.html", False), + ("ftp://example.com/image.png", False), + ], + ids=["http", "https", "query-params", "non-image-ext", "ftp"], +) +def test_is_image_url(value: str, expected: bool) -> None: + assert is_image_url(value) is expected -# Tests for is_base64_image +# --------------------------------------------------------------------------- +# is_base64_image +# --------------------------------------------------------------------------- -def test_is_base64_image_data_uri(): +def test_is_base64_image_data_uri() -> None: assert is_base64_image("") is True -def test_is_base64_image_long_valid_base64(): +def test_is_base64_image_long_valid_base64() -> None: long_base64 = base64.b64encode(b"x" * 100).decode() assert is_base64_image(long_base64) is True -def test_is_base64_image_short_string(): +def test_is_base64_image_short_string() -> None: assert is_base64_image("short") is False -# Tests for is_image_url - - -def test_is_image_url_http_and_https(): - assert is_image_url("http://example.com/image.png") is True - assert is_image_url("https://example.com/photo.jpg") is True - - -def test_is_image_url_with_query_params(): - assert is_image_url("https://example.com/image.png?size=large") is True - - -def 
test_is_image_url_without_image_extension(): - assert is_image_url("https://example.com/page.html") is False - - -def test_is_image_url_non_http(): - assert is_image_url("ftp://example.com/image.png") is False - - -# Tests for is_image_diffusion_model - - -def test_is_image_diffusion_model_dall_e(): - assert is_image_diffusion_model("dall-e-3") is True - assert is_image_diffusion_model("DALL-E-2") is True - assert is_image_diffusion_model("openai/dalle-2") is True - - -def test_is_image_diffusion_model_stable_diffusion(): - assert is_image_diffusion_model("stable-diffusion-xl") is True - assert is_image_diffusion_model("sd-2.1") is True - assert is_image_diffusion_model("sd_1.5") is True - +def test_is_base64_image_invalid_base64_decode() -> None: + invalid_base64 = "A" * 50 + "=" + "A" * 49 + "more text" + assert is_base64_image(invalid_base64) is False -def test_is_image_diffusion_model_imagen(): - assert is_image_diffusion_model("imagen-3") is True - assert is_image_diffusion_model("google/imagen") is True +# --------------------------------------------------------------------------- +# Non-string guard (is_image_path, is_base64_image, is_image_url) +# --------------------------------------------------------------------------- -def test_is_image_diffusion_model_chat_completion_image_models(): - assert is_image_diffusion_model("gemini-3-pro-image-preview") is False - assert is_image_diffusion_model("gpt-5-image") is False - assert is_image_diffusion_model("flux.2-pro") is False +@pytest.mark.parametrize( + "func", + [is_image_path, is_base64_image, is_image_url], + ids=["is_image_path", "is_base64_image", "is_image_url"], +) +@pytest.mark.parametrize("value", [123, None, []], ids=["int", "none", "list"]) +def test_non_string_input_returns_false(func: object, value: object) -> None: + assert func(value) is False + + +# --------------------------------------------------------------------------- +# is_image_diffusion_model +# 
--------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "model_name,expected", + [ + ("dall-e-3", True), + ("DALL-E-2", True), + ("openai/dalle-2", True), + ("stable-diffusion-xl", True), + ("sd-2.1", True), + ("sd_1.5", True), + ("imagen-3", True), + ("google/imagen", True), + ("gpt-image-1", True), + ("gemini-3-pro-image-preview", False), + ("gpt-5-image", False), + ("flux.2-pro", False), + ], + ids=[ + "dall-e-3", + "DALL-E-2", + "dalle-2", + "stable-diffusion-xl", + "sd-2.1", + "sd_1.5", + "imagen-3", + "google-imagen", + "gpt-image-1", + "gemini-not-diffusion", + "gpt-5-not-diffusion", + "flux-not-diffusion", + ], +) +def test_is_image_diffusion_model(model_name: str, expected: bool) -> None: + assert is_image_diffusion_model(model_name) is expected -# Tests for validate_image +# --------------------------------------------------------------------------- +# validate_image +# --------------------------------------------------------------------------- -def test_validate_image_valid_png(tmp_path): - # Create a valid 1x1 PNG using PIL - img = Image.new("RGB", (1, 1), color="red") - buf = io.BytesIO() - img.save(buf, format="PNG") - png_bytes = buf.getvalue() +def test_validate_image_valid_png(tmp_path: Path, sample_png_bytes: bytes) -> None: image_path = tmp_path / "test.png" - image_path.write_bytes(png_bytes) - - # Should not raise + image_path.write_bytes(sample_png_bytes) validate_image(image_path) -def test_validate_image_corrupted_raises_error(tmp_path): - # Create an invalid image file +def test_validate_image_corrupted_raises_error(tmp_path: Path) -> None: image_path = tmp_path / "corrupted.png" image_path.write_bytes(b"not a valid image") - with pytest.raises(ValueError, match="Image validation failed"): validate_image(image_path) -def test_validate_image_nonexistent_raises_error(tmp_path): +def test_validate_image_nonexistent_raises_error(tmp_path: Path) -> None: image_path = tmp_path / 
"nonexistent.png" - with pytest.raises(ValueError, match="Image validation failed"): validate_image(image_path) -# Additional tests for uncovered lines - +# --------------------------------------------------------------------------- +# load_image_path_to_base64 +# --------------------------------------------------------------------------- -def test_detect_image_format_gif_magic_bytes(tmp_path): - # GIF files start with "GIF87a" or "GIF89a" and are now detected via magic bytes - img = Image.new("RGB", (1, 1), color="red") - gif_path = tmp_path / "test.gif" - img.save(gif_path, format="GIF") - - gif_bytes = gif_path.read_bytes() - result = detect_image_format(gif_bytes) - assert result == ImageFormat.GIF - - -def test_detect_image_format_with_pil_fallback_jpeg(): - # Test PIL fallback path that normalizes "jpeg" -> JPG enum - mock_img = Mock() - mock_img.format = "JPEG" - - # Use bytes that don't match any magic bytes to trigger PIL fallback - test_bytes = b"\x00\x00\x00\x00" - with patch.object(Image, "open", return_value=mock_img): - result = detect_image_format(test_bytes) - assert result == ImageFormat.JPG - - -def test_is_image_path_non_string_input(): - assert is_image_path(123) is False - assert is_image_path(None) is False - assert is_image_path([]) is False - - -def test_is_base64_image_non_string_input(): - assert is_base64_image(123) is False - assert is_base64_image(None) is False - assert is_base64_image([]) is False - - -def test_is_base64_image_invalid_base64_decode(): - # String with valid base64 characters but incorrect padding that causes decode to fail - # Single '=' in middle of string is invalid base64 (padding only allowed at end) - invalid_base64 = "A" * 50 + "=" + "A" * 49 + "more text" - assert is_base64_image(invalid_base64) is False - - -def test_is_image_url_non_string_input(): - assert is_image_url(123) is False - assert is_image_url(None) is False - assert is_image_url([]) is False - - -# Tests for load_image_path_to_base64 - - -def 
test_load_image_path_to_base64_absolute_path(tmp_path): - # Create a test image file +def test_load_image_path_to_base64_absolute_path(tmp_path: Path) -> None: img = Image.new("RGB", (1, 1), color="blue") image_path = tmp_path / "test.png" img.save(image_path) - # Load with absolute path result = load_image_path_to_base64(str(image_path)) assert result is not None assert len(result) > 0 - # Verify it's valid base64 decoded = base64.b64decode(result) assert len(decoded) > 0 -def test_load_image_path_to_base64_relative_with_base_path(tmp_path): - # Create a test image file +def test_load_image_path_to_base64_relative_with_base_path(tmp_path: Path) -> None: img = Image.new("RGB", (1, 1), color="green") image_path = tmp_path / "subdir" / "test.png" image_path.parent.mkdir(exist_ok=True) img.save(image_path) - # Load with relative path and base_path result = load_image_path_to_base64("subdir/test.png", base_path=str(tmp_path)) assert result is not None assert len(result) > 0 -def test_load_image_path_to_base64_nonexistent_file(): +def test_load_image_path_to_base64_nonexistent_file() -> None: result = load_image_path_to_base64("/nonexistent/path/to/image.png") assert result is None -def test_load_image_path_to_base64_relative_with_cwd_fallback(tmp_path, monkeypatch): - # Create test image in current working directory - - # Change to tmp_path as cwd +def test_load_image_path_to_base64_relative_with_cwd_fallback(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.chdir(tmp_path) img = Image.new("RGB", (1, 1), color="yellow") image_path = tmp_path / "test_cwd.png" img.save(image_path) - # Use relative path without base_path - should fall back to cwd result = load_image_path_to_base64("test_cwd.png") assert result is not None assert len(result) > 0 -def test_load_image_path_to_base64_base_path_fallback_to_cwd(tmp_path, monkeypatch): - # Test the case where base_path is provided but file isn't there, falls back to cwd +def 
test_load_image_path_to_base64_base_path_fallback_to_cwd(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.chdir(tmp_path) - # Create image in cwd img = Image.new("RGB", (1, 1), color="red") image_path = tmp_path / "test.png" img.save(image_path) - # Create a different base_path that doesn't have the image wrong_base = tmp_path / "wrong" wrong_base.mkdir() - # Use relative path with wrong base_path - should fall back to cwd result = load_image_path_to_base64("test.png", base_path=str(wrong_base)) assert result is not None assert len(result) > 0 -def test_load_image_path_to_base64_exception_handling(tmp_path): - # Create a directory (not a file) to trigger exception +def test_load_image_path_to_base64_exception_handling(tmp_path: Path) -> None: dir_path = tmp_path / "directory" dir_path.mkdir() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 9077e807..e5d67928 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -286,6 +286,7 @@ def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max progress_tracker.log_final() if len(self._records_to_drop) > 0: + self._cleanup_dropped_record_images(self._records_to_drop) self.batch_manager.drop_records(self._records_to_drop) self._records_to_drop.clear() @@ -362,6 +363,30 @@ def _run_processors( ) from e return dataframe + def _cleanup_dropped_record_images(self, dropped_indices: set[int]) -> None: + """Remove saved image files for records that will be dropped. + + When a record fails during generation, any images already saved to disk + for that record in previous columns become dangling. This method deletes + those files so they don't accumulate. 
+ """ + media_storage = self.artifact_storage.media_storage + if not self._has_image_columns() or media_storage is None or media_storage.mode != StorageMode.DISK: + return + + image_col_names = [ + col.name for col in self.single_column_configs if col.column_type == DataDesignerColumnType.IMAGE + ] + + buffer = self.batch_manager.get_current_batch(as_dataframe=False) + for idx in dropped_indices: + if idx < 0 or idx >= len(buffer): + continue + for col_name in image_col_names: + paths = buffer[idx].get(col_name, []) + for path in [paths] if isinstance(paths, str) else paths: + media_storage.delete_image(path) + def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None: """If a worker fails, we can handle the exception here.""" logger.warning( diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 7f7972d5..cf3c7e6e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -62,6 +62,7 @@ def _try_extract_base64(source: str | litellm.types.utils.ImageObject) -> str | if getattr(source, "url", None): return load_image_url_to_base64(source.url) except Exception: + logger.debug(f"Failed to extract base64 from source of type {type(source).__name__}") return None return None @@ -561,3 +562,6 @@ def _track_token_usage_from_image_diffusion(self, response: litellm.types.utils. 
), request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) + else: + # Successful response but no token usage data (some providers don't report it) + self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=1, failed_requests=0)) diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py index 81387525..1c887c80 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py @@ -57,15 +57,6 @@ def __init__( self.images_subdir = images_subdir self.mode = mode - def _ensure_images_directory(self) -> None: - """Create images directory if it doesn't exist (lazy initialization).""" - self.images_dir.mkdir(parents=True, exist_ok=True) - - def _sanitize_subfolder_name(self, name: str) -> str: - """Sanitize subfolder name to prevent path traversal and filesystem issues.""" - # Replace path separators and parent directory references with underscores - return name.replace("/", "_").replace("\\", "_").replace("..", "_") - def save_base64_image(self, base64_data: str, subfolder_name: str) -> str: """Save or return base64 image based on storage mode. @@ -118,6 +109,24 @@ def save_base64_image(self, base64_data: str, subfolder_name: str) -> str: return relative_path + def delete_image(self, relative_path: str) -> bool: + """Delete a saved image file given its relative path. + + Args: + relative_path: Relative path as returned by save_base64_image (e.g., "images/col/uuid.png") + + Returns: + True if the file was deleted, False if it didn't exist or deletion failed. 
+ """ + try: + full_path = self.base_path / relative_path + if full_path.exists() and self.images_dir in full_path.parents: + full_path.unlink() + return True + except OSError: + pass + return False + def _validate_image(self, image_path: Path) -> None: """Validate that saved image is readable. @@ -133,3 +142,12 @@ def _validate_image(self, image_path: Path) -> None: # Clean up invalid file image_path.unlink(missing_ok=True) raise + + def _ensure_images_directory(self) -> None: + """Create images directory if it doesn't exist (lazy initialization).""" + self.images_dir.mkdir(parents=True, exist_ok=True) + + def _sanitize_subfolder_name(self, name: str) -> str: + """Sanitize subfolder name to prevent path traversal and filesystem issues.""" + # Replace path separators and parent directory references with underscores + return name.replace("/", "_").replace("\\", "_").replace("..", "_") diff --git a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py index 2e690fb4..e79c854b 100644 --- a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py +++ b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py @@ -38,24 +38,21 @@ def sample_base64_jpg() -> str: return base64.b64encode(jpg_bytes).decode() -def test_media_storage_init(tmp_path): - """Test MediaStorage initialization.""" - storage = MediaStorage(base_path=tmp_path) +@pytest.mark.parametrize( + "images_subdir,mode", + [ + (IMAGES_SUBDIR, StorageMode.DISK), + ("custom_images", StorageMode.DATAFRAME), + ], + ids=["defaults", "custom-subdir-dataframe"], +) +def test_media_storage_init(tmp_path, images_subdir: str, mode: StorageMode) -> None: + """Test MediaStorage initialization with various configurations.""" + storage = MediaStorage(base_path=tmp_path, images_subdir=images_subdir, mode=mode) assert storage.base_path == tmp_path - assert storage.images_dir == tmp_path / IMAGES_SUBDIR - 
assert storage.images_subdir == IMAGES_SUBDIR - assert storage.mode == StorageMode.DISK - # Directory should NOT exist until first save (lazy initialization) - assert not storage.images_dir.exists() - - -def test_media_storage_init_custom_subdir(tmp_path): - """Test MediaStorage initialization with custom subdirectory and mode.""" - custom_subdir = "custom_images" - storage = MediaStorage(base_path=tmp_path, images_subdir=custom_subdir, mode=StorageMode.DATAFRAME) - assert storage.images_subdir == custom_subdir - assert storage.images_dir == tmp_path / custom_subdir - assert storage.mode == StorageMode.DATAFRAME + assert storage.images_subdir == images_subdir + assert storage.images_dir == tmp_path / images_subdir + assert storage.mode == mode # Directory should NOT exist until first save (lazy initialization) assert not storage.images_dir.exists() @@ -149,12 +146,12 @@ def test_save_base64_image_disk_mode_corrupted_image_raises_error(tmp_path): assert len(list(column_dir.iterdir())) == 0 -def test_save_base64_image_dataframe_mode_returns_base64(tmp_path, sample_base64_png): - """Test that DATAFRAME mode returns base64 directly without disk operations.""" +@pytest.mark.parametrize("subfolder_name", ["test_column", "test_subfolder"], ids=["column", "subfolder"]) +def test_save_base64_image_dataframe_mode_returns_base64(tmp_path, sample_base64_png, subfolder_name): + """Test that DATAFRAME mode returns base64 directly regardless of subfolder name.""" storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DATAFRAME) - # Should return the same base64 data (column_name is ignored in DATAFRAME mode) - result = storage.save_base64_image(sample_base64_png, subfolder_name="test_column") + result = storage.save_base64_image(sample_base64_png, subfolder_name=subfolder_name) assert result == sample_base64_png # Directory should not be created in DATAFRAME mode (lazy initialization) @@ -201,18 +198,6 @@ def 
test_save_base64_image_with_different_subfolder_names(media_storage, sample_ assert (media_storage.base_path / path2).exists() -def test_save_base64_image_dataframe_mode_with_subfolder_name(tmp_path, sample_base64_png): - """Test that DATAFRAME mode returns base64 directly even with subfolder name.""" - storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DATAFRAME) - - # Should return the same base64 data regardless of subfolder name - result = storage.save_base64_image(sample_base64_png, subfolder_name="test_subfolder") - assert result == sample_base64_png - - # Directory should not be created in DATAFRAME mode - assert not storage.images_dir.exists() - - @pytest.mark.parametrize( "unsafe_name,expected_sanitized", [ @@ -236,3 +221,34 @@ def test_save_base64_image_sanitizes_subfolder_name(media_storage, sample_base64 full_path = media_storage.base_path / relative_path assert full_path.exists() assert media_storage.images_dir in full_path.parents + + +# --------------------------------------------------------------------------- +# delete_image +# --------------------------------------------------------------------------- + + +def test_delete_image_removes_saved_file(media_storage, sample_base64_png) -> None: + """Test that delete_image removes a previously saved image.""" + relative_path = media_storage.save_base64_image(sample_base64_png, subfolder_name="col") + full_path = media_storage.base_path / relative_path + assert full_path.exists() + + result = media_storage.delete_image(relative_path) + assert result is True + assert not full_path.exists() + + +def test_delete_image_returns_false_for_nonexistent(media_storage) -> None: + """Test that delete_image returns False when the file doesn't exist.""" + assert media_storage.delete_image(f"{IMAGES_SUBDIR}/col/nonexistent.png") is False + + +def test_delete_image_rejects_path_outside_images_dir(media_storage, tmp_path) -> None: + """Test that delete_image refuses to delete files outside the images 
directory.""" + outside_file = tmp_path / "outside.txt" + outside_file.write_text("should not be deleted") + + result = media_storage.delete_image("../outside.txt") + assert result is False + assert outside_file.exists() From b1648c78bfc3d4f1b985454bfdd01f34193ac22c Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 12 Feb 2026 11:33:44 -0700 Subject: [PATCH 62/64] reorder docstring --- .../src/data_designer/config/column_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index e3ea013d..49dbb831 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -492,7 +492,6 @@ class ImageColumnConfig(SingleColumnConfig): The API used is automatically determined based on the model name: Attributes: - column_type: Discriminator field, always "image" for this configuration type. prompt: Prompt template for image generation. Supports Jinja2 templating to reference other columns (e.g., "Generate an image of a {{ character_name }}"). Must be a valid Jinja2 template. @@ -500,6 +499,7 @@ class ImageColumnConfig(SingleColumnConfig): multi_modal_context: Optional list of image contexts for multi-modal generation. Enables autoregressive multi-modal models to generate images based on image inputs. Only works with autoregressive models that support image-to-image generation. + column_type: Discriminator field, always "image" for this configuration type. 
""" prompt: str From 78202339a4da2b623ce652b7c473c79bba3c96aa Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 12 Feb 2026 11:47:32 -0700 Subject: [PATCH 63/64] Fix bug with display sample record with index=0 --- .../src/data_designer/config/utils/visualization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index bd2876d4..ac4df1d8 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -171,7 +171,7 @@ def display_sample_record( processors_to_display: List of processors to display the artifacts for. If None, all processors will be displayed. hide_seed_columns: If True, seed columns will not be displayed separately. """ - i = index or self._display_cycle_index + i = self._display_cycle_index if index is None else index try: record = self._record_sampler_dataset.iloc[i] From f491c11d4bd9b8d01ba72f04bd5443290892c52f Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 12 Feb 2026 11:51:16 -0700 Subject: [PATCH 64/64] remove redundant kwargs consolidation --- .../src/data_designer/engine/models/facade.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index cf3c7e6e..ef328a9a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -407,7 +407,6 @@ def _generate_image_chat_completion( Returns: List of base64-encoded image strings """ - kwargs = self.consolidate_kwargs(**kwargs) messages = prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context) response = None