diff --git a/docs/colab_notebooks/1-the-basics.ipynb b/docs/colab_notebooks/1-the-basics.ipynb
index ec2c5a99..f50209f7 100644
--- a/docs/colab_notebooks/1-the-basics.ipynb
+++ b/docs/colab_notebooks/1-the-basics.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "c79eea7a",
+ "id": "96178d08",
"metadata": {},
"source": [
"# π¨ Data Designer Tutorial: The Basics\n",
@@ -14,7 +14,7 @@
},
{
"cell_type": "markdown",
- "id": "2476f160",
+ "id": "1d02a1d6",
"metadata": {},
"source": [
"### π¦ Import Data Designer\n",
@@ -26,7 +26,7 @@
},
{
"cell_type": "markdown",
- "id": "3646f62e",
+ "id": "2292d817",
"metadata": {},
"source": [
"### β‘ Colab Setup\n",
@@ -37,7 +37,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3348e5c8",
+ "id": "8af621fc",
"metadata": {},
"outputs": [],
"source": [
@@ -48,7 +48,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "19cd9249",
+ "id": "70e6a11c",
"metadata": {},
"outputs": [],
"source": [
@@ -66,7 +66,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5a6d13a9",
+ "id": "41031828",
"metadata": {},
"outputs": [],
"source": [
@@ -76,7 +76,7 @@
},
{
"cell_type": "markdown",
- "id": "d445af5b",
+ "id": "0b480b10",
"metadata": {},
"source": [
"### βοΈ Initialize the Data Designer interface\n",
@@ -89,7 +89,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4df0031d",
+ "id": "d434a8e2",
"metadata": {},
"outputs": [],
"source": [
@@ -98,7 +98,7 @@
},
{
"cell_type": "markdown",
- "id": "0f69b576",
+ "id": "f88f6792",
"metadata": {},
"source": [
"### ποΈ Define model configurations\n",
@@ -115,7 +115,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "65d9be99",
+ "id": "4261574c",
"metadata": {},
"outputs": [],
"source": [
@@ -145,7 +145,7 @@
},
{
"cell_type": "markdown",
- "id": "72582d09",
+ "id": "bbbc3d58",
"metadata": {},
"source": [
"### ποΈ Initialize the Data Designer Config Builder\n",
@@ -160,7 +160,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8d7992b4",
+ "id": "92c0cf35",
"metadata": {},
"outputs": [],
"source": [
@@ -169,7 +169,7 @@
},
{
"cell_type": "markdown",
- "id": "741a15a0",
+ "id": "44246c7d",
"metadata": {},
"source": [
"## π² Getting started with sampler columns\n",
@@ -186,7 +186,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c3879c70",
+ "id": "07d20f3f",
"metadata": {},
"outputs": [],
"source": [
@@ -195,7 +195,7 @@
},
{
"cell_type": "markdown",
- "id": "1575ef81",
+ "id": "9d3c87b0",
"metadata": {},
"source": [
"Let's start designing our product review dataset by adding product category and subcategory columns.\n"
@@ -204,7 +204,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "87a88d7b",
+ "id": "c646b021",
"metadata": {},
"outputs": [],
"source": [
@@ -285,7 +285,7 @@
},
{
"cell_type": "markdown",
- "id": "8c74b738",
+ "id": "ff18b032",
"metadata": {},
"source": [
"Next, let's add samplers to generate data related to the customer and their review.\n"
@@ -294,7 +294,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4eb1da1f",
+ "id": "78846d99",
"metadata": {},
"outputs": [],
"source": [
@@ -331,7 +331,7 @@
},
{
"cell_type": "markdown",
- "id": "4324d869",
+ "id": "97059bfc",
"metadata": {},
"source": [
"## π¦ LLM-generated columns\n",
@@ -346,7 +346,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1302a503",
+ "id": "98c66eff",
"metadata": {},
"outputs": [],
"source": [
@@ -382,7 +382,7 @@
},
{
"cell_type": "markdown",
- "id": "7cf8241b",
+ "id": "ff2d52b9",
"metadata": {},
"source": [
"### π Iteration is key βΒ preview the dataset!\n",
@@ -399,7 +399,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6fc6cf39",
+ "id": "6e622478",
"metadata": {},
"outputs": [],
"source": [
@@ -409,7 +409,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c929e068",
+ "id": "1addc7d8",
"metadata": {},
"outputs": [],
"source": [
@@ -420,7 +420,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "dfb04e2a",
+ "id": "7af4b9c3",
"metadata": {},
"outputs": [],
"source": [
@@ -430,7 +430,7 @@
},
{
"cell_type": "markdown",
- "id": "adb879da",
+ "id": "91d0ee89",
"metadata": {},
"source": [
"### π Analyze the generated data\n",
@@ -443,7 +443,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ff58dd9f",
+ "id": "e1e3aed0",
"metadata": {},
"outputs": [],
"source": [
@@ -453,7 +453,7 @@
},
{
"cell_type": "markdown",
- "id": "57c7355d",
+ "id": "6eaa402e",
"metadata": {},
"source": [
"### π Scale up!\n",
@@ -466,7 +466,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "df49db99",
+ "id": "f6b148d4",
"metadata": {},
"outputs": [],
"source": [
@@ -476,7 +476,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2bbc48dd",
+ "id": "f4e62e5b",
"metadata": {},
"outputs": [],
"source": [
@@ -489,7 +489,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "dc0673fa",
+ "id": "7d426ab0",
"metadata": {},
"outputs": [],
"source": [
@@ -501,7 +501,7 @@
},
{
"cell_type": "markdown",
- "id": "7688217b",
+ "id": "449d003c",
"metadata": {},
"source": [
"## βοΈ Next Steps\n",
@@ -512,7 +512,9 @@
"\n",
"- [Seeding synthetic data generation with an external dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/)\n",
"\n",
- "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n"
+ "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n",
+ "\n",
+ "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)\n"
]
}
],
diff --git a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb
index c813ea50..a6e04680 100644
--- a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb
+++ b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "258752cd",
+ "id": "ba22504d",
"metadata": {},
"source": [
"# π¨ Data Designer Tutorial: Structured Outputs and Jinja Expressions\n",
@@ -16,7 +16,7 @@
},
{
"cell_type": "markdown",
- "id": "fc4217c3",
+ "id": "c176fe63",
"metadata": {},
"source": [
"### π¦ Import Data Designer\n",
@@ -28,7 +28,7 @@
},
{
"cell_type": "markdown",
- "id": "2b831130",
+ "id": "32c80f72",
"metadata": {},
"source": [
"### β‘ Colab Setup\n",
@@ -39,7 +39,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "fa1eda43",
+ "id": "4ab45e3a",
"metadata": {},
"outputs": [],
"source": [
@@ -50,7 +50,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5f014571",
+ "id": "2ae70d67",
"metadata": {},
"outputs": [],
"source": [
@@ -68,7 +68,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f7409282",
+ "id": "2cdc070b",
"metadata": {},
"outputs": [],
"source": [
@@ -78,7 +78,7 @@
},
{
"cell_type": "markdown",
- "id": "8234dd4b",
+ "id": "a04261b9",
"metadata": {},
"source": [
"### βοΈ Initialize the Data Designer interface\n",
@@ -91,7 +91,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "21633aed",
+ "id": "c8bef18a",
"metadata": {},
"outputs": [],
"source": [
@@ -100,7 +100,7 @@
},
{
"cell_type": "markdown",
- "id": "9b215265",
+ "id": "ed555636",
"metadata": {},
"source": [
"### ποΈ Define model configurations\n",
@@ -117,7 +117,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "76260638",
+ "id": "47208094",
"metadata": {},
"outputs": [],
"source": [
@@ -147,7 +147,7 @@
},
{
"cell_type": "markdown",
- "id": "e6bfd93d",
+ "id": "36c200d9",
"metadata": {},
"source": [
"### ποΈ Initialize the Data Designer Config Builder\n",
@@ -162,7 +162,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a0fbd497",
+ "id": "57c0d82f",
"metadata": {},
"outputs": [],
"source": [
@@ -171,7 +171,7 @@
},
{
"cell_type": "markdown",
- "id": "7faae40e",
+ "id": "01ff63ca",
"metadata": {},
"source": [
"### π§βπ¨ Designing our data\n",
@@ -198,7 +198,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f2f94909",
+ "id": "4fb0f1ca",
"metadata": {},
"outputs": [],
"source": [
@@ -226,7 +226,7 @@
},
{
"cell_type": "markdown",
- "id": "696f19f4",
+ "id": "8f35bd87",
"metadata": {},
"source": [
"Next, let's design our product review dataset using a few more tricks compared to the previous notebook.\n"
@@ -235,7 +235,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "312b50cd",
+ "id": "43341f16",
"metadata": {},
"outputs": [],
"source": [
@@ -344,7 +344,7 @@
},
{
"cell_type": "markdown",
- "id": "ecd971ca",
+ "id": "34c3e08b",
"metadata": {},
"source": [
"Next, we will use more advanced Jinja expressions to create new columns.\n",
@@ -361,7 +361,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "bda01ffc",
+ "id": "c168c089",
"metadata": {},
"outputs": [],
"source": [
@@ -414,7 +414,7 @@
},
{
"cell_type": "markdown",
- "id": "059613e1",
+ "id": "7e6521a2",
"metadata": {},
"source": [
"### π Iteration is key βΒ preview the dataset!\n",
@@ -431,7 +431,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "23c9b839",
+ "id": "03510f78",
"metadata": {},
"outputs": [],
"source": [
@@ -441,7 +441,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e5adcdbd",
+ "id": "ad599c43",
"metadata": {},
"outputs": [],
"source": [
@@ -452,7 +452,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1cc39cae",
+ "id": "dbd3e17c",
"metadata": {},
"outputs": [],
"source": [
@@ -462,7 +462,7 @@
},
{
"cell_type": "markdown",
- "id": "bcca3f06",
+ "id": "4db52c26",
"metadata": {},
"source": [
"### π Analyze the generated data\n",
@@ -475,7 +475,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6e1957ca",
+ "id": "f1007ac4",
"metadata": {},
"outputs": [],
"source": [
@@ -485,7 +485,7 @@
},
{
"cell_type": "markdown",
- "id": "9db283d3",
+ "id": "dcd68de4",
"metadata": {},
"source": [
"### π Scale up!\n",
@@ -498,7 +498,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "30826883",
+ "id": "27b6bfe8",
"metadata": {},
"outputs": [],
"source": [
@@ -508,7 +508,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "88d4d3bd",
+ "id": "d4e9a395",
"metadata": {},
"outputs": [],
"source": [
@@ -521,7 +521,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8762a2bb",
+ "id": "946b3aa8",
"metadata": {},
"outputs": [],
"source": [
@@ -533,7 +533,7 @@
},
{
"cell_type": "markdown",
- "id": "0375fcd2",
+ "id": "f50d996e",
"metadata": {},
"source": [
"## βοΈ Next Steps\n",
@@ -542,7 +542,9 @@
"\n",
"- [Seeding synthetic data generation with an external dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/)\n",
"\n",
- "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n"
+ "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n",
+ "\n",
+ "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)\n"
]
}
],
diff --git a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb
index c5d427d0..639e88df 100644
--- a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb
+++ b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "b2a3e544",
+ "id": "25501772",
"metadata": {},
"source": [
"# π¨ Data Designer Tutorial: Seeding Synthetic Data Generation with an External Dataset\n",
@@ -16,7 +16,7 @@
},
{
"cell_type": "markdown",
- "id": "d57c4f0a",
+ "id": "67ffc49e",
"metadata": {},
"source": [
"### π¦ Import Data Designer\n",
@@ -28,7 +28,7 @@
},
{
"cell_type": "markdown",
- "id": "f7da8723",
+ "id": "54a42504",
"metadata": {},
"source": [
"### β‘ Colab Setup\n",
@@ -39,7 +39,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "90a12556",
+ "id": "05b45354",
"metadata": {},
"outputs": [],
"source": [
@@ -50,7 +50,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8fcdfde5",
+ "id": "039360fe",
"metadata": {},
"outputs": [],
"source": [
@@ -68,7 +68,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5899e85c",
+ "id": "028d5e8a",
"metadata": {},
"outputs": [],
"source": [
@@ -78,7 +78,7 @@
},
{
"cell_type": "markdown",
- "id": "6c093c90",
+ "id": "15a1df61",
"metadata": {},
"source": [
"### βοΈ Initialize the Data Designer interface\n",
@@ -91,7 +91,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6a2066fe",
+ "id": "a87b6ff6",
"metadata": {},
"outputs": [],
"source": [
@@ -100,7 +100,7 @@
},
{
"cell_type": "markdown",
- "id": "f5e81142",
+ "id": "b9166cfd",
"metadata": {},
"source": [
"### ποΈ Define model configurations\n",
@@ -117,7 +117,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "880012ea",
+ "id": "4961d3b0",
"metadata": {},
"outputs": [],
"source": [
@@ -147,7 +147,7 @@
},
{
"cell_type": "markdown",
- "id": "4b77a92c",
+ "id": "b1d8588a",
"metadata": {},
"source": [
"### ποΈ Initialize the Data Designer Config Builder\n",
@@ -162,7 +162,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f4ab6628",
+ "id": "cf42a4dd",
"metadata": {},
"outputs": [],
"source": [
@@ -171,7 +171,7 @@
},
{
"cell_type": "markdown",
- "id": "26fb0a63",
+ "id": "8d6b26aa",
"metadata": {},
"source": [
"## π₯ Prepare a seed dataset\n",
@@ -196,7 +196,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "84908e88",
+ "id": "fc90401d",
"metadata": {},
"outputs": [],
"source": [
@@ -214,7 +214,7 @@
},
{
"cell_type": "markdown",
- "id": "1947e70a",
+ "id": "6f5ee960",
"metadata": {},
"source": [
"## π¨ Designing our synthetic patient notes dataset\n",
@@ -227,7 +227,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "be2fbad1",
+ "id": "e9db2ff0",
"metadata": {},
"outputs": [],
"source": [
@@ -308,7 +308,7 @@
},
{
"cell_type": "markdown",
- "id": "8fcce5dc",
+ "id": "00efc894",
"metadata": {},
"source": [
"### π Iteration is key βΒ preview the dataset!\n",
@@ -325,7 +325,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "82dc02f8",
+ "id": "3e3d824e",
"metadata": {},
"outputs": [],
"source": [
@@ -335,7 +335,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1f2d1583",
+ "id": "27785af7",
"metadata": {},
"outputs": [],
"source": [
@@ -346,7 +346,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "62a9173b",
+ "id": "430998d1",
"metadata": {},
"outputs": [],
"source": [
@@ -356,7 +356,7 @@
},
{
"cell_type": "markdown",
- "id": "5263e705",
+ "id": "dda6458b",
"metadata": {},
"source": [
"### π Analyze the generated data\n",
@@ -369,7 +369,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5295320f",
+ "id": "f45bc088",
"metadata": {},
"outputs": [],
"source": [
@@ -379,7 +379,7 @@
},
{
"cell_type": "markdown",
- "id": "3ecc195f",
+ "id": "1e913fd8",
"metadata": {},
"source": [
"### π Scale up!\n",
@@ -392,7 +392,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3865fb59",
+ "id": "30b8b7f7",
"metadata": {},
"outputs": [],
"source": [
@@ -402,7 +402,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a7acf2b0",
+ "id": "b7ff96d1",
"metadata": {},
"outputs": [],
"source": [
@@ -415,7 +415,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "81a6e999",
+ "id": "dbfef8a8",
"metadata": {},
"outputs": [],
"source": [
@@ -427,14 +427,16 @@
},
{
"cell_type": "markdown",
- "id": "4503b1cf",
+ "id": "5db3f38d",
"metadata": {},
"source": [
"## βοΈ Next Steps\n",
"\n",
"Check out the following notebook to learn more about:\n",
"\n",
- "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n"
+ "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n",
+ "\n",
+ "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)\n"
]
}
],
diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb
index cd175537..9797695e 100644
--- a/docs/colab_notebooks/4-providing-images-as-context.ipynb
+++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "90dda708",
+ "id": "19e57933",
"metadata": {},
"source": [
"# π¨ Data Designer Tutorial: Providing Images as Context for Vision-Based Data Generation"
@@ -10,7 +10,7 @@
},
{
"cell_type": "markdown",
- "id": "52ccb1e5",
+ "id": "25e3cc64",
"metadata": {},
"source": [
"#### π What you'll learn\n",
@@ -25,7 +25,7 @@
},
{
"cell_type": "markdown",
- "id": "9627c4eb",
+ "id": "4aae5c82",
"metadata": {},
"source": [
"### π¦ Import Data Designer\n",
@@ -37,7 +37,7 @@
},
{
"cell_type": "markdown",
- "id": "1817171a",
+ "id": "24dfae6c",
"metadata": {},
"source": [
"### β‘ Colab Setup\n",
@@ -48,7 +48,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1f15a669",
+ "id": "619b1aae",
"metadata": {},
"outputs": [],
"source": [
@@ -59,7 +59,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1201c93b",
+ "id": "0d49a542",
"metadata": {},
"outputs": [],
"source": [
@@ -77,7 +77,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f814b76c",
+ "id": "1b28f160",
"metadata": {},
"outputs": [],
"source": [
@@ -100,7 +100,7 @@
},
{
"cell_type": "markdown",
- "id": "ac423d57",
+ "id": "63dc34de",
"metadata": {},
"source": [
"### βοΈ Initialize the Data Designer interface\n",
@@ -113,7 +113,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3c655c2d",
+ "id": "672155c8",
"metadata": {},
"outputs": [],
"source": [
@@ -122,7 +122,7 @@
},
{
"cell_type": "markdown",
- "id": "7d41e922",
+ "id": "4b32c25e",
"metadata": {},
"source": [
"### ποΈ Define model configurations\n",
@@ -139,7 +139,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a8b5f4bf",
+ "id": "72971915",
"metadata": {},
"outputs": [],
"source": [
@@ -162,7 +162,7 @@
},
{
"cell_type": "markdown",
- "id": "6455fc58",
+ "id": "115ad20f",
"metadata": {},
"source": [
"### ποΈ Initialize the Data Designer Config Builder\n",
@@ -177,7 +177,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "462c2e01",
+ "id": "11e844d2",
"metadata": {},
"outputs": [],
"source": [
@@ -186,7 +186,7 @@
},
{
"cell_type": "markdown",
- "id": "31369d10",
+ "id": "77862fce",
"metadata": {},
"source": [
"### π± Seed Dataset Creation\n",
@@ -203,7 +203,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "55d9432a",
+ "id": "e415a502",
"metadata": {},
"outputs": [],
"source": [
@@ -218,7 +218,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8614c4e9",
+ "id": "335f2611",
"metadata": {},
"outputs": [],
"source": [
@@ -266,7 +266,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "80550e46",
+ "id": "f055e88d",
"metadata": {},
"outputs": [],
"source": [
@@ -284,7 +284,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "65ced9bb",
+ "id": "47a1c586",
"metadata": {},
"outputs": [],
"source": [
@@ -294,7 +294,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "34b210e8",
+ "id": "3a77fc52",
"metadata": {},
"outputs": [],
"source": [
@@ -306,7 +306,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d506903d",
+ "id": "c0941cc7",
"metadata": {},
"outputs": [],
"source": [
@@ -335,7 +335,7 @@
},
{
"cell_type": "markdown",
- "id": "b91032a2",
+ "id": "578e77dc",
"metadata": {},
"source": [
"### π Iteration is key β preview the dataset!\n",
@@ -352,7 +352,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4bd947de",
+ "id": "9f0c11ce",
"metadata": {},
"outputs": [],
"source": [
@@ -362,7 +362,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d0ff4c07",
+ "id": "b10412c1",
"metadata": {},
"outputs": [],
"source": [
@@ -373,7 +373,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e97e4dfe",
+ "id": "766ee2d7",
"metadata": {},
"outputs": [],
"source": [
@@ -383,7 +383,7 @@
},
{
"cell_type": "markdown",
- "id": "0a284c12",
+ "id": "6370bfa5",
"metadata": {},
"source": [
"### π Analyze the generated data\n",
@@ -396,7 +396,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2570e7fd",
+ "id": "d57ded0e",
"metadata": {},
"outputs": [],
"source": [
@@ -406,7 +406,7 @@
},
{
"cell_type": "markdown",
- "id": "28b8eb5a",
+ "id": "5afd8e8c",
"metadata": {},
"source": [
"### π Visual Inspection\n",
@@ -417,7 +417,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5d0d9336",
+ "id": "aa4bfcc3",
"metadata": {
"lines_to_next_cell": 2
},
@@ -441,7 +441,7 @@
},
{
"cell_type": "markdown",
- "id": "1c257a81",
+ "id": "4eeaada6",
"metadata": {},
"source": [
"### π Scale up!\n",
@@ -454,7 +454,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e6d840e9",
+ "id": "0ee5b1b9",
"metadata": {},
"outputs": [],
"source": [
@@ -464,7 +464,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "909e6f3f",
+ "id": "e5e8b241",
"metadata": {},
"outputs": [],
"source": [
@@ -477,7 +477,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "adbb4cae",
+ "id": "23ebb3ca",
"metadata": {},
"outputs": [],
"source": [
@@ -489,7 +489,7 @@
},
{
"cell_type": "markdown",
- "id": "d085584c",
+ "id": "14a78533",
"metadata": {},
"source": [
"## βοΈ Next Steps\n",
@@ -499,7 +499,9 @@
"- Experiment with different vision models for specific document types\n",
"- Try different prompt variations to generate specialized descriptions (e.g., technical details, key findings)\n",
"- Combine vision-based summaries with other column types for multi-modal workflows\n",
- "- Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering\n"
+ "- Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering\n",
+ "\n",
+ "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) with Data Designer\n"
]
}
],
diff --git a/docs/colab_notebooks/5-generating-images.ipynb b/docs/colab_notebooks/5-generating-images.ipynb
new file mode 100644
index 00000000..c8092938
--- /dev/null
+++ b/docs/colab_notebooks/5-generating-images.ipynb
@@ -0,0 +1,437 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "735e6197",
+ "metadata": {},
+ "source": [
+ "# π¨ Data Designer Tutorial: Generating Images\n",
+ "\n",
+ "#### π What you'll learn\n",
+ "\n",
+ "This notebook shows how to generate synthetic image data with Data Designer using image-generation models.\n",
+ "\n",
+ "- πΌοΈ **Image generation columns**: Add columns that produce images from text prompts\n",
+ "- π **Jinja2 prompts**: Drive diversity by referencing other columns in your prompt template\n",
+ "- πΎ **Preview vs create**: Preview stores base64 in the dataframe; create saves images to disk and stores paths\n",
+ "\n",
+ "Data Designer supports both **diffusion** (e.g. DALL·E, Stable Diffusion, Imagen) and **autoregressive** (e.g. Gemini image, GPT image) models; the API is chosen automatically from the model name.\n",
+ "\n",
+ "If this is your first time using Data Designer, we recommend starting with the [first notebook](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/) in this tutorial series.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "92ae4afe",
+ "metadata": {},
+ "source": [
+ "### π¦ Import Data Designer\n",
+ "\n",
+ "- `data_designer.config` provides the configuration API.\n",
+ "- `DataDesigner` is the main interface for generation.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ccc77347",
+ "metadata": {},
+ "source": [
+ "### β‘ Colab Setup\n",
+ "\n",
+ "Run the cells below to install the dependencies and set up the API key. If you don't have an API key, you can generate one from [build.nvidia.com](https://build.nvidia.com).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "23627c23",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "!pip install -U data-designer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bf958dc6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import getpass\n",
+ "import os\n",
+ "\n",
+ "from google.colab import userdata\n",
+ "\n",
+ "try:\n",
+ " os.environ[\"NVIDIA_API_KEY\"] = userdata.get(\"NVIDIA_API_KEY\")\n",
+ "except userdata.SecretNotFoundError:\n",
+ " os.environ[\"NVIDIA_API_KEY\"] = getpass.getpass(\"Enter your NVIDIA API key: \")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ab0cfff8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from IPython.display import Image as IPImage\n",
+ "from IPython.display import display\n",
+ "\n",
+ "import data_designer.config as dd\n",
+ "from data_designer.interface import DataDesigner"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a18ef5ce",
+ "metadata": {},
+ "source": [
+ "### βοΈ Initialize the Data Designer interface\n",
+ "\n",
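+ "When initialized without arguments, [default model providers](https://nvidia-nemo.github.io/DataDesigner/latest/concepts/models/default-model-settings/) are used. This tutorial uses [OpenRouter](https://openrouter.ai) with the Flux 2 Pro image model, so `OPENROUTER_API_KEY` must be set in your environment. The Colab setup cells above only configure `NVIDIA_API_KEY`; add `OPENROUTER_API_KEY` as well (for example via Colab secrets or `os.environ`) before running the cells below.\n"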
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5fe11301",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data_designer = DataDesigner()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b913d454",
+ "metadata": {},
+ "source": [
+ "### ποΈ Define an image-generation model\n",
+ "\n",
+ "- Use `ImageInferenceParams` so Data Designer treats this model as an image generator.\n",
+ "- Image options (size, quality, aspect ratio, etc.) are model-specific; pass them via `extra_body`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a50d26ee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_PROVIDER = \"openrouter\"\n",
+ "MODEL_ID = \"black-forest-labs/flux.2-pro\"\n",
+ "MODEL_ALIAS = \"image-model\"\n",
+ "\n",
+ "model_configs = [\n",
+ " dd.ModelConfig(\n",
+ " alias=MODEL_ALIAS,\n",
+ " model=MODEL_ID,\n",
+ " provider=MODEL_PROVIDER,\n",
+ " inference_parameters=dd.ImageInferenceParams(\n",
+ " extra_body={\"height\": 512, \"width\": 512},\n",
+ " ),\n",
+ " )\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "122374d9",
+ "metadata": {},
+ "source": [
+ "### ποΈ Build the config: samplers + image column\n",
+ "\n",
+ "We'll generate diverse **family pet portrait** images featuring a dog and a cat: sampler columns drive the overall style plus each animal's breed, age, look direction, and emotion. The image-generation column uses a Jinja2 prompt that references all of them.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "940f2b70",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config_builder = dd.DataDesignerConfigBuilder(model_configs=model_configs)\n",
+ "\n",
+ "config_builder.add_column(\n",
+ " dd.SamplerColumnConfig(\n",
+ " name=\"style\",\n",
+ " sampler_type=dd.SamplerType.CATEGORY,\n",
+ " params=dd.CategorySamplerParams(\n",
+ " values=[\n",
+ " \"photorealistic\",\n",
+ " \"oil painting\",\n",
+ " \"watercolor\",\n",
+ " \"digital art\",\n",
+ " \"sketch\",\n",
+ " \"anime\",\n",
+ " ],\n",
+ " ),\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "config_builder.add_column(\n",
+ " dd.SamplerColumnConfig(\n",
+ " name=\"dog_breed\",\n",
+ " sampler_type=dd.SamplerType.CATEGORY,\n",
+ " params=dd.CategorySamplerParams(\n",
+ " values=[\n",
+ " \"a Golden Retriever\",\n",
+ " \"a German Shepherd\",\n",
+ " \"a Labrador Retriever\",\n",
+ " \"a Bulldog\",\n",
+ " \"a Beagle\",\n",
+ " \"a Poodle\",\n",
+ " \"a Corgi\",\n",
+ " \"a Siberian Husky\",\n",
+ " \"a Dalmatian\",\n",
+ " \"a Yorkshire Terrier\",\n",
+ " \"a Boxer\",\n",
+ " \"a Dachshund\",\n",
+ " \"a Doberman Pinscher\",\n",
+ " \"a Shih Tzu\",\n",
+ " \"a Chihuahua\",\n",
+ " \"a Border Collie\",\n",
+ " \"an Australian Shepherd\",\n",
+ " \"a Cocker Spaniel\",\n",
+ " \"a Maltese\",\n",
+ " \"a Pomeranian\",\n",
+ " \"a Saint Bernard\",\n",
+ " \"a Great Dane\",\n",
+ " \"an Akita\",\n",
+ " \"a Samoyed\",\n",
+ " \"a Boston Terrier\",\n",
+ " ],\n",
+ " ),\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "config_builder.add_column(\n",
+ " dd.SamplerColumnConfig(\n",
+ " name=\"cat_breed\",\n",
+ " sampler_type=dd.SamplerType.CATEGORY,\n",
+ " params=dd.CategorySamplerParams(\n",
+ " values=[\n",
+ " \"a Persian\",\n",
+ " \"a Maine Coon\",\n",
+ " \"a Siamese\",\n",
+ " \"a Ragdoll\",\n",
+ " \"a Bengal\",\n",
+ " \"an Abyssinian\",\n",
+ " \"a British Shorthair\",\n",
+ " \"a Sphynx\",\n",
+ " \"a Scottish Fold\",\n",
+ " \"a Russian Blue\",\n",
+ " \"a Birman\",\n",
+ " \"an Oriental Shorthair\",\n",
+ " \"a Norwegian Forest Cat\",\n",
+ " \"a Devon Rex\",\n",
+ " \"a Burmese\",\n",
+ " \"an Egyptian Mau\",\n",
+ " \"a Tonkinese\",\n",
+ " \"a Himalayan\",\n",
+ " \"a Savannah\",\n",
+ " \"a Chartreux\",\n",
+ " \"a Somali\",\n",
+ " \"a Manx\",\n",
+ " \"a Turkish Angora\",\n",
+ " \"a Balinese\",\n",
+ " \"an American Shorthair\",\n",
+ " ],\n",
+ " ),\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "config_builder.add_column(\n",
+ " dd.SamplerColumnConfig(\n",
+ " name=\"dog_age\",\n",
+ " sampler_type=dd.SamplerType.CATEGORY,\n",
+ " params=dd.CategorySamplerParams(\n",
+ " values=[\"1-3\", \"3-6\", \"6-9\", \"9-12\", \"12-15\"],\n",
+ " ),\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "config_builder.add_column(\n",
+ " dd.SamplerColumnConfig(\n",
+ " name=\"cat_age\",\n",
+ " sampler_type=dd.SamplerType.CATEGORY,\n",
+ " params=dd.CategorySamplerParams(\n",
+ " values=[\"1-3\", \"3-6\", \"6-9\", \"9-12\", \"12-18\"],\n",
+ " ),\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "config_builder.add_column(\n",
+ " dd.SamplerColumnConfig(\n",
+ " name=\"dog_look_direction\",\n",
+ " sampler_type=dd.SamplerType.CATEGORY,\n",
+ " params=dd.CategorySamplerParams(\n",
+ " values=[\"left\", \"right\", \"front\", \"up\", \"down\"],\n",
+ " ),\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "config_builder.add_column(\n",
+ " dd.SamplerColumnConfig(\n",
+ " name=\"cat_look_direction\",\n",
+ " sampler_type=dd.SamplerType.CATEGORY,\n",
+ " params=dd.CategorySamplerParams(\n",
+ " values=[\"left\", \"right\", \"front\", \"up\", \"down\"],\n",
+ " ),\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "config_builder.add_column(\n",
+ " dd.SamplerColumnConfig(\n",
+ " name=\"dog_emotion\",\n",
+ " sampler_type=dd.SamplerType.CATEGORY,\n",
+ " params=dd.CategorySamplerParams(\n",
+ " values=[\"happy\", \"curious\", \"serious\", \"sleepy\", \"excited\"],\n",
+ " ),\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "config_builder.add_column(\n",
+ " dd.SamplerColumnConfig(\n",
+ " name=\"cat_emotion\",\n",
+ " sampler_type=dd.SamplerType.CATEGORY,\n",
+ " params=dd.CategorySamplerParams(\n",
+ " values=[\"aloof\", \"curious\", \"content\", \"sleepy\", \"playful\"],\n",
+ " ),\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "config_builder.add_column(\n",
+ " dd.ImageColumnConfig(\n",
+ " name=\"generated_image\",\n",
+ " prompt=(\n",
+ " \"\"\"\n",
+ "A {{ style }} family pet portrait of {{ dog_breed }} dog of {{ dog_age }} years old looking {{ dog_look_direction }} with an {{ dog_emotion }} expression and\n",
+ "{{ cat_breed }} cat of {{ cat_age }} years old looking {{ cat_look_direction }} with an {{ cat_emotion }} expression in the background. Both subjects should be in focus.\n",
+ " \"\"\"\n",
+ " ),\n",
+ " model_alias=MODEL_ALIAS,\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "data_designer.validate(config_builder)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e13e0bb4",
+ "metadata": {},
+ "source": [
+ "### π Preview: images as base64\n",
+ "\n",
+ "In **preview** mode, generated images are stored as base64 strings in the dataframe. Run the next cell to step through each record (images are shown in the sample record display, but only in a notebook environment).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2a60a76f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "preview = data_designer.preview(config_builder, num_records=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3c831ee8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for i in range(len(preview.dataset)):\n",
+ " preview.display_sample_record()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "143e762f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "preview.dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a84606b4",
+ "metadata": {},
+ "source": [
+ "### π Create: images saved to disk\n",
+ "\n",
+ "In **create** mode, images are written to an `images/` folder with UUID filenames; the dataframe stores relative paths (e.g. `images/1d16b6e2-562f-4f51-91e5-baaa999ea916.png`).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "89147954",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "results = data_designer.create(config_builder, num_records=2, dataset_name=\"tutorial-5-images\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "04c96063",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = results.load_dataset()\n",
+ "dataset.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "edb794bb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Display all images from the created dataset. Paths are relative to the artifact output directory.\n",
+ "for index, row in dataset.iterrows():\n",
+ " path_or_list = row.get(\"generated_image\")\n",
+ " if path_or_list is not None:\n",
+ " for path in path_or_list:\n",
+ " base = results.artifact_storage.base_dataset_path\n",
+ " full_path = base / path\n",
+ " display(IPImage(data=full_path))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e0a72bf6",
+ "metadata": {},
+ "source": [
+ "## βοΈ Next steps\n",
+ "\n",
+ "- [The basics](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/): samplers and LLM text columns\n",
+ "- [Structured outputs and Jinja](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/2-structured-outputs-and-jinja-expressions/)\n",
+ "- [Seeding with a dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/)\n",
+ "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/notebook_source/1-the-basics.py b/docs/notebook_source/1-the-basics.py
index 392efb34..8735d582 100644
--- a/docs/notebook_source/1-the-basics.py
+++ b/docs/notebook_source/1-the-basics.py
@@ -330,3 +330,5 @@
#
# - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)
#
+# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)
+#
diff --git a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py
index 66b3773f..df581612 100644
--- a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py
+++ b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py
@@ -372,3 +372,5 @@ class ProductReview(BaseModel):
#
# - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)
#
+# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)
+#
diff --git a/docs/notebook_source/3-seeding-with-a-dataset.py b/docs/notebook_source/3-seeding-with-a-dataset.py
index c9d694a8..e4f9218e 100644
--- a/docs/notebook_source/3-seeding-with-a-dataset.py
+++ b/docs/notebook_source/3-seeding-with-a-dataset.py
@@ -274,3 +274,5 @@
#
# - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)
#
+# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)
+#
diff --git a/docs/notebook_source/4-providing-images-as-context.py b/docs/notebook_source/4-providing-images-as-context.py
index a11880ba..1fd68dac 100644
--- a/docs/notebook_source/4-providing-images-as-context.py
+++ b/docs/notebook_source/4-providing-images-as-context.py
@@ -299,3 +299,5 @@ def convert_image_to_chat_format(record, height: int) -> dict:
# - Combine vision-based summaries with other column types for multi-modal workflows
# - Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering
#
+# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) with Data Designer
+#
diff --git a/docs/notebook_source/5-generating-images.py b/docs/notebook_source/5-generating-images.py
new file mode 100644
index 00000000..b445b950
--- /dev/null
+++ b/docs/notebook_source/5-generating-images.py
@@ -0,0 +1,296 @@
+# ---
+# jupyter:
+# jupytext:
+# text_representation:
+# extension: .py
+# format_name: percent
+# format_version: '1.3'
+# jupytext_version: 1.18.1
+# kernelspec:
+# display_name: .venv
+# language: python
+# name: python3
+# ---
+
+# %% [markdown]
+# # π¨ Data Designer Tutorial: Generating Images
+#
+# #### π What you'll learn
+#
+# This notebook shows how to generate synthetic image data with Data Designer using image-generation models.
+#
+# - πΌοΈ **Image generation columns**: Add columns that produce images from text prompts
+# - π **Jinja2 prompts**: Drive diversity by referencing other columns in your prompt template
+# - πΎ **Preview vs create**: Preview stores base64 in the dataframe; create saves images to disk and stores paths
+#
+# Data Designer supports both **diffusion** (e.g. DALLΒ·E, Stable Diffusion, Imagen) and **autoregressive** (e.g. Gemini image, GPT image) models; the API is chosen automatically from the model name.
+#
+# If this is your first time using Data Designer, we recommend starting with the [first notebook](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/) in this tutorial series.
+#
+
+# %% [markdown]
+# ### π¦ Import Data Designer
+#
+# - `data_designer.config` provides the configuration API.
+# - `DataDesigner` is the main interface for generation.
+#
+
+# %%
+from IPython.display import Image as IPImage
+from IPython.display import display
+
+import data_designer.config as dd
+from data_designer.interface import DataDesigner
+
+# %% [markdown]
+# ### βοΈ Initialize the Data Designer interface
+#
+# When initialized without arguments, [default model providers](https://nvidia-nemo.github.io/DataDesigner/latest/concepts/models/default-model-settings/) are used. This tutorial uses [OpenRouter](https://openrouter.ai) with the Flux 2 Pro image model; set `OPENROUTER_API_KEY` in your environment.
+#
+
+# %%
+data_designer = DataDesigner()
+
+# %% [markdown]
+# ### ποΈ Define an image-generation model
+#
+# - Use `ImageInferenceParams` so Data Designer treats this model as an image generator.
+# - Image options (size, quality, aspect ratio, etc.) are model-specific; pass them via `extra_body`.
+#
+
+# %%
+MODEL_PROVIDER = "openrouter"
+MODEL_ID = "black-forest-labs/flux.2-pro"
+MODEL_ALIAS = "image-model"
+
+model_configs = [
+ dd.ModelConfig(
+ alias=MODEL_ALIAS,
+ model=MODEL_ID,
+ provider=MODEL_PROVIDER,
+ inference_parameters=dd.ImageInferenceParams(
+ extra_body={"height": 512, "width": 512},
+ ),
+ )
+]
+
+# %% [markdown]
+# ### ποΈ Build the config: samplers + image column
+#
+# We'll generate diverse **family pet portrait** images, each featuring a dog with a cat in the background: sampler columns drive the style plus each animal's breed, age, look direction, and emotion. The image-generation column uses a Jinja2 prompt that references all of them.
+#
+
+# %%
+config_builder = dd.DataDesignerConfigBuilder(model_configs=model_configs)
+
+config_builder.add_column(
+ dd.SamplerColumnConfig(
+ name="style",
+ sampler_type=dd.SamplerType.CATEGORY,
+ params=dd.CategorySamplerParams(
+ values=[
+ "photorealistic",
+ "oil painting",
+ "watercolor",
+ "digital art",
+ "sketch",
+ "anime",
+ ],
+ ),
+ )
+)
+
+config_builder.add_column(
+ dd.SamplerColumnConfig(
+ name="dog_breed",
+ sampler_type=dd.SamplerType.CATEGORY,
+ params=dd.CategorySamplerParams(
+ values=[
+ "a Golden Retriever",
+ "a German Shepherd",
+ "a Labrador Retriever",
+ "a Bulldog",
+ "a Beagle",
+ "a Poodle",
+ "a Corgi",
+ "a Siberian Husky",
+ "a Dalmatian",
+ "a Yorkshire Terrier",
+ "a Boxer",
+ "a Dachshund",
+ "a Doberman Pinscher",
+ "a Shih Tzu",
+ "a Chihuahua",
+ "a Border Collie",
+ "an Australian Shepherd",
+ "a Cocker Spaniel",
+ "a Maltese",
+ "a Pomeranian",
+ "a Saint Bernard",
+ "a Great Dane",
+ "an Akita",
+ "a Samoyed",
+ "a Boston Terrier",
+ ],
+ ),
+ )
+)
+
+config_builder.add_column(
+ dd.SamplerColumnConfig(
+ name="cat_breed",
+ sampler_type=dd.SamplerType.CATEGORY,
+ params=dd.CategorySamplerParams(
+ values=[
+ "a Persian",
+ "a Maine Coon",
+ "a Siamese",
+ "a Ragdoll",
+ "a Bengal",
+ "an Abyssinian",
+ "a British Shorthair",
+ "a Sphynx",
+ "a Scottish Fold",
+ "a Russian Blue",
+ "a Birman",
+ "an Oriental Shorthair",
+ "a Norwegian Forest Cat",
+ "a Devon Rex",
+ "a Burmese",
+ "an Egyptian Mau",
+ "a Tonkinese",
+ "a Himalayan",
+ "a Savannah",
+ "a Chartreux",
+ "a Somali",
+ "a Manx",
+ "a Turkish Angora",
+ "a Balinese",
+ "an American Shorthair",
+ ],
+ ),
+ )
+)
+
+config_builder.add_column(
+ dd.SamplerColumnConfig(
+ name="dog_age",
+ sampler_type=dd.SamplerType.CATEGORY,
+ params=dd.CategorySamplerParams(
+ values=["1-3", "3-6", "6-9", "9-12", "12-15"],
+ ),
+ )
+)
+
+config_builder.add_column(
+ dd.SamplerColumnConfig(
+ name="cat_age",
+ sampler_type=dd.SamplerType.CATEGORY,
+ params=dd.CategorySamplerParams(
+ values=["1-3", "3-6", "6-9", "9-12", "12-18"],
+ ),
+ )
+)
+
+config_builder.add_column(
+ dd.SamplerColumnConfig(
+ name="dog_look_direction",
+ sampler_type=dd.SamplerType.CATEGORY,
+ params=dd.CategorySamplerParams(
+ values=["left", "right", "front", "up", "down"],
+ ),
+ )
+)
+
+config_builder.add_column(
+ dd.SamplerColumnConfig(
+ name="cat_look_direction",
+ sampler_type=dd.SamplerType.CATEGORY,
+ params=dd.CategorySamplerParams(
+ values=["left", "right", "front", "up", "down"],
+ ),
+ )
+)
+
+config_builder.add_column(
+ dd.SamplerColumnConfig(
+ name="dog_emotion",
+ sampler_type=dd.SamplerType.CATEGORY,
+ params=dd.CategorySamplerParams(
+ values=["happy", "curious", "serious", "sleepy", "excited"],
+ ),
+ )
+)
+
+config_builder.add_column(
+ dd.SamplerColumnConfig(
+ name="cat_emotion",
+ sampler_type=dd.SamplerType.CATEGORY,
+ params=dd.CategorySamplerParams(
+ values=["aloof", "curious", "content", "sleepy", "playful"],
+ ),
+ )
+)
+
+config_builder.add_column(
+ dd.ImageColumnConfig(
+ name="generated_image",
+ prompt=(
+ """
+A {{ style }} family pet portrait of a {{ dog_breed }} dog of {{ dog_age }} years old looking {{dog_look_direction}} with an {{ dog_emotion }} expression and
+{{ cat_breed }} cat of {{ cat_age }} years old looking {{ cat_look_direction }} with an {{ cat_emotion }} expression in the background. Both subjects should be in focus.
+ """
+ ),
+ model_alias=MODEL_ALIAS,
+ )
+)
+
+data_designer.validate(config_builder)
+
+# %% [markdown]
+# ### π Preview: images as base64
+#
+# In **preview** mode, generated images are stored as base64 strings in the dataframe. Run the next cell to step through each record (images are shown in the sample record display, but only in a notebook environment).
+#
+
+# %%
+preview = data_designer.preview(config_builder, num_records=2)
+
+# %%
+for i in range(len(preview.dataset)):
+ preview.display_sample_record()
+
+# %%
+preview.dataset
+
+# %% [markdown]
+# ### π Create: images saved to disk
+#
+# In **create** mode, images are written to an `images/` folder with UUID filenames; the dataframe stores relative paths (e.g. `images/1d16b6e2-562f-4f51-91e5-baaa999ea916.png`).
+#
+
+# %%
+results = data_designer.create(config_builder, num_records=2, dataset_name="tutorial-5-images")
+
+# %%
+dataset = results.load_dataset()
+dataset.head()
+
+# %%
+# Display all image from the created dataset. Paths are relative to the artifact output directory.
+for index, row in dataset.iterrows():
+ path_or_list = row.get("generated_image")
+ if path_or_list is not None:
+ for path in path_or_list:
+ base = results.artifact_storage.base_dataset_path
+ full_path = base / path
+ display(IPImage(data=full_path))
+
+# %% [markdown]
+# ## βοΈ Next steps
+#
+# - [The basics](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/): samplers and LLM text columns
+# - [Structured outputs and Jinja](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/2-structured-outputs-and-jinja-expressions/)
+# - [Seeding with a dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/)
+# - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)
+#
diff --git a/docs/notebook_source/_README.md b/docs/notebook_source/_README.md
index 09053c22..7bcd77d1 100644
--- a/docs/notebook_source/_README.md
+++ b/docs/notebook_source/_README.md
@@ -97,6 +97,15 @@ Learn how to use vision-language models to generate text descriptions from image
- Generating detailed summaries from document images
- Inspecting and validating vision-based generation results
+### [5. Generating Images](5-generating-images.ipynb)
+
+Generate synthetic image data with Data Designer:
+
+- Configuring image-generation models with `ImageInferenceParams`
+- Adding image columns with Jinja2 prompts and sampler-driven diversity
+- Preview (base64 in dataframe) vs create (images saved to disk, paths in dataframe)
+- Displaying generated images in the notebook
+
## π Important Documentation Sections
Before diving into the tutorials, familiarize yourself with these key documentation sections:
diff --git a/packages/data-designer-config/pyproject.toml b/packages/data-designer-config/pyproject.toml
index cb61dceb..c69d1ba0 100644
--- a/packages/data-designer-config/pyproject.toml
+++ b/packages/data-designer-config/pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
"jinja2>=3.1.6,<4",
"numpy>=1.23.5,<3",
"pandas>=2.3.3,<3",
+ "pillow>=12.0.0,<13",
"pyarrow>=19.0.1,<20", # Required for parquet I/O operations
"pydantic[email]>=2.9.2,<3",
"pygments>=2.19.2,<3",
diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py
index 2ea641e7..baf21754 100644
--- a/packages/data-designer-config/src/data_designer/config/__init__.py
+++ b/packages/data-designer-config/src/data_designer/config/__init__.py
@@ -17,6 +17,7 @@
EmbeddingColumnConfig,
ExpressionColumnConfig,
GenerationStrategy,
+ ImageColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
LLMStructuredColumnConfig,
@@ -41,6 +42,7 @@
GenerationType,
ImageContext,
ImageFormat,
+ ImageInferenceParams,
ManualDistribution,
ManualDistributionParams,
Modality,
@@ -123,6 +125,7 @@
"CustomColumnConfig": (_MOD_COLUMN_CONFIGS, "CustomColumnConfig"),
"EmbeddingColumnConfig": (_MOD_COLUMN_CONFIGS, "EmbeddingColumnConfig"),
"ExpressionColumnConfig": (_MOD_COLUMN_CONFIGS, "ExpressionColumnConfig"),
+ "ImageColumnConfig": (_MOD_COLUMN_CONFIGS, "ImageColumnConfig"),
"GenerationStrategy": (_MOD_COLUMN_CONFIGS, "GenerationStrategy"),
"LLMCodeColumnConfig": (_MOD_COLUMN_CONFIGS, "LLMCodeColumnConfig"),
"LLMJudgeColumnConfig": (_MOD_COLUMN_CONFIGS, "LLMJudgeColumnConfig"),
@@ -150,6 +153,7 @@
"GenerationType": (_MOD_MODELS, "GenerationType"),
"ImageContext": (_MOD_MODELS, "ImageContext"),
"ImageFormat": (_MOD_MODELS, "ImageFormat"),
+ "ImageInferenceParams": (_MOD_MODELS, "ImageInferenceParams"),
"ManualDistribution": (_MOD_MODELS, "ManualDistribution"),
"ManualDistributionParams": (_MOD_MODELS, "ManualDistributionParams"),
"Modality": (_MOD_MODELS, "Modality"),
diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py
index b2eefd26..49dbb831 100644
--- a/packages/data-designer-config/src/data_designer/config/column_configs.py
+++ b/packages/data-designer-config/src/data_designer/config/column_configs.py
@@ -485,6 +485,59 @@ def side_effect_columns(self) -> list[str]:
return []
+class ImageColumnConfig(SingleColumnConfig):
+ """Configuration for image generation columns.
+
+ Image columns generate images using either autoregressive or diffusion models.
+ The API used is automatically determined based on the model name.
+
+ Attributes:
+ prompt: Prompt template for image generation. Supports Jinja2 templating to
+ reference other columns (e.g., "Generate an image of a {{ character_name }}").
+ Must be a valid Jinja2 template.
+ model_alias: The model to use for image generation.
+ multi_modal_context: Optional list of image contexts for multi-modal generation.
+ Enables autoregressive multi-modal models to generate images based on image inputs.
+ Only works with autoregressive models that support image-to-image generation.
+ column_type: Discriminator field, always "image" for this configuration type.
+ """
+
+ prompt: str
+ model_alias: str
+ multi_modal_context: list[ImageContext] | None = None
+ column_type: Literal["image"] = "image"
+
+ @staticmethod
+ def get_column_emoji() -> str:
+ return "πΌοΈ"
+
+ @property
+ def required_columns(self) -> list[str]:
+ """Get columns referenced in the prompt template.
+
+ Returns:
+ List of unique column names referenced in Jinja2 templates.
+ """
+ return list(extract_keywords_from_jinja2_template(self.prompt))
+
+ @model_validator(mode="after")
+ def assert_prompt_valid_jinja(self) -> Self:
+ """Validate that prompt is a valid Jinja2 template.
+
+ Returns:
+ The validated instance.
+
+ Raises:
+ InvalidConfigError: If prompt contains invalid Jinja2 syntax.
+ """
+ assert_valid_jinja2_template(self.prompt)
+ return self
+
+ @property
+ def side_effect_columns(self) -> list[str]:
+ return []
+
+
class CustomColumnConfig(SingleColumnConfig):
"""Configuration for custom user-defined column generators.
diff --git a/packages/data-designer-config/src/data_designer/config/column_types.py b/packages/data-designer-config/src/data_designer/config/column_types.py
index 8a82223e..baba25dd 100644
--- a/packages/data-designer-config/src/data_designer/config/column_types.py
+++ b/packages/data-designer-config/src/data_designer/config/column_types.py
@@ -9,6 +9,7 @@
CustomColumnConfig,
EmbeddingColumnConfig,
ExpressionColumnConfig,
+ ImageColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
LLMStructuredColumnConfig,
@@ -39,6 +40,7 @@
| SeedDatasetColumnConfig
| ValidationColumnConfig
| EmbeddingColumnConfig
+ | ImageColumnConfig
)
ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT)
@@ -87,6 +89,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]:
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_JUDGE,
DataDesignerColumnType.EMBEDDING,
+ DataDesignerColumnType.IMAGE,
DataDesignerColumnType.VALIDATION,
DataDesignerColumnType.EXPRESSION,
DataDesignerColumnType.CUSTOM,
@@ -142,4 +145,5 @@ def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict:
DataDesignerColumnType.SAMPLER: SamplerColumnConfig,
DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnConfig,
DataDesignerColumnType.EMBEDDING: EmbeddingColumnConfig,
+ DataDesignerColumnType.IMAGE: ImageColumnConfig,
}
diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py
index c91443d0..51954a74 100644
--- a/packages/data-designer-config/src/data_designer/config/models.py
+++ b/packages/data-designer-config/src/data_designer/config/models.py
@@ -242,6 +242,7 @@ def sample(self) -> float:
class GenerationType(str, Enum):
CHAT_COMPLETION = "chat-completion"
EMBEDDING = "embedding"
+ IMAGE = "image"
class BaseInferenceParams(ConfigBase, ABC):
@@ -421,8 +422,41 @@ def generate_kwargs(self) -> dict[str, float | int]:
return result
+class ImageInferenceParams(BaseInferenceParams):
+ """Configuration for image generation models.
+
+ Works for both diffusion and autoregressive image generation models. Pass all model-specific image options via `extra_body`.
+
+ Attributes:
+ generation_type: Type of generation, always "image" for this class.
+
+ Example:
+ ```python
+ # OpenAI-style (DALLΒ·E): quality and size in extra_body or as top-level kwargs
+ dd.ImageInferenceParams(
+ extra_body={"size": "1024x1024", "quality": "hd"}
+ )
+
+ # Gemini-style: generationConfig.imageConfig
+ dd.ImageInferenceParams(
+ extra_body={
+ "generationConfig": {
+ "imageConfig": {
+ "aspectRatio": "1:1",
+ "imageSize": "1024"
+ }
+ }
+ }
+ )
+ ```
+ """
+
+ generation_type: Literal[GenerationType.IMAGE] = GenerationType.IMAGE
+
+
InferenceParamsT: TypeAlias = Annotated[
- ChatCompletionInferenceParams | EmbeddingInferenceParams, Field(discriminator="generation_type")
+ ChatCompletionInferenceParams | EmbeddingInferenceParams | ImageInferenceParams,
+ Field(discriminator="generation_type"),
]
@@ -454,8 +488,13 @@ def generation_type(self) -> GenerationType:
def _convert_inference_parameters(cls, value: Any) -> Any:
"""Convert raw dict to appropriate inference parameters type based on field presence."""
if isinstance(value, dict):
- # Infer type from presence of embedding-specific fields
- if "encoding_format" in value or "dimensions" in value:
+ # Check for explicit generation_type first
+ gen_type = value.get("generation_type")
+
+ # Infer type from generation_type or field presence
+ if gen_type == "image":
+ return ImageInferenceParams(**value)
+ elif gen_type == "embedding" or "encoding_format" in value or "dimensions" in value:
return EmbeddingInferenceParams(**value)
else:
return ChatCompletionInferenceParams(**value)
diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py
new file mode 100644
index 00000000..45f43622
--- /dev/null
+++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py
@@ -0,0 +1,269 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Helper utilities for working with images."""
+
+from __future__ import annotations
+
+import base64
+import io
+import re
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import requests
+
+from data_designer.config.models import ImageFormat
+from data_designer.lazy_heavy_imports import Image
+
+if TYPE_CHECKING:
+ from PIL import Image
+
+# Magic bytes for image format detection
+IMAGE_FORMAT_MAGIC_BYTES = {
+ ImageFormat.PNG: b"\x89PNG\r\n\x1a\n",
+ ImageFormat.JPG: b"\xff\xd8\xff",
+ ImageFormat.GIF: b"GIF8",
+ # WEBP uses RIFF header - handled separately
+}
+
+# Maps PIL format name (lowercase) to our ImageFormat enum.
+# PIL reports "JPEG" (not "JPG"), so we normalize it here.
+_PIL_FORMAT_TO_IMAGE_FORMAT: dict[str, ImageFormat] = {
+ "png": ImageFormat.PNG,
+ "jpeg": ImageFormat.JPG,
+ "jpg": ImageFormat.JPG,
+ "gif": ImageFormat.GIF,
+ "webp": ImageFormat.WEBP,
+}
+
+_BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$")
+
+# Name patterns for image models routed to the image_generation API: diffusion-style models plus the "gpt-image-" family.
+IMAGE_DIFFUSION_MODEL_PATTERNS = (
+ "dall-e-",
+ "dalle",
+ "stable-diffusion",
+ "sd-",
+ "sd_",
+ "imagen",
+ "gpt-image-",
+)
+
+SUPPORTED_IMAGE_EXTENSIONS = [f".{fmt.value.lower()}" for fmt in ImageFormat]
+
+
+def is_image_diffusion_model(model_name: str) -> bool:
+ """Return True if the model name matches a known diffusion image-model pattern.
+
+ Args:
+ model_name: Model name or identifier (e.g. from provider); compared case-insensitively.
+
+ Returns:
+ True if the lowercased name contains any IMAGE_DIFFUSION_MODEL_PATTERNS entry, False otherwise.
+ """
+ return any(pattern in model_name.lower() for pattern in IMAGE_DIFFUSION_MODEL_PATTERNS)
+
+
+def extract_base64_from_data_uri(data: str) -> str:
+ """Extract base64 from data URI or return as-is.
+
+ Handles data URIs like "data:image/png;base64,iVBORw0..." and returns
+ just the base64 portion.
+
+ Args:
+ data: Data URI (e.g., "data:image/png;base64,XXX") or plain base64
+
+ Returns:
+ Base64 string without data URI prefix
+
+ Raises:
+ ValueError: If data URI format is invalid
+ """
+ if data.startswith("data:"):
+ if "," in data:
+ return data.split(",", 1)[1]
+ raise ValueError("Invalid data URI format: missing comma separator")
+ return data
+
+
+def decode_base64_image(base64_data: str) -> bytes:
+ """Decode base64 string to image bytes.
+
+ Automatically handles data URIs by extracting the base64 portion first.
+
+ Args:
+ base64_data: Base64 string (with or without data URI prefix)
+
+ Returns:
+ Decoded image bytes
+
+ Raises:
+ ValueError: If base64 data is invalid
+ """
+ # Remove data URI prefix if present
+ base64_data = extract_base64_from_data_uri(base64_data)
+
+ try:
+ return base64.b64decode(base64_data, validate=True)
+ except Exception as e:
+ raise ValueError(f"Invalid base64 data: {e}") from e
+
+
+def detect_image_format(image_bytes: bytes) -> ImageFormat:
+ """Detect image format from bytes.
+
+ Uses magic bytes for fast detection, falls back to PIL for robust detection.
+
+ Args:
+ image_bytes: Image data as bytes
+
+ Returns:
+ Detected ImageFormat
+
+ Raises:
+ ValueError: If the image format cannot be determined
+ """
+ # Check magic bytes first (fast)
+ if image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.PNG]):
+ return ImageFormat.PNG
+ elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.JPG]):
+ return ImageFormat.JPG
+ elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.GIF]):
+ return ImageFormat.GIF
+ elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]:
+ return ImageFormat.WEBP
+
+ # Fallback to PIL for robust detection
+ try:
+ img = Image.open(io.BytesIO(image_bytes))
+ format_str = img.format.lower() if img.format else None
+ if format_str in _PIL_FORMAT_TO_IMAGE_FORMAT:
+ return _PIL_FORMAT_TO_IMAGE_FORMAT[format_str]
+ except Exception:
+ pass
+
+ raise ValueError(
+ f"Unable to detect image format (first 8 bytes: {image_bytes[:8]!r}). "
+ f"Supported formats: {', '.join(SUPPORTED_IMAGE_EXTENSIONS)}."
+ )
+
+
+def is_image_path(value: str) -> bool:
+ """Check if a string ends with a supported image extension (no filesystem access).
+
+ Args:
+ value: String to check (non-string values yield False)
+
+ Returns:
+ True if the lowercased string ends with one of SUPPORTED_IMAGE_EXTENSIONS, False otherwise
+ """
+ if not isinstance(value, str):
+ return False
+ return any(value.lower().endswith(ext) for ext in SUPPORTED_IMAGE_EXTENSIONS)
+
+
+def is_base64_image(value: str) -> bool:
+ """Heuristically check if a string is base64-encoded image data.
+
+ Args:
+ value: String to check (non-string values yield False)
+
+ Returns:
+ True for "data:image/" URIs or strings over 100 chars whose first 100 chars decode as base64, False otherwise
+ """
+ if not isinstance(value, str):
+ return False
+ # Check if it starts with data URI scheme
+ if value.startswith("data:image/"):
+ return True
+ # Check if it looks like base64 (more than 100 chars, first 100 contain only base64 chars)
+ if len(value) > 100 and _BASE64_PATTERN.match(value[:100]):
+ try:
+ # Try to decode the 100-char prefix to verify it's valid base64
+ base64.b64decode(value[:100])
+ return True
+ except Exception:
+ return False
+ return False
+
+
+def is_image_url(value: str) -> bool:
+ """Check if a string is an HTTP(S) URL that mentions a supported image extension.
+
+ Args:
+ value: String to check (non-string values yield False)
+
+ Returns:
+ True if it starts with http:// or https:// and contains a supported extension anywhere (case-insensitive), False otherwise
+ """
+ if not isinstance(value, str):
+ return False
+ return value.startswith(("http://", "https://")) and any(ext in value.lower() for ext in SUPPORTED_IMAGE_EXTENSIONS)
+
+
+def load_image_path_to_base64(image_path: str, base_path: str | None = None) -> str | None:
+ """Load an image from a file path and return as base64.
+
+ Args:
+ image_path: Relative or absolute path to the image file.
+ base_path: Optional base path to resolve relative paths from.
+
+ Returns:
+ Base64-encoded image data or None if loading fails.
+ """
+ try:
+ path = Path(image_path)
+
+ # If path is not absolute, try to resolve it
+ if not path.is_absolute():
+ if base_path:
+ path = Path(base_path) / path
+ # If still not found, try current working directory
+ if not path.exists():
+ path = Path.cwd() / image_path
+
+ # Check if file exists
+ if not path.exists():
+ return None
+
+ # Read image file and convert to base64
+ with open(path, "rb") as f:
+ image_bytes = f.read()
+ return base64.b64encode(image_bytes).decode()
+ except Exception:
+ return None
+
+
+def load_image_url_to_base64(url: str, timeout: int = 60) -> str:
+ """Download an image from a URL and return as base64.
+
+ Args:
+ url: HTTP(S) URL pointing to an image.
+ timeout: Request timeout in seconds.
+
+ Returns:
+ Base64-encoded image data.
+
+ Raises:
+ requests.HTTPError: If the download fails with a non-2xx status.
+ """
+ resp = requests.get(url, timeout=timeout)
+ resp.raise_for_status()
+ return base64.b64encode(resp.content).decode()
+
+
+def validate_image(image_path: Path) -> None:
+ """Validate that an image file is readable and not corrupted.
+
+ Args:
+ image_path: Path to image file
+
+ Raises:
+ ValueError: If image is corrupted or unreadable
+ """
+ try:
+ with Image.open(image_path) as img:
+ img.verify()
+ except Exception as e:
+ raise ValueError(f"Image validation failed: {e}") from e
diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py
index 4f388bce..ac4df1d8 100644
--- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py
+++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py
@@ -3,6 +3,7 @@
from __future__ import annotations
+import html
import json
import os
from collections import OrderedDict
@@ -26,6 +27,13 @@
from data_designer.config.utils.code_lang import code_lang_to_syntax_lexer
from data_designer.config.utils.constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME
from data_designer.config.utils.errors import DatasetSampleDisplayError
+from data_designer.config.utils.image_helpers import (
+ extract_base64_from_data_uri,
+ is_base64_image,
+ is_image_path,
+ is_image_url,
+ load_image_path_to_base64,
+)
from data_designer.lazy_heavy_imports import np, pd
if TYPE_CHECKING:
@@ -39,6 +47,69 @@
console = Console()
+def _display_image_if_in_notebook(image_data: str, col_name: str) -> bool:
+ """Display image with caption in Jupyter notebook if available.
+
+ Args:
+ image_data: Base64-encoded image data, data URI, file path, or URL.
+ col_name: Name of the column (used for caption).
+
+ Returns:
+ True if image was displayed, False otherwise.
+ """
+ try:
+ # Check if we're in a Jupyter environment
+ from IPython.display import HTML, display
+
+ get_ipython() # This will raise NameError if not in IPython/Jupyter
+
+ # Escape column name to prevent HTML injection
+ escaped_col_name = html.escape(col_name)
+
+ # URLs: render directly by embedding the escaped URL in the HTML (no base64 conversion)
+ if is_image_url(image_data):
+ escaped_url = html.escape(image_data)
+ html_content = f"""
+
+
πΌοΈ {escaped_col_name}
+

+
+ """
+ display(HTML(html_content))
+ return True
+
+ # File paths: load from disk and convert to base64
+ if is_image_path(image_data) and not image_data.startswith("data:image/"):
+ loaded_base64 = load_image_path_to_base64(image_data)
+ if loaded_base64 is None:
+ console.print(
+ f"[yellow]β οΈ Could not load image from path '{image_data}' for column '{col_name}'[/yellow]"
+ )
+ return False
+ base64_data = loaded_base64
+ else:
+ base64_data = image_data
+
+ # Extract base64 from data URI if present
+ img_base64 = extract_base64_from_data_uri(base64_data)
+
+ # Create HTML with caption and image in left-aligned container
+ html_content = f"""
+
+
πΌοΈ {escaped_col_name}
+

+
+ """
+ display(HTML(html_content))
+ return True
+ except (ImportError, NameError):
+ # Not in a notebook environment
+ return False
+ except Exception as e:
+ console.print(f"[yellow]β οΈ Could not display image for column '{col_name}': {e}[/yellow]")
+ return False
+
+
def get_nvidia_api_key() -> str | None:
return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME)
@@ -100,7 +171,7 @@ def display_sample_record(
processors_to_display: List of processors to display the artifacts for. If None, all processors will be displayed.
hide_seed_columns: If True, seed columns will not be displayed separately.
"""
- i = index or self._display_cycle_index
+ i = self._display_cycle_index if index is None else index
try:
record = self._record_sampler_dataset.iloc[i]
@@ -223,6 +294,66 @@ def display_sample_record(
table.add_row(output_col, convert_to_row_element(record[output_col]))
render_list.append(pad_console_element(table))
+ # Collect image generation columns (will be displayed at the end)
+ image_columns = config_builder.get_columns_of_type(DataDesignerColumnType.IMAGE)
+ images_to_display_later = []
+ if len(image_columns) > 0:
+ # Check if we're in a notebook to decide display style
+ try:
+ get_ipython()
+ in_notebook = True
+ except NameError:
+ in_notebook = False
+
+ # Create table for image columns
+ table = Table(title="Images", **table_kws)
+ table.add_column("Name")
+ table.add_column("Preview")
+
+ for col in image_columns:
+ if col.drop:
+ continue
+ image_data = record[col.name]
+
+ # Handle list of images
+ if isinstance(image_data, list):
+ previews = []
+ for idx, img in enumerate(image_data):
+ if is_base64_image(img):
+ previews.append(f"[{idx}] ")
+ if in_notebook:
+ images_to_display_later.append((f"{col.name}[{idx}]", img))
+ elif is_image_url(img):
+ previews.append(f"[{idx}] ")
+ if in_notebook:
+ images_to_display_later.append((f"{col.name}[{idx}]", img))
+ elif is_image_path(img):
+ previews.append(f"[{idx}] ")
+ if in_notebook:
+ images_to_display_later.append((f"{col.name}[{idx}]", img))
+ else:
+ previews.append(f"[{idx}] {str(img)[:30]}")
+ preview = "\n".join(previews) if previews else ""
+ # Handle single image (backwards compatibility)
+ elif is_base64_image(image_data):
+ preview = f""
+ if in_notebook:
+ images_to_display_later.append((col.name, image_data))
+ elif is_image_url(image_data):
+ preview = f""
+ if in_notebook:
+ images_to_display_later.append((col.name, image_data))
+ elif is_image_path(image_data):
+ preview = f""
+ if in_notebook:
+ images_to_display_later.append((col.name, image_data))
+ else:
+ preview = str(image_data)[:100] + "..." if len(str(image_data)) > 100 else str(image_data)
+
+ table.add_row(col.name, preview)
+
+ render_list.append(pad_console_element(table))
+
for col in config_builder.get_columns_of_type(DataDesignerColumnType.LLM_CODE):
panel = Panel(
Syntax(
@@ -287,6 +418,11 @@ def display_sample_record(
console.print(Group(*render_list), markup=False)
+ # Display images at the bottom with captions (only in notebook)
+ if len(images_to_display_later) > 0:
+ for col_name, image_data in images_to_display_later:
+ _display_image_if_in_notebook(image_data, col_name)
+
def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str:
if max_items <= 0:
diff --git a/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py b/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py
index be7b7185..0e95f248 100644
--- a/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py
+++ b/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py
@@ -35,6 +35,8 @@
"nx": "networkx",
"scipy": "scipy",
"jsonschema": "jsonschema",
+ "PIL": "PIL",
+ "Image": "PIL.Image",
}
diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py
index a069fea0..e633518d 100644
--- a/packages/data-designer-config/tests/config/test_columns.py
+++ b/packages/data-designer-config/tests/config/test_columns.py
@@ -53,6 +53,7 @@ def test_data_designer_column_type_get_display_order():
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_JUDGE,
DataDesignerColumnType.EMBEDDING,
+ DataDesignerColumnType.IMAGE,
DataDesignerColumnType.VALIDATION,
DataDesignerColumnType.EXPRESSION,
DataDesignerColumnType.CUSTOM,
diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py
index 38b8079e..564b235c 100644
--- a/packages/data-designer-config/tests/config/test_models.py
+++ b/packages/data-designer-config/tests/config/test_models.py
@@ -17,6 +17,7 @@
GenerationType,
ImageContext,
ImageFormat,
+ ImageInferenceParams,
ManualDistribution,
ManualDistributionParams,
ModalityDataType,
@@ -412,6 +413,12 @@ def test_model_config_construction():
assert model_config.inference_parameters == embedding_params
assert model_config.generation_type == GenerationType.EMBEDDING
+ # test construction with image inference parameters
+ image_params = ImageInferenceParams(extra_body={"size": "1024x1024", "quality": "hd"})
+ model_config = ModelConfig(alias="test", model="test", inference_parameters=image_params)
+ assert model_config.inference_parameters == image_params
+ assert model_config.generation_type == GenerationType.IMAGE
+
def test_model_config_generation_type_from_dict():
# Test that generation_type in dict is used to create the right inference params type
@@ -435,6 +442,30 @@ def test_model_config_generation_type_from_dict():
assert isinstance(model_config.inference_parameters, ChatCompletionInferenceParams)
assert model_config.generation_type == GenerationType.CHAT_COMPLETION
+ model_config = ModelConfig.model_validate(
+ {
+ "alias": "test",
+ "model": "image-model",
+ "inference_parameters": {
+ "generation_type": "image",
+ "extra_body": {"size": "1024x1024", "quality": "hd"},
+ },
+ }
+ )
+ assert isinstance(model_config.inference_parameters, ImageInferenceParams)
+ assert model_config.inference_parameters.extra_body == {"size": "1024x1024", "quality": "hd"}
+ assert model_config.generation_type == GenerationType.IMAGE
+
+
+def test_image_inference_params_generate_kwargs() -> None:
+    """Image options never become top-level kwargs; they are forwarded only through ``extra_body``."""
+    default_params = ImageInferenceParams()
+    for image_option in ("quality", "size"):
+        assert image_option not in default_params.generate_kwargs
+
+    configured_params = ImageInferenceParams(extra_body={"size": "1024x1024", "quality": "hd"})
+    forwarded_body = configured_params.generate_kwargs.get("extra_body")
+    assert forwarded_body == {"size": "1024x1024", "quality": "hd"}
+
def test_chat_completion_params_format_for_display_all_params():
"""Test formatting chat completion model with all parameters."""
diff --git a/packages/data-designer-config/tests/config/utils/test_image_helpers.py b/packages/data-designer-config/tests/config/utils/test_image_helpers.py
new file mode 100644
index 00000000..8b2f557a
--- /dev/null
+++ b/packages/data-designer-config/tests/config/utils/test_image_helpers.py
@@ -0,0 +1,335 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import base64
+import io
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+import pytest
+
+from data_designer.config.models import ImageFormat
+from data_designer.config.utils.image_helpers import (
+ decode_base64_image,
+ detect_image_format,
+ extract_base64_from_data_uri,
+ is_base64_image,
+ is_image_diffusion_model,
+ is_image_path,
+ is_image_url,
+ load_image_path_to_base64,
+ validate_image,
+)
+from data_designer.lazy_heavy_imports import Image
+
+
+@pytest.fixture
+def sample_png_bytes() -> bytes:
+    """Return the encoded bytes of a valid 1x1 red PNG image."""
+    with io.BytesIO() as png_buffer:
+        Image.new("RGB", (1, 1), color="red").save(png_buffer, format="PNG")
+        return png_buffer.getvalue()
+
+
+# ---------------------------------------------------------------------------
+# extract_base64_from_data_uri
+# ---------------------------------------------------------------------------
+
+
+def test_extract_base64_from_data_uri_with_prefix() -> None:
+    """A ``data:`` URI is reduced to the payload that follows the comma."""
+    payload = "iVBORw0KGgoAAAANS"
+    assert extract_base64_from_data_uri(f"data:image/png;base64,{payload}") == payload
+
+
+def test_extract_base64_plain_base64_without_prefix() -> None:
+    """Input that carries no ``data:`` prefix is handed back unchanged."""
+    raw_payload = "iVBORw0KGgoAAAANS"
+    assert extract_base64_from_data_uri(raw_payload) == raw_payload
+
+
+def test_extract_base64_invalid_data_uri_raises_error() -> None:
+    """A ``data:`` URI without the comma separator is rejected with ``ValueError``."""
+    malformed_uri = "data:image/png;base64"
+    expected_message = "Invalid data URI format: missing comma separator"
+    with pytest.raises(ValueError, match=expected_message):
+        extract_base64_from_data_uri(malformed_uri)
+
+
+# ---------------------------------------------------------------------------
+# decode_base64_image
+# ---------------------------------------------------------------------------
+
+
+def test_decode_base64_image_valid() -> None:
+    """Plain base64 text decodes back to the exact bytes that were encoded."""
+    original_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01"
+    encoded_text = base64.b64encode(original_bytes).decode()
+    assert decode_base64_image(encoded_text) == original_bytes
+
+
+def test_decode_base64_image_with_data_uri() -> None:
+    """A full ``data:image/png;base64,...`` URI decodes to the embedded bytes."""
+    original_bytes = b"\x89PNG\r\n\x1a\n"
+    encoded_text = base64.b64encode(original_bytes).decode()
+    assert decode_base64_image(f"data:image/png;base64,{encoded_text}") == original_bytes
+
+
+def test_decode_base64_image_invalid_raises_error() -> None:
+    """Text that is not valid base64 is rejected with ``ValueError``."""
+    not_base64 = "not-valid-base64!!!"
+    with pytest.raises(ValueError, match="Invalid base64 data"):
+        decode_base64_image(not_base64)
+
+
+# ---------------------------------------------------------------------------
+# detect_image_format (magic bytes)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize(
+    "header_bytes,expected_format",
+    [
+        pytest.param(b"\x89PNG\r\n\x1a\n" + b"\x00" * 10, ImageFormat.PNG, id="png"),
+        pytest.param(b"\xff\xd8\xff" + b"\x00" * 10, ImageFormat.JPG, id="jpg"),
+        pytest.param(b"RIFF" + b"\x00" * 4 + b"WEBP", ImageFormat.WEBP, id="webp"),
+    ],
+)
+def test_detect_image_format_magic_bytes(header_bytes: bytes, expected_format: ImageFormat) -> None:
+    """PNG, JPEG and WEBP inputs are recognised from their leading signature bytes."""
+    detected_format = detect_image_format(header_bytes)
+    assert detected_format == expected_format
+
+
+def test_detect_image_format_gif_magic_bytes(tmp_path: Path) -> None:
+    """Bytes of a GIF file written by PIL are detected as ``ImageFormat.GIF``."""
+    gif_file = tmp_path / "test.gif"
+    Image.new("RGB", (1, 1), color="red").save(gif_file, format="GIF")
+    assert detect_image_format(gif_file.read_bytes()) == ImageFormat.GIF
+
+
+def test_detect_image_format_with_pil_fallback_jpeg() -> None:
+    """With ``Image.open`` patched to report "JPEG", unrecognised bytes resolve to ``ImageFormat.JPG``."""
+    fake_image = Mock()
+    fake_image.format = "JPEG"
+    unrecognised_bytes = b"\x00\x00\x00\x00"
+
+    with patch.object(Image, "open", return_value=fake_image):
+        detected_format = detect_image_format(unrecognised_bytes)
+
+    assert detected_format == ImageFormat.JPG
+
+
+def test_detect_image_format_unknown_raises_error() -> None:
+    """Fourteen zero bytes match no known format and raise ``ValueError``."""
+    all_zero_bytes = b"\x00" * 14
+    with pytest.raises(ValueError, match="Unable to detect image format"):
+        detect_image_format(all_zero_bytes)
+
+
+# ---------------------------------------------------------------------------
+# is_image_path
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize(
+    "value,expected",
+    [
+        pytest.param("/path/to/image.png", True, id="png"),
+        pytest.param("image.PNG", True, id="png-upper"),
+        pytest.param("image.jpg", True, id="jpg"),
+        pytest.param("image.jpeg", True, id="jpeg"),
+        pytest.param("/path/to/file.txt", False, id="txt"),
+        pytest.param("document.pdf", False, id="pdf"),
+        pytest.param("/some.png/file.txt", False, id="ext-in-dir"),
+    ],
+)
+def test_is_image_path(value: str, expected: bool) -> None:
+    """Image extensions match case-insensitively; an image extension on a directory component does not count."""
+    outcome = is_image_path(value)
+    assert outcome is expected
+
+
+# ---------------------------------------------------------------------------
+# is_image_url
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize(
+    "value,expected",
+    [
+        pytest.param("http://example.com/image.png", True, id="http"),
+        pytest.param("https://example.com/photo.jpg", True, id="https"),
+        pytest.param("https://example.com/image.png?size=large", True, id="query-params"),
+        pytest.param("https://example.com/page.html", False, id="non-image-ext"),
+        pytest.param("ftp://example.com/image.png", False, id="ftp"),
+    ],
+)
+def test_is_image_url(value: str, expected: bool) -> None:
+    """http(s) URLs to image files are accepted, query string included; non-image paths and ftp are rejected."""
+    outcome = is_image_url(value)
+    assert outcome is expected
+
+
+# ---------------------------------------------------------------------------
+# is_base64_image
+# ---------------------------------------------------------------------------
+
+
+def test_is_base64_image_data_uri() -> None:
+    """An image ``data:`` URI is classified as a base64 image."""
+    image_data_uri = "data:image/png;base64,iVBORw0KGgo"
+    assert is_base64_image(image_data_uri) is True
+
+
+def test_is_base64_image_long_valid_base64() -> None:
+    """The base64 encoding of 100 bytes, without a prefix, is classified as a base64 image."""
+    encoded_payload = base64.b64encode(b"x" * 100).decode()
+    outcome = is_base64_image(encoded_payload)
+    assert outcome is True
+
+
+def test_is_base64_image_short_string() -> None:
+    """The five-character string "short" is not classified as a base64 image."""
+    outcome = is_base64_image("short")
+    assert outcome is False
+
+
+def test_is_base64_image_invalid_base64_decode() -> None:
+    """A long string with misplaced padding and trailing plain text is not a base64 image."""
+    malformed_payload = "".join(["A" * 50, "=", "A" * 49, "more text"])
+    outcome = is_base64_image(malformed_payload)
+    assert outcome is False
+
+
+# ---------------------------------------------------------------------------
+# Non-string guard (is_image_path, is_base64_image, is_image_url)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize(
+    "func",
+    [
+        pytest.param(is_image_path, id="is_image_path"),
+        pytest.param(is_base64_image, id="is_base64_image"),
+        pytest.param(is_image_url, id="is_image_url"),
+    ],
+)
+@pytest.mark.parametrize(
+    "value",
+    [
+        pytest.param(123, id="int"),
+        pytest.param(None, id="none"),
+        pytest.param([], id="list"),
+    ],
+)
+def test_non_string_input_returns_false(func: object, value: object) -> None:
+    """Each image predicate returns ``False`` for non-string input rather than raising."""
+    outcome = func(value)
+    assert outcome is False
+
+
+# ---------------------------------------------------------------------------
+# is_image_diffusion_model
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize(
+    "model_name,expected",
+    [
+        pytest.param("dall-e-3", True, id="dall-e-3"),
+        pytest.param("DALL-E-2", True, id="DALL-E-2"),
+        pytest.param("openai/dalle-2", True, id="dalle-2"),
+        pytest.param("stable-diffusion-xl", True, id="stable-diffusion-xl"),
+        pytest.param("sd-2.1", True, id="sd-2.1"),
+        pytest.param("sd_1.5", True, id="sd_1.5"),
+        pytest.param("imagen-3", True, id="imagen-3"),
+        pytest.param("google/imagen", True, id="google-imagen"),
+        pytest.param("gpt-image-1", True, id="gpt-image-1"),
+        pytest.param("gemini-3-pro-image-preview", False, id="gemini-not-diffusion"),
+        pytest.param("gpt-5-image", False, id="gpt-5-not-diffusion"),
+        pytest.param("flux.2-pro", False, id="flux-not-diffusion"),
+    ],
+)
+def test_is_image_diffusion_model(model_name: str, expected: bool) -> None:
+    """Model names are classified as diffusion-style or not, per the table of known names above."""
+    outcome = is_image_diffusion_model(model_name)
+    assert outcome is expected
+
+
+# ---------------------------------------------------------------------------
+# validate_image
+# ---------------------------------------------------------------------------
+
+
+def test_validate_image_valid_png(tmp_path: Path, sample_png_bytes: bytes) -> None:
+    """A well-formed PNG on disk validates without raising."""
+    png_file = tmp_path / "test.png"
+    png_file.write_bytes(sample_png_bytes)
+    # No assertion needed: the test passes as long as validate_image does not raise.
+    validate_image(png_file)
+
+
+def test_validate_image_corrupted_raises_error(tmp_path: Path) -> None:
+    """A ``.png`` file whose content is not image data fails validation with ``ValueError``."""
+    corrupted_file = tmp_path / "corrupted.png"
+    corrupted_file.write_bytes(b"not a valid image")
+    expected_message = "Image validation failed"
+    with pytest.raises(ValueError, match=expected_message):
+        validate_image(corrupted_file)
+
+
+def test_validate_image_nonexistent_raises_error(tmp_path: Path) -> None:
+    """Validating a path that does not exist raises ``ValueError``."""
+    missing_file = tmp_path / "nonexistent.png"
+    expected_message = "Image validation failed"
+    with pytest.raises(ValueError, match=expected_message):
+        validate_image(missing_file)
+
+
+# ---------------------------------------------------------------------------
+# load_image_path_to_base64
+# ---------------------------------------------------------------------------
+
+
+def test_load_image_path_to_base64_absolute_path(tmp_path: Path) -> None:
+    """An absolute path to an existing image yields non-empty, decodable base64 text."""
+    png_file = tmp_path / "test.png"
+    Image.new("RGB", (1, 1), color="blue").save(png_file)
+
+    encoded = load_image_path_to_base64(str(png_file))
+
+    assert encoded is not None
+    assert len(encoded) > 0
+    raw_bytes = base64.b64decode(encoded)
+    assert len(raw_bytes) > 0
+
+
+def test_load_image_path_to_base64_relative_with_base_path(tmp_path: Path) -> None:
+    """A relative path is found when the matching ``base_path`` is supplied."""
+    image_dir = tmp_path / "subdir"
+    image_dir.mkdir(exist_ok=True)
+    Image.new("RGB", (1, 1), color="green").save(image_dir / "test.png")
+
+    encoded = load_image_path_to_base64("subdir/test.png", base_path=str(tmp_path))
+
+    assert encoded is not None
+    assert len(encoded) > 0
+
+
+def test_load_image_path_to_base64_nonexistent_file() -> None:
+    """A path that does not exist produces ``None`` instead of an exception."""
+    missing_path = "/nonexistent/path/to/image.png"
+    assert load_image_path_to_base64(missing_path) is None
+
+
+def test_load_image_path_to_base64_relative_with_cwd_fallback(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+    """With no ``base_path`` given, a relative path is found in the current working directory."""
+    monkeypatch.chdir(tmp_path)
+
+    file_name = "test_cwd.png"
+    Image.new("RGB", (1, 1), color="yellow").save(tmp_path / file_name)
+
+    encoded = load_image_path_to_base64(file_name)
+
+    assert encoded is not None
+    assert len(encoded) > 0
+
+
+def test_load_image_path_to_base64_base_path_fallback_to_cwd(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+    """When the file is absent under ``base_path``, it is still found in the working directory."""
+    monkeypatch.chdir(tmp_path)
+
+    file_name = "test.png"
+    Image.new("RGB", (1, 1), color="red").save(tmp_path / file_name)
+
+    # An existing directory that deliberately does not contain the image.
+    empty_base_dir = tmp_path / "wrong"
+    empty_base_dir.mkdir()
+
+    encoded = load_image_path_to_base64(file_name, base_path=str(empty_base_dir))
+
+    assert encoded is not None
+    assert len(encoded) > 0
+
+
+def test_load_image_path_to_base64_exception_handling(tmp_path: Path) -> None:
+    """Pointing the loader at a directory yields ``None`` instead of an exception."""
+    not_a_file = tmp_path / "directory"
+    not_a_file.mkdir()
+
+    assert load_image_path_to_base64(str(not_a_file)) is None
diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py
new file mode 100644
index 00000000..730e73bb
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py
@@ -0,0 +1,80 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from data_designer.config.column_configs import ImageColumnConfig
+from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy
+from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
+from data_designer.engine.processing.utils import deserialize_json_values
+
+if TYPE_CHECKING:
+ from data_designer.engine.storage.media_storage import MediaStorage
+
+
+class ImageCellGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorWithModel[ImageColumnConfig]):
+    """Cell-by-cell generator that fills an image column from a rendered prompt.
+
+    Persistence is delegated to the media storage, whose mode decides the stored value:
+    - DISK mode: images are written to disk and the dataframe holds relative paths
+    - DATAFRAME mode: the base64 payload itself is kept in the dataframe
+    """
+
+    @property
+    def media_storage(self) -> MediaStorage:
+        """Media storage owned by the resource provider's artifact storage."""
+        return self._resource_provider.artifact_storage.media_storage
+
+    @staticmethod
+    def get_generation_strategy() -> GenerationStrategy:
+        """Images are produced one record at a time."""
+        return GenerationStrategy.CELL_BY_CELL
+
+    def generate(self, data: dict) -> dict:
+        """Render the prompt for one record, generate its image(s) and store them.
+
+        Args:
+            data: Record data; it is updated in place with the new column.
+
+        Returns:
+            The same record, with ``data[<column name>]`` set to the list of values
+            returned by the media storage for each generated image.
+
+        Raises:
+            ValueError: If a required column is absent from ``data`` or the rendered
+                prompt is empty or whitespace-only.
+        """
+        record = deserialize_json_values(data)
+
+        # Fail early when the template references columns this record does not have.
+        absent_columns = list(set(self.config.required_columns) - set(data.keys()))
+        if absent_columns:
+            raise ValueError(
+                f"There was an error preparing the Jinja2 expression template. "
+                f"The following columns {absent_columns} are missing!"
+            )
+
+        self.prepare_jinja2_template_renderer(self.config.prompt, list(record.keys()))
+        rendered_prompt = self.render_template(record)
+        if not (rendered_prompt and rendered_prompt.strip()):
+            raise ValueError(f"Rendered prompt for column {self.config.name!r} is empty")
+
+        # Flatten the contexts of every configured multi-modal source; stays None when none are configured.
+        context_configs = self.config.multi_modal_context
+        contexts = None
+        if context_configs is not None and len(context_configs) > 0:
+            contexts = [item for context in context_configs for item in context.get_contexts(record)]
+
+        generated_images = self.model.generate_image(prompt=rendered_prompt, multi_modal_context=contexts)
+
+        # The column name doubles as the subfolder so images are grouped per column.
+        stored_values = []
+        for generated_image in generated_images:
+            stored_values.append(
+                self.media_storage.save_base64_image(generated_image, subfolder_name=self.config.name)
+            )
+        data[self.config.name] = stored_values
+
+        return data
diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py
index 46642622..f4fc27b9 100644
--- a/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py
+++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py
@@ -8,6 +8,7 @@
CustomColumnConfig,
EmbeddingColumnConfig,
ExpressionColumnConfig,
+ ImageColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
LLMStructuredColumnConfig,
@@ -19,6 +20,7 @@
from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator
from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
+from data_designer.engine.column_generators.generators.image import ImageCellGenerator
from data_designer.engine.column_generators.generators.llm_completion import (
LLMCodeCellGenerator,
LLMJudgeCellGenerator,
@@ -52,6 +54,7 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum
registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig)
registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig)
registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig)
+ registry.register(DataDesignerColumnType.IMAGE, ImageCellGenerator, ImageColumnConfig)
if with_plugins:
for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
registry.register(
diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py
index 1b891b16..1411374d 100644
--- a/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py
+++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py
@@ -23,6 +23,7 @@ def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType)
DataDesignerColumnType.LLM_TEXT,
DataDesignerColumnType.VALIDATION,
DataDesignerColumnType.EMBEDDING,
+ DataDesignerColumnType.IMAGE,
}
dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType))
return column_type in dag_column_types
@@ -37,6 +38,7 @@ def column_type_is_model_generated(column_type: str | DataDesignerColumnType) ->
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_JUDGE,
DataDesignerColumnType.EMBEDDING,
+ DataDesignerColumnType.IMAGE,
}
for plugin in plugin_manager.get_column_generator_plugins():
if issubclass(plugin.impl_cls, ColumnGeneratorWithModelRegistry):
diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py
index 35e7d4f8..43b817b0 100644
--- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py
+++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py
@@ -11,11 +11,12 @@
from pathlib import Path
from typing import TYPE_CHECKING
-from pydantic import BaseModel, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator
from data_designer.config.utils.io_helpers import read_parquet_dataset
from data_designer.config.utils.type_helpers import StrEnum, resolve_string_enum
from data_designer.engine.dataset_builders.errors import ArtifactStorageError
+from data_designer.engine.storage.media_storage import MediaStorage, StorageMode
from data_designer.lazy_heavy_imports import pd
if TYPE_CHECKING:
@@ -38,12 +39,25 @@ class BatchStage(StrEnum):
class ArtifactStorage(BaseModel):
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
artifact_path: Path | str
dataset_name: str = "dataset"
final_dataset_folder_name: str = FINAL_DATASET_FOLDER_NAME
partial_results_folder_name: str = "tmp-partial-parquet-files"
dropped_columns_folder_name: str = "dropped-columns-parquet-files"
processors_outputs_folder_name: str = PROCESSORS_OUTPUTS_FOLDER_NAME
+ _media_storage: MediaStorage = PrivateAttr(default=None)
+
+    @property
+    def media_storage(self) -> MediaStorage:
+        """Return the media storage held in the private ``_media_storage`` attribute."""
+        return self._media_storage
+
+    @media_storage.setter
+    def media_storage(self, value: MediaStorage) -> None:
+        """Replace the media storage held in the private ``_media_storage`` attribute."""
+        self._media_storage = value
@property
def artifact_path_exists(self) -> bool:
@@ -114,8 +128,22 @@ def validate_folder_names(self):
if any(char in invalid_chars for char in name):
raise ArtifactStorageError(f"π Directory name '{name}' contains invalid characters.")
+ # Initialize media storage with DISK mode by default
+ self._media_storage = MediaStorage(
+ base_path=self.base_dataset_path,
+ mode=StorageMode.DISK,
+ )
+
return self
+    def set_media_storage_mode(self, mode: StorageMode) -> None:
+        """Switch the mode of the existing media storage instance.
+
+        Args:
+            mode: ``StorageMode.DISK`` to save media to disk, or ``StorageMode.DATAFRAME``
+                to keep it in memory.
+        """
+        self._media_storage.mode = mode
+
@staticmethod
def mkdir_if_needed(path: Path | str) -> Path:
"""Create the directory if it does not exist."""
diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py
index e5404d49..6e42844b 100644
--- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py
+++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py
@@ -11,7 +11,7 @@
from typing import TYPE_CHECKING, Callable
from data_designer.config.column_configs import CustomColumnConfig
-from data_designer.config.column_types import ColumnConfigT
+from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType
from data_designer.config.config_builder import BuilderConfig
from data_designer.config.data_designer_config import DataDesignerConfig
from data_designer.config.processors import (
@@ -40,6 +40,7 @@
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
from data_designer.engine.resources.resource_provider import ResourceProvider
+from data_designer.engine.storage.media_storage import StorageMode
from data_designer.lazy_heavy_imports import pd
if TYPE_CHECKING:
@@ -108,10 +109,29 @@ def build(
*,
num_records: int,
on_batch_complete: Callable[[Path], None] | None = None,
+ save_multimedia_to_disk: bool = True,
) -> Path:
+ """Build the dataset.
+
+ Args:
+ num_records: Number of records to generate.
+ on_batch_complete: Optional callback function called when each batch completes.
+ save_multimedia_to_disk: Whether to save generated multimedia (images, audio, video) to disk.
+ If False, multimedia is stored directly in the DataFrame (e.g., images as base64).
+ Default is True.
+
+ Returns:
+ Path to the generated dataset directory.
+ """
self._run_model_health_check_if_needed()
self._run_mcp_tool_check_if_needed()
self._write_builder_config()
+
+ # Set media storage mode based on parameters
+ if self._has_image_columns():
+ mode = StorageMode.DISK if save_multimedia_to_disk else StorageMode.DATAFRAME
+ self.artifact_storage.set_media_storage_mode(mode)
+
generators = self._initialize_generators()
start_time = time.perf_counter()
group_id = uuid.uuid4().hex
@@ -138,6 +158,10 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame:
self._run_model_health_check_if_needed()
self._run_mcp_tool_check_if_needed()
+ # Set media storage to DATAFRAME mode for preview - base64 stored directly in DataFrame
+ if self._has_image_columns():
+ self.artifact_storage.set_media_storage_mode(StorageMode.DATAFRAME)
+
generators = self._initialize_generators()
group_id = uuid.uuid4().hex
start_time = time.perf_counter()
@@ -154,6 +178,10 @@ def process_preview(self, dataset: pd.DataFrame) -> pd.DataFrame:
df = self._processor_runner.run_post_batch(dataset.copy(), current_batch_number=None)
return self._processor_runner.run_after_generation_on_df(df)
+    def _has_image_columns(self) -> bool:
+        """Return True when at least one single-column config is of the IMAGE column type."""
+        for column_config in self.single_column_configs:
+            if column_config.column_type == DataDesignerColumnType.IMAGE:
+                return True
+        return False
+
def _initialize_generators(self) -> list[ColumnGenerator]:
return [
self._registry.column_generators.get_for_config_type(type(config))(
@@ -286,6 +314,7 @@ def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max
progress_tracker.log_final()
if len(self._records_to_drop) > 0:
+ self._cleanup_dropped_record_images(self._records_to_drop)
self.batch_manager.drop_records(self._records_to_drop)
self._records_to_drop.clear()
@@ -349,6 +378,30 @@ def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> li
return processors
+    def _cleanup_dropped_record_images(self, dropped_indices: set[int]) -> None:
+        """Remove saved image files for records that will be dropped.
+
+        When a record fails during generation, any images already saved to disk
+        for that record in previous columns become dangling. This method deletes
+        those files so they don't accumulate.
+
+        Nothing is done unless the config has image columns and the media storage
+        exists and is in DISK mode.
+
+        Args:
+            dropped_indices: Positions in the current batch buffer of the records being
+                dropped. Out-of-range indices are ignored.
+        """
+        media_storage = self.artifact_storage.media_storage
+        if not self._has_image_columns() or media_storage is None or media_storage.mode != StorageMode.DISK:
+            return
+
+        image_col_names = [
+            col.name for col in self.single_column_configs if col.column_type == DataDesignerColumnType.IMAGE
+        ]
+
+        buffer = self.batch_manager.get_current_batch(as_dataframe=False)
+        for idx in dropped_indices:
+            if idx < 0 or idx >= len(buffer):
+                continue
+            record = buffer[idx]
+            for col_name in image_col_names:
+                paths = record.get(col_name)
+                if paths is None:
+                    # Key absent or explicitly None: no images were stored for this
+                    # column, and iterating None would raise TypeError mid-cleanup.
+                    continue
+                # A column may hold a single path string or a list of paths.
+                for path in [paths] if isinstance(paths, str) else paths:
+                    media_storage.delete_image(path)
+
def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None:
"""If a worker fails, we can handle the exception here."""
logger.warning(
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py
index 3e1ddf01..8ca1ebfd 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py
@@ -83,6 +83,9 @@ class ModelStructuredOutputError(DataDesignerError): ...
class ModelGenerationValidationFailureError(DataDesignerError): ...
+class ImageGenerationError(DataDesignerError): ...
+
+
class FormattedLLMErrorMessage(BaseModel):
cause: str
solution: str
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py
index 9f5fc85c..ef328a9a 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py
@@ -9,16 +9,23 @@
from typing import TYPE_CHECKING, Any
from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
+from data_designer.config.utils.image_helpers import (
+ extract_base64_from_data_uri,
+ is_base64_image,
+ is_image_diffusion_model,
+ load_image_url_to_base64,
+)
from data_designer.engine.mcp.errors import MCPConfigurationError
from data_designer.engine.model_provider import ModelProviderRegistry
from data_designer.engine.models.errors import (
GenerationValidationFailureError,
+ ImageGenerationError,
catch_llm_exceptions,
get_exception_primary_cause,
)
from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
from data_designer.engine.models.parsers.errors import ParserException
-from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
+from data_designer.engine.models.usage import ImageUsageStats, ModelUsageStats, RequestUsageStats, TokenUsageStats
from data_designer.engine.models.utils import ChatMessage, prompt_to_messages
from data_designer.engine.secret_resolver import SecretResolver
from data_designer.lazy_heavy_imports import litellm
@@ -35,6 +42,32 @@ def _identity(x: Any) -> Any:
return x
+def _try_extract_base64(source: str | litellm.types.utils.ImageObject) -> str | None:
+    """Try to extract base64 image data from a data URI string or image response object.
+
+    Sources are tried in order: a string goes through ``extract_base64_from_data_uri``;
+    otherwise a truthy ``b64_json`` attribute is returned as-is; otherwise a truthy
+    ``url`` attribute is loaded via ``load_image_url_to_base64``.
+
+    Args:
+        source: Either a data URI string (e.g. "data:image/png;base64,...")
+            or a litellm ImageObject with b64_json/url attributes.
+
+    Returns:
+        Base64-encoded image string, or None if extraction fails or the source has
+        neither usable attribute.
+    """
+    try:
+        if isinstance(source, str):
+            return extract_base64_from_data_uri(source)
+
+        b64_json = getattr(source, "b64_json", None)
+        if b64_json:
+            return b64_json
+
+        url = getattr(source, "url", None)
+        if url:
+            return load_image_url_to_base64(url)
+    except Exception:
+        # Best-effort: callers get None on failure, but keep the traceback in the
+        # debug log so the underlying cause (bad URI, failed download, ...) is not lost.
+        logger.debug(
+            "Failed to extract base64 from source of type %s",
+            type(source).__name__,
+            exc_info=True,
+        )
+        return None
+
+    return None
+
+
logger = logging.getLogger(__name__)
@@ -105,7 +138,7 @@ def completion(
raise e
finally:
if not skip_usage_tracking and response is not None:
- self._track_usage(response)
+ self._track_token_usage_from_completion(response)
def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
# Remove purpose from kwargs to avoid passing it to the model
@@ -117,50 +150,6 @@ def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
kwargs["extra_headers"] = self.model_provider.extra_headers
return kwargs
- def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None:
- if tool_alias is None:
- return None
- if self._mcp_registry is None:
- raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.")
-
- try:
- return self._mcp_registry.get_mcp(tool_alias=tool_alias)
- except ValueError as exc:
- raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc
-
- @catch_llm_exceptions
- def generate_text_embeddings(
- self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
- ) -> list[list[float]]:
- logger.debug(
- f"Generating embeddings with model {self.model_name!r}...",
- extra={
- "model": self.model_name,
- "input_count": len(input_texts),
- },
- )
- kwargs = self.consolidate_kwargs(**kwargs)
- response = None
- try:
- response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
- logger.debug(
- f"Received embeddings from model {self.model_name!r}",
- extra={
- "model": self.model_name,
- "embedding_count": len(response.data) if response.data else 0,
- "usage": self._usage_stats.model_dump(),
- },
- )
- if response.data and len(response.data) == len(input_texts):
- return [data["embedding"] for data in response.data]
- else:
- raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
- except Exception as e:
- raise e
- finally:
- if not skip_usage_tracking and response is not None:
- self._track_usage_from_embedding(response)
-
@catch_llm_exceptions
def generate(
self,
@@ -309,6 +298,208 @@ def generate(
return output_obj, messages
+    @catch_llm_exceptions
+    def generate_text_embeddings(
+        self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
+    ) -> list[list[float]]:
+        """Embed each input text through the router's embedding endpoint.
+
+        Args:
+            input_texts: Texts to embed; one embedding is expected per entry.
+            skip_usage_tracking: When True, the response is not recorded in the usage stats.
+            **kwargs: Extra arguments, passed through ``consolidate_kwargs`` before being
+                forwarded to the router.
+
+        Returns:
+            The "embedding" entry of every item in the response, in response order.
+
+        Raises:
+            ValueError: If the number of returned embeddings differs from the number of
+                input texts. NOTE(review): ``catch_llm_exceptions`` wraps this method and
+                may translate the error - its behavior is not visible here.
+        """
+        logger.debug(
+            f"Generating embeddings with model {self.model_name!r}...",
+            extra={
+                "model": self.model_name,
+                "input_count": len(input_texts),
+            },
+        )
+        kwargs = self.consolidate_kwargs(**kwargs)
+        response = None
+        try:
+            response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
+            # Normalise a missing payload to an empty list so the mismatch below is
+            # reported as a ValueError rather than a TypeError from len(None).
+            embeddings = response.data or []
+            logger.debug(
+                f"Received embeddings from model {self.model_name!r}",
+                extra={
+                    "model": self.model_name,
+                    "embedding_count": len(embeddings),
+                    "usage": self._usage_stats.model_dump(),
+                },
+            )
+            if len(embeddings) != len(input_texts):
+                raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(embeddings)}")
+            return [data["embedding"] for data in embeddings]
+        finally:
+            # Usage is recorded whenever a response arrived, even if validation failed above.
+            if not skip_usage_tracking and response is not None:
+                self._track_token_usage_from_embedding(response)
+
+    @catch_llm_exceptions
+    def generate_image(
+        self,
+        prompt: str,
+        multi_modal_context: list[dict[str, Any]] | None = None,
+        skip_usage_tracking: bool = False,
+        **kwargs,
+    ) -> list[str]:
+        """Generate one or more images and return them as base64 strings.
+
+        The API is chosen from the model name:
+        - names for which ``is_image_diffusion_model`` is true -> image_generation API
+        - every other model -> chat/completions API
+
+        Whatever list the selected helper produces is returned as-is, so a response
+        carrying several images yields several entries.
+
+        Args:
+            prompt: The prompt for image generation.
+            multi_modal_context: Optional image contexts; only forwarded on the
+                chat/completions path.
+            skip_usage_tracking: When True, neither the helper nor this method records usage.
+            **kwargs: Additional arguments forwarded to the selected helper
+                (for example ``n`` for the number of images).
+
+        Returns:
+            List of base64-encoded image strings.
+
+        Raises:
+            ImageGenerationError: Raised by the helpers when the response holds no usable image.
+        """
+        logger.debug(
+            f"Generating image with model {self.model_name!r}...",
+            extra={"model": self.model_name, "prompt": prompt},
+        )
+
+        use_diffusion_api = is_image_diffusion_model(self.model_name)
+        images = (
+            self._generate_image_diffusion(prompt, skip_usage_tracking, **kwargs)
+            if use_diffusion_api
+            else self._generate_image_chat_completion(prompt, multi_modal_context, skip_usage_tracking, **kwargs)
+        )
+
+        # Token/request usage is recorded by the helpers; only the image count is added here.
+        if not skip_usage_tracking and (image_count := len(images)) > 0:
+            self._usage_stats.extend(image_usage=ImageUsageStats(total_images=image_count))
+
+        return images
+
+    def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None:
+        """Look up the MCP facade registered under ``tool_alias``.
+
+        Returns:
+            None when no alias is given, otherwise whatever the registry returns for it.
+
+        Raises:
+            MCPConfigurationError: If an alias is given but no registry is configured,
+                or the registry rejects the alias with a ValueError.
+        """
+        if tool_alias is None:
+            return None
+
+        registry = self._mcp_registry
+        if registry is None:
+            raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.")
+
+        try:
+            facade = registry.get_mcp(tool_alias=tool_alias)
+        except ValueError as exc:
+            raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc
+        return facade
+
+    def _generate_image_chat_completion(
+        self,
+        prompt: str,
+        multi_modal_context: list[dict[str, Any]] | None = None,
+        skip_usage_tracking: bool = False,
+        **kwargs,
+    ) -> list[str]:
+        """Generate image(s) through the chat completions API.
+
+        Image payloads are read from the first choice's ``message.images`` entries
+        (dicts with an ``image_url`` that is either a ``{"url": ...}`` dict or a string,
+        or plain strings). If none of them yields data, ``message.content`` is used
+        when it is a string that looks like image data.
+
+        Args:
+            prompt: The prompt for image generation.
+            multi_modal_context: Optional image contexts added to the user message.
+            skip_usage_tracking: Forwarded to ``completion``.
+            **kwargs: Additional arguments forwarded to ``completion``.
+
+        Returns:
+            List of base64-encoded image strings (never empty).
+
+        Raises:
+            ImageGenerationError: If the response has no choices or no image data
+                could be extracted from it.
+        """
+        messages = prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context)
+        response = self.completion(
+            messages=messages,
+            skip_usage_tracking=skip_usage_tracking,
+            **kwargs,
+        )
+
+        logger.debug(
+            f"Received image(s) from autoregressive model {self.model_name!r}",
+            extra={"model": self.model_name, "response": response},
+        )
+
+        if not response.choices:
+            raise ImageGenerationError("Image generation response missing choices")
+
+        message = response.choices[0].message
+
+        # Collect every candidate payload first, then run a single extraction pass.
+        candidates: list[Any] = []
+        for image in getattr(message, "images", None) or []:
+            if isinstance(image, dict) and "image_url" in image:
+                image_url = image["image_url"]
+                if isinstance(image_url, dict) and "url" in image_url:
+                    candidates.append(image_url["url"])
+                elif isinstance(image_url, str):
+                    candidates.append(image_url)
+            elif isinstance(image, str):
+                candidates.append(image)
+
+        images = [b64 for candidate in candidates if (b64 := _try_extract_base64(candidate)) is not None]
+
+        # Fallback: use the content field, but only when it is a string that looks like
+        # image data. Non-string content is skipped so it cannot fail on `.startswith`.
+        if not images:
+            content = message.content
+            if isinstance(content, str) and content:
+                if content.startswith("data:image/") or is_base64_image(content):
+                    if (b64 := _try_extract_base64(content)) is not None:
+                        images.append(b64)
+
+        if not images:
+            raise ImageGenerationError("No image data found in image generation response")
+
+        return images
+
+    def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]:
+        """Generate image(s) through the router's image_generation API.
+
+        Each response item is converted with ``_try_extract_base64``, which reads an
+        inline ``b64_json`` payload or, failing that, loads the item's ``url``.
+        Items that yield nothing are dropped.
+
+        Args:
+            prompt: The prompt for image generation.
+            skip_usage_tracking: When True, the response is not recorded in the usage stats.
+            **kwargs: Extra arguments, passed through ``consolidate_kwargs`` before being
+                forwarded to the router.
+
+        Returns:
+            List of base64-encoded image strings (never empty).
+
+        Raises:
+            ImageGenerationError: If the response carries no data or nothing could be
+                extracted from it.
+        """
+        kwargs = self.consolidate_kwargs(**kwargs)
+
+        response = None
+        try:
+            response = self._router.image_generation(prompt=prompt, model=self.model_name, **kwargs)
+
+            # Normalise a missing payload before measuring it: len(None) would raise a
+            # TypeError and hide the intended ImageGenerationError below.
+            items = response.data or []
+
+            logger.debug(
+                f"Received {len(items)} image(s) from diffusion model {self.model_name!r}",
+                extra={"model": self.model_name, "response": response},
+            )
+
+            if not items:
+                raise ImageGenerationError("Image generation returned no data")
+
+            images = [b64 for img in items if (b64 := _try_extract_base64(img)) is not None]
+
+            if not images:
+                raise ImageGenerationError("No image data could be extracted from response")
+
+            return images
+        finally:
+            # Usage is recorded whenever a response arrived, even if extraction failed.
+            if not skip_usage_tracking and response is not None:
+                self._track_token_usage_from_image_diffusion(response)
+
def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
provider = self._model_provider_registry.get_provider(model_config.provider)
api_key = None
@@ -326,7 +517,7 @@ def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.Deployme
"litellm_params": litellm_params.model_dump(),
}
- def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None:
+ def _track_token_usage_from_completion(self, response: litellm.types.utils.ModelResponse | None) -> None:
if response is None:
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
return
@@ -343,7 +534,7 @@ def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> No
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
)
- def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None:
+ def _track_token_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None:
if response is None:
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
return
@@ -355,3 +546,21 @@ def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingRes
),
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
)
+
+    def _track_token_usage_from_image_diffusion(self, response: litellm.types.utils.ImageResponse | None) -> None:
+        """Record request (and, when reported, token) usage for an image_generation response.
+
+        A None response counts as one failed request. Any other response counts as one
+        successful request; token counts are added only when ``response.usage`` is a
+        litellm ``ImageUsage`` instance.
+        """
+        if response is None:
+            self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
+            return
+
+        usage = response.usage
+        reports_tokens = usage is not None and isinstance(usage, litellm.types.utils.ImageUsage)
+        if not reports_tokens:
+            # Successful response but no token usage data (some providers don't report it)
+            self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=1, failed_requests=0))
+            return
+
+        self._usage_stats.extend(
+            token_usage=TokenUsageStats(
+                input_tokens=usage.input_tokens,
+                output_tokens=usage.output_tokens,
+            ),
+            request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
+        )
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/registry.py b/packages/data-designer-engine/src/data_designer/engine/models/registry.py
index 03bef223..0b103e76 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/registry.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/registry.py
@@ -120,6 +120,10 @@ def log_model_usage(self, total_time_elapsed: float) -> None:
f"turns={tool_usage['total_tool_call_turns']}"
)
+ if image_usage := stats.get("image_usage"):
+ total_images = image_usage["total_images"]
+ logger.info(f"{LOG_INDENT}images: total={total_images}")
+
if model_index < len(sorted_model_names) - 1:
logger.info(LOG_INDENT.rstrip())
@@ -183,6 +187,12 @@ def run_health_check(self, model_aliases: list[str]) -> None:
skip_usage_tracking=True,
purpose="running health checks",
)
+ elif model.model_generation_type == GenerationType.IMAGE:
+ model.generate_image(
+ prompt="Generate a simple illustration of a thumbs up sign.",
+ skip_usage_tracking=True,
+ purpose="running health checks",
+ )
else:
raise ValueError(f"Unsupported generation type: {model.model_generation_type}")
logger.info(f"{LOG_INDENT}β
Passed!")
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/usage.py b/packages/data-designer-engine/src/data_designer/engine/models/usage.py
index f44a31ae..64e82b47 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/usage.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/usage.py
@@ -71,14 +71,27 @@ def merge(self, other: ToolUsageStats) -> ToolUsageStats:
return self
+class ImageUsageStats(BaseModel):
+    """Running count of images generated by a model."""
+
+    total_images: int = 0
+
+    @property
+    def has_usage(self) -> bool:
+        """True once at least one image has been counted."""
+        counted = self.total_images
+        return counted > 0
+
+    def extend(self, *, images: int) -> None:
+        """Add ``images`` to the running total."""
+        self.total_images = self.total_images + images
+
+
class ModelUsageStats(BaseModel):
token_usage: TokenUsageStats = TokenUsageStats()
request_usage: RequestUsageStats = RequestUsageStats()
tool_usage: ToolUsageStats = ToolUsageStats()
+ image_usage: ImageUsageStats = ImageUsageStats()
@property
def has_usage(self) -> bool:
- return self.token_usage.has_usage and self.request_usage.has_usage
+ return self.token_usage.has_usage or self.request_usage.has_usage or self.image_usage.has_usage
def extend(
self,
@@ -86,6 +99,7 @@ def extend(
token_usage: TokenUsageStats | None = None,
request_usage: RequestUsageStats | None = None,
tool_usage: ToolUsageStats | None = None,
+ image_usage: ImageUsageStats | None = None,
) -> None:
if token_usage is not None:
self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens)
@@ -95,9 +109,16 @@ def extend(
)
if tool_usage is not None:
self.tool_usage.merge(tool_usage)
+ if image_usage is not None:
+ self.image_usage.extend(images=image_usage.total_images)
def get_usage_stats(self, *, total_time_elapsed: float) -> dict:
- exclude = {"tool_usage"} if not self.tool_usage.has_usage else None
+ exclude = set()
+ if not self.tool_usage.has_usage:
+ exclude.add("tool_usage")
+ if not self.image_usage.has_usage:
+ exclude.add("image_usage")
+ exclude = exclude if exclude else None
return self.model_dump(exclude=exclude) | {
"tokens_per_second": int(self.token_usage.total_tokens / total_time_elapsed)
if total_time_elapsed > 0
diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py
new file mode 100644
index 00000000..9d416c65
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py
@@ -0,0 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from data_designer.engine.storage.media_storage import MediaStorage, StorageMode
+
+__all__ = ["MediaStorage", "StorageMode"]
diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py
new file mode 100644
index 00000000..1c887c80
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py
@@ -0,0 +1,153 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import uuid
+from pathlib import Path
+
+from data_designer.config.utils.image_helpers import decode_base64_image, detect_image_format, validate_image
+from data_designer.config.utils.type_helpers import StrEnum
+
+IMAGES_SUBDIR = "images"
+
+
+class StorageMode(StrEnum):
+ """Storage mode for generated media content.
+
+ - DISK: Save media to disk and store relative paths in dataframe (for dataset creation)
+ - DATAFRAME: Store base64 data directly in dataframe (for preview mode)
+ """
+
+ # MediaStorage.save_base64_image writes the image file and returns its relative path.
+ DISK = "disk"
+ # MediaStorage.save_base64_image returns the base64 string it was given, untouched.
+ DATAFRAME = "dataframe"
+
+
+class MediaStorage:
+    """Manages storage of generated media content (currently images only).
+
+    Storage modes:
+    - DISK: decode the base64 payload, write it to
+      ``<base_path>/<images_subdir>/<subfolder>/<uuid>.<format>``, validate the written
+      file, and return its relative path (for dataset creation).
+    - DATAFRAME: return the base64 payload unchanged (for preview mode). Nothing is
+      decoded, validated, or written in this mode.
+    """
+
+    def __init__(
+        self, base_path: Path, images_subdir: str = IMAGES_SUBDIR, mode: StorageMode = StorageMode.DISK
+    ) -> None:
+        """Initialize media storage manager.
+
+        No directory is created here; directories are created lazily on first save.
+
+        Args:
+            base_path: Base directory for dataset
+            images_subdir: Subdirectory name for images (default: "images")
+            mode: Storage mode - DISK (save to disk) or DATAFRAME (return base64)
+        """
+        self.base_path = Path(base_path)
+        self.images_dir = self.base_path / images_subdir
+        self.images_subdir = images_subdir
+        self.mode = mode
+
+    def save_base64_image(self, base64_data: str, subfolder_name: str) -> str:
+        """Save or return base64 image based on storage mode.
+
+        Args:
+            base64_data: Base64 encoded image string, decoded with ``decode_base64_image``
+                in DISK mode.
+            subfolder_name: Name of the folder under the images directory used to group
+                images (e.g. a column name such as "test_image"). It is sanitized before use.
+
+        Returns:
+            DISK mode: Relative path to saved image (e.g., "images/subfolder_name/f47ac10b-58cc.png")
+            DATAFRAME mode: Original base64 data string
+
+        Raises:
+            ValueError: If base64 data is invalid (DISK mode only)
+            OSError: If disk write fails (DISK mode only)
+        """
+        # DATAFRAME mode: return base64 directly without disk operations
+        if self.mode == StorageMode.DATAFRAME:
+            return base64_data
+
+        sanitized_subfolder = self._sanitize_subfolder_name(subfolder_name)
+
+        # Decode and sniff the format before touching the filesystem, so an invalid
+        # payload does not leave an empty directory behind.
+        image_bytes = decode_base64_image(base64_data)
+        image_format = detect_image_format(image_bytes)
+
+        # Ensure target directory exists (lazy initialization)
+        target_dir = self.images_dir / sanitized_subfolder
+        target_dir.mkdir(parents=True, exist_ok=True)
+
+        # A UUID filename avoids collisions between images in the same subfolder.
+        filename = f"{uuid.uuid4()}.{image_format.value}"
+        full_path = target_dir / filename
+
+        with open(full_path, "wb") as f:
+            f.write(image_bytes)
+
+        # Validation reads the file back; an invalid file is removed by _validate_image.
+        self._validate_image(full_path)
+
+        # Always "/"-separated, independent of the platform's path separator.
+        return f"{self.images_subdir}/{sanitized_subfolder}/{filename}"
+
+    def delete_image(self, relative_path: str) -> bool:
+        """Delete a saved image file given its relative path.
+
+        Only files located inside the images directory are deleted. Both paths are
+        resolved first, so ".." components or symlinks cannot be used to reach a file
+        outside of it.
+
+        Args:
+            relative_path: Relative path as returned by save_base64_image (e.g., "images/col/uuid.png")
+
+        Returns:
+            True if the file was deleted, False if it didn't exist, lies outside the
+            images directory, or deletion failed.
+        """
+        try:
+            images_root = self.images_dir.resolve()
+            full_path = (self.base_path / relative_path).resolve()
+            if full_path.is_file() and images_root in full_path.parents:
+                full_path.unlink()
+                return True
+        except (OSError, RuntimeError):
+            # Best-effort cleanup: a failed delete is reported through the return value.
+            pass
+        return False
+
+    def _validate_image(self, image_path: Path) -> None:
+        """Validate that saved image is readable, removing the file when it is not.
+
+        Args:
+            image_path: Path to image file
+
+        Raises:
+            ValueError: If image is corrupted or unreadable. Any other exception raised
+                by ``validate_image`` is re-raised unchanged after the same cleanup.
+        """
+        try:
+            validate_image(image_path)
+        except Exception:
+            # Never leave a file behind that failed validation, whatever the failure was.
+            image_path.unlink(missing_ok=True)
+            raise
+
+    def _ensure_images_directory(self) -> None:
+        """Create images directory if it doesn't exist (lazy initialization)."""
+        self.images_dir.mkdir(parents=True, exist_ok=True)
+
+    def _sanitize_subfolder_name(self, name: str) -> str:
+        """Sanitize subfolder name to prevent path traversal and filesystem issues.
+
+        Path separators and ".." sequences are replaced with underscores. A name that
+        ends up empty or as "." is replaced by "_" so images always land in a real
+        subfolder rather than in the images directory itself.
+        """
+        sanitized = name.replace("/", "_").replace("\\", "_").replace("..", "_")
+        return "_" if sanitized in ("", ".") else sanitized
diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py
new file mode 100644
index 00000000..ca5cbfae
--- /dev/null
+++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py
@@ -0,0 +1,218 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from data_designer.config.column_configs import ImageColumnConfig
+from data_designer.config.models import ImageContext, ImageFormat, ModalityDataType
+from data_designer.engine.column_generators.generators.base import GenerationStrategy
+from data_designer.engine.column_generators.generators.image import ImageCellGenerator
+from data_designer.engine.processing.ginja.exceptions import UserTemplateError
+
+
+@pytest.fixture
+def stub_image_column_config():
+    """Image column whose prompt template references the `style` and `subject` columns."""
+    return ImageColumnConfig(
+        name="test_image",
+        prompt="A {{ style }} image of {{ subject }}",
+        model_alias="test_model",
+    )
+
+
+@pytest.fixture
+def stub_base64_images() -> list[str]:
+    """Two placeholder strings standing in for base64-encoded images."""
+    return [f"base64_image_{index}" for index in (1, 2)]
+
+
+def test_image_cell_generator_generation_strategy(
+    stub_image_column_config: ImageColumnConfig, stub_resource_provider: None
+) -> None:
+    """The image generator produces its column one cell at a time."""
+    strategy = ImageCellGenerator(
+        config=stub_image_column_config,
+        resource_provider=stub_resource_provider,
+    ).get_generation_strategy()
+    assert strategy == GenerationStrategy.CELL_BY_CELL
+
+
+def test_image_cell_generator_media_storage_property(
+    stub_image_column_config: ImageColumnConfig, stub_resource_provider: None
+) -> None:
+    """The generator exposes a media storage object through its `media_storage` property."""
+    generator = ImageCellGenerator(
+        config=stub_image_column_config,
+        resource_provider=stub_resource_provider,
+    )
+    storage = generator.media_storage
+    assert storage is not None
+
+
+def test_image_cell_generator_generate_with_storage(
+    stub_image_column_config, stub_resource_provider, stub_base64_images
+):
+    """Each generated image goes through media storage and the column holds the returned paths."""
+    saved_paths = ["images/test_image/uuid1.png", "images/test_image/uuid2.png"]
+    storage = Mock()
+    storage.save_base64_image.side_effect = saved_paths
+    stub_resource_provider.artifact_storage.media_storage = storage
+
+    model = stub_resource_provider.model_registry.get_model.return_value
+    with patch.object(model, "generate_image", return_value=stub_base64_images) as generate_image:
+        generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider)
+        row = generator.generate(data={"style": "photorealistic", "subject": "cat"})
+
+    # The column holds the relative paths handed back by storage, in order.
+    column = stub_image_column_config.name
+    assert column in row
+    assert row[column] == saved_paths
+
+    # The model received the rendered prompt exactly once.
+    generate_image.assert_called_once_with(prompt="A photorealistic image of cat", multi_modal_context=None)
+
+    # Storage was asked to save every image under the column's subfolder.
+    assert storage.save_base64_image.call_count == 2
+    for encoded in ("base64_image_1", "base64_image_2"):
+        storage.save_base64_image.assert_any_call(encoded, subfolder_name="test_image")
+
+
+def test_image_cell_generator_generate_in_dataframe_mode(
+    stub_image_column_config, stub_resource_provider, stub_base64_images
+):
+    """When storage echoes the payload back (as DATAFRAME mode does), the column holds base64."""
+    # The mock returns each base64 string unchanged to simulate DATAFRAME mode.
+    storage = Mock()
+    storage.save_base64_image.side_effect = stub_base64_images
+    stub_resource_provider.artifact_storage.media_storage = storage
+
+    model = stub_resource_provider.model_registry.get_model.return_value
+    with patch.object(model, "generate_image", return_value=stub_base64_images) as generate_image:
+        generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider)
+        row = generator.generate(data={"style": "watercolor", "subject": "dog"})
+
+    column = stub_image_column_config.name
+    assert column in row
+    assert row[column] == stub_base64_images
+
+    generate_image.assert_called_once_with(prompt="A watercolor image of dog", multi_modal_context=None)
+
+    # Storage is still consulted once per image, with the column name as subfolder.
+    assert storage.save_base64_image.call_count == 2
+    for encoded in ("base64_image_1", "base64_image_2"):
+        storage.save_base64_image.assert_any_call(encoded, subfolder_name="test_image")
+
+
+def test_image_cell_generator_missing_columns_error(stub_image_column_config, stub_resource_provider):
+    """Generating without a column the prompt template needs raises ValueError."""
+    generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider)
+    row_without_subject = {"style": "photorealistic"}
+
+    with pytest.raises(ValueError, match="columns.*missing"):
+        generator.generate(data=row_without_subject)
+
+
+def test_image_cell_generator_empty_prompt_error(stub_resource_provider):
+    """A template that renders to an empty string raises UserTemplateError."""
+    empty_prompt_config = ImageColumnConfig(name="test_image", prompt="{{ empty }}", model_alias="test_model")
+    generator = ImageCellGenerator(config=empty_prompt_config, resource_provider=stub_resource_provider)
+    row = {"empty": ""}
+
+    with pytest.raises(UserTemplateError):
+        generator.generate(data=row)
+
+
+def test_image_cell_generator_whitespace_only_prompt_error(stub_resource_provider):
+    """A prompt that renders to whitespace only raises a ValueError mentioning "empty"."""
+    blank_prompt_config = ImageColumnConfig(name="test_image", prompt="{{ spaces }}", model_alias="test_model")
+    generator = ImageCellGenerator(config=blank_prompt_config, resource_provider=stub_resource_provider)
+    row = {"spaces": " "}
+
+    with pytest.raises(ValueError, match="empty"):
+        generator.generate(data=row)
+
+
+def test_image_cell_generator_with_multi_modal_context(stub_resource_provider):
+    """A URL image context is forwarded to the model as a single `image_url` entry."""
+    reference_url = "https://example.com/image.png"
+    config = ImageColumnConfig(
+        name="test_image",
+        prompt="Generate a similar image to the reference",
+        model_alias="test_model",
+        multi_modal_context=[ImageContext(column_name="reference_image", data_type=ModalityDataType.URL)],
+    )
+
+    storage = Mock()
+    storage.save_base64_image.return_value = "images/generated.png"
+    stub_resource_provider.artifact_storage.media_storage = storage
+
+    model = stub_resource_provider.model_registry.get_model.return_value
+    with patch.object(model, "generate_image", return_value=["base64_generated_image"]) as generate_image:
+        generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider)
+        row = generator.generate(data={"reference_image": reference_url})
+
+    # The generated column holds whatever path storage returned.
+    assert config.name in row
+    assert row[config.name] == ["images/generated.png"]
+
+    # The model saw the prompt plus exactly one context entry pointing at the URL.
+    generate_image.assert_called_once()
+    sent = generate_image.call_args.kwargs
+    assert sent["prompt"] == "Generate a similar image to the reference"
+    context = sent["multi_modal_context"]
+    assert context is not None
+    assert len(context) == 1
+    assert context[0]["type"] == "image_url"
+    assert context[0]["image_url"] == "https://example.com/image.png"
+
+
+def test_image_cell_generator_with_base64_multi_modal_context(stub_resource_provider):
+    """A base64 image context reaches the model formatted as a PNG data URI."""
+    config = ImageColumnConfig(
+        name="test_image",
+        prompt="Generate a variation of this image",
+        model_alias="test_model",
+        multi_modal_context=[
+            ImageContext(
+                column_name="reference_image",
+                data_type=ModalityDataType.BASE64,
+                image_format=ImageFormat.PNG,
+            )
+        ],
+    )
+
+    storage = Mock()
+    storage.save_base64_image.return_value = "images/generated.png"
+    stub_resource_provider.artifact_storage.media_storage = storage
+
+    model = stub_resource_provider.model_registry.get_model.return_value
+    with patch.object(model, "generate_image", return_value=["base64_generated_image"]) as generate_image:
+        generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider)
+        row = generator.generate(data={"reference_image": "iVBORw0KGgoAAAANS"})
+
+    # The generated column holds whatever path storage returned.
+    assert config.name in row
+    assert row[config.name] == ["images/generated.png"]
+
+    # The model saw the prompt plus exactly one context entry.
+    generate_image.assert_called_once()
+    sent = generate_image.call_args.kwargs
+    assert sent["prompt"] == "Generate a variation of this image"
+    context = sent["multi_modal_context"]
+    assert context is not None
+    assert len(context) == 1
+    assert context[0]["type"] == "image_url"
+    # The raw base64 value must have been wrapped in a data URI.
+    assert "data:image/png;base64," in context[0]["image_url"]["url"]
diff --git a/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py b/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py
index e1136233..fb3cdab5 100644
--- a/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py
+++ b/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py
@@ -14,6 +14,7 @@ def test_column_type_is_model_generated() -> None:
assert column_type_is_model_generated(DataDesignerColumnType.LLM_STRUCTURED)
assert column_type_is_model_generated(DataDesignerColumnType.LLM_JUDGE)
assert column_type_is_model_generated(DataDesignerColumnType.EMBEDDING)
+ assert column_type_is_model_generated(DataDesignerColumnType.IMAGE)
assert not column_type_is_model_generated(DataDesignerColumnType.SAMPLER)
assert not column_type_is_model_generated(DataDesignerColumnType.VALIDATION)
assert not column_type_is_model_generated(DataDesignerColumnType.EXPRESSION)
@@ -29,5 +30,6 @@ def test_column_type_used_in_execution_dag() -> None:
assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT)
assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION)
assert column_type_used_in_execution_dag(DataDesignerColumnType.EMBEDDING)
+ assert column_type_used_in_execution_dag(DataDesignerColumnType.IMAGE)
assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER)
assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET)
diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py
index df15b4f7..35edf892 100644
--- a/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py
+++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py
@@ -213,10 +213,11 @@ def test_artifact_storage_resolved_dataset_name(mock_datetime, tmp_path):
(af_storage.artifact_path / af_storage.dataset_name).mkdir()
assert af_storage.resolved_dataset_name == "dataset"
- # dataset path exists and is not empty
+ # dataset path exists and is not empty (create file BEFORE constructing ArtifactStorage)
+ dataset_dir = tmp_path / "dataset"
+ dataset_dir.mkdir(exist_ok=True)
+ (dataset_dir / "stub_file.txt").touch()
af_storage = ArtifactStorage(artifact_path=tmp_path)
- (af_storage.artifact_path / af_storage.dataset_name / "stub_file.txt").touch()
- print(af_storage.resolved_dataset_name)
assert af_storage.resolved_dataset_name == "dataset_01-01-2025_120304"
diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py
index 235ead55..84da6ebb 100644
--- a/packages/data-designer-engine/tests/engine/models/test_facade.py
+++ b/packages/data-designer-engine/tests/engine/models/test_facade.py
@@ -1,18 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
-from typing import Any
-from unittest.mock import patch
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+from unittest.mock import MagicMock, patch
import pytest
-from litellm.types.utils import Choices, EmbeddingResponse, Message, ModelResponse
from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError
-from data_designer.engine.models.errors import ModelGenerationValidationFailureError
+from data_designer.engine.models.errors import ImageGenerationError, ModelGenerationValidationFailureError
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.parsers.errors import ParserException
from data_designer.engine.models.utils import ChatMessage
from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, StubMessage, StubResponse
+from data_designer.lazy_heavy_imports import litellm
+
+if TYPE_CHECKING:
+ import litellm
def mock_oai_response_object(response_text: str) -> StubResponse:
@@ -35,12 +40,14 @@ def stub_completion_messages() -> list[ChatMessage]:
@pytest.fixture
def stub_expected_completion_response():
- return ModelResponse(choices=Choices(message=Message(content="Test response")))
+ return litellm.types.utils.ModelResponse(
+ choices=litellm.types.utils.Choices(message=litellm.types.utils.Message(content="Test response"))
+ )
@pytest.fixture
def stub_expected_embedding_response():
- return EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2)
+ return litellm.types.utils.EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2)
@pytest.mark.parametrize(
@@ -106,9 +113,11 @@ def test_generate_with_system_prompt(
# Capture messages at call time since they get mutated after the call
captured_messages = []
- def capture_and_return(*args: Any, **kwargs: Any) -> ModelResponse:
+ def capture_and_return(*args: Any, **kwargs: Any) -> litellm.types.utils.ModelResponse:
captured_messages.append(list(args[1])) # Copy the messages list
- return ModelResponse(choices=Choices(message=Message(content="Hello!")))
+ return litellm.types.utils.ModelResponse(
+ choices=litellm.types.utils.Choices(message=litellm.types.utils.Message(content="Hello!"))
+ )
mock_completion.side_effect = capture_and_return
@@ -188,7 +197,7 @@ def test_completion_success(
stub_completion_messages: list[ChatMessage],
stub_model_configs: Any,
stub_model_facade: ModelFacade,
- stub_expected_completion_response: ModelResponse,
+ stub_expected_completion_response: litellm.types.utils.ModelResponse,
skip_usage_tracking: bool,
) -> None:
mock_router_completion.side_effect = lambda self, model, messages, **kwargs: stub_expected_completion_response
@@ -221,11 +230,13 @@ def test_completion_with_kwargs(
stub_completion_messages: list[ChatMessage],
stub_model_configs: Any,
stub_model_facade: ModelFacade,
- stub_expected_completion_response: ModelResponse,
+ stub_expected_completion_response: litellm.types.utils.ModelResponse,
) -> None:
captured_kwargs = {}
- def mock_completion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> ModelResponse:
+ def mock_completion(
+ self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any
+ ) -> litellm.types.utils.ModelResponse:
captured_kwargs.update(kwargs)
return stub_expected_completion_response
@@ -1011,3 +1022,253 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe
with patch.object(ModelFacade, "completion", new=_completion):
with pytest.raises(MCPToolError, match="Invalid tool arguments"):
model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")
+
+
+# =============================================================================
+# Image generation tests
+# =============================================================================
+
+
+@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True)
+def test_generate_image_diffusion_tracks_image_usage(
+    mock_image_generation: Any,
+    stub_model_facade: ModelFacade,
+) -> None:
+    """Diffusion path: every image returned by the router is counted in image_usage."""
+    # Router stub hands back three base64 payloads
+    expected_payloads = [
+        "image1_base64",
+        "image2_base64",
+        "image3_base64",
+    ]
+    mock_image_generation.return_value = litellm.types.utils.ImageResponse(
+        data=[litellm.types.utils.ImageObject(b64_json=payload) for payload in expected_payloads]
+    )
+
+    # Counter starts at zero before any generation
+    assert stub_model_facade.usage_stats.image_usage.total_images == 0
+
+    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
+        generated = stub_model_facade.generate_image(prompt="test prompt", n=3)
+
+    # Payloads come back unchanged and in order
+    assert len(generated) == 3
+    assert generated == expected_payloads
+
+    # All three images were recorded
+    usage_after = stub_model_facade.usage_stats.image_usage
+    assert usage_after.total_images == 3
+    assert usage_after.has_usage is True
+
+
+@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+def test_generate_image_chat_completion_tracks_image_usage(
+    mock_completion: Any,
+    stub_model_facade: ModelFacade,
+) -> None:
+    """Chat-completion path: images attached to the assistant message are counted in image_usage."""
+    # ImageURLListItem entries carry ``type`` and ``index`` alongside the data URI
+    first_item = litellm.types.utils.ImageURLListItem(
+        type="image_url", image_url={"url": "data:image/png;base64,image1"}, index=0
+    )
+    second_item = litellm.types.utils.ImageURLListItem(
+        type="image_url", image_url={"url": "data:image/png;base64,image2"}, index=1
+    )
+    assistant_message = litellm.types.utils.Message(
+        role="assistant",
+        content="",
+        images=[first_item, second_item],
+    )
+    mock_completion.return_value = litellm.types.utils.ModelResponse(
+        choices=[litellm.types.utils.Choices(message=assistant_message)]
+    )
+
+    # Counter starts at zero before any generation
+    assert stub_model_facade.usage_stats.image_usage.total_images == 0
+
+    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
+        generated = stub_model_facade.generate_image(prompt="test prompt")
+
+    # Only the base64 payload after the data-URI prefix is returned
+    assert len(generated) == 2
+    assert generated == ["image1", "image2"]
+
+    # Both images were recorded
+    usage_after = stub_model_facade.usage_stats.image_usage
+    assert usage_after.total_images == 2
+    assert usage_after.has_usage is True
+
+
+@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+def test_generate_image_chat_completion_with_dict_format(
+    mock_completion: Any,
+    stub_model_facade: ModelFacade,
+) -> None:
+    """Chat-completion path: dict entries whose "image_url" is a plain string are unpacked."""
+    # Each image is a dict holding the data URI directly under "image_url"
+    dict_images = [
+        {"image_url": "data:image/png;base64,image1"},
+        {"image_url": "data:image/jpeg;base64,image2"},
+    ]
+
+    assistant_message = MagicMock(
+        role="assistant",
+        content="",
+        images=dict_images,
+    )
+
+    only_choice = MagicMock(message=assistant_message)
+    completion_response = MagicMock(choices=[only_choice])
+
+    mock_completion.return_value = completion_response
+
+    # Route the call through the chat-completion branch
+    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
+        generated = stub_model_facade.generate_image(prompt="test prompt")
+
+    # The data-URI prefix is stripped from each entry
+    assert len(generated) == 2
+    assert generated == ["image1", "image2"]
+
+
+@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+def test_generate_image_chat_completion_with_plain_strings(
+    mock_completion: Any,
+    stub_model_facade: ModelFacade,
+) -> None:
+    """Chat-completion path: images given as bare strings are accepted."""
+    # One entry is a full data URI, the other is already bare base64
+    string_images = [
+        "data:image/png;base64,image1",
+        "image2",  # Plain base64 without data URI prefix
+    ]
+
+    assistant_message = MagicMock(
+        role="assistant",
+        content="",
+        images=string_images,
+    )
+
+    only_choice = MagicMock(message=assistant_message)
+    completion_response = MagicMock(choices=[only_choice])
+
+    mock_completion.return_value = completion_response
+
+    # Route the call through the chat-completion branch
+    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
+        generated = stub_model_facade.generate_image(prompt="test prompt")
+
+    # Both entries come back as bare base64
+    assert len(generated) == 2
+    assert generated == ["image1", "image2"]
+
+
+@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True)
+def test_generate_image_skip_usage_tracking(
+    mock_image_generation: Any,
+    stub_model_facade: ModelFacade,
+) -> None:
+    """With skip_usage_tracking=True, images are returned but image_usage stays untouched."""
+    mock_image_generation.return_value = litellm.types.utils.ImageResponse(
+        data=[
+            litellm.types.utils.ImageObject(b64_json="image1_base64"),
+            litellm.types.utils.ImageObject(b64_json="image2_base64"),
+        ]
+    )
+
+    # Counter starts at zero before any generation
+    assert stub_model_facade.usage_stats.image_usage.total_images == 0
+
+    # Opt out of usage accounting for this call
+    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
+        generated = stub_model_facade.generate_image(prompt="test prompt", skip_usage_tracking=True)
+
+    # The images themselves are still returned
+    assert len(generated) == 2
+
+    # ...but nothing was added to the usage counters
+    usage_after = stub_model_facade.usage_stats.image_usage
+    assert usage_after.total_images == 0
+    assert usage_after.has_usage is False
+
+
+@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+def test_generate_image_chat_completion_no_choices(
+    mock_completion: Any,
+    stub_model_facade: ModelFacade,
+) -> None:
+    """Chat-completion path: a response with an empty ``choices`` list raises ImageGenerationError."""
+    # The stubbed completion carries no choices at all
+    mock_completion.return_value = litellm.types.utils.ModelResponse(choices=[])
+
+    with pytest.raises(ImageGenerationError, match="Image generation response missing choices"):
+        with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
+            stub_model_facade.generate_image(prompt="test prompt")
+
+
+@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+def test_generate_image_chat_completion_no_image_data(
+    mock_completion: Any,
+    stub_model_facade: ModelFacade,
+) -> None:
+    """Chat-completion path: a text-only assistant message raises ImageGenerationError."""
+    text_only_message = litellm.types.utils.Message(role="assistant", content="just text, no image")
+    lone_choice = litellm.types.utils.Choices(message=text_only_message)
+    mock_completion.return_value = litellm.types.utils.ModelResponse(choices=[lone_choice])
+
+    with pytest.raises(ImageGenerationError, match="No image data found in image generation response"):
+        with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
+            stub_model_facade.generate_image(prompt="test prompt")
+
+
+@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True)
+def test_generate_image_diffusion_no_data(
+    mock_image_generation: Any,
+    stub_model_facade: ModelFacade,
+) -> None:
+    """Diffusion path: an ImageResponse with empty ``data`` raises ImageGenerationError."""
+    # The stubbed router returns a response without any image objects
+    mock_image_generation.return_value = litellm.types.utils.ImageResponse(data=[])
+
+    with pytest.raises(ImageGenerationError, match="Image generation returned no data"):
+        with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
+            stub_model_facade.generate_image(prompt="test prompt")
+
+
+@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True)
+def test_generate_image_accumulates_usage(
+    mock_image_generation: Any,
+    stub_model_facade: ModelFacade,
+) -> None:
+    """Image counts from successive generate_image calls add up in image_usage."""
+    # Router stub answers the first call with 2 images...
+    first_batch = litellm.types.utils.ImageResponse(
+        data=[
+            litellm.types.utils.ImageObject(b64_json="image1"),
+            litellm.types.utils.ImageObject(b64_json="image2"),
+        ]
+    )
+    # ...and the second call with 3 images
+    second_batch = litellm.types.utils.ImageResponse(
+        data=[
+            litellm.types.utils.ImageObject(b64_json="image3"),
+            litellm.types.utils.ImageObject(b64_json="image4"),
+            litellm.types.utils.ImageObject(b64_json="image5"),
+        ]
+    )
+    mock_image_generation.side_effect = [first_batch, second_batch]
+
+    # Counter starts at zero before any generation
+    assert stub_model_facade.usage_stats.image_usage.total_images == 0
+
+    # Both calls run under one patch so each takes the diffusion branch
+    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
+        first_images = stub_model_facade.generate_image(prompt="test1")
+        assert len(first_images) == 2
+        assert stub_model_facade.usage_stats.image_usage.total_images == 2
+
+        # The second call must add to the running total, not reset it
+        second_images = stub_model_facade.generate_image(prompt="test2")
+        assert len(second_images) == 3
+        # 2 + 3
+        assert stub_model_facade.usage_stats.image_usage.total_images == 5
diff --git a/packages/data-designer-engine/tests/engine/models/test_usage.py b/packages/data-designer-engine/tests/engine/models/test_usage.py
index 8e7adb04..2bfea4b4 100644
--- a/packages/data-designer-engine/tests/engine/models/test_usage.py
+++ b/packages/data-designer-engine/tests/engine/models/test_usage.py
@@ -1,7 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
-from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats, ToolUsageStats
+from data_designer.engine.models.usage import (
+ ImageUsageStats,
+ ModelUsageStats,
+ RequestUsageStats,
+ TokenUsageStats,
+ ToolUsageStats,
+)
def test_token_usage_stats() -> None:
@@ -32,6 +38,20 @@ def test_request_usage_stats() -> None:
assert request_usage_stats.has_usage is True
+def test_image_usage_stats() -> None:
+    """ImageUsageStats starts empty and sums the counts passed to extend()."""
+    usage = ImageUsageStats()
+    # A fresh instance reports no usage
+    assert usage.total_images == 0
+    assert usage.has_usage is False
+
+    # Each extend() call adds to the running total: 5, then 5 + 3 = 8
+    for added, running_total in ((5, 5), (3, 8)):
+        usage.extend(images=added)
+        assert usage.total_images == running_total
+        assert usage.has_usage is True
+
+
def test_tool_usage_stats_empty_state() -> None:
"""Test ToolUsageStats initialization with empty state."""
tool_usage = ToolUsageStats()
@@ -132,9 +152,10 @@ def test_model_usage_stats() -> None:
assert model_usage_stats.token_usage.output_tokens == 0
assert model_usage_stats.request_usage.successful_requests == 0
assert model_usage_stats.request_usage.failed_requests == 0
+ assert model_usage_stats.image_usage.total_images == 0
assert model_usage_stats.has_usage is False
- # tool_usage is excluded when has_usage is False
+    # tool_usage and image_usage are omitted from the report while they have no usage of their own
assert model_usage_stats.get_usage_stats(total_time_elapsed=10) == {
"token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
"request_usage": {"successful_requests": 0, "failed_requests": 0, "total_requests": 0},
@@ -152,7 +173,7 @@ def test_model_usage_stats() -> None:
assert model_usage_stats.request_usage.failed_requests == 1
assert model_usage_stats.has_usage is True
- # tool_usage is excluded when has_usage is False
+    # tool_usage and image_usage are still omitted: neither has recorded any usage of its own
assert model_usage_stats.get_usage_stats(total_time_elapsed=2) == {
"token_usage": {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
"request_usage": {"successful_requests": 2, "failed_requests": 1, "total_requests": 3},
@@ -177,3 +198,58 @@ def test_model_usage_stats_extend_with_tool_usage() -> None:
assert stats1.tool_usage.total_tool_call_turns == 6
assert stats1.tool_usage.total_generations == 4
assert stats1.tool_usage.generations_with_tools == 3
+
+
+def test_model_usage_stats_with_image_usage() -> None:
+    """ModelUsageStats reports an image_usage section once images have been recorded."""
+    stats_with_images = ModelUsageStats()
+    stats_with_images.extend(
+        token_usage=TokenUsageStats(input_tokens=10, output_tokens=20),
+        request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
+        image_usage=ImageUsageStats(total_images=5),
+    )
+
+    assert stats_with_images.image_usage.total_images == 5
+    assert stats_with_images.image_usage.has_usage is True
+
+    # The serialized report carries the image count under "image_usage"
+    report = stats_with_images.get_usage_stats(total_time_elapsed=2)
+    assert "image_usage" in report
+    assert report["image_usage"] == {"total_images": 5}
+
+
+def test_model_usage_stats_has_usage_any_of() -> None:
+    """has_usage is True if token, request, or image usage alone is present, and False otherwise."""
+    # One usage category per case; any single one is enough to flip has_usage
+    single_category_updates = [
+        # Token counts only
+        {"token_usage": TokenUsageStats(input_tokens=1, output_tokens=0)},
+        # Requests only (e.g. diffusion API without token counts)
+        {"request_usage": RequestUsageStats(successful_requests=1, failed_requests=0)},
+        # Image count only
+        {"image_usage": ImageUsageStats(total_images=2)},
+    ]
+
+    for update in single_category_updates:
+        stats = ModelUsageStats()
+        stats.extend(**update)
+        assert stats.has_usage is True
+
+    # Nothing recorded at all
+    untouched = ModelUsageStats()
+    assert untouched.has_usage is False
+
+
+def test_model_usage_stats_exclude_unused_stats() -> None:
+    """tool_usage and image_usage are left out of the report when neither has recorded usage."""
+    text_only_stats = ModelUsageStats()
+    text_only_stats.extend(
+        token_usage=TokenUsageStats(input_tokens=10, output_tokens=20),
+        request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
+    )
+
+    report = text_only_stats.get_usage_stats(total_time_elapsed=2)
+    assert "token_usage" in report
+    assert "request_usage" in report
+    assert "tool_usage" not in report
+    assert "image_usage" not in report
diff --git a/packages/data-designer-engine/tests/engine/storage/__init__.py b/packages/data-designer-engine/tests/engine/storage/__init__.py
new file mode 100644
index 00000000..e5725ea5
--- /dev/null
+++ b/packages/data-designer-engine/tests/engine/storage/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py
new file mode 100644
index 00000000..e79c854b
--- /dev/null
+++ b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py
@@ -0,0 +1,254 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import base64
+import io
+
+import pytest
+
+from data_designer.engine.storage.media_storage import IMAGES_SUBDIR, MediaStorage, StorageMode
+from data_designer.lazy_heavy_imports import Image
+
+
+@pytest.fixture
+def media_storage(tmp_path):
+    """Provide a MediaStorage whose base_path is the test's ``tmp_path``; other settings keep their defaults."""
+    return MediaStorage(base_path=tmp_path)
+
+
+@pytest.fixture
+def sample_base64_png() -> str:
+    """Return the base64 text of an in-memory 1x1 red PNG."""
+    with io.BytesIO() as png_buffer:
+        Image.new("RGB", (1, 1), color="red").save(png_buffer, format="PNG")
+        encoded_png = base64.b64encode(png_buffer.getvalue())
+    # b64encode returns bytes; decode so the fixture yields a str
+    return encoded_png.decode()
+
+
+@pytest.fixture
+def sample_base64_jpg() -> str:
+    """Return the base64 text of an in-memory 1x1 blue JPEG."""
+    with io.BytesIO() as jpg_buffer:
+        Image.new("RGB", (1, 1), color="blue").save(jpg_buffer, format="JPEG")
+        encoded_jpg = base64.b64encode(jpg_buffer.getvalue())
+    # b64encode returns bytes; decode so the fixture yields a str
+    return encoded_jpg.decode()
+
+
+@pytest.mark.parametrize(
+    ("images_subdir", "mode"),
+    [
+        pytest.param(IMAGES_SUBDIR, StorageMode.DISK, id="defaults"),
+        pytest.param("custom_images", StorageMode.DATAFRAME, id="custom-subdir-dataframe"),
+    ],
+)
+def test_media_storage_init(tmp_path, images_subdir: str, mode: StorageMode) -> None:
+    """Constructor arguments are exposed as attributes and no directory is created up front."""
+    storage = MediaStorage(base_path=tmp_path, images_subdir=images_subdir, mode=mode)
+    # Arguments are stored unchanged; images_dir is base_path joined with the subdir name
+    assert storage.base_path == tmp_path
+    assert storage.images_subdir == images_subdir
+    assert storage.mode == mode
+    assert storage.images_dir == tmp_path / images_subdir
+    # Lazy initialization: the directory should NOT exist until the first save
+    assert not storage.images_dir.exists()
+
+
+@pytest.mark.parametrize(
+    ("image_fixture", "expected_extension"),
+    [
+        ("sample_base64_png", ".png"),
+        ("sample_base64_jpg", ".jpg"),
+    ],
+)
+def test_save_base64_image_format(media_storage, image_fixture, expected_extension, request):
+    """Saving PNG and JPEG payloads yields a file with the matching extension and identical bytes."""
+    # Resolve the fixture named by the parameter into its actual base64 value
+    encoded_image = request.getfixturevalue(image_fixture)
+
+    stored_path = media_storage.save_base64_image(encoded_image, subfolder_name="test_column")
+
+    # Returned path sits under the column folder and ends with the format's extension
+    expected_prefix = f"{IMAGES_SUBDIR}/test_column/"
+    assert stored_path.startswith(expected_prefix)
+    assert stored_path.endswith(expected_extension)
+
+    # The file exists relative to base_path
+    file_on_disk = media_storage.base_path / stored_path
+    assert file_on_disk.exists()
+
+    # Stored bytes equal the decoded payload
+    original_bytes = base64.b64decode(encoded_image)
+    assert file_on_disk.read_bytes() == original_bytes
+
+
+def test_save_base64_image_with_data_uri(media_storage, sample_base64_png):
+    """A ``data:image/png;base64,...`` URI is accepted and stored as the decoded PNG bytes."""
+    uri_payload = f"data:image/png;base64,{sample_base64_png}"
+    stored_path = media_storage.save_base64_image(uri_payload, subfolder_name="test_column")
+
+    # The base64 part is extracted and filed under the column folder as a PNG
+    assert stored_path.startswith(f"{IMAGES_SUBDIR}/test_column/")
+    assert stored_path.endswith(".png")
+
+    # On-disk content equals the payload without the URI prefix
+    file_on_disk = media_storage.base_path / stored_path
+    assert file_on_disk.exists()
+
+    original_bytes = base64.b64decode(sample_base64_png)
+    assert file_on_disk.read_bytes() == original_bytes
+
+
+def test_save_base64_image_invalid_base64_raises_error(media_storage):
+    """Test that a string which is not valid base64 is rejected with a ValueError matching "Invalid base64"."""
+    with pytest.raises(ValueError, match="Invalid base64"):
+        media_storage.save_base64_image("not-valid-base64!!!", subfolder_name="test_column")
+
+
+def test_save_base64_image_multiple_images_unique_filenames(media_storage, sample_base64_png):
+    """Saving the same payload twice produces two distinct files."""
+    # Store the identical image twice in the same subfolder
+    saved_paths = [
+        media_storage.save_base64_image(sample_base64_png, subfolder_name="test_column") for _ in range(2)
+    ]
+
+    # Paths should be different (different UUIDs)
+    assert saved_paths[0] != saved_paths[1]
+    # Neither save overwrote the other
+    assert all((media_storage.base_path / saved).exists() for saved in saved_paths)
+
+
+def test_save_base64_image_disk_mode_validates(tmp_path, sample_base64_png):
+    """DISK mode accepts a valid image and stores it under the column folder."""
+    disk_storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DISK)
+    stored_path = disk_storage.save_base64_image(sample_base64_png, subfolder_name="test_column")
+    # A well-formed PNG gets through and is filed under <images>/test_column/
+    assert stored_path.startswith(f"{IMAGES_SUBDIR}/test_column/")
+
+
+def test_save_base64_image_disk_mode_corrupted_image_raises_error(tmp_path):
+    """DISK mode rejects bytes that are not an image and leaves no file behind."""
+    disk_storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DISK)
+
+    # Well-formed base64 whose decoded bytes are not an image
+    bogus_payload = base64.b64encode(b"not a valid image").decode()
+
+    # Validation must refuse the payload
+    with pytest.raises(ValueError, match="Unable to detect image format"):
+        disk_storage.save_base64_image(bogus_payload, subfolder_name="test_column")
+
+    # Either the column folder was never created, or it contains no leftover files
+    leftover_dir = disk_storage.images_dir / "test_column"
+    leftovers = list(leftover_dir.iterdir()) if leftover_dir.exists() else []
+    assert len(leftovers) == 0
+
+
+@pytest.mark.parametrize("subfolder_name", ["test_column", "test_subfolder"], ids=["column", "subfolder"])
+def test_save_base64_image_dataframe_mode_returns_base64(tmp_path, sample_base64_png, subfolder_name):
+    """DATAFRAME mode hands the base64 payload back unchanged, whatever the subfolder name."""
+    in_memory_storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DATAFRAME)
+
+    # The return value is the payload itself rather than a file path
+    returned_value = in_memory_storage.save_base64_image(sample_base64_png, subfolder_name=subfolder_name)
+    assert returned_value == sample_base64_png
+    # Directory should not be created in DATAFRAME mode (lazy initialization)
+    assert not in_memory_storage.images_dir.exists()
+
+
+def test_save_base64_image_with_subfolder_name(media_storage, sample_base64_png):
+    """The subfolder name becomes the directory that holds the saved image."""
+    target_folder = "test_subfolder"
+    stored_path = media_storage.save_base64_image(sample_base64_png, subfolder_name=target_folder)
+
+    # Returned path names the subfolder and carries the PNG extension
+    assert stored_path.startswith(f"{IMAGES_SUBDIR}/{target_folder}/")
+    assert stored_path.endswith(".png")
+
+    # The file lives directly inside that subfolder
+    file_on_disk = media_storage.base_path / stored_path
+    assert file_on_disk.exists()
+    assert file_on_disk.parent.name == target_folder
+
+    # Stored bytes equal the decoded payload
+    original_bytes = base64.b64decode(sample_base64_png)
+    stored_bytes = file_on_disk.read_bytes()
+    assert stored_bytes == original_bytes
+
+
+def test_save_base64_image_with_different_subfolder_names(media_storage, sample_base64_png, sample_base64_jpg):
+    """Images saved under different subfolder names end up in separate directories."""
+    # A PNG goes into subfolder_a and a JPEG into subfolder_b
+    saved_by_folder = {
+        "subfolder_a": media_storage.save_base64_image(sample_base64_png, subfolder_name="subfolder_a"),
+        "subfolder_b": media_storage.save_base64_image(sample_base64_jpg, subfolder_name="subfolder_b"),
+    }
+
+    # Every saved image must pass the same three checks
+    for folder_name, stored_path in saved_by_folder.items():
+        # The returned path mentions its own subfolder
+        assert folder_name in stored_path
+
+        # The subfolder exists beneath the images directory
+        assert (media_storage.images_dir / folder_name).exists()
+
+        # The image file itself is present
+        assert (media_storage.base_path / stored_path).exists()
+
+
+@pytest.mark.parametrize(
+    "unsafe_name,expected_sanitized",
+    [
+        ("../evil", "__evil"),  # Parent directory traversal: .. -> _, / -> _
+        ("foo/bar", "foo_bar"),  # Path separator (forward slash)
+        ("foo\\bar", "foo_bar"),  # Path separator (backslash)
+        ("test..name", "test_name"),  # Double dots in middle: .. -> _
+    ],
+)
+def test_save_base64_image_sanitizes_subfolder_name(media_storage, sample_base64_png, unsafe_name, expected_sanitized):
+    """Test that subfolder names are sanitized to prevent path traversal."""
+    relative_path = media_storage.save_base64_image(sample_base64_png, subfolder_name=unsafe_name)
+    assert relative_path.startswith(f"{IMAGES_SUBDIR}/") and expected_sanitized in relative_path
+
+    # Check the subfolder actually written (text between the images prefix and the filename),
+    # not the parametrized expectation, so an unsanitized name cannot slip through.
+    saved_subfolder = relative_path[len(IMAGES_SUBDIR) + 1 :].rsplit("/", 1)[0]
+    assert "/" not in saved_subfolder and "\\" not in saved_subfolder  # No path separators
+    assert ".." not in saved_subfolder  # No parent references
+
+    # Resolve before comparing: ``parents`` is purely lexical and would accept "images/../x"
+    full_path = (media_storage.base_path / relative_path).resolve()
+    assert full_path.exists() and media_storage.images_dir.resolve() in full_path.parents
+
+
+# ---------------------------------------------------------------------------
+# delete_image
+# ---------------------------------------------------------------------------
+
+
+def test_delete_image_removes_saved_file(media_storage, sample_base64_png) -> None:
+    """delete_image removes a file previously written by save_base64_image and returns True."""
+    stored_path = media_storage.save_base64_image(sample_base64_png, subfolder_name="col")
+    file_on_disk = media_storage.base_path / stored_path
+    assert file_on_disk.exists()
+
+    # Deletion reports success and the file is gone afterwards
+    assert media_storage.delete_image(stored_path) is True
+    assert not file_on_disk.exists()
+
+
+def test_delete_image_returns_false_for_nonexistent(media_storage) -> None:
+    """Test that delete_image returns False for a path under the images directory that was never saved."""
+    assert media_storage.delete_image(f"{IMAGES_SUBDIR}/col/nonexistent.png") is False
+
+
+def test_delete_image_rejects_path_outside_images_dir(media_storage, tmp_path) -> None:
+    """delete_image must refuse paths escaping the images directory, including "<images>/../outside.txt"."""
+    outside_file = tmp_path / "outside.txt"
+    outside_file.write_text("should not be deleted")
+    media_storage.images_dir.mkdir(parents=True, exist_ok=True)  # let the ".." hop resolve on disk
+    for escaping_path in ("../outside.txt", f"{IMAGES_SUBDIR}/../outside.txt"):
+        assert media_storage.delete_image(escaping_path) is False
+    assert outside_file.exists()
diff --git a/packages/data-designer-engine/tests/engine/test_configurable_task.py b/packages/data-designer-engine/tests/engine/test_configurable_task.py
index f20936a2..6e3673de 100644
--- a/packages/data-designer-engine/tests/engine/test_configurable_task.py
+++ b/packages/data-designer-engine/tests/engine/test_configurable_task.py
@@ -25,7 +25,7 @@ def test_configurable_task_generic_type_variables() -> None:
assert TaskConfigT.__bound__ == ConfigBase
-def test_configurable_task_concrete_implementation() -> None:
+def test_configurable_task_concrete_implementation(tmp_path) -> None:
class TestConfig(ConfigBase):
value: str
@@ -41,13 +41,8 @@ def _initialize(self) -> None:
pass
config = TestConfig(value="test")
- mock_artifact_storage = Mock(spec=ArtifactStorage)
- mock_artifact_storage.dataset_name = "test_dataset"
- mock_artifact_storage.final_dataset_folder_name = "final_dataset"
- mock_artifact_storage.partial_results_folder_name = "partial_results"
- mock_artifact_storage.dropped_columns_folder_name = "dropped_columns"
- mock_artifact_storage.processors_outputs_folder_name = "processors_outputs"
- resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage)
+ artifact_storage = ArtifactStorage(artifact_path=tmp_path)
+ resource_provider = ResourceProvider(artifact_storage=artifact_storage)
task = TestTask(config=config, resource_provider=resource_provider)
@@ -55,7 +50,7 @@ def _initialize(self) -> None:
assert task._resource_provider == resource_provider
-def test_configurable_task_config_validation() -> None:
+def test_configurable_task_config_validation(tmp_path) -> None:
class TestConfig(ConfigBase):
value: str
@@ -69,13 +64,8 @@ def _validate(self) -> None:
raise ValueError("Invalid config")
config = TestConfig(value="test")
- mock_artifact_storage = Mock(spec=ArtifactStorage)
- mock_artifact_storage.dataset_name = "test_dataset"
- mock_artifact_storage.final_dataset_folder_name = "final_dataset"
- mock_artifact_storage.partial_results_folder_name = "partial_results"
- mock_artifact_storage.dropped_columns_folder_name = "dropped_columns"
- mock_artifact_storage.processors_outputs_folder_name = "processors_outputs"
- resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage)
+ artifact_storage = ArtifactStorage(artifact_path=tmp_path)
+ resource_provider = ResourceProvider(artifact_storage=artifact_storage)
task = TestTask(config=config, resource_provider=resource_provider)
assert task._config.value == "test"
@@ -85,7 +75,7 @@ def _validate(self) -> None:
TestTask(config=invalid_config, resource_provider=resource_provider)
-def test_configurable_task_resource_validation() -> None:
+def test_configurable_task_resource_validation(tmp_path) -> None:
class TestConfig(ConfigBase):
value: str
@@ -102,14 +92,9 @@ def _initialize(self) -> None:
config = TestConfig(value="test")
- mock_artifact_storage = Mock(spec=ArtifactStorage)
- mock_artifact_storage.dataset_name = "test_dataset"
- mock_artifact_storage.final_dataset_folder_name = "final_dataset"
- mock_artifact_storage.partial_results_folder_name = "partial_results"
- mock_artifact_storage.dropped_columns_folder_name = "dropped_columns"
- mock_artifact_storage.processors_outputs_folder_name = "processors_outputs"
+ artifact_storage = ArtifactStorage(artifact_path=tmp_path)
mock_model_registry = Mock(spec=ModelRegistry)
- resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage, model_registry=mock_model_registry)
+ resource_provider = ResourceProvider(artifact_storage=artifact_storage, model_registry=mock_model_registry)
task = TestTask(config=config, resource_provider=resource_provider)
assert task._resource_provider == resource_provider
diff --git a/packages/data-designer/src/data_designer/integrations/huggingface/client.py b/packages/data-designer/src/data_designer/integrations/huggingface/client.py
index c047d73b..1d0a0f0e 100644
--- a/packages/data-designer/src/data_designer/integrations/huggingface/client.py
+++ b/packages/data-designer/src/data_designer/integrations/huggingface/client.py
@@ -66,6 +66,7 @@ def upload_dataset(
Uploads the complete dataset including:
- Main parquet batch files from parquet-files/ β data/
+ - Images from images/ → images/ (if present)
- Processor output batch files from processors-files/{name}/ β {name}/
- Existing builder_config.json and metadata.json files
- Auto-generated README.md (dataset card)
@@ -102,6 +103,7 @@ def upload_dataset(
raise HuggingFaceHubClientUploadError(f"Failed to upload dataset card: {e}") from e
self._upload_main_dataset_files(repo_id=repo_id, parquet_folder=base_dataset_path / FINAL_DATASET_FOLDER_NAME)
+ self._upload_images_folder(repo_id=repo_id, images_folder=base_dataset_path / "images")
self._upload_processor_files(
repo_id=repo_id, processors_folder=base_dataset_path / PROCESSORS_OUTPUTS_FOLDER_NAME
)
@@ -178,6 +180,36 @@ def _upload_main_dataset_files(self, repo_id: str, parquet_folder: Path) -> None
except Exception as e:
raise HuggingFaceHubClientUploadError(f"Failed to upload parquet files: {e}") from e
+    def _upload_images_folder(self, repo_id: str, images_folder: Path) -> None:
+        """Upload the images folder to ``images/`` in the dataset repo; no-op if absent or file-less.
+
+        Args:
+            repo_id: Hugging Face dataset repo ID
+            images_folder: Path to images folder
+
+        Raises:
+            HuggingFaceHubClientUploadError: If upload fails
+        """
+        if not images_folder.exists():
+            return
+
+        # Count real files only: "*.*" also matched dotted directory names and missed extension-less files.
+        image_files = [path for path in images_folder.rglob("*") if path.is_file()]
+        if not image_files:
+            return
+
+        logger.info(f" |-- {RandomEmoji.loading()} Uploading {len(image_files)} image files...")
+        try:
+            self._api.upload_folder(
+                repo_id=repo_id,
+                folder_path=str(images_folder),
+                path_in_repo="images",
+                repo_type="dataset",
+                commit_message="Upload images",
+            )
+        except Exception as e:
+            raise HuggingFaceHubClientUploadError(f"Failed to upload images: {e}") from e
+
def _upload_processor_files(self, repo_id: str, processors_folder: Path) -> None:
"""Upload processor output files.
diff --git a/packages/data-designer/tests/integrations/huggingface/test_client.py b/packages/data-designer/tests/integrations/huggingface/test_client.py
index 735ea3bc..924a6bfe 100644
--- a/packages/data-designer/tests/integrations/huggingface/test_client.py
+++ b/packages/data-designer/tests/integrations/huggingface/test_client.py
@@ -462,6 +462,76 @@ def test_validate_dataset_path_invalid_builder_config_json(tmp_path: Path) -> No
client.upload_dataset("test/dataset", base_path, "Test")
+def test_upload_dataset_uploads_images_folder(
+    mock_hf_api: MagicMock, mock_dataset_card: MagicMock, sample_dataset_path: Path
+) -> None:
+    """Verify upload_dataset issues exactly one images upload_folder call for a populated images folder."""
+    # Create images directory with column subfolders (matches MediaStorage structure)
+    images_dir = sample_dataset_path / "images"
+    column_dir = images_dir / "my_image_column"
+    column_dir.mkdir(parents=True)
+    for file_name in ("uuid1.png", "uuid2.png"):
+        (column_dir / file_name).write_bytes(b"fake png data")
+
+    client = HuggingFaceHubClient(token="test-token")
+    client.upload_dataset(repo_id="test/dataset", base_dataset_path=sample_dataset_path, description="Test dataset")
+
+    # .get(): an upload_folder call recorded without path_in_repo is skipped rather than raising KeyError
+    image_calls = [c for c in mock_hf_api.upload_folder.call_args_list if c.kwargs.get("path_in_repo") == "images"]
+    assert len(image_calls) == 1
+    assert image_calls[0].kwargs["folder_path"] == str(images_dir)
+    assert image_calls[0].kwargs["repo_type"] == "dataset"
+
+
+def test_upload_dataset_skips_images_when_folder_missing(
+    mock_hf_api: MagicMock, mock_dataset_card: MagicMock, sample_dataset_path: Path
+) -> None:
+    """Verify upload_dataset makes no images upload_folder call when the images folder doesn't exist."""
+    # sample_dataset_path has no images/ directory by default
+    client = HuggingFaceHubClient(token="test-token")
+    client.upload_dataset(repo_id="test/dataset", base_dataset_path=sample_dataset_path, description="Test dataset")
+
+    # No upload_folder call should target "images"; .get() tolerates calls recorded without path_in_repo
+    image_calls = [c for c in mock_hf_api.upload_folder.call_args_list if c.kwargs.get("path_in_repo") == "images"]
+    assert len(image_calls) == 0
+
+
+def test_upload_dataset_skips_images_when_folder_empty(
+    mock_hf_api: MagicMock, mock_dataset_card: MagicMock, sample_dataset_path: Path
+) -> None:
+    """Verify upload_dataset makes no images upload_folder call when the images folder is empty."""
+    empty_images_dir = sample_dataset_path / "images"
+    empty_images_dir.mkdir()
+
+    client = HuggingFaceHubClient(token="test-token")
+    client.upload_dataset(repo_id="test/dataset", base_dataset_path=sample_dataset_path, description="Test dataset")
+
+    image_calls = [c for c in mock_hf_api.upload_folder.call_args_list if c.kwargs.get("path_in_repo") == "images"]
+    assert len(image_calls) == 0
+
+
+def test_upload_dataset_images_upload_failure(
+    mock_hf_api: MagicMock, mock_dataset_card: MagicMock, sample_dataset_path: Path
+) -> None:
+    """Verify a failing images upload surfaces as HuggingFaceHubClientUploadError."""
+    # Seed images/col/img.png so the images folder holds one file
+    image_path = sample_dataset_path / "images" / "col" / "img.png"
+    image_path.parent.mkdir(parents=True)
+    image_path.write_bytes(b"fake")
+
+    # Only the upload_folder call that targets "images" raises
+    def fail_for_images(**kwargs):
+        target = kwargs.get("path_in_repo")
+        if target == "images":
+            raise Exception("Network error")
+
+    mock_hf_api.upload_folder.side_effect = fail_for_images
+
+    client = HuggingFaceHubClient(token="test-token")
+    with pytest.raises(HuggingFaceHubClientUploadError, match="Failed to upload images"):
+        client.upload_dataset(repo_id="test/dataset", base_dataset_path=sample_dataset_path, description="Test dataset")
+
+
def test_upload_dataset_invalid_repo_id(mock_hf_api: MagicMock, sample_dataset_path: Path) -> None:
"""Test upload_dataset fails with invalid repo_id."""
client = HuggingFaceHubClient(token="test-token")
diff --git a/pyproject.toml b/pyproject.toml
index f93bdcd5..4f8578b2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,7 +57,6 @@ notebooks = [
"datasets>=4.0.0,<5",
"ipykernel>=6.29.0,<7",
"jupyter>=1.0.0,<2",
- "pillow>=12.0.0,<13",
]
recipes = [
"bm25s>=0.2.0,<1",
diff --git a/uv.lock b/uv.lock
index 2271ad88..dbb512fd 100644
--- a/uv.lock
+++ b/uv.lock
@@ -799,6 +799,7 @@ dependencies = [
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "pandas" },
+ { name = "pillow" },
{ name = "pyarrow" },
{ name = "pydantic", extra = ["email"] },
{ name = "pygments" },
@@ -813,6 +814,7 @@ requires-dist = [
{ name = "jinja2", specifier = ">=3.1.6,<4" },
{ name = "numpy", specifier = ">=1.23.5,<3" },
{ name = "pandas", specifier = ">=2.3.3,<3" },
+ { name = "pillow", specifier = ">=12.0.0,<13" },
{ name = "pyarrow", specifier = ">=19.0.1,<20" },
{ name = "pydantic", extras = ["email"], specifier = ">=2.9.2,<3" },
{ name = "pygments", specifier = ">=2.19.2,<3" },
@@ -902,7 +904,6 @@ notebooks = [
{ name = "datasets" },
{ name = "ipykernel" },
{ name = "jupyter" },
- { name = "pillow" },
]
recipes = [
{ name = "bm25s" },
@@ -936,7 +937,6 @@ notebooks = [
{ name = "datasets", specifier = ">=4.0.0,<5" },
{ name = "ipykernel", specifier = ">=6.29.0,<7" },
{ name = "jupyter", specifier = ">=1.0.0,<2" },
- { name = "pillow", specifier = ">=12.0.0,<13" },
]
recipes = [
{ name = "bm25s", specifier = ">=0.2.0,<1" },