diff --git a/README.md b/README.md index aaa9e58..c23b67b 100644 --- a/README.md +++ b/README.md @@ -46,3 +46,141 @@ If you use DiffUTE in your research or wish to refer to the baseline results pub Please feel free to contact us if you have any problems. Email: [hx.chen@hotmail.com](hx.chen@hotmail.com) or [zhuoerxu.xzr@antgroup.com](zhuoerxu.xzr@antgroup.com) + +# DiffUTE Training Scripts V2 + +This repository contains updated training scripts for the DiffUTE (Diffusion Universal Text Editor) model. The scripts have been modernized with improved data handling, better code organization, and MinIO integration for efficient data storage. + +## Key Changes + +1. Replaced pcache_fileio with MinIO for data handling +2. Removed alps dependencies +3. Improved code organization and readability +4. Enhanced error handling and logging +5. Better type hints and documentation +6. Modernized training loops + +## Requirements + +Install the required packages: + +```bash +pip install -r requirements.txt +``` + +## Directory Structure + +``` +. +├── README.md +├── requirements.txt +├── train_vae_v2.py +├── train_diffute_v2.py +└── utils/ + └── minio_utils.py +``` + +## Training Scripts + +### VAE Training + +Train the VAE component using: + +```bash +python train_vae_v2.py \ + --pretrained_model_name_or_path "path/to/model" \ + --output_dir "vae-fine-tuned" \ + --data_path "path/to/data.csv" \ + --resolution 512 \ + --train_batch_size 16 \ + --num_train_epochs 100 \ + --learning_rate 1e-4 \ + --minio_endpoint "your-minio-endpoint" \ + --minio_access_key "your-access-key" \ + --minio_secret_key "your-secret-key" \ + --minio_bucket "your-bucket-name" +``` + +### DiffUTE Training + +Train the complete DiffUTE model using: + +```bash +python train_diffute_v2.py \ + --pretrained_model_name_or_path "path/to/model" \ + --output_dir "diffute-fine-tuned" \ + --data_path "path/to/data.csv" \ + --resolution 512 \ + --train_batch_size 16 \ + --num_train_epochs 100 \ + --learning_rate 1e-4 \ + --guidance_scale 0.8 \ + --minio_endpoint "your-minio-endpoint" \ + --minio_access_key "your-access-key" \ + --minio_secret_key "your-secret-key" \ + --minio_bucket "your-bucket-name" +``` + +## Data Format + +The training data should be specified in a CSV file with the following columns: + +For VAE training: +- `path`: Path to the image file in MinIO storage + +For DiffUTE training: +- `image_path`: Path to the image file in MinIO storage +- `ocr_path`: Path to the OCR results JSON file in MinIO storage + +## MinIO Setup + +1. Install and configure MinIO server +2. Create a bucket for storing training data +3. Upload your training images and OCR results +4. Configure access credentials in the training scripts + +## Model Architecture + +The DiffUTE model consists of three main components: + +1. VAE (Variational AutoEncoder): + - Handles image encoding/decoding + - Pre-trained and frozen during DiffUTE training + - Reduces computational complexity by working in latent space + +2. UNet: + - Main trainable component + - Performs denoising in latent space + - Conditioned on text embeddings + - Takes concatenated input of noisy latents, mask, and masked image + +3. TrOCR: + - Pre-trained text recognition model + - Provides text embeddings for conditioning + - Frozen during training + +## Training Process + +1. Data Preparation: + - Images are loaded from MinIO storage + - OCR results are used to identify text regions + - Images are preprocessed and normalized + +2. Training Loop: + - VAE encodes images to latent space + - Random noise is added according to diffusion schedule + - UNet predicts noise or velocity + - Loss is calculated and model is updated + - Checkpoints are saved periodically + +## Error Handling + +The scripts include robust error handling: +- Graceful handling of failed image loads +- Fallback mechanisms for missing data +- Detailed logging of errors +- Proper cleanup of resources + +## Contributing + +All rights go to original authors diff --git a/requirements.txt b/requirements.txt index 31b9026..a862899 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,12 @@ -accelerate>=0.16.0 -torchvision -transformers>=4.25.1 -datasets -ftfy -tensorboard -Jinja2 +torch>=2.0.0 +accelerate>=0.20.0 +transformers>=4.30.0 +diffusers>=0.15.0 +albumentations>=1.3.0 +opencv-python>=4.7.0 +pandas>=2.0.0 +numpy>=1.24.0 +Pillow>=9.5.0 +tqdm>=4.65.0 +minio>=7.1.0 +scikit-image>=0.20.0 diff --git a/stable_diffusion_text_inpaint/README.md b/stable_diffusion_text_inpaint/README.md new file mode 100644 index 0000000..5514d81 --- /dev/null +++ b/stable_diffusion_text_inpaint/README.md @@ -0,0 +1,168 @@ +# Using Stable Diffusion for Text Inpainting + +This guide explains how to use Stable Diffusion's inpainting capability to add text to specific regions in an image. While not as specialized as DiffUTE for text editing, this approach can still achieve decent results. + +## Requirements + +```python +pip install diffusers transformers torch +``` + +## Basic Implementation + +```python +import torch +from diffusers import StableDiffusionInpaintPipeline +from PIL import Image, ImageDraw +import numpy as np + +def create_text_mask(image, text_box): + """Create a binary mask for the text region + + Args: + image: PIL Image + text_box: tuple of (x1, y1, x2, y2) coordinates + """ + mask = Image.new("RGB", image.size, "black") + draw = ImageDraw.Draw(mask) + draw.rectangle(text_box, fill="white") + return mask + +# Load the model +model_id = "stabilityai/stable-diffusion-2-inpainting" +pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +# Load your image +image = Image.open("your_image.png") + +# Define the text region (x1, y1, x2, y2) +text_box = (100, 100, 300, 150) # Example coordinates + +# Create the mask +mask = create_text_mask(image, text_box) + +# Generate the inpainting +prompt = "Clear black text saying 'Hello World' on a white background" +negative_prompt = "blurry, unclear text, multiple texts, watermark" + +result = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + mask_image=mask, + num_inference_steps=50, + guidance_scale=7.5, +).images[0] +``` + +## Tips for Better Results + +1. **Mask Preparation**: + - Make the mask slightly larger than the text area + - Use anti-aliasing on mask edges for smoother blending + - Consider the text baseline and x-height in mask creation + +2. **Prompt Engineering**: + - Be specific about text style: "sharp, clear black text" + - Mention text properties: "centered, serif font" + - Include context: "text on a white background" + +3. **Negative Prompts**: + - "blurry, unclear text" + - "multiple texts, overlapping text" + - "watermark, artifacts" + - "distorted, warped text" + +4. **Parameter Tuning**: + ```python + # For clearer text + result = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + mask_image=mask, + num_inference_steps=50, # More steps for better quality + guidance_scale=7.5, # Higher for more prompt adherence + strength=0.8, # Control how much to change + ).images[0] + ``` + +## Advanced Usage + +### 1. Style Matching + +To match existing text styles in the image: + +```python +def match_text_style(image, text_region): + """Analyze existing text style in the image""" + # Add OCR or style analysis here + return "style_description" + +style = match_text_style(image, text_region) +prompt = f"Text saying 'Hello World' in style: {style}" +``` + +### 2. Context-Aware Masking + +```python +def create_context_mask(image, text_box, padding=10): + """Create a mask with context awareness""" + x1, y1, x2, y2 = text_box + padded_box = (x1-padding, y1-padding, x2+padding, y2+padding) + mask = create_text_mask(image, padded_box) + return mask +``` + +### 3. Multiple Attempts + +```python +def generate_multiple_attempts(pipe, image, mask, prompt, num_attempts=3): + """Generate multiple versions and pick the best""" + results = [] + for _ in range(num_attempts): + result = pipe( + prompt=prompt, + image=image, + mask_image=mask, + num_inference_steps=50, + ).images[0] + results.append(result) + return results +``` + +## Limitations + +1. Less precise text control compared to DiffUTE +2. May require multiple attempts to get desired results +3. Text style matching is less reliable +4. May introduce artifacts around text regions + +## Best Practices + +1. **Preparation**: + - Clean the text region thoroughly + - Create precise masks + - Use high-resolution images + +2. **Generation**: + - Start with lower strength values + - Generate multiple variations + - Use detailed prompts + +3. **Post-processing**: + - Check text clarity and alignment + - Verify style consistency + - Touch up edges if needed + +## When to Use DiffUTE Instead + +Consider using DiffUTE when: +- Precise text style matching is crucial +- Multiple text regions need editing +- Text needs to perfectly match surrounding context +- Working with complex backgrounds \ No newline at end of file diff --git a/stable_diffusion_text_inpaint/__init__.py b/stable_diffusion_text_inpaint/__init__.py new file mode 100644 index 0000000..3b12315 --- /dev/null +++ b/stable_diffusion_text_inpaint/__init__.py @@ -0,0 +1,19 @@ +"""Text inpainting package using Stable Diffusion.""" + +from .text_inpainter import TextInpainter +from .utils.mask_utils import ( + create_text_mask, + create_context_mask, + create_antialiased_mask, +) +from .utils.style_utils import TextStyleAnalyzer, generate_style_prompt + +__version__ = "0.1.0" +__all__ = [ + "TextInpainter", + "create_text_mask", + "create_context_mask", + "create_antialiased_mask", + "TextStyleAnalyzer", + "generate_style_prompt", +] diff --git a/stable_diffusion_text_inpaint/__pycache__/text_inpainter.cpython-311.pyc b/stable_diffusion_text_inpaint/__pycache__/text_inpainter.cpython-311.pyc new file mode 100644 index 0000000..e06dbec Binary files /dev/null and b/stable_diffusion_text_inpaint/__pycache__/text_inpainter.cpython-311.pyc differ diff --git a/stable_diffusion_text_inpaint/cli.py b/stable_diffusion_text_inpaint/cli.py new file mode 100644 index 0000000..c40396d --- /dev/null +++ b/stable_diffusion_text_inpaint/cli.py @@ -0,0 +1,147 @@ +"""Command line interface for text inpainting.""" + +import click +from PIL import Image +import sys +from text_inpainter import TextInpainter +from utils.region_finder import interactive_region_select, detect_text_regions, visualize_region + +@click.group() +def cli(): + """Text inpainting tools using Stable Diffusion.""" + pass + +@cli.command() +@click.argument('image_path', type=click.Path(exists=True)) +@click.option('--text', '-t', help='Text to add (if not provided, will be prompted)') +@click.option('--region', '-r', nargs=4, type=int, help='Text region coordinates (x1 y1 x2 y2)') +@click.option('--match-style/--no-match-style', default=True, help='Match existing text style') +@click.option('--attempts', '-a', default=1, help='Number of variations to generate') +@click.option('--output', '-o', help='Output file path (default: inpainted_)') +@click.option('--device', default='cuda', help='Device to run on (cuda or cpu)') +@click.option('--interactive/--no-interactive', default=True, help='Use interactive region selection') +def inpaint(image_path, text, region, match_style, attempts, output, device, interactive): + """Inpaint text in an image. + + Examples: + # Interactive mode (recommended for first use) + python cli.py inpaint image.png + + # Specify everything via command line + python cli.py inpaint image.png -t "Hello" -r 100 50 300 100 + + # Generate multiple variations + python cli.py inpaint image.png -t "Hello" -a 3 + """ + try: + # Load image + try: + image = Image.open(image_path) + # Ensure image is in RGB mode + if image.mode != 'RGB': + image = image.convert('RGB') + except Exception as e: + click.echo(f"Error loading image: {str(e)}", err=True) + sys.exit(1) + + # Get text if not provided + if not text: + text = click.prompt('Enter the text to add', type=str) + + # Get region if not provided + if not region and interactive: + click.echo("\nSelect the region for the text:") + try: + region = interactive_region_select(image_path) + except Exception as e: + click.echo(f"Error selecting region: {str(e)}", err=True) + sys.exit(1) + elif not region: + click.echo("Error: Must provide either --region or use --interactive", err=True) + sys.exit(1) + + # Initialize inpainter + click.echo("\nInitializing Stable Diffusion (this may take a moment)...") + try: + inpainter = TextInpainter(device=device) + except Exception as e: + click.echo(f"Error initializing Stable Diffusion: {str(e)}", err=True) + sys.exit(1) + + # Generate the inpainting + click.echo(f"\nInpainting text: '{text}'") + with click.progressbar(length=attempts, label='Generating variations') as bar: + try: + results = inpainter.inpaint_text( + image=image, + text=text, + text_box=region, + match_style=match_style, + num_attempts=attempts + ) + bar.update(attempts) + except Exception as e: + click.echo(f"\nError during inpainting: {str(e)}", err=True) + sys.exit(1) + + # Save results + if not output: + import os + base, ext = os.path.splitext(image_path) + output = f"{base}_inpainted{ext}" + + if attempts == 1: + results = [results] # Make it a list for consistent handling + + # Save all variations + for i, result in enumerate(results): + try: + if attempts > 1: + base, ext = os.path.splitext(output) + save_path = f"{base}_{i+1}{ext}" + else: + save_path = output + result.save(save_path) + click.echo(f"Saved result to: {save_path}") + except Exception as e: + click.echo(f"Error saving result {i+1}: {str(e)}", err=True) + continue + + except KeyboardInterrupt: + click.echo("\nOperation cancelled by user", err=True) + sys.exit(1) + except Exception as e: + click.echo(f"Unexpected error: {str(e)}", err=True) + sys.exit(1) + +@cli.command() +@click.argument('image_path', type=click.Path(exists=True)) +@click.option('--auto/--no-auto', default=False, help='Use automatic detection') +def find_regions(image_path, auto): + """Find text regions in an image. + + This command helps identify the coordinates of text regions for inpainting. + """ + try: + if auto: + click.echo("Detecting text regions automatically...") + regions = detect_text_regions(image_path) + click.echo(f"\nFound {len(regions)} regions:") + for i, region in enumerate(regions, 1): + click.echo(f"Region {i}: {region}") + visualize_region(image_path, region) + else: + click.echo("Starting interactive region selection...") + region = interactive_region_select(image_path) + click.echo(f"\nSelected region coordinates: {region}") + click.echo("\nTo use these coordinates with the inpaint command:") + click.echo(f"python cli.py inpaint {image_path} -r {region[0]} {region[1]} {region[2]} {region[3]}") + except KeyboardInterrupt: + click.echo("\nOperation cancelled by user", err=True) + sys.exit(1) + except Exception as e: + click.echo(f"Error: {str(e)}", err=True) + sys.exit(1) + +if __name__ == '__main__': + cli() \ No newline at end of file diff --git a/stable_diffusion_text_inpaint/custom_params_result.png b/stable_diffusion_text_inpaint/custom_params_result.png new file mode 100644 index 0000000..5cf20ee Binary files /dev/null and b/stable_diffusion_text_inpaint/custom_params_result.png differ diff --git a/stable_diffusion_text_inpaint/example.png b/stable_diffusion_text_inpaint/example.png new file mode 100644 index 0000000..2d12241 Binary files /dev/null and b/stable_diffusion_text_inpaint/example.png differ diff --git a/stable_diffusion_text_inpaint/example.py b/stable_diffusion_text_inpaint/example.py new file mode 100644 index 0000000..b5832bb --- /dev/null +++ b/stable_diffusion_text_inpaint/example.py @@ -0,0 +1,120 @@ +"""Example usage of the TextInpainter class. + +Note: On first run, this script will download required models from Hugging Face: +1. stabilityai/stable-diffusion-2-inpainting (~4GB): Used for text inpainting +2. microsoft/trocr-large-printed (~1GB): Used for text style analysis + +The style analysis system helps match the appearance of existing text by: +- Analyzing text color by finding dominant colors in the text region +- Detecting background color by sampling corner pixels +- Determining text size based on the region dimensions +- Converting these properties into natural language prompts for Stable Diffusion + (e.g. "clear black text on white background") + +This helps ensure that newly inpainted text matches the style of text in the rest +of the image, maintaining visual consistency. +""" + +from PIL import Image, ImageDraw +from text_inpainter import TextInpainter +import os + + +def create_sample_image(size=(512, 512), color="white"): + """Create a sample image for testing. + + Args: + size (tuple): Image dimensions (width, height) + color (str): Background color + + Returns: + PIL.Image: Sample image + """ + image = Image.new("RGB", size, color) + draw = ImageDraw.Draw(image) + + # Add some shapes for visual reference + draw.rectangle((50, 50, 462, 462), outline="gray", width=2) + draw.line((50, 256, 462, 256), fill="gray", width=2) + draw.line((256, 50, 256, 462), fill="gray", width=2) + + # Save the image + if not os.path.exists("example.png"): + image.save("example.png") + return image + + +def single_text_example(): + """Example of inpainting a single text region.""" + # Initialize inpainter + inpainter = TextInpainter() + + # Create or load sample image + image = create_sample_image() + + # Define text region (centered horizontally) + text_box = (156, 100, 356, 150) # 200px wide, centered in 512px image + + # Simple inpainting + result = inpainter.inpaint_text(image=image, text="Hello World", text_box=text_box) + result.save("single_text_result.png") + + # Multiple attempts with style matching + variations = inpainter.inpaint_text( + image=image, + text="Hello World", + text_box=text_box, + match_style=True, + num_attempts=3, + ) + + # Save variations + for i, img in enumerate(variations): + img.save(f"variation_{i}.png") + + +def multiple_text_example(): + """Example of inpainting multiple text regions.""" + inpainter = TextInpainter() + image = create_sample_image() + + # Define multiple text regions (vertically stacked, centered) + text_regions = [ + ("First Text", (156, 100, 356, 150)), + ("Second Text", (156, 200, 356, 250)), + ("Third Text", (156, 300, 356, 350)), + ] + + # Batch inpainting + result = inpainter.batch_inpaint_text(image, text_regions) + result.save("multiple_text_result.png") + + +def custom_parameters_example(): + """Example with custom pipeline parameters.""" + inpainter = TextInpainter() + image = create_sample_image() + text_box = (156, 100, 356, 150) + + # Custom parameters for more control + result = inpainter.inpaint_text( + image=image, + text="Custom Text", + text_box=text_box, + num_inference_steps=75, # More steps for better quality + guidance_scale=8.5, # Stronger prompt adherence + negative_prompt="blurry, ugly, bad quality, error, watermark", + match_style=True, + ) + result.save("custom_params_result.png") + + +if __name__ == "__main__": + print("Running single text example...") + single_text_example() + + print("Running multiple text example...") + multiple_text_example() + + print("Running custom parameters example...") + custom_parameters_example() diff --git a/stable_diffusion_text_inpaint/find_regions.py b/stable_diffusion_text_inpaint/find_regions.py new file mode 100644 index 0000000..267649a --- /dev/null +++ b/stable_diffusion_text_inpaint/find_regions.py @@ -0,0 +1,26 @@ +"""Script to help find text regions in an image.""" + +import argparse +from utils.region_finder import interactive_region_select, detect_text_regions, visualize_region + +def main(): + parser = argparse.ArgumentParser(description="Find text regions in an image") + parser.add_argument("image_path", help="Path to the image file") + parser.add_argument("--auto", action="store_true", help="Use automatic detection") + args = parser.parse_args() + + if args.auto: + print("Detecting text regions automatically...") + regions = detect_text_regions(args.image_path) + print(f"\nFound {len(regions)} regions:") + for i, region in enumerate(regions, 1): + print(f"Region {i}: {region}") + visualize_region(args.image_path, region) + else: + print("Starting interactive region selection...") + region = interactive_region_select(args.image_path) + print(f"\nSelected region coordinates: {region}") + print("You can use these coordinates with the TextInpainter class.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/stable_diffusion_text_inpaint/multiple_text_result.png b/stable_diffusion_text_inpaint/multiple_text_result.png new file mode 100644 index 0000000..5f04b77 Binary files /dev/null and b/stable_diffusion_text_inpaint/multiple_text_result.png differ diff --git a/stable_diffusion_text_inpaint/requirements.txt b/stable_diffusion_text_inpaint/requirements.txt new file mode 100644 index 0000000..3149dcf --- /dev/null +++ b/stable_diffusion_text_inpaint/requirements.txt @@ -0,0 +1,9 @@ +diffusers>=0.24.0 +transformers>=4.36.0 +torch>=2.0.0 +Pillow>=10.0.0 +numpy>=1.24.0 +tqdm>=4.65.0 +opencv-python>=4.8.0 +matplotlib>=3.7.0 +click>=8.1.0 \ No newline at end of file diff --git a/stable_diffusion_text_inpaint/selection_1.png b/stable_diffusion_text_inpaint/selection_1.png new file mode 100644 index 0000000..d8d8602 Binary files /dev/null and b/stable_diffusion_text_inpaint/selection_1.png differ diff --git a/stable_diffusion_text_inpaint/selection_2.png b/stable_diffusion_text_inpaint/selection_2.png new file mode 100644 index 0000000..cf0ba0c Binary files /dev/null and b/stable_diffusion_text_inpaint/selection_2.png differ diff --git a/stable_diffusion_text_inpaint/single_text_result.png b/stable_diffusion_text_inpaint/single_text_result.png new file mode 100644 index 0000000..dadb07a Binary files /dev/null and b/stable_diffusion_text_inpaint/single_text_result.png differ diff --git a/stable_diffusion_text_inpaint/text_inpainter.py b/stable_diffusion_text_inpaint/text_inpainter.py new file mode 100644 index 0000000..0b1ac1d --- /dev/null +++ b/stable_diffusion_text_inpaint/text_inpainter.py @@ -0,0 +1,122 @@ +"""Main class for text inpainting using Stable Diffusion.""" + +import torch +from diffusers import StableDiffusionInpaintPipeline +from PIL import Image +from tqdm import tqdm + +from utils.mask_utils import create_context_mask, validate_text_box +from utils.style_utils import TextStyleAnalyzer, generate_style_prompt + + +class TextInpainter: + def __init__(self, device="cuda"): + """Initialize the text inpainting pipeline. + + Args: + device (str): Device to run the model on ("cuda" or "cpu") + """ + self.device = device + self.model_id = "stabilityai/stable-diffusion-2-inpainting" + + # Initialize pipeline + self.pipe = StableDiffusionInpaintPipeline.from_pretrained( + self.model_id, + torch_dtype=torch.float16 if device == "cuda" else torch.float32, + ) + self.pipe = self.pipe.to(device) + + # Initialize style analyzer + self.style_analyzer = TextStyleAnalyzer() + + def inpaint_text( + self, image, text, text_box, match_style=True, num_attempts=1, **kwargs + ): + """Inpaint text in the specified region. + + Args: + image (PIL.Image): Input image + text (str): Text to add + text_box (tuple): (x1, y1, x2, y2) coordinates for text region + match_style (bool): Whether to match existing text style + num_attempts (int): Number of generation attempts + **kwargs: Additional arguments for the pipeline + + Returns: + PIL.Image: Image with inpainted text + list: All generated variations if num_attempts > 1 + """ + # Validate text box + text_box = validate_text_box(image.size, text_box) + + # Create mask + mask = create_context_mask(image, text_box) + + # Generate prompt + if match_style: + style_props = self.style_analyzer.analyze_text_region(image, text_box) + style_prompt = generate_style_prompt(style_props) + prompt = f"{style_prompt}, text saying '{text}'" + else: + prompt = f"Clear text saying '{text}'" + + # Default parameters + params = { + "num_inference_steps": 50, + "guidance_scale": 7.5, + "negative_prompt": "blurry, unclear text, multiple texts, watermark", + } + params.update(kwargs) + + # Generate multiple attempts + results = [] + for _ in tqdm(range(num_attempts), desc="Generating variations"): + result = self.pipe( + prompt=prompt, image=image, mask_image=mask, **params + ).images[0] + results.append(result) + + return results[0] if num_attempts == 1 else results + + def batch_inpaint_text(self, image, text_regions): + """Inpaint multiple text regions in an image. + + Args: + image (PIL.Image): Input image + text_regions (list): List of (text, box) tuples + + Returns: + PIL.Image: Image with all text regions inpainted + """ + result = image.copy() + for text, box in text_regions: + result = self.inpaint_text(result, text, box) + return result + + +def main(): + """Example usage of TextInpainter.""" + # Initialize inpainter + inpainter = TextInpainter() + + # Load image + image = Image.open("example.png") + + # Define text region + text_box = (100, 100, 300, 150) + + # Inpaint text + result = inpainter.inpaint_text( + image=image, text="Hello World", text_box=text_box, num_attempts=3 + ) + + # Save results + if isinstance(result, list): + for i, img in enumerate(result): + img.save(f"result_{i}.png") + else: + result.save("result.png") + + +if __name__ == "__main__": + main() diff --git a/stable_diffusion_text_inpaint/utils/__init__.py b/stable_diffusion_text_inpaint/utils/__init__.py new file mode 100644 index 0000000..45bf751 --- /dev/null +++ b/stable_diffusion_text_inpaint/utils/__init__.py @@ -0,0 +1,18 @@ +"""Utility functions for text inpainting.""" + +from .mask_utils import ( + create_text_mask, + create_context_mask, + create_antialiased_mask, + validate_text_box, +) +from .style_utils import TextStyleAnalyzer, generate_style_prompt + +__all__ = [ + "create_text_mask", + "create_context_mask", + "create_antialiased_mask", + "validate_text_box", + "TextStyleAnalyzer", + "generate_style_prompt", +] diff --git a/stable_diffusion_text_inpaint/utils/__pycache__/__init__.cpython-311.pyc b/stable_diffusion_text_inpaint/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..e059bb1 Binary files /dev/null and b/stable_diffusion_text_inpaint/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/stable_diffusion_text_inpaint/utils/__pycache__/mask_utils.cpython-311.pyc b/stable_diffusion_text_inpaint/utils/__pycache__/mask_utils.cpython-311.pyc new file mode 100644 index 0000000..2d85a52 Binary files /dev/null and b/stable_diffusion_text_inpaint/utils/__pycache__/mask_utils.cpython-311.pyc differ diff --git a/stable_diffusion_text_inpaint/utils/__pycache__/region_finder.cpython-311.pyc b/stable_diffusion_text_inpaint/utils/__pycache__/region_finder.cpython-311.pyc new file mode 100644 index 0000000..b6f0d7d Binary files /dev/null and b/stable_diffusion_text_inpaint/utils/__pycache__/region_finder.cpython-311.pyc differ diff --git a/stable_diffusion_text_inpaint/utils/__pycache__/style_utils.cpython-311.pyc b/stable_diffusion_text_inpaint/utils/__pycache__/style_utils.cpython-311.pyc new file mode 100644 index 0000000..065eb1d Binary files /dev/null and b/stable_diffusion_text_inpaint/utils/__pycache__/style_utils.cpython-311.pyc differ diff --git a/stable_diffusion_text_inpaint/utils/mask_utils.py b/stable_diffusion_text_inpaint/utils/mask_utils.py new file mode 100644 index 0000000..49d8b29 --- /dev/null +++ b/stable_diffusion_text_inpaint/utils/mask_utils.py @@ -0,0 +1,76 @@ +"""Utility functions for creating and manipulating masks for text inpainting.""" + +from PIL import Image, ImageDraw + + +def create_text_mask(image, text_box): + """Create a binary mask for the text region. + + Args: + image (PIL.Image): Input image + text_box (tuple): (x1, y1, x2, y2) coordinates for text region + + Returns: + PIL.Image: Binary mask with white region for text area + """ + mask = Image.new("RGB", image.size, "black") + draw = ImageDraw.Draw(mask) + draw.rectangle(text_box, fill="white") + return mask + + +def create_context_mask(image, text_box, padding=10): + """Create a mask with padding for context awareness. + + Args: + image (PIL.Image): Input image + text_box (tuple): (x1, y1, x2, y2) coordinates for text region + padding (int): Number of pixels to pad around the text region + + Returns: + PIL.Image: Binary mask with padding + """ + x1, y1, x2, y2 = text_box + padded_box = ( + max(0, x1 - padding), + max(0, y1 - padding), + min(image.size[0], x2 + padding), + min(image.size[1], y2 + padding), + ) + return create_text_mask(image, padded_box) + + +def create_antialiased_mask(image, text_box, blur_radius=2): + """Create an anti-aliased mask for smoother blending. + + Args: + image (PIL.Image): Input image + text_box (tuple): (x1, y1, x2, y2) coordinates for text region + blur_radius (int): Radius for Gaussian blur + + Returns: + PIL.Image: Anti-aliased mask + """ + mask = create_text_mask(image, text_box) + return mask.filter(ImageFilter.GaussianBlur(blur_radius)) + + +def validate_text_box(image_size, text_box): + """Validate and adjust text box coordinates to fit within image bounds. + + Args: + image_size (tuple): (width, height) of the image + text_box (tuple): (x1, y1, x2, y2) coordinates for text region + + Returns: + tuple: Adjusted text box coordinates + """ + x1, y1, x2, y2 = text_box + width, height = image_size + + return ( + max(0, min(x1, width)), + max(0, min(y1, height)), + max(0, min(x2, width)), + max(0, min(y2, height)), + ) diff --git a/stable_diffusion_text_inpaint/utils/region_finder.py b/stable_diffusion_text_inpaint/utils/region_finder.py new file mode 100644 index 0000000..8342b3a --- /dev/null +++ b/stable_diffusion_text_inpaint/utils/region_finder.py @@ -0,0 +1,134 @@ +"""Utilities for finding and visualizing text regions in images.""" + +import cv2 +import numpy as np +from PIL import Image, ImageDraw +import matplotlib.pyplot as plt +from transformers import TrOCRProcessor, VisionEncoderDecoderModel +import torch + +def visualize_region(image_path, text_box=None): + """Display image with optional text region highlighted. + + Args: + image_path (str): Path to the image + text_box (tuple, optional): (x1, y1, x2, y2) coordinates to highlight + """ + # Load and display image + img = Image.open(image_path) + plt.figure(figsize=(12, 8)) + + # Create a copy for drawing + draw_img = img.copy() + draw = ImageDraw.Draw(draw_img) + + # Draw grid lines every 50 pixels + for x in range(0, img.width, 50): + draw.line([(x, 0), (x, img.height)], fill='gray', width=1) + if x % 100 == 0: # Add labels for every 100 pixels + draw.text((x, 5), str(x), fill='gray') + + for y in range(0, img.height, 50): + draw.line([(0, y), (img.width, y)], fill='gray', width=1) + if y % 100 == 0: # Add labels for every 100 pixels + draw.text((5, y), str(y), fill='gray') + + # If text box provided, highlight it + if text_box: + draw.rectangle(text_box, outline='red', width=2) + # Add coordinate labels + x1, y1, x2, y2 = text_box + draw.text((x1, y1-20), f'({x1}, {y1})', fill='red') + draw.text((x2, y2+5), f'({x2}, {y2})', fill='red') + + plt.imshow(draw_img) + plt.axis('off') + plt.show() + + return img.size + +def detect_text_regions(image_path): + """Automatically detect text regions using OCR. + + Args: + image_path (str): Path to the image + + Returns: + list: List of (text, box) tuples where box is (x1, y1, x2, y2) + """ + # Load image + img = cv2.imread(image_path) + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + # Use EAST text detector or Tesseract OCR + # For now, using simple edge detection as placeholder + edges = cv2.Canny(gray, 100, 200) + contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + regions = [] + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + # Filter small regions + if w > 20 and h > 10: + regions.append((x, y, x+w, y+h)) + + return regions + +def interactive_region_select(image_path): + """Display image and let user click points to define region. + + Args: + image_path (str): Path to the image + + Returns: + tuple: (x1, y1, x2, y2) coordinates of selected region + """ + img = Image.open(image_path) + + print("\nInstructions:") + print("1. Image will be displayed with a grid overlay") + print("2. Use the grid lines and coordinates to identify the region") + print("3. Close the image window when done") + + # Show image with grid + size = visualize_region(image_path) + + # Get coordinates from user + while True: + try: + print("\nEnter coordinates (x1 y1 x2 y2) separated by spaces:") + coords = input("> ").strip().split() + if len(coords) != 4: + print("Please enter exactly 4 numbers") + continue + + x1, y1, x2, y2 = map(int, coords) + + # Validate coordinates + if not (0 <= x1 < size[0] and 0 <= x2 < size[0] and + 0 <= y1 < size[1] and 0 <= y2 < size[1]): + print(f"Coordinates must be within image bounds: width={size[0]}, height={size[1]}") + continue + + # Show region for confirmation + print("\nSelected region (close window to confirm):") + visualize_region(image_path, (x1, y1, x2, y2)) + + confirm = input("Is this region correct? (y/n): ").lower().strip() + if confirm == 'y': + return (x1, y1, x2, y2) + + except ValueError: + print("Please enter valid numbers") + +if __name__ == "__main__": + # Example usage + image_path = "example.png" + + print("Method 1: Visual Grid Helper") + region = interactive_region_select(image_path) + print(f"Selected region: {region}") + + print("\nMethod 2: Automatic Detection") + regions = detect_text_regions(image_path) + print(f"Detected regions: {regions}") \ No newline at end of file diff --git a/stable_diffusion_text_inpaint/utils/style_utils.py b/stable_diffusion_text_inpaint/utils/style_utils.py new file mode 100644 index 0000000..dc7ecde --- /dev/null +++ b/stable_diffusion_text_inpaint/utils/style_utils.py @@ -0,0 +1,117 @@ +"""Utility functions for analyzing and matching text styles.""" + +from PIL import Image +import numpy as np +from transformers import TrOCRProcessor, VisionEncoderDecoderModel +import torch + + +class TextStyleAnalyzer: + def __init__(self): + """Initialize the text style analyzer with TrOCR model.""" + self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-printed") + self.model = VisionEncoderDecoderModel.from_pretrained( + "microsoft/trocr-large-printed" + ) + + def analyze_text_region(self, image, text_box): + """Analyze text style in the specified region. + + Args: + image (PIL.Image): Input image + text_box (tuple): (x1, y1, x2, y2) coordinates for text region + + Returns: + dict: Style properties including font, color, size + """ + # Crop the text region + x1, y1, x2, y2 = text_box + text_region = image.crop((x1, y1, x2, y2)) + + # Ensure image is in RGB format + if text_region.mode != 'RGB': + text_region = text_region.convert('RGB') + + try: + # Get text features from TrOCR + pixel_values = self.processor(text_region, return_tensors="pt").pixel_values + features = self.model.encoder(pixel_values).last_hidden_state + except Exception as e: + print(f"Warning: Style analysis failed ({str(e)}), using basic analysis only") + features = None + + # Analyze basic properties + style_props = { + "size": y2 - y1, # Approximate text height + "width": x2 - x1, # Region width + "color": self._analyze_color(text_region), + "background": self._analyze_background(text_region), + } + + return style_props + + def _analyze_color(self, text_region): + """Analyze the dominant text color.""" + # Convert to numpy array + img_array = np.array(text_region) + + # Simple color analysis (can be improved) + mean_color = np.mean(img_array, axis=(0, 1)) + return tuple(map(int, mean_color)) + + def _analyze_background(self, text_region): + """Analyze the background color.""" + img_array = np.array(text_region) + + # Assume corners are background + corners = [ + img_array[0, 0], + img_array[0, -1], + img_array[-1, 0], + img_array[-1, -1], + ] + bg_color = np.mean(corners, axis=0) + return tuple(map(int, bg_color)) + + +def generate_style_prompt(style_props): + """Generate a text prompt based on style properties. + + Args: + style_props (dict): Style properties from TextStyleAnalyzer + + Returns: + str: Generated prompt for stable diffusion + """ + # Convert RGB colors to descriptive terms + text_color = _rgb_to_description(style_props["color"]) + bg_color = _rgb_to_description(style_props["background"]) + + prompt = f"clear {text_color} text on {bg_color} background" + + # Add size information + if style_props["size"] < 20: + prompt = "small " + prompt + elif style_props["size"] > 40: + prompt = "large " + prompt + + return prompt + + +def _rgb_to_description(rgb): + """Convert RGB values to color description.""" + r, g, b = rgb + + # Simple color mapping (can be expanded) + if max(r, g, b) < 50: + return "black" + elif min(r, g, b) > 200: + return "white" + elif r > max(g, b) + 50: + return "red" + elif g > max(r, b) + 50: + return "green" + elif b > max(r, g) + 50: + return "blue" + else: + return "gray" diff --git a/stable_diffusion_text_inpaint/variation_0.png b/stable_diffusion_text_inpaint/variation_0.png new file mode 100644 index 0000000..0385813 Binary files /dev/null and b/stable_diffusion_text_inpaint/variation_0.png differ diff --git a/stable_diffusion_text_inpaint/variation_1.png b/stable_diffusion_text_inpaint/variation_1.png new file mode 100644 index 0000000..15b2522 Binary files /dev/null and b/stable_diffusion_text_inpaint/variation_1.png differ diff --git a/stable_diffusion_text_inpaint/variation_2.png b/stable_diffusion_text_inpaint/variation_2.png new file mode 100644 index 0000000..b348e06 Binary files /dev/null and b/stable_diffusion_text_inpaint/variation_2.png differ diff --git a/train_diffute_v1.py b/train_diffute_v1.py index 738f9bb..23ea62c 100644 --- a/train_diffute_v1.py +++ b/train_diffute_v1.py @@ -840,9 +840,7 @@ def main(): else: repo_name = args.hub_model_id create_repo(repo_name, exist_ok=True, token=args.hub_token) - Repository( - args.output_dir, clone_from=repo_name, token=args.hub_token - ) + Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: if "step_*" not in gitignore: diff --git a/train_diffute_v2.py b/train_diffute_v2.py new file mode 100644 index 0000000..cb1d42d --- /dev/null +++ b/train_diffute_v2.py @@ -0,0 +1,498 @@ +""" +DiffUTE (Diffusion Universal Text Editor) Training Script V2 + +This script implements the training process for the DiffUTE model, which edits text +in images while preserving surrounding context. The implementation uses MinIO for +efficient data loading and modern PyTorch practices for training. + +Key Components: +- VAE: Pre-trained and frozen, handles image encoding/decoding +- UNet: Main trainable component for denoising in latent space +- TrOCR: Pre-trained text recognition for conditioning +""" + +import os +import logging +import math +from typing import Tuple +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from tqdm.auto import tqdm +import albumentations as alb +from albumentations.pytorch import ToTensorV2 +from transformers import TrOCRProcessor, VisionEncoderDecoderModel + +from utils.minio_utils import MinioHandler + +# Will error if the minimal version of diffusers is not installed +check_min_version("0.15.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +def parse_args(): + """Parse training arguments.""" + import argparse + + parser = argparse.ArgumentParser(description="DiffUTE training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models", + ) + parser.add_argument( + "--output_dir", + type=str, + default="diffute-model-finetuned", + help="Output directory for checkpoints and models", + ) + parser.add_argument( + "--data_path", + type=str, + required=True, + help="Path to CSV file containing training data paths", + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help="Resolution for training images", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=16, + help="Batch size (per device) for training", + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=100, + help="Number of training epochs", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before backward pass", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="Mixed precision training mode", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=0.8, + help="Scale for guidance loss", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=4, + help="Number of dataloader workers", + ) + # MinIO configuration + parser.add_argument("--minio_endpoint", type=str, required=True) + parser.add_argument("--minio_access_key", type=str, required=True) + parser.add_argument("--minio_secret_key", type=str, required=True) + parser.add_argument("--minio_bucket", type=str, required=True) + + args = parser.parse_args() + return args + + +def prepare_mask_and_masked_image(image: np.ndarray, mask: np.ndarray) -> np.ndarray: + """ + Create a masked version of the input image. + + Args: + image: Input image array (H, W, C) + mask: Binary mask array (H, W) + + Returns: + Masked image with target regions set to 0 + """ + return image * np.stack([mask < 0.5] * 3, axis=2) + + +def generate_mask(size: Tuple[int, int], bbox: list, dilation: int = 10) -> np.ndarray: + """ + Generate a binary mask for the text region. + + Args: + size: (width, height) of the mask + bbox: [x1, y1, x2, y2] text bounding box + dilation: Number of pixels to dilate the mask + + Returns: + Binary mask array + """ + mask = np.zeros(size[::-1], dtype=np.float32) + x1, y1, x2, y2 = map(int, bbox) + mask[y1:y2, x1:x2] = 1 + + if dilation > 0: + import cv2 + + kernel = np.ones((dilation, dilation), np.uint8) + mask = cv2.dilate(mask, kernel, iterations=1) + + return mask + + +class DiffUTEDataset(Dataset): + """Dataset for DiffUTE training using MinIO storage.""" + + def __init__( + self, + minio_handler: MinioHandler, + data_paths: list, + ocr_paths: list, + resolution: int = 512, + ): + """ + Initialize the dataset. + + Args: + minio_handler: Initialized MinIO handler + data_paths: List of image paths in MinIO + ocr_paths: List of OCR result paths in MinIO + resolution: Target image resolution + """ + self.minio = minio_handler + self.image_paths = data_paths + self.ocr_paths = ocr_paths + self.resolution = resolution + + self.transform = alb.Compose( + [ + alb.SmallestMaxSize(max_size=resolution), + alb.CenterCrop(resolution, resolution), + alb.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + ToTensorV2(), + ] + ) + + self.mask_transform = alb.Compose( + [ + alb.SmallestMaxSize(max_size=resolution), + alb.CenterCrop(resolution, resolution), + ToTensorV2(), + ] + ) + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + image_path = self.image_paths[idx] + ocr_path = self.ocr_paths[idx] + + try: + # Load image and OCR results + image = self.minio.download_image(image_path) + ocr_data = self.minio.read_json(ocr_path) + + # Process OCR results + ocr_df = pd.DataFrame(ocr_data["document"]) + ocr_df = ocr_df[ocr_df["score"] > 0.8] + + if len(ocr_df) == 0: + raise ValueError("No valid OCR results found") + + # Randomly select one text region + ocr_sample = ocr_df.sample(n=1).iloc[0] + text = ocr_sample["text"] + bbox = ocr_sample["box"] + + # Convert bbox to [x1, y1, x2, y2] format + bbox = [ + min(x[0] for x in bbox), + min(x[1] for x in bbox), + max(x[0] for x in bbox), + max(x[1] for x in bbox), + ] + + # Generate mask and masked image + mask = generate_mask(image.shape[:2][::-1], bbox) + masked_image = prepare_mask_and_masked_image(image, mask) + + # Apply transforms + transformed = self.transform(image=image) + transformed_mask = self.mask_transform(image=mask) + transformed_masked = self.transform(image=masked_image) + + return { + "pixel_values": transformed["image"], + "mask": transformed_mask["image"][0], + "masked_image": transformed_masked["image"], + "text": text, + } + + except Exception as e: + logger.error(f"Error processing {image_path}: {str(e)}") + # Return random tensors as fallback + return { + "pixel_values": torch.randn(3, self.resolution, self.resolution), + "mask": torch.zeros(self.resolution, self.resolution), + "masked_image": torch.randn(3, self.resolution, self.resolution), + "text": "", + } + + +def main(): + args = parse_args() + + # Initialize accelerator + accelerator_project_config = ProjectConfiguration() + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + # Initialize MinIO handler + minio_handler = MinioHandler( + endpoint=args.minio_endpoint, + access_key=args.minio_access_key, + secret_key=args.minio_secret_key, + bucket_name=args.minio_bucket, + ) + + # Load data paths + df = pd.read_csv(args.data_path) + image_paths = df["image_path"].tolist() + ocr_paths = df["ocr_path"].tolist() + + # Create dataset and dataloader + dataset = DiffUTEDataset( + minio_handler=minio_handler, + data_paths=image_paths, + ocr_paths=ocr_paths, + resolution=args.resolution, + ) + + dataloader = DataLoader( + dataset, + batch_size=args.train_batch_size, + shuffle=True, + num_workers=args.dataloader_num_workers, + ) + + # Load models + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + ) + + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + ) + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + ) + + processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-printed") + trocr_model = VisionEncoderDecoderModel.from_pretrained( + "microsoft/trocr-large-printed" + ).encoder + + # Freeze VAE and TrOCR + vae.requires_grad_(False) + trocr_model.requires_grad_(False) + + # Optimizer + optimizer = torch.optim.AdamW( + unet.parameters(), + lr=args.learning_rate, + ) + + # Get number of training steps + num_update_steps_per_epoch = math.ceil( + len(dataloader) / args.gradient_accumulation_steps + ) + num_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + # Learning rate scheduler + lr_scheduler = get_scheduler( + "cosine", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=num_train_steps, + ) + + # Prepare everything with accelerator + unet, optimizer, dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, dataloader, lr_scheduler + ) + + # Move models to device and cast to dtype + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + vae.to(accelerator.device, dtype=weight_dtype) + trocr_model.to(accelerator.device, dtype=weight_dtype) + + # Get VAE scale factor + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + + # Train! + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num epochs = {args.num_train_epochs}") + logger.info(f" Batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size = {total_batch_size}") + logger.info(f" Total optimization steps = {num_train_steps}") + + global_step = 0 + for epoch in range(args.num_train_epochs): + unet.train() + train_loss = 0.0 + + progress_bar = tqdm( + total=num_update_steps_per_epoch, + disable=not accelerator.is_local_main_process, + ) + progress_bar.set_description(f"Epoch {epoch}") + + for step, batch in enumerate(dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode( + batch["pixel_values"].to(weight_dtype) + ).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Add noise + noise = torch.randn_like(latents) + timesteps = torch.randint( + 0, + noise_scheduler.num_train_timesteps, + (latents.shape[0],), + device=latents.device, + ) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Prepare mask + mask = F.interpolate( + batch["mask"].unsqueeze(1), + size=latents.shape[2:], + mode="nearest", + ) + mask = mask.to(weight_dtype) + + # Get masked image latents + masked_image_latents = vae.encode( + batch["masked_image"].to(weight_dtype) + ).latent_dist.sample() + masked_image_latents = masked_image_latents * vae.config.scaling_factor + + # Get text embeddings + text_embeddings = trocr_model( + batch["pixel_values"].to(weight_dtype) + ).last_hidden_state + + # Prepare model input + model_input = torch.cat( + [noisy_latents, mask, masked_image_latents], + dim=1, + ) + + # Get model prediction + model_pred = unet( + model_input, + timesteps, + encoder_hidden_states=text_embeddings, + ).sample + + # Calculate loss + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Scale loss by guidance + loss = loss * args.guidance_scale + + accelerator.backward(loss) + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), 1.0) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + progress_bar.update(1) + global_step += 1 + train_loss += loss.detach().item() + + logs = { + "loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + "step": global_step, + } + progress_bar.set_postfix(**logs) + + progress_bar.close() + + accelerator.wait_for_everyone() + + # Save checkpoint + if accelerator.is_main_process: + pipeline = accelerator.unwrap_model(unet) + pipeline.save_pretrained( + os.path.join(args.output_dir, f"checkpoint-{global_step}") + ) + + train_loss = train_loss / num_update_steps_per_epoch + logger.info(f"Epoch {epoch}: Average loss = {train_loss:.4f}") + + +if __name__ == "__main__": + main() diff --git a/train_vae.py b/train_vae.py index 8a383c3..c1b800e 100644 --- a/train_vae.py +++ b/train_vae.py @@ -664,9 +664,7 @@ def main(): else: repo_name = args.hub_model_id create_repo(repo_name, exist_ok=True, token=args.hub_token) - Repository( - args.output_dir, clone_from=repo_name, token=args.hub_token - ) + Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: if "step_*" not in gitignore: diff --git a/train_vae_v2.py b/train_vae_v2.py new file mode 100644 index 0000000..5077942 --- /dev/null +++ b/train_vae_v2.py @@ -0,0 +1,295 @@ +""" +Script for fine-tuning the Variational Autoencoder (VAE) component of a pre-trained Stable Diffusion model. + +This script focuses on optimizing the VAE for better image reconstruction, +using MinIO for efficient data loading and distributed training capabilities. +""" + +import os +import logging +import math +import pandas as pd +import torch +from torch.utils.data import Dataset, DataLoader +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from diffusers import AutoencoderKL +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from tqdm.auto import tqdm +import albumentations as alb +from albumentations.pytorch import ToTensorV2 + +from utils.minio_utils import MinioHandler + +# Will error if the minimal version of diffusers is not installed +check_min_version("0.15.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +def parse_args(): + """Parse training arguments.""" + import argparse + + parser = argparse.ArgumentParser(description="VAE fine-tuning script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models", + ) + parser.add_argument( + "--output_dir", + type=str, + default="vae-fine-tuned", + help="Output directory for checkpoints and models", + ) + parser.add_argument( + "--data_path", + type=str, + required=True, + help="Path to CSV file containing training data paths", + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help="Resolution for training images", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=16, + help="Batch size (per device) for training", + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=100, + help="Number of training epochs", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before backward pass", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="Mixed precision training mode", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=4, + help="Number of dataloader workers", + ) + # MinIO configuration + parser.add_argument("--minio_endpoint", type=str, required=True) + parser.add_argument("--minio_access_key", type=str, required=True) + parser.add_argument("--minio_secret_key", type=str, required=True) + parser.add_argument("--minio_bucket", type=str, required=True) + + args = parser.parse_args() + return args + + +class TrainingDataset(Dataset): + """Dataset for VAE training using MinIO storage.""" + + def __init__( + self, + minio_handler: MinioHandler, + data_paths: list, + resolution: int = 512, + ): + """ + Initialize the dataset. + + Args: + minio_handler: Initialized MinIO handler + data_paths: List of image paths in MinIO + resolution: Target image resolution + """ + self.minio = minio_handler + self.image_paths = data_paths + self.resolution = resolution + + self.transform = alb.Compose( + [ + alb.SmallestMaxSize(max_size=resolution), + alb.CenterCrop(resolution, resolution), + alb.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + ToTensorV2(), + ] + ) + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + image_path = self.image_paths[idx] + + try: + # Load and preprocess image + image = self.minio.download_image(image_path) + transformed = self.transform(image=image) + return {"pixel_values": transformed["image"]} + except Exception as e: + logger.error(f"Error loading image {image_path}: {str(e)}") + # Return a random noise image as fallback + return {"pixel_values": torch.randn(3, self.resolution, self.resolution)} + + +def main(): + args = parse_args() + + # Initialize accelerator + accelerator_project_config = ProjectConfiguration() + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + # Initialize MinIO handler + minio_handler = MinioHandler( + endpoint=args.minio_endpoint, + access_key=args.minio_access_key, + secret_key=args.minio_secret_key, + bucket_name=args.minio_bucket, + ) + + # Load data paths + df = pd.read_csv(args.data_path) + image_paths = df["path"].tolist() + + # Create dataset and dataloader + dataset = TrainingDataset( + minio_handler=minio_handler, + data_paths=image_paths, + resolution=args.resolution, + ) + + dataloader = DataLoader( + dataset, + batch_size=args.train_batch_size, + shuffle=True, + num_workers=args.dataloader_num_workers, + ) + + # Load VAE model + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + ) + + # Optimizer + optimizer = torch.optim.AdamW( + vae.parameters(), + lr=args.learning_rate, + ) + + # Get number of training steps + num_update_steps_per_epoch = math.ceil( + len(dataloader) / args.gradient_accumulation_steps + ) + num_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + # Learning rate scheduler + lr_scheduler = get_scheduler( + "cosine", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=num_train_steps, + ) + + # Prepare everything with accelerator + vae, optimizer, dataloader, lr_scheduler = accelerator.prepare( + vae, optimizer, dataloader, lr_scheduler + ) + + # Train! + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num epochs = {args.num_train_epochs}") + logger.info(f" Batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size = {total_batch_size}") + logger.info(f" Total optimization steps = {num_train_steps}") + + global_step = 0 + for epoch in range(args.num_train_epochs): + vae.train() + train_loss = 0.0 + + progress_bar = tqdm( + total=num_update_steps_per_epoch, + disable=not accelerator.is_local_main_process, + ) + progress_bar.set_description(f"Epoch {epoch}") + + for step, batch in enumerate(dataloader): + with accelerator.accumulate(vae): + # Get VAE loss + loss = vae(batch["pixel_values"], return_dict=False)[0] + accelerator.backward(loss) + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(vae.parameters(), 1.0) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + progress_bar.update(1) + global_step += 1 + train_loss += loss.detach().item() + + logs = { + "loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + "step": global_step, + } + progress_bar.set_postfix(**logs) + + progress_bar.close() + + accelerator.wait_for_everyone() + + # Save checkpoint + if accelerator.is_main_process: + pipeline = accelerator.unwrap_model(vae) + pipeline.save_pretrained( + os.path.join(args.output_dir, f"checkpoint-{global_step}") + ) + + train_loss = train_loss / num_update_steps_per_epoch + logger.info(f"Epoch {epoch}: Average loss = {train_loss:.4f}") + + +if __name__ == "__main__": + main() diff --git a/utils/minio_utils.py b/utils/minio_utils.py new file mode 100644 index 0000000..374392a --- /dev/null +++ b/utils/minio_utils.py @@ -0,0 +1,163 @@ +""" +MinIO utility functions for handling S3-compatible storage operations. + +This module provides a clean interface for interacting with MinIO/S3 storage, +replacing the previous pcache_fileio implementation. +""" + +import cv2 +import numpy as np +from minio import Minio +from typing import Optional, Union, BinaryIO +import io + + +class MinioHandler: + def __init__( + self, + endpoint: str, + access_key: str, + secret_key: str, + bucket_name: str, + secure: bool = True, + ): + """ + Initialize MinIO client with credentials and configuration. + + Args: + endpoint (str): MinIO server endpoint + access_key (str): Access key for authentication + secret_key (str): Secret key for authentication + bucket_name (str): Default bucket to use + secure (bool): Whether to use HTTPS (default: True) + """ + self.client = Minio( + endpoint=endpoint, + access_key=access_key, + secret_key=secret_key, + secure=secure, + ) + self.bucket_name = bucket_name + + # Ensure bucket exists + if not self.client.bucket_exists(bucket_name): + raise ValueError(f"Bucket {bucket_name} does not exist") + + def download_file(self, object_name: str) -> bytes: + """ + Download a file from MinIO storage. + + Args: + object_name (str): Name of the object to download + + Returns: + bytes: File contents as bytes + + Raises: + Exception: If download fails + """ + try: + response = self.client.get_object(self.bucket_name, object_name) + return response.read() + except Exception as e: + raise Exception(f"Failed to download {object_name}: {str(e)}") + finally: + response.close() + response.release_conn() + + def download_image(self, object_name: str) -> np.ndarray: + """ + Download and decode an image from MinIO storage. + + Args: + object_name (str): Name of the image object + + Returns: + np.ndarray: Decoded image as numpy array + + Raises: + Exception: If download or decoding fails + """ + try: + content = self.download_file(object_name) + img_array = np.frombuffer(content, dtype=np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + if img is None: + raise ValueError("Failed to decode image") + return img + except Exception as e: + raise Exception(f"Failed to download/decode image {object_name}: {str(e)}") + + def upload_file( + self, + file_data: Union[bytes, BinaryIO], + object_name: str, + content_type: Optional[str] = None, + ) -> None: + """ + Upload a file to MinIO storage. + + Args: + file_data: File contents as bytes or file-like object + object_name (str): Name to give the uploaded object + content_type (str, optional): Content type of the file + + Raises: + Exception: If upload fails + """ + try: + if isinstance(file_data, bytes): + file_data = io.BytesIO(file_data) + + file_size = file_data.seek(0, 2) + file_data.seek(0) + + self.client.put_object( + bucket_name=self.bucket_name, + object_name=object_name, + data=file_data, + length=file_size, + content_type=content_type, + ) + except Exception as e: + raise Exception(f"Failed to upload {object_name}: {str(e)}") + + def list_objects(self, prefix: str = "", recursive: bool = True): + """ + List objects in the bucket with optional prefix filtering. + + Args: + prefix (str): Filter objects by prefix + recursive (bool): Whether to list objects recursively in directories + + Returns: + Generator yielding object names + """ + try: + objects = self.client.list_objects( + self.bucket_name, prefix=prefix, recursive=recursive + ) + return (obj.object_name for obj in objects) + except Exception as e: + raise Exception(f"Failed to list objects: {str(e)}") + + def read_json(self, object_name: str) -> dict: + """ + Read and parse a JSON file from MinIO storage. + + Args: + object_name (str): Name of the JSON object + + Returns: + dict: Parsed JSON content + + Raises: + Exception: If reading or parsing fails + """ + import json + + try: + content = self.download_file(object_name) + return json.loads(content.decode("utf-8")) + except Exception as e: + raise Exception(f"Failed to read/parse JSON {object_name}: {str(e)}")