diff --git a/docs/colab_notebooks/1-the-basics.ipynb b/docs/colab_notebooks/1-the-basics.ipynb index ec2c5a99..f50209f7 100644 --- a/docs/colab_notebooks/1-the-basics.ipynb +++ b/docs/colab_notebooks/1-the-basics.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "c79eea7a", + "id": "96178d08", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: The Basics\n", @@ -14,7 +14,7 @@ }, { "cell_type": "markdown", - "id": "2476f160", + "id": "1d02a1d6", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -26,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "3646f62e", + "id": "2292d817", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -37,7 +37,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3348e5c8", + "id": "8af621fc", "metadata": {}, "outputs": [], "source": [ @@ -48,7 +48,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19cd9249", + "id": "70e6a11c", "metadata": {}, "outputs": [], "source": [ @@ -66,7 +66,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5a6d13a9", + "id": "41031828", "metadata": {}, "outputs": [], "source": [ @@ -76,7 +76,7 @@ }, { "cell_type": "markdown", - "id": "d445af5b", + "id": "0b480b10", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -89,7 +89,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4df0031d", + "id": "d434a8e2", "metadata": {}, "outputs": [], "source": [ @@ -98,7 +98,7 @@ }, { "cell_type": "markdown", - "id": "0f69b576", + "id": "f88f6792", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -115,7 +115,7 @@ { "cell_type": "code", "execution_count": null, - "id": "65d9be99", + "id": "4261574c", "metadata": {}, "outputs": [], "source": [ @@ -145,7 +145,7 @@ }, { "cell_type": "markdown", - "id": "72582d09", + "id": "bbbc3d58", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d7992b4", + "id": "92c0cf35", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "markdown", - "id": "741a15a0", + "id": "44246c7d", "metadata": {}, "source": [ "## 🎲 Getting started with sampler columns\n", @@ -186,7 +186,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c3879c70", + "id": "07d20f3f", "metadata": {}, "outputs": [], "source": [ @@ -195,7 +195,7 @@ }, { "cell_type": "markdown", - "id": "1575ef81", + "id": "9d3c87b0", "metadata": {}, "source": [ "Let's start designing our product review dataset by adding product category and subcategory columns.\n" @@ -204,7 +204,7 @@ { "cell_type": "code", "execution_count": null, - "id": "87a88d7b", + "id": "c646b021", "metadata": {}, "outputs": [], "source": [ @@ -285,7 +285,7 @@ }, { "cell_type": "markdown", - "id": "8c74b738", + "id": "ff18b032", "metadata": {}, "source": [ "Next, let's add samplers to generate data related to the customer and their review.\n" @@ -294,7 +294,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4eb1da1f", + "id": "78846d99", "metadata": {}, "outputs": [], "source": [ @@ -331,7 +331,7 @@ }, { "cell_type": "markdown", - "id": "4324d869", + "id": "97059bfc", "metadata": {}, "source": [ "## 🦜 LLM-generated columns\n", @@ -346,7 +346,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1302a503", + "id": "98c66eff", "metadata": {}, "outputs": [], "source": [ @@ -382,7 +382,7 @@ }, { "cell_type": "markdown", - "id": "7cf8241b", + "id": "ff2d52b9", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -399,7 +399,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6fc6cf39", + "id": "6e622478", "metadata": {}, "outputs": [], "source": [ @@ -409,7 +409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c929e068", + "id": "1addc7d8", "metadata": {}, "outputs": [], "source": [ @@ -420,7 +420,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dfb04e2a", + "id": "7af4b9c3", "metadata": {}, "outputs": [], "source": [ @@ -430,7 +430,7 @@ }, { "cell_type": "markdown", - "id": "adb879da", + "id": "91d0ee89", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -443,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ff58dd9f", + "id": "e1e3aed0", "metadata": {}, "outputs": [], "source": [ @@ -453,7 +453,7 @@ }, { "cell_type": "markdown", - "id": "57c7355d", + "id": "6eaa402e", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -466,7 +466,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df49db99", + "id": "f6b148d4", "metadata": {}, "outputs": [], "source": [ @@ -476,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2bbc48dd", + "id": "f4e62e5b", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +489,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dc0673fa", + "id": "7d426ab0", "metadata": {}, "outputs": [], "source": [ @@ -501,7 +501,7 @@ }, { "cell_type": "markdown", - "id": "7688217b", + "id": "449d003c", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", @@ -512,7 +512,9 @@ "\n", "- [Seeding synthetic data generation with an external dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/)\n", "\n", - "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n" + "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n", + "\n", + "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)\n" ] } ], diff --git a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb index c813ea50..a6e04680 100644 --- a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb +++ b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "258752cd", + "id": "ba22504d", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Structured Outputs and Jinja Expressions\n", @@ -16,7 +16,7 @@ }, { "cell_type": "markdown", - "id": "fc4217c3", + "id": "c176fe63", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -28,7 +28,7 @@ }, { "cell_type": "markdown", - "id": "2b831130", + "id": "32c80f72", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -39,7 +39,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fa1eda43", + "id": "4ab45e3a", "metadata": {}, "outputs": [], "source": [ @@ -50,7 +50,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5f014571", + "id": "2ae70d67", "metadata": {}, "outputs": [], "source": [ @@ -68,7 +68,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f7409282", + "id": "2cdc070b", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "8234dd4b", + "id": "a04261b9", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -91,7 +91,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21633aed", + "id": "c8bef18a", "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "9b215265", + "id": "ed555636", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -117,7 +117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76260638", + "id": "47208094", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "e6bfd93d", + "id": "36c200d9", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -162,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a0fbd497", + "id": "57c0d82f", "metadata": {}, "outputs": [], "source": [ @@ -171,7 +171,7 @@ }, { "cell_type": "markdown", - "id": "7faae40e", + "id": "01ff63ca", "metadata": {}, "source": [ "### πŸ§‘β€πŸŽ¨ Designing our data\n", @@ -198,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f2f94909", + "id": "4fb0f1ca", "metadata": {}, "outputs": [], "source": [ @@ -226,7 +226,7 @@ }, { "cell_type": "markdown", - "id": "696f19f4", + "id": "8f35bd87", "metadata": {}, "source": [ "Next, let's design our product review dataset using a few more tricks compared to the previous notebook.\n" @@ -235,7 +235,7 @@ { "cell_type": "code", "execution_count": null, - "id": "312b50cd", + "id": "43341f16", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +344,7 @@ }, { "cell_type": "markdown", - "id": "ecd971ca", + "id": "34c3e08b", "metadata": {}, "source": [ "Next, we will use more advanced Jinja expressions to create new columns.\n", @@ -361,7 +361,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bda01ffc", + "id": "c168c089", "metadata": {}, "outputs": [], "source": [ @@ -414,7 +414,7 @@ }, { "cell_type": "markdown", - "id": "059613e1", + "id": "7e6521a2", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -431,7 +431,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23c9b839", + "id": "03510f78", "metadata": {}, "outputs": [], "source": [ @@ -441,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5adcdbd", + "id": "ad599c43", "metadata": {}, "outputs": [], "source": [ @@ -452,7 +452,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1cc39cae", + "id": "dbd3e17c", "metadata": {}, "outputs": [], "source": [ @@ -462,7 +462,7 @@ }, { "cell_type": "markdown", - "id": "bcca3f06", + "id": "4db52c26", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -475,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6e1957ca", + "id": "f1007ac4", "metadata": {}, "outputs": [], "source": [ @@ -485,7 +485,7 @@ }, { "cell_type": "markdown", - "id": "9db283d3", + "id": "dcd68de4", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -498,7 +498,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30826883", + "id": "27b6bfe8", "metadata": {}, "outputs": [], "source": [ @@ -508,7 +508,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88d4d3bd", + "id": "d4e9a395", "metadata": {}, "outputs": [], "source": [ @@ -521,7 +521,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8762a2bb", + "id": "946b3aa8", "metadata": {}, "outputs": [], "source": [ @@ -533,7 +533,7 @@ }, { "cell_type": "markdown", - "id": "0375fcd2", + "id": "f50d996e", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", @@ -542,7 +542,9 @@ "\n", "- [Seeding synthetic data generation with an external dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/)\n", "\n", - "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n" + "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n", + "\n", + "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)\n" ] } ], diff --git a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb index c5d427d0..639e88df 100644 --- a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb +++ b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "b2a3e544", + "id": "25501772", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Seeding Synthetic Data Generation with an External Dataset\n", @@ -16,7 +16,7 @@ }, { "cell_type": "markdown", - "id": "d57c4f0a", + "id": "67ffc49e", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -28,7 +28,7 @@ }, { "cell_type": "markdown", - "id": "f7da8723", + "id": "54a42504", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -39,7 +39,7 @@ { "cell_type": "code", "execution_count": null, - "id": "90a12556", + "id": "05b45354", "metadata": {}, "outputs": [], "source": [ @@ -50,7 +50,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8fcdfde5", + "id": "039360fe", "metadata": {}, "outputs": [], "source": [ @@ -68,7 +68,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5899e85c", + "id": "028d5e8a", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "6c093c90", + "id": "15a1df61", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -91,7 +91,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6a2066fe", + "id": "a87b6ff6", "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "f5e81142", + "id": "b9166cfd", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -117,7 +117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "880012ea", + "id": "4961d3b0", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "4b77a92c", + "id": "b1d8588a", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -162,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f4ab6628", + "id": "cf42a4dd", "metadata": {}, "outputs": [], "source": [ @@ -171,7 +171,7 @@ }, { "cell_type": "markdown", - "id": "26fb0a63", + "id": "8d6b26aa", "metadata": {}, "source": [ "## πŸ₯ Prepare a seed dataset\n", @@ -196,7 +196,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84908e88", + "id": "fc90401d", "metadata": {}, "outputs": [], "source": [ @@ -214,7 +214,7 @@ }, { "cell_type": "markdown", - "id": "1947e70a", + "id": "6f5ee960", "metadata": {}, "source": [ "## 🎨 Designing our synthetic patient notes dataset\n", @@ -227,7 +227,7 @@ { "cell_type": "code", "execution_count": null, - "id": "be2fbad1", + "id": "e9db2ff0", "metadata": {}, "outputs": [], "source": [ @@ -308,7 +308,7 @@ }, { "cell_type": "markdown", - "id": "8fcce5dc", + "id": "00efc894", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -325,7 +325,7 @@ { "cell_type": "code", "execution_count": null, - "id": "82dc02f8", + "id": "3e3d824e", "metadata": {}, "outputs": [], "source": [ @@ -335,7 +335,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f2d1583", + "id": "27785af7", "metadata": {}, "outputs": [], "source": [ @@ -346,7 +346,7 @@ { "cell_type": "code", "execution_count": null, - "id": "62a9173b", + "id": "430998d1", "metadata": {}, "outputs": [], "source": [ @@ -356,7 +356,7 @@ }, { "cell_type": "markdown", - "id": "5263e705", + "id": "dda6458b", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -369,7 +369,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5295320f", + "id": "f45bc088", "metadata": {}, "outputs": [], "source": [ @@ -379,7 +379,7 @@ }, { "cell_type": "markdown", - "id": "3ecc195f", + "id": "1e913fd8", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -392,7 +392,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3865fb59", + "id": "30b8b7f7", "metadata": {}, "outputs": [], "source": [ @@ -402,7 +402,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a7acf2b0", + "id": "b7ff96d1", "metadata": {}, "outputs": [], "source": [ @@ -415,7 +415,7 @@ { "cell_type": "code", "execution_count": null, - "id": "81a6e999", + "id": "dbfef8a8", "metadata": {}, "outputs": [], "source": [ @@ -427,14 +427,16 @@ }, { "cell_type": "markdown", - "id": "4503b1cf", + "id": "5db3f38d", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", "\n", "Check out the following notebook to learn more about:\n", "\n", - "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n" + "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n", + "\n", + "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/)\n" ] } ], diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb index cd175537..9797695e 100644 --- a/docs/colab_notebooks/4-providing-images-as-context.ipynb +++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "90dda708", + "id": "19e57933", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Providing Images as Context for Vision-Based Data Generation" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "52ccb1e5", + "id": "25e3cc64", "metadata": {}, "source": [ "#### πŸ“š What you'll learn\n", @@ -25,7 +25,7 @@ }, { "cell_type": "markdown", - "id": "9627c4eb", + "id": "4aae5c82", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -37,7 +37,7 @@ }, { "cell_type": "markdown", - "id": "1817171a", + "id": "24dfae6c", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -48,7 +48,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f15a669", + "id": "619b1aae", "metadata": {}, "outputs": [], "source": [ @@ -59,7 +59,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1201c93b", + "id": "0d49a542", "metadata": {}, "outputs": [], "source": [ @@ -77,7 +77,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f814b76c", + "id": "1b28f160", "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "ac423d57", + "id": "63dc34de", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -113,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3c655c2d", + "id": "672155c8", "metadata": {}, "outputs": [], "source": [ @@ -122,7 +122,7 @@ }, { "cell_type": "markdown", - "id": "7d41e922", + "id": "4b32c25e", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -139,7 +139,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a8b5f4bf", + "id": "72971915", "metadata": {}, "outputs": [], "source": [ @@ -162,7 +162,7 @@ }, { "cell_type": "markdown", - "id": "6455fc58", + "id": "115ad20f", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -177,7 +177,7 @@ { "cell_type": "code", "execution_count": null, - "id": "462c2e01", + "id": "11e844d2", "metadata": {}, "outputs": [], "source": [ @@ -186,7 +186,7 @@ }, { "cell_type": "markdown", - "id": "31369d10", + "id": "77862fce", "metadata": {}, "source": [ "### 🌱 Seed Dataset Creation\n", @@ -203,7 +203,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55d9432a", + "id": "e415a502", "metadata": {}, "outputs": [], "source": [ @@ -218,7 +218,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8614c4e9", + "id": "335f2611", "metadata": {}, "outputs": [], "source": [ @@ -266,7 +266,7 @@ { "cell_type": "code", "execution_count": null, - "id": "80550e46", + "id": "f055e88d", "metadata": {}, "outputs": [], "source": [ @@ -284,7 +284,7 @@ { "cell_type": "code", "execution_count": null, - "id": "65ced9bb", + "id": "47a1c586", "metadata": {}, "outputs": [], "source": [ @@ -294,7 +294,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34b210e8", + "id": "3a77fc52", "metadata": {}, "outputs": [], "source": [ @@ -306,7 +306,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d506903d", + "id": "c0941cc7", "metadata": {}, "outputs": [], "source": [ @@ -335,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "b91032a2", + "id": "578e77dc", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -352,7 +352,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4bd947de", + "id": "9f0c11ce", "metadata": {}, "outputs": [], "source": [ @@ -362,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0ff4c07", + "id": "b10412c1", "metadata": {}, "outputs": [], "source": [ @@ -373,7 +373,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e97e4dfe", + "id": "766ee2d7", "metadata": {}, "outputs": [], "source": [ @@ -383,7 +383,7 @@ }, { "cell_type": "markdown", - "id": "0a284c12", + "id": "6370bfa5", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -396,7 +396,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2570e7fd", + "id": "d57ded0e", "metadata": {}, "outputs": [], "source": [ @@ -406,7 +406,7 @@ }, { "cell_type": "markdown", - "id": "28b8eb5a", + "id": "5afd8e8c", "metadata": {}, "source": [ "### πŸ”Ž Visual Inspection\n", @@ -417,7 +417,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d0d9336", + "id": "aa4bfcc3", "metadata": { "lines_to_next_cell": 2 }, @@ -441,7 +441,7 @@ }, { "cell_type": "markdown", - "id": "1c257a81", + "id": "4eeaada6", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -454,7 +454,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6d840e9", + "id": "0ee5b1b9", "metadata": {}, "outputs": [], "source": [ @@ -464,7 +464,7 @@ { "cell_type": "code", "execution_count": null, - "id": "909e6f3f", + "id": "e5e8b241", "metadata": {}, "outputs": [], "source": [ @@ -477,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "adbb4cae", + "id": "23ebb3ca", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +489,7 @@ }, { "cell_type": "markdown", - "id": "d085584c", + "id": "14a78533", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", @@ -499,7 +499,9 @@ "- Experiment with different vision models for specific document types\n", "- Try different prompt variations to generate specialized descriptions (e.g., technical details, key findings)\n", "- Combine vision-based summaries with other column types for multi-modal workflows\n", - "- Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering\n" + "- Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering\n", + "\n", + "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) with Data Designer\n" ] } ], diff --git a/docs/colab_notebooks/5-generating-images.ipynb b/docs/colab_notebooks/5-generating-images.ipynb new file mode 100644 index 00000000..c8092938 --- /dev/null +++ b/docs/colab_notebooks/5-generating-images.ipynb @@ -0,0 +1,437 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "735e6197", + "metadata": {}, + "source": [ + "# 🎨 Data Designer Tutorial: Generating Images\n", + "\n", + "#### πŸ“š What you'll learn\n", + "\n", + "This notebook shows how to generate synthetic image data with Data Designer using image-generation models.\n", + "\n", + "- πŸ–ΌοΈ **Image generation columns**: Add columns that produce images from text prompts\n", + "- πŸ“ **Jinja2 prompts**: Drive diversity by referencing other columns in your prompt template\n", + "- πŸ’Ύ **Preview vs create**: Preview stores base64 in the dataframe; create saves images to disk and stores paths\n", + "\n", + "Data Designer supports both **diffusion** (e.g. DALLΒ·E, Stable Diffusion, Imagen) and **autoregressive** (e.g. Gemini image, GPT image) models; the API is chosen automatically from the model name.\n", + "\n", + "If this is your first time using Data Designer, we recommend starting with the [first notebook](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/) in this tutorial series.\n" + ] + }, + { + "cell_type": "markdown", + "id": "92ae4afe", + "metadata": {}, + "source": [ + "### πŸ“¦ Import Data Designer\n", + "\n", + "- `data_designer.config` provides the configuration API.\n", + "- `DataDesigner` is the main interface for generation.\n" + ] + }, + { + "cell_type": "markdown", + "id": "ccc77347", + "metadata": {}, + "source": [ + "### ⚑ Colab Setup\n", + "\n", + "Run the cells below to install the dependencies and set up the API key. If you don't have an API key, you can generate one from [build.nvidia.com](https://build.nvidia.com).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23627c23", + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install -U data-designer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf958dc6", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "from google.colab import userdata\n", + "\n", + "try:\n", + " os.environ[\"NVIDIA_API_KEY\"] = userdata.get(\"NVIDIA_API_KEY\")\n", + "except userdata.SecretNotFoundError:\n", + " os.environ[\"NVIDIA_API_KEY\"] = getpass.getpass(\"Enter your NVIDIA API key: \")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab0cfff8", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image as IPImage\n", + "from IPython.display import display\n", + "\n", + "import data_designer.config as dd\n", + "from data_designer.interface import DataDesigner" + ] + }, + { + "cell_type": "markdown", + "id": "a18ef5ce", + "metadata": {}, + "source": [ + "### βš™οΈ Initialize the Data Designer interface\n", + "\n", + "When initialized without arguments, [default model providers](https://nvidia-nemo.github.io/DataDesigner/latest/concepts/models/default-model-settings/) are used. This tutorial uses [OpenRouter](https://openrouter.ai) with the Flux 2 Pro image model; set `OPENROUTER_API_KEY` in your environment.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fe11301", + "metadata": {}, + "outputs": [], + "source": [ + "data_designer = DataDesigner()" + ] + }, + { + "cell_type": "markdown", + "id": "b913d454", + "metadata": {}, + "source": [ + "### πŸŽ›οΈ Define an image-generation model\n", + "\n", + "- Use `ImageInferenceParams` so Data Designer treats this model as an image generator.\n", + "- Image options (size, quality, aspect ratio, etc.) are model-specific; pass them via `extra_body`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a50d26ee", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PROVIDER = \"openrouter\"\n", + "MODEL_ID = \"black-forest-labs/flux.2-pro\"\n", + "MODEL_ALIAS = \"image-model\"\n", + "\n", + "model_configs = [\n", + " dd.ModelConfig(\n", + " alias=MODEL_ALIAS,\n", + " model=MODEL_ID,\n", + " provider=MODEL_PROVIDER,\n", + " inference_parameters=dd.ImageInferenceParams(\n", + " extra_body={\"height\": 512, \"width\": 512},\n", + " ),\n", + " )\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "122374d9", + "metadata": {}, + "source": [ + "### πŸ—οΈ Build the config: samplers + image column\n", + "\n", + "We'll generate diverse **dog portrait** images: sampler columns drive subject (breed), age, style, look direction, and emotion. The image-generation column uses a Jinja2 prompt that references all of them.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "940f2b70", + "metadata": {}, + "outputs": [], + "source": [ + "config_builder = dd.DataDesignerConfigBuilder(model_configs=model_configs)\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"style\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\n", + " \"photorealistic\",\n", + " \"oil painting\",\n", + " \"watercolor\",\n", + " \"digital art\",\n", + " \"sketch\",\n", + " \"anime\",\n", + " ],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"dog_breed\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\n", + " \"a Golden Retriever\",\n", + " \"a German Shepherd\",\n", + " \"a Labrador Retriever\",\n", + " \"a Bulldog\",\n", + " \"a Beagle\",\n", + " \"a Poodle\",\n", + " \"a Corgi\",\n", + " \"a Siberian Husky\",\n", + " \"a Dalmatian\",\n", + " \"a Yorkshire Terrier\",\n", + " \"a Boxer\",\n", + " \"a Dachshund\",\n", + " \"a Doberman Pinscher\",\n", + " \"a Shih Tzu\",\n", + " \"a Chihuahua\",\n", + " \"a Border Collie\",\n", + " \"an Australian Shepherd\",\n", + " \"a Cocker Spaniel\",\n", + " \"a Maltese\",\n", + " \"a Pomeranian\",\n", + " \"a Saint Bernard\",\n", + " \"a Great Dane\",\n", + " \"an Akita\",\n", + " \"a Samoyed\",\n", + " \"a Boston Terrier\",\n", + " ],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"cat_breed\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\n", + " \"a Persian\",\n", + " \"a Maine Coon\",\n", + " \"a Siamese\",\n", + " \"a Ragdoll\",\n", + " \"a Bengal\",\n", + " \"an Abyssinian\",\n", + " \"a British Shorthair\",\n", + " \"a Sphynx\",\n", + " \"a Scottish Fold\",\n", + " \"a Russian Blue\",\n", + " \"a Birman\",\n", + " \"an Oriental Shorthair\",\n", + " \"a Norwegian Forest Cat\",\n", + " \"a Devon Rex\",\n", + " \"a Burmese\",\n", + " \"an Egyptian Mau\",\n", + " \"a Tonkinese\",\n", + " \"a Himalayan\",\n", + " \"a Savannah\",\n", + " \"a Chartreux\",\n", + " \"a Somali\",\n", + " \"a Manx\",\n", + " \"a Turkish Angora\",\n", + " \"a Balinese\",\n", + " \"an American Shorthair\",\n", + " ],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"dog_age\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"1-3\", \"3-6\", \"6-9\", \"9-12\", \"12-15\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"cat_age\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"1-3\", \"3-6\", \"6-9\", \"9-12\", \"12-18\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"dog_look_direction\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"left\", \"right\", \"front\", \"up\", \"down\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"cat_look_direction\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"left\", \"right\", \"front\", \"up\", \"down\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"dog_emotion\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"happy\", \"curious\", \"serious\", \"sleepy\", \"excited\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.SamplerColumnConfig(\n", + " name=\"cat_emotion\",\n", + " sampler_type=dd.SamplerType.CATEGORY,\n", + " params=dd.CategorySamplerParams(\n", + " values=[\"aloof\", \"curious\", \"content\", \"sleepy\", \"playful\"],\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " dd.ImageColumnConfig(\n", + " name=\"generated_image\",\n", + " prompt=(\n", + " \"\"\"\n", + "A {{ style }} family pet portrait of a {{ dog_breed }} dog of {{ dog_age }} years old looking {{dog_look_direction}} with an {{ dog_emotion }} expression and\n", + "{{ cat_breed }} cat of {{ cat_age }} years old looking {{ cat_look_direction }} with an {{ cat_emotion }} expression in the background. Both subjects should be in focus.\n", + " \"\"\"\n", + " ),\n", + " model_alias=MODEL_ALIAS,\n", + " )\n", + ")\n", + "\n", + "data_designer.validate(config_builder)" + ] + }, + { + "cell_type": "markdown", + "id": "e13e0bb4", + "metadata": {}, + "source": [ + "### πŸ” Preview: images as base64\n", + "\n", + "In **preview** mode, generated images are stored as base64 strings in the dataframe. Run the next cell to step through each record (images are shown in the sample record display, but only in a notebook environment).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a60a76f", + "metadata": {}, + "outputs": [], + "source": [ + "preview = data_designer.preview(config_builder, num_records=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c831ee8", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(len(preview.dataset)):\n", + " preview.display_sample_record()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "143e762f", + "metadata": {}, + "outputs": [], + "source": [ + "preview.dataset" + ] + }, + { + "cell_type": "markdown", + "id": "a84606b4", + "metadata": {}, + "source": [ + "### πŸ†™ Create: images saved to disk\n", + "\n", + "In **create** mode, images are written to an `images/` folder with UUID filenames; the dataframe stores relative paths (e.g. `images/1d16b6e2-562f-4f51-91e5-baaa999ea916.png`).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89147954", + "metadata": {}, + "outputs": [], + "source": [ + "results = data_designer.create(config_builder, num_records=2, dataset_name=\"tutorial-5-images\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04c96063", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = results.load_dataset()\n", + "dataset.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edb794bb", + "metadata": {}, + "outputs": [], + "source": [ + "# Display all image from the created dataset. Paths are relative to the artifact output directory.\n", + "for index, row in dataset.iterrows():\n", + " path_or_list = row.get(\"generated_image\")\n", + " if path_or_list is not None:\n", + " for path in path_or_list:\n", + " base = results.artifact_storage.base_dataset_path\n", + " full_path = base / path\n", + " display(IPImage(data=full_path))" + ] + }, + { + "cell_type": "markdown", + "id": "e0a72bf6", + "metadata": {}, + "source": [ + "## ⏭️ Next steps\n", + "\n", + "- [The basics](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/): samplers and LLM text columns\n", + "- [Structured outputs and Jinja](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/2-structured-outputs-and-jinja-expressions/)\n", + "- [Seeding with a dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/)\n", + "- [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebook_source/1-the-basics.py b/docs/notebook_source/1-the-basics.py index 392efb34..8735d582 100644 --- a/docs/notebook_source/1-the-basics.py +++ b/docs/notebook_source/1-the-basics.py @@ -330,3 +330,5 @@ # # - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/) # +# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) +# diff --git a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py index 66b3773f..df581612 100644 --- a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py +++ b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py @@ -372,3 +372,5 @@ class ProductReview(BaseModel): # # - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/) # +# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) +# diff --git a/docs/notebook_source/3-seeding-with-a-dataset.py b/docs/notebook_source/3-seeding-with-a-dataset.py index c9d694a8..e4f9218e 100644 --- a/docs/notebook_source/3-seeding-with-a-dataset.py +++ b/docs/notebook_source/3-seeding-with-a-dataset.py @@ -274,3 +274,5 @@ # # - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/) # +# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) +# diff --git a/docs/notebook_source/4-providing-images-as-context.py b/docs/notebook_source/4-providing-images-as-context.py index a11880ba..1fd68dac 100644 --- a/docs/notebook_source/4-providing-images-as-context.py +++ b/docs/notebook_source/4-providing-images-as-context.py @@ -299,3 +299,5 @@ def convert_image_to_chat_format(record, height: int) -> dict: # - Combine vision-based summaries with other column types for multi-modal workflows # - Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering # +# - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) with Data Designer +# diff --git a/docs/notebook_source/5-generating-images.py b/docs/notebook_source/5-generating-images.py new file mode 100644 index 00000000..b445b950 --- /dev/null +++ b/docs/notebook_source/5-generating-images.py @@ -0,0 +1,296 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.18.1 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # 🎨 Data Designer Tutorial: Generating Images +# +# #### πŸ“š What you'll learn +# +# This notebook shows how to generate synthetic image data with Data Designer using image-generation models. +# +# - πŸ–ΌοΈ **Image generation columns**: Add columns that produce images from text prompts +# - πŸ“ **Jinja2 prompts**: Drive diversity by referencing other columns in your prompt template +# - πŸ’Ύ **Preview vs create**: Preview stores base64 in the dataframe; create saves images to disk and stores paths +# +# Data Designer supports both **diffusion** (e.g. DALLΒ·E, Stable Diffusion, Imagen) and **autoregressive** (e.g. Gemini image, GPT image) models; the API is chosen automatically from the model name. +# +# If this is your first time using Data Designer, we recommend starting with the [first notebook](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/) in this tutorial series. +# + +# %% [markdown] +# ### πŸ“¦ Import Data Designer +# +# - `data_designer.config` provides the configuration API. +# - `DataDesigner` is the main interface for generation. +# + +# %% +from IPython.display import Image as IPImage +from IPython.display import display + +import data_designer.config as dd +from data_designer.interface import DataDesigner + +# %% [markdown] +# ### βš™οΈ Initialize the Data Designer interface +# +# When initialized without arguments, [default model providers](https://nvidia-nemo.github.io/DataDesigner/latest/concepts/models/default-model-settings/) are used. This tutorial uses [OpenRouter](https://openrouter.ai) with the Flux 2 Pro image model; set `OPENROUTER_API_KEY` in your environment. +# + +# %% +data_designer = DataDesigner() + +# %% [markdown] +# ### πŸŽ›οΈ Define an image-generation model +# +# - Use `ImageInferenceParams` so Data Designer treats this model as an image generator. +# - Image options (size, quality, aspect ratio, etc.) are model-specific; pass them via `extra_body`. +# + +# %% +MODEL_PROVIDER = "openrouter" +MODEL_ID = "black-forest-labs/flux.2-pro" +MODEL_ALIAS = "image-model" + +model_configs = [ + dd.ModelConfig( + alias=MODEL_ALIAS, + model=MODEL_ID, + provider=MODEL_PROVIDER, + inference_parameters=dd.ImageInferenceParams( + extra_body={"height": 512, "width": 512}, + ), + ) +] + +# %% [markdown] +# ### πŸ—οΈ Build the config: samplers + image column +# +# We'll generate diverse **dog portrait** images: sampler columns drive subject (breed), age, style, look direction, and emotion. The image-generation column uses a Jinja2 prompt that references all of them. +# + +# %% +config_builder = dd.DataDesignerConfigBuilder(model_configs=model_configs) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="style", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=[ + "photorealistic", + "oil painting", + "watercolor", + "digital art", + "sketch", + "anime", + ], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="dog_breed", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=[ + "a Golden Retriever", + "a German Shepherd", + "a Labrador Retriever", + "a Bulldog", + "a Beagle", + "a Poodle", + "a Corgi", + "a Siberian Husky", + "a Dalmatian", + "a Yorkshire Terrier", + "a Boxer", + "a Dachshund", + "a Doberman Pinscher", + "a Shih Tzu", + "a Chihuahua", + "a Border Collie", + "an Australian Shepherd", + "a Cocker Spaniel", + "a Maltese", + "a Pomeranian", + "a Saint Bernard", + "a Great Dane", + "an Akita", + "a Samoyed", + "a Boston Terrier", + ], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="cat_breed", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=[ + "a Persian", + "a Maine Coon", + "a Siamese", + "a Ragdoll", + "a Bengal", + "an Abyssinian", + "a British Shorthair", + "a Sphynx", + "a Scottish Fold", + "a Russian Blue", + "a Birman", + "an Oriental Shorthair", + "a Norwegian Forest Cat", + "a Devon Rex", + "a Burmese", + "an Egyptian Mau", + "a Tonkinese", + "a Himalayan", + "a Savannah", + "a Chartreux", + "a Somali", + "a Manx", + "a Turkish Angora", + "a Balinese", + "an American Shorthair", + ], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="dog_age", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["1-3", "3-6", "6-9", "9-12", "12-15"], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="cat_age", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["1-3", "3-6", "6-9", "9-12", "12-18"], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="dog_look_direction", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["left", "right", "front", "up", "down"], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="cat_look_direction", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["left", "right", "front", "up", "down"], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="dog_emotion", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["happy", "curious", "serious", "sleepy", "excited"], + ), + ) +) + +config_builder.add_column( + dd.SamplerColumnConfig( + name="cat_emotion", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams( + values=["aloof", "curious", "content", "sleepy", "playful"], + ), + ) +) + +config_builder.add_column( + dd.ImageColumnConfig( + name="generated_image", + prompt=( + """ +A {{ style }} family pet portrait of a {{ dog_breed }} dog of {{ dog_age }} years old looking {{dog_look_direction}} with an {{ dog_emotion }} expression and +{{ cat_breed }} cat of {{ cat_age }} years old looking {{ cat_look_direction }} with an {{ cat_emotion }} expression in the background. Both subjects should be in focus. + """ + ), + model_alias=MODEL_ALIAS, + ) +) + +data_designer.validate(config_builder) + +# %% [markdown] +# ### πŸ” Preview: images as base64 +# +# In **preview** mode, generated images are stored as base64 strings in the dataframe. Run the next cell to step through each record (images are shown in the sample record display, but only in a notebook environment). +# + +# %% +preview = data_designer.preview(config_builder, num_records=2) + +# %% +for i in range(len(preview.dataset)): + preview.display_sample_record() + +# %% +preview.dataset + +# %% [markdown] +# ### πŸ†™ Create: images saved to disk +# +# In **create** mode, images are written to an `images/` folder with UUID filenames; the dataframe stores relative paths (e.g. `images/1d16b6e2-562f-4f51-91e5-baaa999ea916.png`). +# + +# %% +results = data_designer.create(config_builder, num_records=2, dataset_name="tutorial-5-images") + +# %% +dataset = results.load_dataset() +dataset.head() + +# %% +# Display all image from the created dataset. Paths are relative to the artifact output directory. +for index, row in dataset.iterrows(): + path_or_list = row.get("generated_image") + if path_or_list is not None: + for path in path_or_list: + base = results.artifact_storage.base_dataset_path + full_path = base / path + display(IPImage(data=full_path)) + +# %% [markdown] +# ## ⏭️ Next steps +# +# - [The basics](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/): samplers and LLM text columns +# - [Structured outputs and Jinja](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/2-structured-outputs-and-jinja-expressions/) +# - [Seeding with a dataset](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/3-seeding-with-a-dataset/) +# - [Providing images as context](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/4-providing-images-as-context/) +# diff --git a/docs/notebook_source/_README.md b/docs/notebook_source/_README.md index 09053c22..7bcd77d1 100644 --- a/docs/notebook_source/_README.md +++ b/docs/notebook_source/_README.md @@ -97,6 +97,15 @@ Learn how to use vision-language models to generate text descriptions from image - Generating detailed summaries from document images - Inspecting and validating vision-based generation results +### [5. Generating Images](5-generating-images.ipynb) + +Generate synthetic image data with Data Designer: + +- Configuring image-generation models with `ImageInferenceParams` +- Adding image columns with Jinja2 prompts and sampler-driven diversity +- Preview (base64 in dataframe) vs create (images saved to disk, paths in dataframe) +- Displaying generated images in the notebook + ## πŸ“– Important Documentation Sections Before diving into the tutorials, familiarize yourself with these key documentation sections: diff --git a/packages/data-designer-config/pyproject.toml b/packages/data-designer-config/pyproject.toml index cb61dceb..c69d1ba0 100644 --- a/packages/data-designer-config/pyproject.toml +++ b/packages/data-designer-config/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "jinja2>=3.1.6,<4", "numpy>=1.23.5,<3", "pandas>=2.3.3,<3", + "pillow>=12.0.0,<13", "pyarrow>=19.0.1,<20", # Required for parquet I/O operations "pydantic[email]>=2.9.2,<3", "pygments>=2.19.2,<3", diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index 2ea641e7..baf21754 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -17,6 +17,7 @@ EmbeddingColumnConfig, ExpressionColumnConfig, GenerationStrategy, + ImageColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -41,6 +42,7 @@ GenerationType, ImageContext, ImageFormat, + ImageInferenceParams, ManualDistribution, ManualDistributionParams, Modality, @@ -123,6 +125,7 @@ "CustomColumnConfig": (_MOD_COLUMN_CONFIGS, "CustomColumnConfig"), "EmbeddingColumnConfig": (_MOD_COLUMN_CONFIGS, "EmbeddingColumnConfig"), "ExpressionColumnConfig": (_MOD_COLUMN_CONFIGS, "ExpressionColumnConfig"), + "ImageColumnConfig": (_MOD_COLUMN_CONFIGS, "ImageColumnConfig"), "GenerationStrategy": (_MOD_COLUMN_CONFIGS, "GenerationStrategy"), "LLMCodeColumnConfig": (_MOD_COLUMN_CONFIGS, "LLMCodeColumnConfig"), "LLMJudgeColumnConfig": (_MOD_COLUMN_CONFIGS, "LLMJudgeColumnConfig"), @@ -150,6 +153,7 @@ "GenerationType": (_MOD_MODELS, "GenerationType"), "ImageContext": (_MOD_MODELS, "ImageContext"), "ImageFormat": (_MOD_MODELS, "ImageFormat"), + "ImageInferenceParams": (_MOD_MODELS, "ImageInferenceParams"), "ManualDistribution": (_MOD_MODELS, "ManualDistribution"), "ManualDistributionParams": (_MOD_MODELS, "ManualDistributionParams"), "Modality": (_MOD_MODELS, "Modality"), diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index b2eefd26..49dbb831 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -485,6 +485,59 @@ def side_effect_columns(self) -> list[str]: return [] +class ImageColumnConfig(SingleColumnConfig): + """Configuration for image generation columns. + + Image columns generate images using either autoregressive or diffusion models. + The API used is automatically determined based on the model name: + + Attributes: + prompt: Prompt template for image generation. Supports Jinja2 templating to + reference other columns (e.g., "Generate an image of a {{ character_name }}"). + Must be a valid Jinja2 template. + model_alias: The model to use for image generation. + multi_modal_context: Optional list of image contexts for multi-modal generation. + Enables autoregressive multi-modal models to generate images based on image inputs. + Only works with autoregressive models that support image-to-image generation. + column_type: Discriminator field, always "image" for this configuration type. + """ + + prompt: str + model_alias: str + multi_modal_context: list[ImageContext] | None = None + column_type: Literal["image"] = "image" + + @staticmethod + def get_column_emoji() -> str: + return "πŸ–ΌοΈ" + + @property + def required_columns(self) -> list[str]: + """Get columns referenced in the prompt template. + + Returns: + List of unique column names referenced in Jinja2 templates. + """ + return list(extract_keywords_from_jinja2_template(self.prompt)) + + @model_validator(mode="after") + def assert_prompt_valid_jinja(self) -> Self: + """Validate that prompt is a valid Jinja2 template. + + Returns: + The validated instance. + + Raises: + InvalidConfigError: If prompt contains invalid Jinja2 syntax. + """ + assert_valid_jinja2_template(self.prompt) + return self + + @property + def side_effect_columns(self) -> list[str]: + return [] + + class CustomColumnConfig(SingleColumnConfig): """Configuration for custom user-defined column generators. diff --git a/packages/data-designer-config/src/data_designer/config/column_types.py b/packages/data-designer-config/src/data_designer/config/column_types.py index 8a82223e..baba25dd 100644 --- a/packages/data-designer-config/src/data_designer/config/column_types.py +++ b/packages/data-designer-config/src/data_designer/config/column_types.py @@ -9,6 +9,7 @@ CustomColumnConfig, EmbeddingColumnConfig, ExpressionColumnConfig, + ImageColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -39,6 +40,7 @@ | SeedDatasetColumnConfig | ValidationColumnConfig | EmbeddingColumnConfig + | ImageColumnConfig ) ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT) @@ -87,6 +89,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]: DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, DataDesignerColumnType.CUSTOM, @@ -142,4 +145,5 @@ def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict: DataDesignerColumnType.SAMPLER: SamplerColumnConfig, DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnConfig, DataDesignerColumnType.EMBEDDING: EmbeddingColumnConfig, + DataDesignerColumnType.IMAGE: ImageColumnConfig, } diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index c91443d0..51954a74 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -242,6 +242,7 @@ def sample(self) -> float: class GenerationType(str, Enum): CHAT_COMPLETION = "chat-completion" EMBEDDING = "embedding" + IMAGE = "image" class BaseInferenceParams(ConfigBase, ABC): @@ -421,8 +422,41 @@ def generate_kwargs(self) -> dict[str, float | int]: return result +class ImageInferenceParams(BaseInferenceParams): + """Configuration for image generation models. + + Works for both diffusion and autoregressive image generation models. Pass all model-specific image options via `extra_body`. + + Attributes: + generation_type: Type of generation, always "image" for this class. + + Example: + ```python + # OpenAI-style (DALLΒ·E): quality and size in extra_body or as top-level kwargs + dd.ImageInferenceParams( + extra_body={"size": "1024x1024", "quality": "hd"} + ) + + # Gemini-style: generationConfig.imageConfig + dd.ImageInferenceParams( + extra_body={ + "generationConfig": { + "imageConfig": { + "aspectRatio": "1:1", + "imageSize": "1024" + } + } + } + ) + ``` + """ + + generation_type: Literal[GenerationType.IMAGE] = GenerationType.IMAGE + + InferenceParamsT: TypeAlias = Annotated[ - ChatCompletionInferenceParams | EmbeddingInferenceParams, Field(discriminator="generation_type") + ChatCompletionInferenceParams | EmbeddingInferenceParams | ImageInferenceParams, + Field(discriminator="generation_type"), ] @@ -454,8 +488,13 @@ def generation_type(self) -> GenerationType: def _convert_inference_parameters(cls, value: Any) -> Any: """Convert raw dict to appropriate inference parameters type based on field presence.""" if isinstance(value, dict): - # Infer type from presence of embedding-specific fields - if "encoding_format" in value or "dimensions" in value: + # Check for explicit generation_type first + gen_type = value.get("generation_type") + + # Infer type from generation_type or field presence + if gen_type == "image": + return ImageInferenceParams(**value) + elif gen_type == "embedding" or "encoding_format" in value or "dimensions" in value: return EmbeddingInferenceParams(**value) else: return ChatCompletionInferenceParams(**value) diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py new file mode 100644 index 00000000..45f43622 --- /dev/null +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -0,0 +1,269 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Helper utilities for working with images.""" + +from __future__ import annotations + +import base64 +import io +import re +from pathlib import Path +from typing import TYPE_CHECKING + +import requests + +from data_designer.config.models import ImageFormat +from data_designer.lazy_heavy_imports import Image + +if TYPE_CHECKING: + from PIL import Image + +# Magic bytes for image format detection +IMAGE_FORMAT_MAGIC_BYTES = { + ImageFormat.PNG: b"\x89PNG\r\n\x1a\n", + ImageFormat.JPG: b"\xff\xd8\xff", + ImageFormat.GIF: b"GIF8", + # WEBP uses RIFF header - handled separately +} + +# Maps PIL format name (lowercase) to our ImageFormat enum. +# PIL reports "JPEG" (not "JPG"), so we normalize it here. +_PIL_FORMAT_TO_IMAGE_FORMAT: dict[str, ImageFormat] = { + "png": ImageFormat.PNG, + "jpeg": ImageFormat.JPG, + "jpg": ImageFormat.JPG, + "gif": ImageFormat.GIF, + "webp": ImageFormat.WEBP, +} + +_BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$") + +# Patterns for diffusion-based image models only (use image_generation API). +IMAGE_DIFFUSION_MODEL_PATTERNS = ( + "dall-e-", + "dalle", + "stable-diffusion", + "sd-", + "sd_", + "imagen", + "gpt-image-", +) + +SUPPORTED_IMAGE_EXTENSIONS = [f".{fmt.value.lower()}" for fmt in ImageFormat] + + +def is_image_diffusion_model(model_name: str) -> bool: + """Return True if the model is a diffusion-based image generation model. + + Args: + model_name: Model name or identifier (e.g. from provider). + + Returns: + True if the model is detected as diffusion-based, False otherwise. + """ + return any(pattern in model_name.lower() for pattern in IMAGE_DIFFUSION_MODEL_PATTERNS) + + +def extract_base64_from_data_uri(data: str) -> str: + """Extract base64 from data URI or return as-is. + + Handles data URIs like "data:image/png;base64,iVBORw0..." and returns + just the base64 portion. + + Args: + data: Data URI (e.g., "data:image/png;base64,XXX") or plain base64 + + Returns: + Base64 string without data URI prefix + + Raises: + ValueError: If data URI format is invalid + """ + if data.startswith("data:"): + if "," in data: + return data.split(",", 1)[1] + raise ValueError("Invalid data URI format: missing comma separator") + return data + + +def decode_base64_image(base64_data: str) -> bytes: + """Decode base64 string to image bytes. + + Automatically handles data URIs by extracting the base64 portion first. + + Args: + base64_data: Base64 string (with or without data URI prefix) + + Returns: + Decoded image bytes + + Raises: + ValueError: If base64 data is invalid + """ + # Remove data URI prefix if present + base64_data = extract_base64_from_data_uri(base64_data) + + try: + return base64.b64decode(base64_data, validate=True) + except Exception as e: + raise ValueError(f"Invalid base64 data: {e}") from e + + +def detect_image_format(image_bytes: bytes) -> ImageFormat: + """Detect image format from bytes. + + Uses magic bytes for fast detection, falls back to PIL for robust detection. + + Args: + image_bytes: Image data as bytes + + Returns: + Detected ImageFormat + + Raises: + ValueError: If the image format cannot be determined + """ + # Check magic bytes first (fast) + if image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.PNG]): + return ImageFormat.PNG + elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.JPG]): + return ImageFormat.JPG + elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.GIF]): + return ImageFormat.GIF + elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]: + return ImageFormat.WEBP + + # Fallback to PIL for robust detection + try: + img = Image.open(io.BytesIO(image_bytes)) + format_str = img.format.lower() if img.format else None + if format_str in _PIL_FORMAT_TO_IMAGE_FORMAT: + return _PIL_FORMAT_TO_IMAGE_FORMAT[format_str] + except Exception: + pass + + raise ValueError( + f"Unable to detect image format (first 8 bytes: {image_bytes[:8]!r}). " + f"Supported formats: {', '.join(SUPPORTED_IMAGE_EXTENSIONS)}." + ) + + +def is_image_path(value: str) -> bool: + """Check if a string is an image file path. + + Args: + value: String to check + + Returns: + True if the string looks like an image file path, False otherwise + """ + if not isinstance(value, str): + return False + return any(value.lower().endswith(ext) for ext in SUPPORTED_IMAGE_EXTENSIONS) + + +def is_base64_image(value: str) -> bool: + """Check if a string is base64-encoded image data. + + Args: + value: String to check + + Returns: + True if the string looks like base64-encoded image data, False otherwise + """ + if not isinstance(value, str): + return False + # Check if it starts with data URI scheme + if value.startswith("data:image/"): + return True + # Check if it looks like base64 (at least 100 chars, contains only base64 chars) + if len(value) > 100 and _BASE64_PATTERN.match(value[:100]): + try: + # Try to decode a small portion to verify it's valid base64 + base64.b64decode(value[:100]) + return True + except Exception: + return False + return False + + +def is_image_url(value: str) -> bool: + """Check if a string is an image URL. + + Args: + value: String to check + + Returns: + True if the string looks like an image URL, False otherwise + """ + if not isinstance(value, str): + return False + return value.startswith(("http://", "https://")) and any(ext in value.lower() for ext in SUPPORTED_IMAGE_EXTENSIONS) + + +def load_image_path_to_base64(image_path: str, base_path: str | None = None) -> str | None: + """Load an image from a file path and return as base64. + + Args: + image_path: Relative or absolute path to the image file. + base_path: Optional base path to resolve relative paths from. + + Returns: + Base64-encoded image data or None if loading fails. + """ + try: + path = Path(image_path) + + # If path is not absolute, try to resolve it + if not path.is_absolute(): + if base_path: + path = Path(base_path) / path + # If still not found, try current working directory + if not path.exists(): + path = Path.cwd() / image_path + + # Check if file exists + if not path.exists(): + return None + + # Read image file and convert to base64 + with open(path, "rb") as f: + image_bytes = f.read() + return base64.b64encode(image_bytes).decode() + except Exception: + return None + + +def load_image_url_to_base64(url: str, timeout: int = 60) -> str: + """Download an image from a URL and return as base64. + + Args: + url: HTTP(S) URL pointing to an image. + timeout: Request timeout in seconds. + + Returns: + Base64-encoded image data. + + Raises: + requests.HTTPError: If the download fails with a non-2xx status. + """ + resp = requests.get(url, timeout=timeout) + resp.raise_for_status() + return base64.b64encode(resp.content).decode() + + +def validate_image(image_path: Path) -> None: + """Validate that an image file is readable and not corrupted. + + Args: + image_path: Path to image file + + Raises: + ValueError: If image is corrupted or unreadable + """ + try: + with Image.open(image_path) as img: + img.verify() + except Exception as e: + raise ValueError(f"Image validation failed: {e}") from e diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index 4f388bce..ac4df1d8 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -3,6 +3,7 @@ from __future__ import annotations +import html import json import os from collections import OrderedDict @@ -26,6 +27,13 @@ from data_designer.config.utils.code_lang import code_lang_to_syntax_lexer from data_designer.config.utils.constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME from data_designer.config.utils.errors import DatasetSampleDisplayError +from data_designer.config.utils.image_helpers import ( + extract_base64_from_data_uri, + is_base64_image, + is_image_path, + is_image_url, + load_image_path_to_base64, +) from data_designer.lazy_heavy_imports import np, pd if TYPE_CHECKING: @@ -39,6 +47,69 @@ console = Console() +def _display_image_if_in_notebook(image_data: str, col_name: str) -> bool: + """Display image with caption in Jupyter notebook if available. + + Args: + image_data: Base64-encoded image data, data URI, file path, or URL. + col_name: Name of the column (used for caption). + + Returns: + True if image was displayed, False otherwise. + """ + try: + # Check if we're in a Jupyter environment + from IPython.display import HTML, display + + get_ipython() # This will raise NameError if not in IPython/Jupyter + + # Escape column name to prevent HTML injection + escaped_col_name = html.escape(col_name) + + # URLs: render directly as + if is_image_url(image_data): + escaped_url = html.escape(image_data) + html_content = f""" +
+
πŸ–ΌοΈ {escaped_col_name}
+ +
+ """ + display(HTML(html_content)) + return True + + # File paths: load from disk and convert to base64 + if is_image_path(image_data) and not image_data.startswith("data:image/"): + loaded_base64 = load_image_path_to_base64(image_data) + if loaded_base64 is None: + console.print( + f"[yellow]⚠️ Could not load image from path '{image_data}' for column '{col_name}'[/yellow]" + ) + return False + base64_data = loaded_base64 + else: + base64_data = image_data + + # Extract base64 from data URI if present + img_base64 = extract_base64_from_data_uri(base64_data) + + # Create HTML with caption and image in left-aligned container + html_content = f""" +
+
πŸ–ΌοΈ {escaped_col_name}
+ +
+ """ + display(HTML(html_content)) + return True + except (ImportError, NameError): + # Not in a notebook environment + return False + except Exception as e: + console.print(f"[yellow]⚠️ Could not display image for column '{col_name}': {e}[/yellow]") + return False + + def get_nvidia_api_key() -> str | None: return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME) @@ -100,7 +171,7 @@ def display_sample_record( processors_to_display: List of processors to display the artifacts for. If None, all processors will be displayed. hide_seed_columns: If True, seed columns will not be displayed separately. """ - i = index or self._display_cycle_index + i = self._display_cycle_index if index is None else index try: record = self._record_sampler_dataset.iloc[i] @@ -223,6 +294,66 @@ def display_sample_record( table.add_row(output_col, convert_to_row_element(record[output_col])) render_list.append(pad_console_element(table)) + # Collect image generation columns (will be displayed at the end) + image_columns = config_builder.get_columns_of_type(DataDesignerColumnType.IMAGE) + images_to_display_later = [] + if len(image_columns) > 0: + # Check if we're in a notebook to decide display style + try: + get_ipython() + in_notebook = True + except NameError: + in_notebook = False + + # Create table for image columns + table = Table(title="Images", **table_kws) + table.add_column("Name") + table.add_column("Preview") + + for col in image_columns: + if col.drop: + continue + image_data = record[col.name] + + # Handle list of images + if isinstance(image_data, list): + previews = [] + for idx, img in enumerate(image_data): + if is_base64_image(img): + previews.append(f"[{idx}] ") + if in_notebook: + images_to_display_later.append((f"{col.name}[{idx}]", img)) + elif is_image_url(img): + previews.append(f"[{idx}] ") + if in_notebook: + images_to_display_later.append((f"{col.name}[{idx}]", img)) + elif is_image_path(img): + previews.append(f"[{idx}] ") + if in_notebook: + images_to_display_later.append((f"{col.name}[{idx}]", img)) + else: + previews.append(f"[{idx}] {str(img)[:30]}") + preview = "\n".join(previews) if previews else "" + # Handle single image (backwards compatibility) + elif is_base64_image(image_data): + preview = f"" + if in_notebook: + images_to_display_later.append((col.name, image_data)) + elif is_image_url(image_data): + preview = f"" + if in_notebook: + images_to_display_later.append((col.name, image_data)) + elif is_image_path(image_data): + preview = f"" + if in_notebook: + images_to_display_later.append((col.name, image_data)) + else: + preview = str(image_data)[:100] + "..." if len(str(image_data)) > 100 else str(image_data) + + table.add_row(col.name, preview) + + render_list.append(pad_console_element(table)) + for col in config_builder.get_columns_of_type(DataDesignerColumnType.LLM_CODE): panel = Panel( Syntax( @@ -287,6 +418,11 @@ def display_sample_record( console.print(Group(*render_list), markup=False) + # Display images at the bottom with captions (only in notebook) + if len(images_to_display_later) > 0: + for col_name, image_data in images_to_display_later: + _display_image_if_in_notebook(image_data, col_name) + def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str: if max_items <= 0: diff --git a/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py b/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py index be7b7185..0e95f248 100644 --- a/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py +++ b/packages/data-designer-config/src/data_designer/lazy_heavy_imports.py @@ -35,6 +35,8 @@ "nx": "networkx", "scipy": "scipy", "jsonschema": "jsonschema", + "PIL": "PIL", + "Image": "PIL.Image", } diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index a069fea0..e633518d 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -53,6 +53,7 @@ def test_data_designer_column_type_get_display_order(): DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, DataDesignerColumnType.CUSTOM, diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index 38b8079e..564b235c 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -17,6 +17,7 @@ GenerationType, ImageContext, ImageFormat, + ImageInferenceParams, ManualDistribution, ManualDistributionParams, ModalityDataType, @@ -412,6 +413,12 @@ def test_model_config_construction(): assert model_config.inference_parameters == embedding_params assert model_config.generation_type == GenerationType.EMBEDDING + # test construction with image inference parameters + image_params = ImageInferenceParams(extra_body={"size": "1024x1024", "quality": "hd"}) + model_config = ModelConfig(alias="test", model="test", inference_parameters=image_params) + assert model_config.inference_parameters == image_params + assert model_config.generation_type == GenerationType.IMAGE + def test_model_config_generation_type_from_dict(): # Test that generation_type in dict is used to create the right inference params type @@ -435,6 +442,30 @@ def test_model_config_generation_type_from_dict(): assert isinstance(model_config.inference_parameters, ChatCompletionInferenceParams) assert model_config.generation_type == GenerationType.CHAT_COMPLETION + model_config = ModelConfig.model_validate( + { + "alias": "test", + "model": "image-model", + "inference_parameters": { + "generation_type": "image", + "extra_body": {"size": "1024x1024", "quality": "hd"}, + }, + } + ) + assert isinstance(model_config.inference_parameters, ImageInferenceParams) + assert model_config.inference_parameters.extra_body == {"size": "1024x1024", "quality": "hd"} + assert model_config.generation_type == GenerationType.IMAGE + + +def test_image_inference_params_generate_kwargs() -> None: + """ImageInferenceParams.generate_kwargs delegates to base; image params go via extra_body.""" + params = ImageInferenceParams() + assert "quality" not in params.generate_kwargs + assert "size" not in params.generate_kwargs + + params = ImageInferenceParams(extra_body={"size": "1024x1024", "quality": "hd"}) + assert params.generate_kwargs.get("extra_body") == {"size": "1024x1024", "quality": "hd"} + def test_chat_completion_params_format_for_display_all_params(): """Test formatting chat completion model with all parameters.""" diff --git a/packages/data-designer-config/tests/config/utils/test_image_helpers.py b/packages/data-designer-config/tests/config/utils/test_image_helpers.py new file mode 100644 index 00000000..8b2f557a --- /dev/null +++ b/packages/data-designer-config/tests/config/utils/test_image_helpers.py @@ -0,0 +1,335 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import base64 +import io +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + +from data_designer.config.models import ImageFormat +from data_designer.config.utils.image_helpers import ( + decode_base64_image, + detect_image_format, + extract_base64_from_data_uri, + is_base64_image, + is_image_diffusion_model, + is_image_path, + is_image_url, + load_image_path_to_base64, + validate_image, +) +from data_designer.lazy_heavy_imports import Image + + +@pytest.fixture +def sample_png_bytes() -> bytes: + """Create a valid 1x1 PNG as raw bytes.""" + img = Image.new("RGB", (1, 1), color="red") + buf = io.BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + +# --------------------------------------------------------------------------- +# extract_base64_from_data_uri +# --------------------------------------------------------------------------- + + +def test_extract_base64_from_data_uri_with_prefix() -> None: + data_uri = "data:image/png;base64,iVBORw0KGgoAAAANS" + result = extract_base64_from_data_uri(data_uri) + assert result == "iVBORw0KGgoAAAANS" + + +def test_extract_base64_plain_base64_without_prefix() -> None: + plain_base64 = "iVBORw0KGgoAAAANS" + result = extract_base64_from_data_uri(plain_base64) + assert result == plain_base64 + + +def test_extract_base64_invalid_data_uri_raises_error() -> None: + with pytest.raises(ValueError, match="Invalid data URI format: missing comma separator"): + extract_base64_from_data_uri("data:image/png;base64") + + +# --------------------------------------------------------------------------- +# decode_base64_image +# --------------------------------------------------------------------------- + + +def test_decode_base64_image_valid() -> None: + png_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" + base64_data = base64.b64encode(png_bytes).decode() + result = decode_base64_image(base64_data) + assert result == png_bytes + + +def test_decode_base64_image_with_data_uri() -> None: + png_bytes = b"\x89PNG\r\n\x1a\n" + base64_data = base64.b64encode(png_bytes).decode() + data_uri = f"data:image/png;base64,{base64_data}" + result = decode_base64_image(data_uri) + assert result == png_bytes + + +def test_decode_base64_image_invalid_raises_error() -> None: + with pytest.raises(ValueError, match="Invalid base64 data"): + decode_base64_image("not-valid-base64!!!") + + +# --------------------------------------------------------------------------- +# detect_image_format (magic bytes) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "header_bytes,expected_format", + [ + (b"\x89PNG\r\n\x1a\n" + b"\x00" * 10, ImageFormat.PNG), + (b"\xff\xd8\xff" + b"\x00" * 10, ImageFormat.JPG), + (b"RIFF" + b"\x00" * 4 + b"WEBP", ImageFormat.WEBP), + ], + ids=["png", "jpg", "webp"], +) +def test_detect_image_format_magic_bytes(header_bytes: bytes, expected_format: ImageFormat) -> None: + assert detect_image_format(header_bytes) == expected_format + + +def test_detect_image_format_gif_magic_bytes(tmp_path: Path) -> None: + img = Image.new("RGB", (1, 1), color="red") + gif_path = tmp_path / "test.gif" + img.save(gif_path, format="GIF") + gif_bytes = gif_path.read_bytes() + assert detect_image_format(gif_bytes) == ImageFormat.GIF + + +def test_detect_image_format_with_pil_fallback_jpeg() -> None: + mock_img = Mock() + mock_img.format = "JPEG" + test_bytes = b"\x00\x00\x00\x00" + + with patch.object(Image, "open", return_value=mock_img): + result = detect_image_format(test_bytes) + assert result == ImageFormat.JPG + + +def test_detect_image_format_unknown_raises_error() -> None: + unknown_bytes = b"\x00\x00\x00\x00" + b"\x00" * 10 + with pytest.raises(ValueError, match="Unable to detect image format"): + detect_image_format(unknown_bytes) + + +# --------------------------------------------------------------------------- +# is_image_path +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "value,expected", + [ + ("/path/to/image.png", True), + ("image.PNG", True), + ("image.jpg", True), + ("image.jpeg", True), + ("/path/to/file.txt", False), + ("document.pdf", False), + ("/some.png/file.txt", False), + ], + ids=["png", "png-upper", "jpg", "jpeg", "txt", "pdf", "ext-in-dir"], +) +def test_is_image_path(value: str, expected: bool) -> None: + assert is_image_path(value) is expected + + +# --------------------------------------------------------------------------- +# is_image_url +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "value,expected", + [ + ("http://example.com/image.png", True), + ("https://example.com/photo.jpg", True), + ("https://example.com/image.png?size=large", True), + ("https://example.com/page.html", False), + ("ftp://example.com/image.png", False), + ], + ids=["http", "https", "query-params", "non-image-ext", "ftp"], +) +def test_is_image_url(value: str, expected: bool) -> None: + assert is_image_url(value) is expected + + +# --------------------------------------------------------------------------- +# is_base64_image +# --------------------------------------------------------------------------- + + +def test_is_base64_image_data_uri() -> None: + assert is_base64_image("data:image/png;base64,iVBORw0KGgo") is True + + +def test_is_base64_image_long_valid_base64() -> None: + long_base64 = base64.b64encode(b"x" * 100).decode() + assert is_base64_image(long_base64) is True + + +def test_is_base64_image_short_string() -> None: + assert is_base64_image("short") is False + + +def test_is_base64_image_invalid_base64_decode() -> None: + invalid_base64 = "A" * 50 + "=" + "A" * 49 + "more text" + assert is_base64_image(invalid_base64) is False + + +# --------------------------------------------------------------------------- +# Non-string guard (is_image_path, is_base64_image, is_image_url) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "func", + [is_image_path, is_base64_image, is_image_url], + ids=["is_image_path", "is_base64_image", "is_image_url"], +) +@pytest.mark.parametrize("value", [123, None, []], ids=["int", "none", "list"]) +def test_non_string_input_returns_false(func: object, value: object) -> None: + assert func(value) is False + + +# --------------------------------------------------------------------------- +# is_image_diffusion_model +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "model_name,expected", + [ + ("dall-e-3", True), + ("DALL-E-2", True), + ("openai/dalle-2", True), + ("stable-diffusion-xl", True), + ("sd-2.1", True), + ("sd_1.5", True), + ("imagen-3", True), + ("google/imagen", True), + ("gpt-image-1", True), + ("gemini-3-pro-image-preview", False), + ("gpt-5-image", False), + ("flux.2-pro", False), + ], + ids=[ + "dall-e-3", + "DALL-E-2", + "dalle-2", + "stable-diffusion-xl", + "sd-2.1", + "sd_1.5", + "imagen-3", + "google-imagen", + "gpt-image-1", + "gemini-not-diffusion", + "gpt-5-not-diffusion", + "flux-not-diffusion", + ], +) +def test_is_image_diffusion_model(model_name: str, expected: bool) -> None: + assert is_image_diffusion_model(model_name) is expected + + +# --------------------------------------------------------------------------- +# validate_image +# --------------------------------------------------------------------------- + + +def test_validate_image_valid_png(tmp_path: Path, sample_png_bytes: bytes) -> None: + image_path = tmp_path / "test.png" + image_path.write_bytes(sample_png_bytes) + validate_image(image_path) + + +def test_validate_image_corrupted_raises_error(tmp_path: Path) -> None: + image_path = tmp_path / "corrupted.png" + image_path.write_bytes(b"not a valid image") + with pytest.raises(ValueError, match="Image validation failed"): + validate_image(image_path) + + +def test_validate_image_nonexistent_raises_error(tmp_path: Path) -> None: + image_path = tmp_path / "nonexistent.png" + with pytest.raises(ValueError, match="Image validation failed"): + validate_image(image_path) + + +# --------------------------------------------------------------------------- +# load_image_path_to_base64 +# --------------------------------------------------------------------------- + + +def test_load_image_path_to_base64_absolute_path(tmp_path: Path) -> None: + img = Image.new("RGB", (1, 1), color="blue") + image_path = tmp_path / "test.png" + img.save(image_path) + + result = load_image_path_to_base64(str(image_path)) + assert result is not None + assert len(result) > 0 + decoded = base64.b64decode(result) + assert len(decoded) > 0 + + +def test_load_image_path_to_base64_relative_with_base_path(tmp_path: Path) -> None: + img = Image.new("RGB", (1, 1), color="green") + image_path = tmp_path / "subdir" / "test.png" + image_path.parent.mkdir(exist_ok=True) + img.save(image_path) + + result = load_image_path_to_base64("subdir/test.png", base_path=str(tmp_path)) + assert result is not None + assert len(result) > 0 + + +def test_load_image_path_to_base64_nonexistent_file() -> None: + result = load_image_path_to_base64("/nonexistent/path/to/image.png") + assert result is None + + +def test_load_image_path_to_base64_relative_with_cwd_fallback(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.chdir(tmp_path) + + img = Image.new("RGB", (1, 1), color="yellow") + image_path = tmp_path / "test_cwd.png" + img.save(image_path) + + result = load_image_path_to_base64("test_cwd.png") + assert result is not None + assert len(result) > 0 + + +def test_load_image_path_to_base64_base_path_fallback_to_cwd(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.chdir(tmp_path) + + img = Image.new("RGB", (1, 1), color="red") + image_path = tmp_path / "test.png" + img.save(image_path) + + wrong_base = tmp_path / "wrong" + wrong_base.mkdir() + + result = load_image_path_to_base64("test.png", base_path=str(wrong_base)) + assert result is not None + assert len(result) > 0 + + +def test_load_image_path_to_base64_exception_handling(tmp_path: Path) -> None: + dir_path = tmp_path / "directory" + dir_path.mkdir() + + result = load_image_path_to_base64(str(dir_path)) + assert result is None diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py new file mode 100644 index 00000000..730e73bb --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from data_designer.config.column_configs import ImageColumnConfig +from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy +from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering +from data_designer.engine.processing.utils import deserialize_json_values + +if TYPE_CHECKING: + from data_designer.engine.storage.media_storage import MediaStorage + + +class ImageCellGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorWithModel[ImageColumnConfig]): + """Generator for image columns with disk or dataframe persistence. + + Media storage always exists and determines behavior via its mode: + - DISK mode: Saves images to disk and stores relative paths in dataframe + - DATAFRAME mode: Stores base64 directly in dataframe + """ + + @property + def media_storage(self) -> MediaStorage: + """Get media storage from resource provider.""" + return self._resource_provider.artifact_storage.media_storage + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + """Generate image(s) and optionally save to disk. + + Args: + data: Record data + + Returns: + Record with image path(s) (create mode) or base64 data (preview mode) added + """ + deserialized_record = deserialize_json_values(data) + + # Validate required columns + missing_columns = list(set(self.config.required_columns) - set(data.keys())) + if len(missing_columns) > 0: + error_msg = ( + f"There was an error preparing the Jinja2 expression template. " + f"The following columns {missing_columns} are missing!" + ) + raise ValueError(error_msg) + + # Render prompt template + self.prepare_jinja2_template_renderer(self.config.prompt, list(deserialized_record.keys())) + prompt = self.render_template(deserialized_record) + + # Validate prompt is non-empty + if not prompt or not prompt.strip(): + raise ValueError(f"Rendered prompt for column {self.config.name!r} is empty") + + # Process multi-modal context if provided + multi_modal_context = None + if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0: + multi_modal_context = [] + for context in self.config.multi_modal_context: + multi_modal_context.extend(context.get_contexts(deserialized_record)) + + # Generate images (returns list of base64 strings) + base64_images = self.model.generate_image(prompt=prompt, multi_modal_context=multi_modal_context) + + # Store via media storage (mode determines disk vs dataframe storage) + # Use column name as subfolder to organize images + results = [ + self.media_storage.save_base64_image(base64_image, subfolder_name=self.config.name) + for base64_image in base64_images + ] + data[self.config.name] = results + + return data diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py index 46642622..f4fc27b9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/registry.py @@ -8,6 +8,7 @@ CustomColumnConfig, EmbeddingColumnConfig, ExpressionColumnConfig, + ImageColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -19,6 +20,7 @@ from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator +from data_designer.engine.column_generators.generators.image import ImageCellGenerator from data_designer.engine.column_generators.generators.llm_completion import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, @@ -52,6 +54,7 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig) registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig) registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig) + registry.register(DataDesignerColumnType.IMAGE, ImageCellGenerator, ImageColumnConfig) if with_plugins: for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR): registry.register( diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py index 1b891b16..1411374d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/generator_classification.py @@ -23,6 +23,7 @@ def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE, } dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) return column_type in dag_column_types @@ -37,6 +38,7 @@ def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE, } for plugin in plugin_manager.get_column_generator_plugins(): if issubclass(plugin.impl_cls, ColumnGeneratorWithModelRegistry): diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py index 35e7d4f8..43b817b0 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/artifact_storage.py @@ -11,11 +11,12 @@ from pathlib import Path from typing import TYPE_CHECKING -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator from data_designer.config.utils.io_helpers import read_parquet_dataset from data_designer.config.utils.type_helpers import StrEnum, resolve_string_enum from data_designer.engine.dataset_builders.errors import ArtifactStorageError +from data_designer.engine.storage.media_storage import MediaStorage, StorageMode from data_designer.lazy_heavy_imports import pd if TYPE_CHECKING: @@ -38,12 +39,25 @@ class BatchStage(StrEnum): class ArtifactStorage(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + artifact_path: Path | str dataset_name: str = "dataset" final_dataset_folder_name: str = FINAL_DATASET_FOLDER_NAME partial_results_folder_name: str = "tmp-partial-parquet-files" dropped_columns_folder_name: str = "dropped-columns-parquet-files" processors_outputs_folder_name: str = PROCESSORS_OUTPUTS_FOLDER_NAME + _media_storage: MediaStorage = PrivateAttr(default=None) + + @property + def media_storage(self) -> MediaStorage: + """Access media storage instance.""" + return self._media_storage + + @media_storage.setter + def media_storage(self, value: MediaStorage) -> None: + """Set media storage instance.""" + self._media_storage = value @property def artifact_path_exists(self) -> bool: @@ -114,8 +128,22 @@ def validate_folder_names(self): if any(char in invalid_chars for char in name): raise ArtifactStorageError(f"πŸ›‘ Directory name '{name}' contains invalid characters.") + # Initialize media storage with DISK mode by default + self._media_storage = MediaStorage( + base_path=self.base_dataset_path, + mode=StorageMode.DISK, + ) + return self + def set_media_storage_mode(self, mode: StorageMode) -> None: + """Set media storage mode. + + Args: + mode: StorageMode.DISK (save to disk) or StorageMode.DATAFRAME (store in memory) + """ + self._media_storage.mode = mode + @staticmethod def mkdir_if_needed(path: Path | str) -> Path: """Create the directory if it does not exist.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index e5404d49..6e42844b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Callable from data_designer.config.column_configs import CustomColumnConfig -from data_designer.config.column_types import ColumnConfigT +from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType from data_designer.config.config_builder import BuilderConfig from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.processors import ( @@ -40,6 +40,7 @@ from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry from data_designer.engine.resources.resource_provider import ResourceProvider +from data_designer.engine.storage.media_storage import StorageMode from data_designer.lazy_heavy_imports import pd if TYPE_CHECKING: @@ -108,10 +109,29 @@ def build( *, num_records: int, on_batch_complete: Callable[[Path], None] | None = None, + save_multimedia_to_disk: bool = True, ) -> Path: + """Build the dataset. + + Args: + num_records: Number of records to generate. + on_batch_complete: Optional callback function called when each batch completes. + save_multimedia_to_disk: Whether to save generated multimedia (images, audio, video) to disk. + If False, multimedia is stored directly in the DataFrame (e.g., images as base64). + Default is True. + + Returns: + Path to the generated dataset directory. + """ self._run_model_health_check_if_needed() self._run_mcp_tool_check_if_needed() self._write_builder_config() + + # Set media storage mode based on parameters + if self._has_image_columns(): + mode = StorageMode.DISK if save_multimedia_to_disk else StorageMode.DATAFRAME + self.artifact_storage.set_media_storage_mode(mode) + generators = self._initialize_generators() start_time = time.perf_counter() group_id = uuid.uuid4().hex @@ -138,6 +158,10 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame: self._run_model_health_check_if_needed() self._run_mcp_tool_check_if_needed() + # Set media storage to DATAFRAME mode for preview - base64 stored directly in DataFrame + if self._has_image_columns(): + self.artifact_storage.set_media_storage_mode(StorageMode.DATAFRAME) + generators = self._initialize_generators() group_id = uuid.uuid4().hex start_time = time.perf_counter() @@ -154,6 +178,10 @@ def process_preview(self, dataset: pd.DataFrame) -> pd.DataFrame: df = self._processor_runner.run_post_batch(dataset.copy(), current_batch_number=None) return self._processor_runner.run_after_generation_on_df(df) + def _has_image_columns(self) -> bool: + """Check if config has any image generation columns.""" + return any(col.column_type == DataDesignerColumnType.IMAGE for col in self.single_column_configs) + def _initialize_generators(self) -> list[ColumnGenerator]: return [ self._registry.column_generators.get_for_config_type(type(config))( @@ -286,6 +314,7 @@ def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max progress_tracker.log_final() if len(self._records_to_drop) > 0: + self._cleanup_dropped_record_images(self._records_to_drop) self.batch_manager.drop_records(self._records_to_drop) self._records_to_drop.clear() @@ -349,6 +378,30 @@ def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> li return processors + def _cleanup_dropped_record_images(self, dropped_indices: set[int]) -> None: + """Remove saved image files for records that will be dropped. + + When a record fails during generation, any images already saved to disk + for that record in previous columns become dangling. This method deletes + those files so they don't accumulate. + """ + media_storage = self.artifact_storage.media_storage + if not self._has_image_columns() or media_storage is None or media_storage.mode != StorageMode.DISK: + return + + image_col_names = [ + col.name for col in self.single_column_configs if col.column_type == DataDesignerColumnType.IMAGE + ] + + buffer = self.batch_manager.get_current_batch(as_dataframe=False) + for idx in dropped_indices: + if idx < 0 or idx >= len(buffer): + continue + for col_name in image_col_names: + paths = buffer[idx].get(col_name, []) + for path in [paths] if isinstance(paths, str) else paths: + media_storage.delete_image(path) + def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None: """If a worker fails, we can handle the exception here.""" logger.warning( diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index 3e1ddf01..8ca1ebfd 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py @@ -83,6 +83,9 @@ class ModelStructuredOutputError(DataDesignerError): ... class ModelGenerationValidationFailureError(DataDesignerError): ... +class ImageGenerationError(DataDesignerError): ... + + class FormattedLLMErrorMessage(BaseModel): cause: str solution: str diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 9f5fc85c..ef328a9a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -9,16 +9,23 @@ from typing import TYPE_CHECKING, Any from data_designer.config.models import GenerationType, ModelConfig, ModelProvider +from data_designer.config.utils.image_helpers import ( + extract_base64_from_data_uri, + is_base64_image, + is_image_diffusion_model, + load_image_url_to_base64, +) from data_designer.engine.mcp.errors import MCPConfigurationError from data_designer.engine.model_provider import ModelProviderRegistry from data_designer.engine.models.errors import ( GenerationValidationFailureError, + ImageGenerationError, catch_llm_exceptions, get_exception_primary_cause, ) from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs from data_designer.engine.models.parsers.errors import ParserException -from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats +from data_designer.engine.models.usage import ImageUsageStats, ModelUsageStats, RequestUsageStats, TokenUsageStats from data_designer.engine.models.utils import ChatMessage, prompt_to_messages from data_designer.engine.secret_resolver import SecretResolver from data_designer.lazy_heavy_imports import litellm @@ -35,6 +42,32 @@ def _identity(x: Any) -> Any: return x +def _try_extract_base64(source: str | litellm.types.utils.ImageObject) -> str | None: + """Try to extract base64 image data from a data URI string or image response object. + + Args: + source: Either a data URI string (e.g. "data:image/png;base64,...") + or a litellm ImageObject with b64_json/url attributes. + + Returns: + Base64-encoded image string, or None if extraction fails. + """ + try: + if isinstance(source, str): + return extract_base64_from_data_uri(source) + + if getattr(source, "b64_json", None): + return source.b64_json + + if getattr(source, "url", None): + return load_image_url_to_base64(source.url) + except Exception: + logger.debug(f"Failed to extract base64 from source of type {type(source).__name__}") + return None + + return None + + logger = logging.getLogger(__name__) @@ -105,7 +138,7 @@ def completion( raise e finally: if not skip_usage_tracking and response is not None: - self._track_usage(response) + self._track_token_usage_from_completion(response) def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: # Remove purpose from kwargs to avoid passing it to the model @@ -117,50 +150,6 @@ def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: kwargs["extra_headers"] = self.model_provider.extra_headers return kwargs - def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: - if tool_alias is None: - return None - if self._mcp_registry is None: - raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.") - - try: - return self._mcp_registry.get_mcp(tool_alias=tool_alias) - except ValueError as exc: - raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc - - @catch_llm_exceptions - def generate_text_embeddings( - self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs - ) -> list[list[float]]: - logger.debug( - f"Generating embeddings with model {self.model_name!r}...", - extra={ - "model": self.model_name, - "input_count": len(input_texts), - }, - ) - kwargs = self.consolidate_kwargs(**kwargs) - response = None - try: - response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs) - logger.debug( - f"Received embeddings from model {self.model_name!r}", - extra={ - "model": self.model_name, - "embedding_count": len(response.data) if response.data else 0, - "usage": self._usage_stats.model_dump(), - }, - ) - if response.data and len(response.data) == len(input_texts): - return [data["embedding"] for data in response.data] - else: - raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") - except Exception as e: - raise e - finally: - if not skip_usage_tracking and response is not None: - self._track_usage_from_embedding(response) - @catch_llm_exceptions def generate( self, @@ -309,6 +298,208 @@ def generate( return output_obj, messages + @catch_llm_exceptions + def generate_text_embeddings( + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs + ) -> list[list[float]]: + logger.debug( + f"Generating embeddings with model {self.model_name!r}...", + extra={ + "model": self.model_name, + "input_count": len(input_texts), + }, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response = None + try: + response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs) + logger.debug( + f"Received embeddings from model {self.model_name!r}", + extra={ + "model": self.model_name, + "embedding_count": len(response.data) if response.data else 0, + "usage": self._usage_stats.model_dump(), + }, + ) + if response.data and len(response.data) == len(input_texts): + return [data["embedding"] for data in response.data] + else: + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") + except Exception as e: + raise e + finally: + if not skip_usage_tracking and response is not None: + self._track_token_usage_from_embedding(response) + + @catch_llm_exceptions + def generate_image( + self, + prompt: str, + multi_modal_context: list[dict[str, Any]] | None = None, + skip_usage_tracking: bool = False, + **kwargs, + ) -> list[str]: + """Generate image(s) and return base64-encoded data. + + Automatically detects the appropriate API based on model name: + - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) β†’ image_generation API + - All other models β†’ chat/completions API (default) + + Both paths return base64-encoded image data. If the API returns multiple images, + all are returned in the list. + + Args: + prompt: The prompt for image generation + multi_modal_context: Optional list of image contexts for multi-modal generation. + Only used with autoregressive models via chat completions API. + skip_usage_tracking: Whether to skip usage tracking + **kwargs: Additional arguments to pass to the model (including n=number of images) + + Returns: + List of base64-encoded image strings (without data URI prefix) + + Raises: + ImageGenerationError: If image generation fails or returns invalid data + """ + logger.debug( + f"Generating image with model {self.model_name!r}...", + extra={"model": self.model_name, "prompt": prompt}, + ) + + # Auto-detect API type based on model name + if is_image_diffusion_model(self.model_name): + images = self._generate_image_diffusion(prompt, skip_usage_tracking, **kwargs) + else: + images = self._generate_image_chat_completion(prompt, multi_modal_context, skip_usage_tracking, **kwargs) + + # Track image usage + if not skip_usage_tracking and len(images) > 0: + self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) + + return images + + def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: + if tool_alias is None: + return None + if self._mcp_registry is None: + raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.") + + try: + return self._mcp_registry.get_mcp(tool_alias=tool_alias) + except ValueError as exc: + raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc + + def _generate_image_chat_completion( + self, + prompt: str, + multi_modal_context: list[dict[str, Any]] | None = None, + skip_usage_tracking: bool = False, + **kwargs, + ) -> list[str]: + """Generate image(s) using autoregressive model via chat completions API. + + Args: + prompt: The prompt for image generation + multi_modal_context: Optional list of image contexts for multi-modal generation + skip_usage_tracking: Whether to skip usage tracking + **kwargs: Additional arguments to pass to the model + + Returns: + List of base64-encoded image strings + """ + messages = prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context) + + response = None + try: + response = self.completion( + messages=messages, + skip_usage_tracking=skip_usage_tracking, + **kwargs, + ) + + logger.debug( + f"Received image(s) from autoregressive model {self.model_name!r}", + extra={"model": self.model_name, "response": response}, + ) + + # Validate response structure + if not response.choices or len(response.choices) == 0: + raise ImageGenerationError("Image generation response missing choices") + + message = response.choices[0].message + images = [] + + # Extract base64 from images attribute (primary path) + if hasattr(message, "images") and message.images: + for image in message.images: + # Handle different response formats + if isinstance(image, dict) and "image_url" in image: + image_url = image["image_url"] + + if isinstance(image_url, dict) and "url" in image_url: + if (b64 := _try_extract_base64(image_url["url"])) is not None: + images.append(b64) + elif isinstance(image_url, str): + if (b64 := _try_extract_base64(image_url)) is not None: + images.append(b64) + # Fallback: treat as base64 string + elif isinstance(image, str): + if (b64 := _try_extract_base64(image)) is not None: + images.append(b64) + + # Fallback: check content field if it looks like image data + if not images: + content = message.content or "" + if content and (content.startswith("data:image/") or is_base64_image(content)): + if (b64 := _try_extract_base64(content)) is not None: + images.append(b64) + + if not images: + raise ImageGenerationError("No image data found in image generation response") + + return images + + except Exception: + raise + + def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: + """Generate image(s) using diffusion model via image_generation API. + + Always returns base64. If the API returns URLs instead of inline base64, + the images are downloaded and converted automatically. + + Returns: + List of base64-encoded image strings + """ + kwargs = self.consolidate_kwargs(**kwargs) + + response = None + + try: + response = self._router.image_generation(prompt=prompt, model=self.model_name, **kwargs) + + logger.debug( + f"Received {len(response.data)} image(s) from diffusion model {self.model_name!r}", + extra={"model": self.model_name, "response": response}, + ) + + # Validate response + if not response.data or len(response.data) == 0: + raise ImageGenerationError("Image generation returned no data") + + images = [b64 for img in response.data if (b64 := _try_extract_base64(img)) is not None] + + if not images: + raise ImageGenerationError("No image data could be extracted from response") + + return images + + except Exception: + raise + finally: + if not skip_usage_tracking and response is not None: + self._track_token_usage_from_image_diffusion(response) + def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict: provider = self._model_provider_registry.get_provider(model_config.provider) api_key = None @@ -326,7 +517,7 @@ def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.Deployme "litellm_params": litellm_params.model_dump(), } - def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None: + def _track_token_usage_from_completion(self, response: litellm.types.utils.ModelResponse | None) -> None: if response is None: self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) return @@ -343,7 +534,7 @@ def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> No request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) - def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None: + def _track_token_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None: if response is None: self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) return @@ -355,3 +546,21 @@ def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingRes ), request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) + + def _track_token_usage_from_image_diffusion(self, response: litellm.types.utils.ImageResponse | None) -> None: + """Track token usage from image_generation API response.""" + if response is None: + self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) + return + + if response.usage is not None and isinstance(response.usage, litellm.types.utils.ImageUsage): + self._usage_stats.extend( + token_usage=TokenUsageStats( + input_tokens=response.usage.input_tokens, + output_tokens=response.usage.output_tokens, + ), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) + else: + # Successful response but no token usage data (some providers don't report it) + self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=1, failed_requests=0)) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/registry.py b/packages/data-designer-engine/src/data_designer/engine/models/registry.py index 03bef223..0b103e76 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/registry.py @@ -120,6 +120,10 @@ def log_model_usage(self, total_time_elapsed: float) -> None: f"turns={tool_usage['total_tool_call_turns']}" ) + if image_usage := stats.get("image_usage"): + total_images = image_usage["total_images"] + logger.info(f"{LOG_INDENT}images: total={total_images}") + if model_index < len(sorted_model_names) - 1: logger.info(LOG_INDENT.rstrip()) @@ -183,6 +187,12 @@ def run_health_check(self, model_aliases: list[str]) -> None: skip_usage_tracking=True, purpose="running health checks", ) + elif model.model_generation_type == GenerationType.IMAGE: + model.generate_image( + prompt="Generate a simple illustration of a thumbs up sign.", + skip_usage_tracking=True, + purpose="running health checks", + ) else: raise ValueError(f"Unsupported generation type: {model.model_generation_type}") logger.info(f"{LOG_INDENT}βœ… Passed!") diff --git a/packages/data-designer-engine/src/data_designer/engine/models/usage.py b/packages/data-designer-engine/src/data_designer/engine/models/usage.py index f44a31ae..64e82b47 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/usage.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/usage.py @@ -71,14 +71,27 @@ def merge(self, other: ToolUsageStats) -> ToolUsageStats: return self +class ImageUsageStats(BaseModel): + total_images: int = 0 + + @property + def has_usage(self) -> bool: + return self.total_images > 0 + + def extend(self, *, images: int) -> None: + """Extend stats with generated images count.""" + self.total_images += images + + class ModelUsageStats(BaseModel): token_usage: TokenUsageStats = TokenUsageStats() request_usage: RequestUsageStats = RequestUsageStats() tool_usage: ToolUsageStats = ToolUsageStats() + image_usage: ImageUsageStats = ImageUsageStats() @property def has_usage(self) -> bool: - return self.token_usage.has_usage and self.request_usage.has_usage + return self.token_usage.has_usage or self.request_usage.has_usage or self.image_usage.has_usage def extend( self, @@ -86,6 +99,7 @@ def extend( token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None, tool_usage: ToolUsageStats | None = None, + image_usage: ImageUsageStats | None = None, ) -> None: if token_usage is not None: self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens) @@ -95,9 +109,16 @@ def extend( ) if tool_usage is not None: self.tool_usage.merge(tool_usage) + if image_usage is not None: + self.image_usage.extend(images=image_usage.total_images) def get_usage_stats(self, *, total_time_elapsed: float) -> dict: - exclude = {"tool_usage"} if not self.tool_usage.has_usage else None + exclude = set() + if not self.tool_usage.has_usage: + exclude.add("tool_usage") + if not self.image_usage.has_usage: + exclude.add("image_usage") + exclude = exclude if exclude else None return self.model_dump(exclude=exclude) | { "tokens_per_second": int(self.token_usage.total_tokens / total_time_elapsed) if total_time_elapsed > 0 diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py new file mode 100644 index 00000000..9d416c65 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/storage/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.engine.storage.media_storage import MediaStorage, StorageMode + +__all__ = ["MediaStorage", "StorageMode"] diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py new file mode 100644 index 00000000..1c887c80 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import uuid +from pathlib import Path + +from data_designer.config.utils.image_helpers import decode_base64_image, detect_image_format, validate_image +from data_designer.config.utils.type_helpers import StrEnum + +IMAGES_SUBDIR = "images" + + +class StorageMode(StrEnum): + """Storage mode for generated media content. + + - DISK: Save media to disk and store relative paths in dataframe (for dataset creation) + - DATAFRAME: Store base64 data directly in dataframe (for preview mode) + """ + + DISK = "disk" + DATAFRAME = "dataframe" + + +class MediaStorage: + """Manages storage of generated media content. + + Currently supports: + - Images (PNG, JPG, WEBP) + + Storage modes: + - DISK: Save media to disk and return relative paths (for dataset creation) + - DATAFRAME: Return base64 data directly (for preview mode) + + Handles: + - Creating storage directories + - Decoding base64 to bytes + - Detecting media format + - Saving with UUID filenames (DISK mode) + - Returning relative paths or base64 data based on mode + - Always validates images to ensure data quality + """ + + def __init__( + self, base_path: Path, images_subdir: str = IMAGES_SUBDIR, mode: StorageMode = StorageMode.DISK + ) -> None: + """Initialize media storage manager. + + Args: + base_path: Base directory for dataset + images_subdir: Subdirectory name for images (default: "images") + mode: Storage mode - DISK (save to disk) or DATAFRAME (return base64) + """ + self.base_path = Path(base_path) + self.images_dir = self.base_path / images_subdir + self.images_subdir = images_subdir + self.mode = mode + + def save_base64_image(self, base64_data: str, subfolder_name: str) -> str: + """Save or return base64 image based on storage mode. + + Args: + base64_data: Base64 encoded image string (with or without data URI prefix) + subfolder_name: Subfolder name to organize images (e.g., "images//") + + Returns: + DISK mode: Relative path to saved image (e.g., "images/subfolder_name/f47ac10b-58cc.png") + DATAFRAME mode: Original base64 data string + + Raises: + ValueError: If base64 data is invalid (DISK mode only) + OSError: If disk write fails (DISK mode only) + """ + # DATAFRAME mode: return base64 directly without disk operations + if self.mode == StorageMode.DATAFRAME: + return base64_data + + # DISK mode: save to disk, validate, and return relative path + # Sanitize subfolder name to prevent path traversal + sanitized_subfolder = self._sanitize_subfolder_name(subfolder_name) + + # Determine the target directory (organized by subfolder) + target_dir = self.images_dir / sanitized_subfolder + + # Ensure target directory exists (lazy initialization) + target_dir.mkdir(parents=True, exist_ok=True) + + # Decode base64 to bytes + image_bytes = decode_base64_image(base64_data) + + # Detect format + image_format = detect_image_format(image_bytes) + + # Generate unique filename + image_id = uuid.uuid4() + filename = f"{image_id}.{image_format.value}" + full_path = target_dir / filename + + # Build relative path + relative_path = f"{self.images_subdir}/{sanitized_subfolder}/{filename}" + + # Write to disk + with open(full_path, "wb") as f: + f.write(image_bytes) + + # Always validate in DISK mode to ensure data quality + self._validate_image(full_path) + + return relative_path + + def delete_image(self, relative_path: str) -> bool: + """Delete a saved image file given its relative path. + + Args: + relative_path: Relative path as returned by save_base64_image (e.g., "images/col/uuid.png") + + Returns: + True if the file was deleted, False if it didn't exist or deletion failed. + """ + try: + full_path = self.base_path / relative_path + if full_path.exists() and self.images_dir in full_path.parents: + full_path.unlink() + return True + except OSError: + pass + return False + + def _validate_image(self, image_path: Path) -> None: + """Validate that saved image is readable. + + Args: + image_path: Path to image file + + Raises: + ValueError: If image is corrupted or unreadable + """ + try: + validate_image(image_path) + except ValueError: + # Clean up invalid file + image_path.unlink(missing_ok=True) + raise + + def _ensure_images_directory(self) -> None: + """Create images directory if it doesn't exist (lazy initialization).""" + self.images_dir.mkdir(parents=True, exist_ok=True) + + def _sanitize_subfolder_name(self, name: str) -> str: + """Sanitize subfolder name to prevent path traversal and filesystem issues.""" + # Replace path separators and parent directory references with underscores + return name.replace("/", "_").replace("\\", "_").replace("..", "_") diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py new file mode 100644 index 00000000..ca5cbfae --- /dev/null +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import Mock, patch + +import pytest + +from data_designer.config.column_configs import ImageColumnConfig +from data_designer.config.models import ImageContext, ImageFormat, ModalityDataType +from data_designer.engine.column_generators.generators.base import GenerationStrategy +from data_designer.engine.column_generators.generators.image import ImageCellGenerator +from data_designer.engine.processing.ginja.exceptions import UserTemplateError + + +@pytest.fixture +def stub_image_column_config(): + return ImageColumnConfig(name="test_image", prompt="A {{ style }} image of {{ subject }}", model_alias="test_model") + + +@pytest.fixture +def stub_base64_images() -> list[str]: + return ["base64_image_1", "base64_image_2"] + + +def test_image_cell_generator_generation_strategy( + stub_image_column_config: ImageColumnConfig, stub_resource_provider: None +) -> None: + generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) + assert generator.get_generation_strategy() == GenerationStrategy.CELL_BY_CELL + + +def test_image_cell_generator_media_storage_property( + stub_image_column_config: ImageColumnConfig, stub_resource_provider: None +) -> None: + generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) + # Should return media_storage from artifact_storage (always exists) + assert generator.media_storage is not None + + +def test_image_cell_generator_generate_with_storage( + stub_image_column_config, stub_resource_provider, stub_base64_images +): + """Test generate with media storage (create mode) - saves to disk.""" + # Setup mock media storage + mock_storage = Mock() + mock_storage.save_base64_image.side_effect = [ + "images/test_image/uuid1.png", + "images/test_image/uuid2.png", + ] + stub_resource_provider.artifact_storage.media_storage = mock_storage + + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + "generate_image", + return_value=stub_base64_images, + ) as mock_generate: + generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) + data = generator.generate(data={"style": "photorealistic", "subject": "cat"}) + + # Check that column was added with relative paths (organized in subfolder) + assert stub_image_column_config.name in data + assert data[stub_image_column_config.name] == [ + "images/test_image/uuid1.png", + "images/test_image/uuid2.png", + ] + + # Verify model was called with rendered prompt + mock_generate.assert_called_once_with(prompt="A photorealistic image of cat", multi_modal_context=None) + + # Verify storage was called for each image with subfolder name + assert mock_storage.save_base64_image.call_count == 2 + mock_storage.save_base64_image.assert_any_call("base64_image_1", subfolder_name="test_image") + mock_storage.save_base64_image.assert_any_call("base64_image_2", subfolder_name="test_image") + + +def test_image_cell_generator_generate_in_dataframe_mode( + stub_image_column_config, stub_resource_provider, stub_base64_images +): + """Test generate with media storage in DATAFRAME mode - stores base64 directly.""" + # Mock save_base64_image to return base64 directly (simulating DATAFRAME mode) + mock_storage = Mock() + mock_storage.save_base64_image.side_effect = stub_base64_images + stub_resource_provider.artifact_storage.media_storage = mock_storage + + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + "generate_image", + return_value=stub_base64_images, + ) as mock_generate: + generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) + data = generator.generate(data={"style": "watercolor", "subject": "dog"}) + + # Check that column was added with base64 data (simulating DATAFRAME mode) + assert stub_image_column_config.name in data + assert data[stub_image_column_config.name] == stub_base64_images + + # Verify model was called with rendered prompt + mock_generate.assert_called_once_with(prompt="A watercolor image of dog", multi_modal_context=None) + + # Verify storage was called for each image with subfolder name (even in DATAFRAME mode) + assert mock_storage.save_base64_image.call_count == 2 + mock_storage.save_base64_image.assert_any_call("base64_image_1", subfolder_name="test_image") + mock_storage.save_base64_image.assert_any_call("base64_image_2", subfolder_name="test_image") + + +def test_image_cell_generator_missing_columns_error(stub_image_column_config, stub_resource_provider): + """Test that missing required columns raises ValueError.""" + generator = ImageCellGenerator(config=stub_image_column_config, resource_provider=stub_resource_provider) + + with pytest.raises(ValueError, match="columns.*missing"): + # Missing 'subject' column + generator.generate(data={"style": "photorealistic"}) + + +def test_image_cell_generator_empty_prompt_error(stub_resource_provider): + """Test that empty rendered prompt raises UserTemplateError.""" + # Create config with template that renders to empty string + config = ImageColumnConfig(name="test_image", prompt="{{ empty }}", model_alias="test_model") + + generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + + with pytest.raises(UserTemplateError): + generator.generate(data={"empty": ""}) + + +def test_image_cell_generator_whitespace_only_prompt_error(stub_resource_provider): + """Test that whitespace-only rendered prompt raises ValueError.""" + config = ImageColumnConfig(name="test_image", prompt="{{ spaces }}", model_alias="test_model") + + generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + + with pytest.raises(ValueError, match="empty"): + generator.generate(data={"spaces": " "}) + + +def test_image_cell_generator_with_multi_modal_context(stub_resource_provider): + """Test generate with multi-modal context for autoregressive models.""" + # Create image context that references a column with URL + image_context = ImageContext(column_name="reference_image", data_type=ModalityDataType.URL) + + config = ImageColumnConfig( + name="test_image", + prompt="Generate a similar image to the reference", + model_alias="test_model", + multi_modal_context=[image_context], + ) + + # Setup mock media storage + mock_storage = Mock() + mock_storage.save_base64_image.return_value = "images/generated.png" + stub_resource_provider.artifact_storage.media_storage = mock_storage + + stub_base64_images = ["base64_generated_image"] + + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + "generate_image", + return_value=stub_base64_images, + ) as mock_generate: + generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + data = generator.generate(data={"reference_image": "https://example.com/image.png"}) + + # Check that column was added + assert config.name in data + assert data[config.name] == ["images/generated.png"] + + # Verify model was called with prompt and multi_modal_context + mock_generate.assert_called_once() + call_args = mock_generate.call_args + assert call_args.kwargs["prompt"] == "Generate a similar image to the reference" + assert call_args.kwargs["multi_modal_context"] is not None + assert len(call_args.kwargs["multi_modal_context"]) == 1 + assert call_args.kwargs["multi_modal_context"][0]["type"] == "image_url" + assert call_args.kwargs["multi_modal_context"][0]["image_url"] == "https://example.com/image.png" + + +def test_image_cell_generator_with_base64_multi_modal_context(stub_resource_provider): + """Test generate with base64 multi-modal context.""" + # Create image context that references a column with base64 data + image_context = ImageContext( + column_name="reference_image", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG + ) + + config = ImageColumnConfig( + name="test_image", + prompt="Generate a variation of this image", + model_alias="test_model", + multi_modal_context=[image_context], + ) + + # Setup mock media storage + mock_storage = Mock() + mock_storage.save_base64_image.return_value = "images/generated.png" + stub_resource_provider.artifact_storage.media_storage = mock_storage + + stub_base64_images = ["base64_generated_image"] + + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + "generate_image", + return_value=stub_base64_images, + ) as mock_generate: + generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + data = generator.generate(data={"reference_image": "iVBORw0KGgoAAAANS"}) + + # Check that column was added + assert config.name in data + assert data[config.name] == ["images/generated.png"] + + # Verify model was called with prompt and multi_modal_context + mock_generate.assert_called_once() + call_args = mock_generate.call_args + assert call_args.kwargs["prompt"] == "Generate a variation of this image" + assert call_args.kwargs["multi_modal_context"] is not None + assert len(call_args.kwargs["multi_modal_context"]) == 1 + assert call_args.kwargs["multi_modal_context"][0]["type"] == "image_url" + # Should be formatted as data URI + assert "data:image/png;base64," in call_args.kwargs["multi_modal_context"][0]["image_url"]["url"] diff --git a/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py b/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py index e1136233..fb3cdab5 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py +++ b/packages/data-designer-engine/tests/engine/column_generators/utils/test_generator_classification.py @@ -14,6 +14,7 @@ def test_column_type_is_model_generated() -> None: assert column_type_is_model_generated(DataDesignerColumnType.LLM_STRUCTURED) assert column_type_is_model_generated(DataDesignerColumnType.LLM_JUDGE) assert column_type_is_model_generated(DataDesignerColumnType.EMBEDDING) + assert column_type_is_model_generated(DataDesignerColumnType.IMAGE) assert not column_type_is_model_generated(DataDesignerColumnType.SAMPLER) assert not column_type_is_model_generated(DataDesignerColumnType.VALIDATION) assert not column_type_is_model_generated(DataDesignerColumnType.EXPRESSION) @@ -29,5 +30,6 @@ def test_column_type_used_in_execution_dag() -> None: assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT) assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION) assert column_type_used_in_execution_dag(DataDesignerColumnType.EMBEDDING) + assert column_type_used_in_execution_dag(DataDesignerColumnType.IMAGE) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py index df15b4f7..35edf892 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_artifact_storage.py @@ -213,10 +213,11 @@ def test_artifact_storage_resolved_dataset_name(mock_datetime, tmp_path): (af_storage.artifact_path / af_storage.dataset_name).mkdir() assert af_storage.resolved_dataset_name == "dataset" - # dataset path exists and is not empty + # dataset path exists and is not empty (create file BEFORE constructing ArtifactStorage) + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir(exist_ok=True) + (dataset_dir / "stub_file.txt").touch() af_storage = ArtifactStorage(artifact_path=tmp_path) - (af_storage.artifact_path / af_storage.dataset_name / "stub_file.txt").touch() - print(af_storage.resolved_dataset_name) assert af_storage.resolved_dataset_name == "dataset_01-01-2025_120304" diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 235ead55..84da6ebb 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -1,18 +1,23 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Any -from unittest.mock import patch +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock, patch import pytest -from litellm.types.utils import Choices, EmbeddingResponse, Message, ModelResponse from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError -from data_designer.engine.models.errors import ModelGenerationValidationFailureError +from data_designer.engine.models.errors import ImageGenerationError, ModelGenerationValidationFailureError from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.parsers.errors import ParserException from data_designer.engine.models.utils import ChatMessage from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, StubMessage, StubResponse +from data_designer.lazy_heavy_imports import litellm + +if TYPE_CHECKING: + import litellm def mock_oai_response_object(response_text: str) -> StubResponse: @@ -35,12 +40,14 @@ def stub_completion_messages() -> list[ChatMessage]: @pytest.fixture def stub_expected_completion_response(): - return ModelResponse(choices=Choices(message=Message(content="Test response"))) + return litellm.types.utils.ModelResponse( + choices=litellm.types.utils.Choices(message=litellm.types.utils.Message(content="Test response")) + ) @pytest.fixture def stub_expected_embedding_response(): - return EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2) + return litellm.types.utils.EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2) @pytest.mark.parametrize( @@ -106,9 +113,11 @@ def test_generate_with_system_prompt( # Capture messages at call time since they get mutated after the call captured_messages = [] - def capture_and_return(*args: Any, **kwargs: Any) -> ModelResponse: + def capture_and_return(*args: Any, **kwargs: Any) -> litellm.types.utils.ModelResponse: captured_messages.append(list(args[1])) # Copy the messages list - return ModelResponse(choices=Choices(message=Message(content="Hello!"))) + return litellm.types.utils.ModelResponse( + choices=litellm.types.utils.Choices(message=litellm.types.utils.Message(content="Hello!")) + ) mock_completion.side_effect = capture_and_return @@ -188,7 +197,7 @@ def test_completion_success( stub_completion_messages: list[ChatMessage], stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_completion_response: ModelResponse, + stub_expected_completion_response: litellm.types.utils.ModelResponse, skip_usage_tracking: bool, ) -> None: mock_router_completion.side_effect = lambda self, model, messages, **kwargs: stub_expected_completion_response @@ -221,11 +230,13 @@ def test_completion_with_kwargs( stub_completion_messages: list[ChatMessage], stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_completion_response: ModelResponse, + stub_expected_completion_response: litellm.types.utils.ModelResponse, ) -> None: captured_kwargs = {} - def mock_completion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> ModelResponse: + def mock_completion( + self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any + ) -> litellm.types.utils.ModelResponse: captured_kwargs.update(kwargs) return stub_expected_completion_response @@ -1011,3 +1022,253 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe with patch.object(ModelFacade, "completion", new=_completion): with pytest.raises(MCPToolError, match="Invalid tool arguments"): model.generate(prompt="question", parser=lambda x: x, tool_alias="tools") + + +# ============================================================================= +# Image generation tests +# ============================================================================= + + +@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) +def test_generate_image_diffusion_tracks_image_usage( + mock_image_generation: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image tracks image usage for diffusion models.""" + # Mock response with 3 images + mock_response = litellm.types.utils.ImageResponse( + data=[ + litellm.types.utils.ImageObject(b64_json="image1_base64"), + litellm.types.utils.ImageObject(b64_json="image2_base64"), + litellm.types.utils.ImageObject(b64_json="image3_base64"), + ] + ) + mock_image_generation.return_value = mock_response + + # Verify initial state + assert stub_model_facade.usage_stats.image_usage.total_images == 0 + + # Generate images + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): + images = stub_model_facade.generate_image(prompt="test prompt", n=3) + + # Verify results + assert len(images) == 3 + assert images == ["image1_base64", "image2_base64", "image3_base64"] + + # Verify image usage was tracked + assert stub_model_facade.usage_stats.image_usage.total_images == 3 + assert stub_model_facade.usage_stats.image_usage.has_usage is True + + +@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +def test_generate_image_chat_completion_tracks_image_usage( + mock_completion: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image tracks image usage for chat completion models.""" + # Mock response with images attribute (Message requires type and index per ImageURLListItem) + mock_message = litellm.types.utils.Message( + role="assistant", + content="", + images=[ + litellm.types.utils.ImageURLListItem( + type="image_url", image_url={"url": "data:image/png;base64,image1"}, index=0 + ), + litellm.types.utils.ImageURLListItem( + type="image_url", image_url={"url": "data:image/png;base64,image2"}, index=1 + ), + ], + ) + mock_response = litellm.types.utils.ModelResponse(choices=[litellm.types.utils.Choices(message=mock_message)]) + mock_completion.return_value = mock_response + + # Verify initial state + assert stub_model_facade.usage_stats.image_usage.total_images == 0 + + # Generate images + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): + images = stub_model_facade.generate_image(prompt="test prompt") + + # Verify results + assert len(images) == 2 + assert images == ["image1", "image2"] + + # Verify image usage was tracked + assert stub_model_facade.usage_stats.image_usage.total_images == 2 + assert stub_model_facade.usage_stats.image_usage.has_usage is True + + +@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +def test_generate_image_chat_completion_with_dict_format( + mock_completion: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image handles images as dicts with image_url string.""" + # Create mock message with images as dict with string image_url + mock_message = MagicMock() + mock_message.role = "assistant" + mock_message.content = "" + mock_message.images = [ + {"image_url": "data:image/png;base64,image1"}, + {"image_url": "data:image/jpeg;base64,image2"}, + ] + + mock_choice = MagicMock() + mock_choice.message = mock_message + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + mock_completion.return_value = mock_response + + # Generate images + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): + images = stub_model_facade.generate_image(prompt="test prompt") + + # Verify results + assert len(images) == 2 + assert images == ["image1", "image2"] + + +@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +def test_generate_image_chat_completion_with_plain_strings( + mock_completion: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image handles images as plain strings.""" + # Create mock message with images as plain strings + mock_message = MagicMock() + mock_message.role = "assistant" + mock_message.content = "" + mock_message.images = [ + "data:image/png;base64,image1", + "image2", # Plain base64 without data URI prefix + ] + + mock_choice = MagicMock() + mock_choice.message = mock_message + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + mock_completion.return_value = mock_response + + # Generate images + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): + images = stub_model_facade.generate_image(prompt="test prompt") + + # Verify results + assert len(images) == 2 + assert images == ["image1", "image2"] + + +@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) +def test_generate_image_skip_usage_tracking( + mock_image_generation: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image respects skip_usage_tracking flag.""" + mock_response = litellm.types.utils.ImageResponse( + data=[ + litellm.types.utils.ImageObject(b64_json="image1_base64"), + litellm.types.utils.ImageObject(b64_json="image2_base64"), + ] + ) + mock_image_generation.return_value = mock_response + + # Verify initial state + assert stub_model_facade.usage_stats.image_usage.total_images == 0 + + # Generate images with skip_usage_tracking=True + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): + images = stub_model_facade.generate_image(prompt="test prompt", skip_usage_tracking=True) + + # Verify results + assert len(images) == 2 + + # Verify image usage was NOT tracked + assert stub_model_facade.usage_stats.image_usage.total_images == 0 + assert stub_model_facade.usage_stats.image_usage.has_usage is False + + +@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +def test_generate_image_chat_completion_no_choices( + mock_completion: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image raises ImageGenerationError when response has no choices.""" + mock_response = litellm.types.utils.ModelResponse(choices=[]) + mock_completion.return_value = mock_response + + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): + with pytest.raises(ImageGenerationError, match="Image generation response missing choices"): + stub_model_facade.generate_image(prompt="test prompt") + + +@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +def test_generate_image_chat_completion_no_image_data( + mock_completion: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image raises ImageGenerationError when no image data in response.""" + mock_message = litellm.types.utils.Message(role="assistant", content="just text, no image") + mock_response = litellm.types.utils.ModelResponse(choices=[litellm.types.utils.Choices(message=mock_message)]) + mock_completion.return_value = mock_response + + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): + with pytest.raises(ImageGenerationError, match="No image data found in image generation response"): + stub_model_facade.generate_image(prompt="test prompt") + + +@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) +def test_generate_image_diffusion_no_data( + mock_image_generation: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image raises ImageGenerationError when diffusion API returns no data.""" + mock_response = litellm.types.utils.ImageResponse(data=[]) + mock_image_generation.return_value = mock_response + + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): + with pytest.raises(ImageGenerationError, match="Image generation returned no data"): + stub_model_facade.generate_image(prompt="test prompt") + + +@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) +def test_generate_image_accumulates_usage( + mock_image_generation: Any, + stub_model_facade: ModelFacade, +) -> None: + """Test that generate_image accumulates image usage across multiple calls.""" + # First call - 2 images + mock_response1 = litellm.types.utils.ImageResponse( + data=[ + litellm.types.utils.ImageObject(b64_json="image1"), + litellm.types.utils.ImageObject(b64_json="image2"), + ] + ) + # Second call - 3 images + mock_response2 = litellm.types.utils.ImageResponse( + data=[ + litellm.types.utils.ImageObject(b64_json="image3"), + litellm.types.utils.ImageObject(b64_json="image4"), + litellm.types.utils.ImageObject(b64_json="image5"), + ] + ) + mock_image_generation.side_effect = [mock_response1, mock_response2] + + # Verify initial state + assert stub_model_facade.usage_stats.image_usage.total_images == 0 + + # First generation + with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): + images1 = stub_model_facade.generate_image(prompt="test1") + assert len(images1) == 2 + assert stub_model_facade.usage_stats.image_usage.total_images == 2 + + # Second generation + images2 = stub_model_facade.generate_image(prompt="test2") + assert len(images2) == 3 + # Usage should accumulate + assert stub_model_facade.usage_stats.image_usage.total_images == 5 diff --git a/packages/data-designer-engine/tests/engine/models/test_usage.py b/packages/data-designer-engine/tests/engine/models/test_usage.py index 8e7adb04..2bfea4b4 100644 --- a/packages/data-designer-engine/tests/engine/models/test_usage.py +++ b/packages/data-designer-engine/tests/engine/models/test_usage.py @@ -1,7 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats, ToolUsageStats +from data_designer.engine.models.usage import ( + ImageUsageStats, + ModelUsageStats, + RequestUsageStats, + TokenUsageStats, + ToolUsageStats, +) def test_token_usage_stats() -> None: @@ -32,6 +38,20 @@ def test_request_usage_stats() -> None: assert request_usage_stats.has_usage is True +def test_image_usage_stats() -> None: + image_usage_stats = ImageUsageStats() + assert image_usage_stats.total_images == 0 + assert image_usage_stats.has_usage is False + + image_usage_stats.extend(images=5) + assert image_usage_stats.total_images == 5 + assert image_usage_stats.has_usage is True + + image_usage_stats.extend(images=3) + assert image_usage_stats.total_images == 8 + assert image_usage_stats.has_usage is True + + def test_tool_usage_stats_empty_state() -> None: """Test ToolUsageStats initialization with empty state.""" tool_usage = ToolUsageStats() @@ -132,9 +152,10 @@ def test_model_usage_stats() -> None: assert model_usage_stats.token_usage.output_tokens == 0 assert model_usage_stats.request_usage.successful_requests == 0 assert model_usage_stats.request_usage.failed_requests == 0 + assert model_usage_stats.image_usage.total_images == 0 assert model_usage_stats.has_usage is False - # tool_usage is excluded when has_usage is False + # tool_usage and image_usage are excluded when has_usage is False assert model_usage_stats.get_usage_stats(total_time_elapsed=10) == { "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, "request_usage": {"successful_requests": 0, "failed_requests": 0, "total_requests": 0}, @@ -152,7 +173,7 @@ def test_model_usage_stats() -> None: assert model_usage_stats.request_usage.failed_requests == 1 assert model_usage_stats.has_usage is True - # tool_usage is excluded when has_usage is False + # tool_usage and image_usage are excluded when has_usage is False assert model_usage_stats.get_usage_stats(total_time_elapsed=2) == { "token_usage": {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, "request_usage": {"successful_requests": 2, "failed_requests": 1, "total_requests": 3}, @@ -177,3 +198,58 @@ def test_model_usage_stats_extend_with_tool_usage() -> None: assert stats1.tool_usage.total_tool_call_turns == 6 assert stats1.tool_usage.total_generations == 4 assert stats1.tool_usage.generations_with_tools == 3 + + +def test_model_usage_stats_with_image_usage() -> None: + """Test that ModelUsageStats includes image_usage when it has usage.""" + model_usage_stats = ModelUsageStats() + model_usage_stats.extend( + token_usage=TokenUsageStats(input_tokens=10, output_tokens=20), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + image_usage=ImageUsageStats(total_images=5), + ) + + assert model_usage_stats.image_usage.total_images == 5 + assert model_usage_stats.image_usage.has_usage is True + + # image_usage should be included in output + usage_stats = model_usage_stats.get_usage_stats(total_time_elapsed=2) + assert "image_usage" in usage_stats + assert usage_stats["image_usage"] == {"total_images": 5} + + +def test_model_usage_stats_has_usage_any_of() -> None: + """Test that has_usage is True when any of token, request, or image usage is present.""" + # Only token usage + stats = ModelUsageStats() + stats.extend(token_usage=TokenUsageStats(input_tokens=1, output_tokens=0)) + assert stats.has_usage is True + + # Only request usage (e.g. diffusion API without token counts) + stats = ModelUsageStats() + stats.extend(request_usage=RequestUsageStats(successful_requests=1, failed_requests=0)) + assert stats.has_usage is True + + # Only image usage + stats = ModelUsageStats() + stats.extend(image_usage=ImageUsageStats(total_images=2)) + assert stats.has_usage is True + + # None of the three + stats = ModelUsageStats() + assert stats.has_usage is False + + +def test_model_usage_stats_exclude_unused_stats() -> None: + """Test that ModelUsageStats excludes tool_usage and image_usage when they have no usage.""" + model_usage_stats = ModelUsageStats() + model_usage_stats.extend( + token_usage=TokenUsageStats(input_tokens=10, output_tokens=20), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) + + usage_stats = model_usage_stats.get_usage_stats(total_time_elapsed=2) + assert "tool_usage" not in usage_stats + assert "image_usage" not in usage_stats + assert "token_usage" in usage_stats + assert "request_usage" in usage_stats diff --git a/packages/data-designer-engine/tests/engine/storage/__init__.py b/packages/data-designer-engine/tests/engine/storage/__init__.py new file mode 100644 index 00000000..e5725ea5 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/storage/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/data-designer-engine/tests/engine/storage/test_media_storage.py b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py new file mode 100644 index 00000000..e79c854b --- /dev/null +++ b/packages/data-designer-engine/tests/engine/storage/test_media_storage.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import base64 +import io + +import pytest + +from data_designer.engine.storage.media_storage import IMAGES_SUBDIR, MediaStorage, StorageMode +from data_designer.lazy_heavy_imports import Image + + +@pytest.fixture +def media_storage(tmp_path): + """Create a MediaStorage instance with a temporary directory.""" + return MediaStorage(base_path=tmp_path) + + +@pytest.fixture +def sample_base64_png() -> str: + """Create a valid 1x1 PNG as base64.""" + img = Image.new("RGB", (1, 1), color="red") + buf = io.BytesIO() + img.save(buf, format="PNG") + png_bytes = buf.getvalue() + return base64.b64encode(png_bytes).decode() + + +@pytest.fixture +def sample_base64_jpg() -> str: + """Create a valid 1x1 JPEG as base64.""" + img = Image.new("RGB", (1, 1), color="blue") + buf = io.BytesIO() + img.save(buf, format="JPEG") + jpg_bytes = buf.getvalue() + return base64.b64encode(jpg_bytes).decode() + + +@pytest.mark.parametrize( + "images_subdir,mode", + [ + (IMAGES_SUBDIR, StorageMode.DISK), + ("custom_images", StorageMode.DATAFRAME), + ], + ids=["defaults", "custom-subdir-dataframe"], +) +def test_media_storage_init(tmp_path, images_subdir: str, mode: StorageMode) -> None: + """Test MediaStorage initialization with various configurations.""" + storage = MediaStorage(base_path=tmp_path, images_subdir=images_subdir, mode=mode) + assert storage.base_path == tmp_path + assert storage.images_subdir == images_subdir + assert storage.images_dir == tmp_path / images_subdir + assert storage.mode == mode + # Directory should NOT exist until first save (lazy initialization) + assert not storage.images_dir.exists() + + +@pytest.mark.parametrize( + "image_fixture,expected_extension", + [ + ("sample_base64_png", ".png"), + ("sample_base64_jpg", ".jpg"), + ], +) +def test_save_base64_image_format(media_storage, image_fixture, expected_extension, request): + """Test saving images from base64 in different formats.""" + # Get the actual fixture value using request.getfixturevalue + sample_base64 = request.getfixturevalue(image_fixture) + + relative_path = media_storage.save_base64_image(sample_base64, subfolder_name="test_column") + + # Check return value format (organized by column name) + assert relative_path.startswith(f"{IMAGES_SUBDIR}/test_column/") + assert relative_path.endswith(expected_extension) + + # Check file exists on disk + full_path = media_storage.base_path / relative_path + assert full_path.exists() + + # Verify file content + saved_bytes = full_path.read_bytes() + expected_bytes = base64.b64decode(sample_base64) + assert saved_bytes == expected_bytes + + +def test_save_base64_image_with_data_uri(media_storage, sample_base64_png): + """Test saving image from data URI format.""" + data_uri = f"data:image/png;base64,{sample_base64_png}" + relative_path = media_storage.save_base64_image(data_uri, subfolder_name="test_column") + + # Should successfully extract base64 and save (organized by column name) + assert relative_path.startswith(f"{IMAGES_SUBDIR}/test_column/") + assert relative_path.endswith(".png") + + # Verify file exists and content is correct + full_path = media_storage.base_path / relative_path + assert full_path.exists() + saved_bytes = full_path.read_bytes() + expected_bytes = base64.b64decode(sample_base64_png) + assert saved_bytes == expected_bytes + + +def test_save_base64_image_invalid_base64_raises_error(media_storage): + """Test that invalid base64 data raises ValueError.""" + with pytest.raises(ValueError, match="Invalid base64"): + media_storage.save_base64_image("not-valid-base64!!!", subfolder_name="test_column") + + +def test_save_base64_image_multiple_images_unique_filenames(media_storage, sample_base64_png): + """Test that multiple images get unique filenames.""" + path1 = media_storage.save_base64_image(sample_base64_png, subfolder_name="test_column") + path2 = media_storage.save_base64_image(sample_base64_png, subfolder_name="test_column") + + # Paths should be different (different UUIDs) + assert path1 != path2 + + # Both files should exist + assert (media_storage.base_path / path1).exists() + assert (media_storage.base_path / path2).exists() + + +def test_save_base64_image_disk_mode_validates(tmp_path, sample_base64_png): + """Test that DISK mode validates images.""" + storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DISK) + # Should succeed with valid image + relative_path = storage.save_base64_image(sample_base64_png, subfolder_name="test_column") + assert relative_path.startswith(f"{IMAGES_SUBDIR}/test_column/") + + +def test_save_base64_image_disk_mode_corrupted_image_raises_error(tmp_path): + """Test that DISK mode validates and rejects corrupted images.""" + storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DISK) + + # Create base64 of invalid image data + corrupted_bytes = b"not a valid image" + corrupted_base64 = base64.b64encode(corrupted_bytes).decode() + + with pytest.raises(ValueError, match="Unable to detect image format"): + storage.save_base64_image(corrupted_base64, subfolder_name="test_column") + + # Check that no files were left behind (cleanup on validation failure) + column_dir = storage.images_dir / "test_column" + if column_dir.exists(): + assert len(list(column_dir.iterdir())) == 0 + + +@pytest.mark.parametrize("subfolder_name", ["test_column", "test_subfolder"], ids=["column", "subfolder"]) +def test_save_base64_image_dataframe_mode_returns_base64(tmp_path, sample_base64_png, subfolder_name): + """Test that DATAFRAME mode returns base64 directly regardless of subfolder name.""" + storage = MediaStorage(base_path=tmp_path, mode=StorageMode.DATAFRAME) + + result = storage.save_base64_image(sample_base64_png, subfolder_name=subfolder_name) + assert result == sample_base64_png + + # Directory should not be created in DATAFRAME mode (lazy initialization) + assert not storage.images_dir.exists() + + +def test_save_base64_image_with_subfolder_name(media_storage, sample_base64_png): + """Test saving image with subfolder name organizes into subdirectory.""" + subfolder = "test_subfolder" + relative_path = media_storage.save_base64_image(sample_base64_png, subfolder_name=subfolder) + + # Check return value format includes subfolder + assert relative_path.startswith(f"{IMAGES_SUBDIR}/{subfolder}/") + assert relative_path.endswith(".png") + + # Check file exists in correct subdirectory + full_path = media_storage.base_path / relative_path + assert full_path.exists() + assert full_path.parent.name == subfolder + + # Verify file content + saved_bytes = full_path.read_bytes() + expected_bytes = base64.b64decode(sample_base64_png) + assert saved_bytes == expected_bytes + + +def test_save_base64_image_with_different_subfolder_names(media_storage, sample_base64_png, sample_base64_jpg): + """Test that images with different subfolder names are stored in separate subdirectories.""" + path1 = media_storage.save_base64_image(sample_base64_png, subfolder_name="subfolder_a") + path2 = media_storage.save_base64_image(sample_base64_jpg, subfolder_name="subfolder_b") + + # Check paths are in different subdirectories + assert "subfolder_a" in path1 + assert "subfolder_b" in path2 + + # Check both directories exist + subfolder_a_dir = media_storage.images_dir / "subfolder_a" + subfolder_b_dir = media_storage.images_dir / "subfolder_b" + assert subfolder_a_dir.exists() + assert subfolder_b_dir.exists() + + # Check files exist in their respective directories + assert (media_storage.base_path / path1).exists() + assert (media_storage.base_path / path2).exists() + + +@pytest.mark.parametrize( + "unsafe_name,expected_sanitized", + [ + ("../evil", "__evil"), # Parent directory traversal: .. -> _, / -> _ + ("foo/bar", "foo_bar"), # Path separator (forward slash) + ("foo\\bar", "foo_bar"), # Path separator (backslash) + ("test..name", "test_name"), # Double dots in middle: .. -> _ + ], +) +def test_save_base64_image_sanitizes_subfolder_name(media_storage, sample_base64_png, unsafe_name, expected_sanitized): + """Test that subfolder names are sanitized to prevent path traversal.""" + relative_path = media_storage.save_base64_image(sample_base64_png, subfolder_name=unsafe_name) + + # Check that path contains sanitized subfolder name + assert expected_sanitized in relative_path + assert "/" not in expected_sanitized # No path separators + assert "\\" not in expected_sanitized # No backslashes + assert ".." not in expected_sanitized # No parent references + + # Verify file is inside images directory (not escaped via path traversal) + full_path = media_storage.base_path / relative_path + assert full_path.exists() + assert media_storage.images_dir in full_path.parents + + +# --------------------------------------------------------------------------- +# delete_image +# --------------------------------------------------------------------------- + + +def test_delete_image_removes_saved_file(media_storage, sample_base64_png) -> None: + """Test that delete_image removes a previously saved image.""" + relative_path = media_storage.save_base64_image(sample_base64_png, subfolder_name="col") + full_path = media_storage.base_path / relative_path + assert full_path.exists() + + result = media_storage.delete_image(relative_path) + assert result is True + assert not full_path.exists() + + +def test_delete_image_returns_false_for_nonexistent(media_storage) -> None: + """Test that delete_image returns False when the file doesn't exist.""" + assert media_storage.delete_image(f"{IMAGES_SUBDIR}/col/nonexistent.png") is False + + +def test_delete_image_rejects_path_outside_images_dir(media_storage, tmp_path) -> None: + """Test that delete_image refuses to delete files outside the images directory.""" + outside_file = tmp_path / "outside.txt" + outside_file.write_text("should not be deleted") + + result = media_storage.delete_image("../outside.txt") + assert result is False + assert outside_file.exists() diff --git a/packages/data-designer-engine/tests/engine/test_configurable_task.py b/packages/data-designer-engine/tests/engine/test_configurable_task.py index f20936a2..6e3673de 100644 --- a/packages/data-designer-engine/tests/engine/test_configurable_task.py +++ b/packages/data-designer-engine/tests/engine/test_configurable_task.py @@ -25,7 +25,7 @@ def test_configurable_task_generic_type_variables() -> None: assert TaskConfigT.__bound__ == ConfigBase -def test_configurable_task_concrete_implementation() -> None: +def test_configurable_task_concrete_implementation(tmp_path) -> None: class TestConfig(ConfigBase): value: str @@ -41,13 +41,8 @@ def _initialize(self) -> None: pass config = TestConfig(value="test") - mock_artifact_storage = Mock(spec=ArtifactStorage) - mock_artifact_storage.dataset_name = "test_dataset" - mock_artifact_storage.final_dataset_folder_name = "final_dataset" - mock_artifact_storage.partial_results_folder_name = "partial_results" - mock_artifact_storage.dropped_columns_folder_name = "dropped_columns" - mock_artifact_storage.processors_outputs_folder_name = "processors_outputs" - resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage) + artifact_storage = ArtifactStorage(artifact_path=tmp_path) + resource_provider = ResourceProvider(artifact_storage=artifact_storage) task = TestTask(config=config, resource_provider=resource_provider) @@ -55,7 +50,7 @@ def _initialize(self) -> None: assert task._resource_provider == resource_provider -def test_configurable_task_config_validation() -> None: +def test_configurable_task_config_validation(tmp_path) -> None: class TestConfig(ConfigBase): value: str @@ -69,13 +64,8 @@ def _validate(self) -> None: raise ValueError("Invalid config") config = TestConfig(value="test") - mock_artifact_storage = Mock(spec=ArtifactStorage) - mock_artifact_storage.dataset_name = "test_dataset" - mock_artifact_storage.final_dataset_folder_name = "final_dataset" - mock_artifact_storage.partial_results_folder_name = "partial_results" - mock_artifact_storage.dropped_columns_folder_name = "dropped_columns" - mock_artifact_storage.processors_outputs_folder_name = "processors_outputs" - resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage) + artifact_storage = ArtifactStorage(artifact_path=tmp_path) + resource_provider = ResourceProvider(artifact_storage=artifact_storage) task = TestTask(config=config, resource_provider=resource_provider) assert task._config.value == "test" @@ -85,7 +75,7 @@ def _validate(self) -> None: TestTask(config=invalid_config, resource_provider=resource_provider) -def test_configurable_task_resource_validation() -> None: +def test_configurable_task_resource_validation(tmp_path) -> None: class TestConfig(ConfigBase): value: str @@ -102,14 +92,9 @@ def _initialize(self) -> None: config = TestConfig(value="test") - mock_artifact_storage = Mock(spec=ArtifactStorage) - mock_artifact_storage.dataset_name = "test_dataset" - mock_artifact_storage.final_dataset_folder_name = "final_dataset" - mock_artifact_storage.partial_results_folder_name = "partial_results" - mock_artifact_storage.dropped_columns_folder_name = "dropped_columns" - mock_artifact_storage.processors_outputs_folder_name = "processors_outputs" + artifact_storage = ArtifactStorage(artifact_path=tmp_path) mock_model_registry = Mock(spec=ModelRegistry) - resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage, model_registry=mock_model_registry) + resource_provider = ResourceProvider(artifact_storage=artifact_storage, model_registry=mock_model_registry) task = TestTask(config=config, resource_provider=resource_provider) assert task._resource_provider == resource_provider diff --git a/packages/data-designer/src/data_designer/integrations/huggingface/client.py b/packages/data-designer/src/data_designer/integrations/huggingface/client.py index c047d73b..1d0a0f0e 100644 --- a/packages/data-designer/src/data_designer/integrations/huggingface/client.py +++ b/packages/data-designer/src/data_designer/integrations/huggingface/client.py @@ -66,6 +66,7 @@ def upload_dataset( Uploads the complete dataset including: - Main parquet batch files from parquet-files/ β†’ data/ + - Images from images/ β†’ images/ (if present) - Processor output batch files from processors-files/{name}/ β†’ {name}/ - Existing builder_config.json and metadata.json files - Auto-generated README.md (dataset card) @@ -102,6 +103,7 @@ def upload_dataset( raise HuggingFaceHubClientUploadError(f"Failed to upload dataset card: {e}") from e self._upload_main_dataset_files(repo_id=repo_id, parquet_folder=base_dataset_path / FINAL_DATASET_FOLDER_NAME) + self._upload_images_folder(repo_id=repo_id, images_folder=base_dataset_path / "images") self._upload_processor_files( repo_id=repo_id, processors_folder=base_dataset_path / PROCESSORS_OUTPUTS_FOLDER_NAME ) @@ -178,6 +180,36 @@ def _upload_main_dataset_files(self, repo_id: str, parquet_folder: Path) -> None except Exception as e: raise HuggingFaceHubClientUploadError(f"Failed to upload parquet files: {e}") from e + def _upload_images_folder(self, repo_id: str, images_folder: Path) -> None: + """Upload images folder to Hugging Face Hub. + + Args: + repo_id: Hugging Face dataset repo ID + images_folder: Path to images folder + + Raises: + HuggingFaceUploadError: If upload fails + """ + if not images_folder.exists(): + return + + image_files = list(images_folder.rglob("*.*")) + if not image_files: + return + + logger.info(f" |-- {RandomEmoji.loading()} Uploading {len(image_files)} image files...") + + try: + self._api.upload_folder( + repo_id=repo_id, + folder_path=str(images_folder), + path_in_repo="images", + repo_type="dataset", + commit_message="Upload images", + ) + except Exception as e: + raise HuggingFaceHubClientUploadError(f"Failed to upload images: {e}") from e + def _upload_processor_files(self, repo_id: str, processors_folder: Path) -> None: """Upload processor output files. diff --git a/packages/data-designer/tests/integrations/huggingface/test_client.py b/packages/data-designer/tests/integrations/huggingface/test_client.py index 735ea3bc..924a6bfe 100644 --- a/packages/data-designer/tests/integrations/huggingface/test_client.py +++ b/packages/data-designer/tests/integrations/huggingface/test_client.py @@ -462,6 +462,76 @@ def test_validate_dataset_path_invalid_builder_config_json(tmp_path: Path) -> No client.upload_dataset("test/dataset", base_path, "Test") +def test_upload_dataset_uploads_images_folder( + mock_hf_api: MagicMock, mock_dataset_card: MagicMock, sample_dataset_path: Path +) -> None: + """Test that upload_dataset uploads images when images folder exists with subfolders.""" + # Create images directory with column subfolders (matches MediaStorage structure) + images_dir = sample_dataset_path / "images" + col_dir = images_dir / "my_image_column" + col_dir.mkdir(parents=True) + (col_dir / "uuid1.png").write_bytes(b"fake png data") + (col_dir / "uuid2.png").write_bytes(b"fake png data") + + client = HuggingFaceHubClient(token="test-token") + client.upload_dataset(repo_id="test/dataset", base_dataset_path=sample_dataset_path, description="Test dataset") + + # Check that upload_folder was called for images + image_calls = [call for call in mock_hf_api.upload_folder.call_args_list if call.kwargs["path_in_repo"] == "images"] + assert len(image_calls) == 1 + assert image_calls[0].kwargs["folder_path"] == str(images_dir) + assert image_calls[0].kwargs["repo_type"] == "dataset" + + +def test_upload_dataset_skips_images_when_folder_missing( + mock_hf_api: MagicMock, mock_dataset_card: MagicMock, sample_dataset_path: Path +) -> None: + """Test that upload_dataset skips images upload when images folder doesn't exist.""" + # sample_dataset_path has no images/ directory by default + client = HuggingFaceHubClient(token="test-token") + client.upload_dataset(repo_id="test/dataset", base_dataset_path=sample_dataset_path, description="Test dataset") + + # No upload_folder call should target "images" + image_calls = [call for call in mock_hf_api.upload_folder.call_args_list if call.kwargs["path_in_repo"] == "images"] + assert len(image_calls) == 0 + + +def test_upload_dataset_skips_images_when_folder_empty( + mock_hf_api: MagicMock, mock_dataset_card: MagicMock, sample_dataset_path: Path +) -> None: + """Test that upload_dataset skips images upload when images folder exists but is empty.""" + images_dir = sample_dataset_path / "images" + images_dir.mkdir() + + client = HuggingFaceHubClient(token="test-token") + client.upload_dataset(repo_id="test/dataset", base_dataset_path=sample_dataset_path, description="Test dataset") + + image_calls = [call for call in mock_hf_api.upload_folder.call_args_list if call.kwargs["path_in_repo"] == "images"] + assert len(image_calls) == 0 + + +def test_upload_dataset_images_upload_failure( + mock_hf_api: MagicMock, mock_dataset_card: MagicMock, sample_dataset_path: Path +) -> None: + """Test that upload_dataset raises error when images upload fails.""" + # Create images directory with a file + images_dir = sample_dataset_path / "images" + col_dir = images_dir / "col" + col_dir.mkdir(parents=True) + (col_dir / "img.png").write_bytes(b"fake") + + # Make upload_folder fail only for images + def failing_upload_folder(**kwargs): + if kwargs.get("path_in_repo") == "images": + raise Exception("Network error") + + mock_hf_api.upload_folder.side_effect = failing_upload_folder + + client = HuggingFaceHubClient(token="test-token") + with pytest.raises(HuggingFaceHubClientUploadError, match="Failed to upload images"): + client.upload_dataset(repo_id="test/dataset", base_dataset_path=sample_dataset_path, description="Test dataset") + + def test_upload_dataset_invalid_repo_id(mock_hf_api: MagicMock, sample_dataset_path: Path) -> None: """Test upload_dataset fails with invalid repo_id.""" client = HuggingFaceHubClient(token="test-token") diff --git a/pyproject.toml b/pyproject.toml index f93bdcd5..4f8578b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,6 @@ notebooks = [ "datasets>=4.0.0,<5", "ipykernel>=6.29.0,<7", "jupyter>=1.0.0,<2", - "pillow>=12.0.0,<13", ] recipes = [ "bm25s>=0.2.0,<1", diff --git a/uv.lock b/uv.lock index 2271ad88..dbb512fd 100644 --- a/uv.lock +++ b/uv.lock @@ -799,6 +799,7 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pandas" }, + { name = "pillow" }, { name = "pyarrow" }, { name = "pydantic", extra = ["email"] }, { name = "pygments" }, @@ -813,6 +814,7 @@ requires-dist = [ { name = "jinja2", specifier = ">=3.1.6,<4" }, { name = "numpy", specifier = ">=1.23.5,<3" }, { name = "pandas", specifier = ">=2.3.3,<3" }, + { name = "pillow", specifier = ">=12.0.0,<13" }, { name = "pyarrow", specifier = ">=19.0.1,<20" }, { name = "pydantic", extras = ["email"], specifier = ">=2.9.2,<3" }, { name = "pygments", specifier = ">=2.19.2,<3" }, @@ -902,7 +904,6 @@ notebooks = [ { name = "datasets" }, { name = "ipykernel" }, { name = "jupyter" }, - { name = "pillow" }, ] recipes = [ { name = "bm25s" }, @@ -936,7 +937,6 @@ notebooks = [ { name = "datasets", specifier = ">=4.0.0,<5" }, { name = "ipykernel", specifier = ">=6.29.0,<7" }, { name = "jupyter", specifier = ">=1.0.0,<2" }, - { name = "pillow", specifier = ">=12.0.0,<13" }, ] recipes = [ { name = "bm25s", specifier = ">=0.2.0,<1" },