IVP-Lab/QualiFusion


MambaVision for Image Quality Assessment (IQA)

This repository contains implementations of MambaVision adapted for Full-Reference Image Quality Assessment (FR-IQA). The models take a reference image and a distorted image as input and predict a continuous quality score.

Overview

The repository includes multiple IQA model variants:

  1. Original MambaVision-IQA: Adapted MambaVision with dual image input and fusion mechanisms
  2. MultiScale Cross-Attention IQA: Advanced model with DINOv2 backbone and multi-scale cross-attention (recommended)

Key Features

  • Dual Image Input: Models accept both reference and distorted images
  • Multiple Fusion Strategies: Concatenation, difference, and attention-based fusion
  • Multi-Scale Processing: Advanced cross-attention across different scales
  • Regression Output: Continuous quality score prediction
  • Standard IQA Metrics: PLCC, SROCC, RMSE, MAE evaluation
  • Comprehensive Visualization: Model architecture and data flow visualization tools

Installation

Prerequisites

  • Python 3.8 or higher
  • CUDA-compatible GPU (recommended)
  • 8GB+ GPU memory for training

Step 1: Environment Setup

Create and activate a virtual environment:

# Using conda (recommended)
conda create -n mambavision-iqa python=3.8
conda activate mambavision-iqa

# Or using venv
python -m venv mambavision-iqa
# On Windows:
mambavision-iqa\Scripts\activate
# On Linux/Mac:
source mambavision-iqa/bin/activate

Step 2: Install PyTorch

Install PyTorch with CUDA support:

# For CUDA 12.4 (recommended)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# For CUDA 11.8
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# For CPU only (not recommended for training)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Step 3: Install Dependencies

Install all required packages:

# Install from requirements.txt
pip install -r requirements.txt

# Or install manually
pip install mamba-ssm==2.2.4
pip install timm==1.0.15
pip install tensorboardX==2.6.2.2
pip install einops==0.8.1
pip install transformers==4.50.0
pip install Pillow==11.1.0
pip install requests==2.32.3

# Additional dependencies for IQA
pip install scipy scikit-learn pandas pyyaml matplotlib seaborn

Step 4: Optional Dependencies

For enhanced performance and features:

# Mixed precision training (optional; PyTorch's built-in AMP via --amp needs no extra package)
# NVIDIA Apex must be built from source; the "apex" package on PyPI is unrelated.
# See https://github.com/NVIDIA/apex for build instructions.

# For advanced visualizations
pip install graphviz torchviz

# For distributed training
pip install accelerate

Step 5: Verify Installation

Test the installation:

python -c "import torch; print(f'PyTorch version: {torch.__version__}'); print(f'CUDA available: {torch.cuda.is_available()}')"
python -c "from mambavision.models.multi_scale_cross_attention_iqa_model import MultiScaleIQA; print('MultiScaleIQA imported successfully')"

Model Architecture

MultiScale Cross-Attention IQA (Recommended)

The advanced MultiScaleIQA model features:

  1. DINOv2 Backbone: Pre-trained vision transformer for robust feature extraction
  2. Multi-Scale Processing: Processes images at multiple resolutions (224×224, 448×448, 896×896)
  3. Cross-Attention Fusion: Advanced attention mechanism between reference and distorted features
  4. Hierarchical Feature Integration: Combines features from different scales and layers
  5. Quality Regression Head: Outputs continuous quality scores

Original MambaVision-IQA

The original adaptation includes:

  1. Separate Patch Embeddings: Individual processing for reference and distorted images
  2. Configurable Fusion: Three fusion strategies available
  3. Quality Head: Regression output for quality scores
  4. Multiple Variants: Tiny (T), Small (S), and Base (B) configurations

Fusion Types (Original Model)

  • Concatenation (concat): Concatenates features from both images
  • Difference (diff): Computes the difference between reference and distorted features
  • Attention (attention): Uses cross-attention mechanism to fuse features
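The three strategies can be sketched on plain 1-D feature vectors. This is an illustrative toy in pure Python (the actual models operate on PyTorch feature maps, and `fuse_features` is a hypothetical helper, not part of the repo's API); the "attention" branch uses a softmax over element-wise similarity only to convey the idea of learned weighting.

```python
import math

def fuse_features(ref, dist, mode="diff"):
    """Toy fusion of two equal-length feature vectors (illustrative only)."""
    if mode == "concat":
        # Concatenation: reference and distorted features side by side
        return ref + dist
    if mode == "diff":
        # Difference: element-wise residual between reference and distorted
        return [r - d for r, d in zip(ref, dist)]
    if mode == "attention":
        # Toy cross-attention: weight distorted features by a softmax over
        # their element-wise similarity to the reference features
        scores = [-abs(r - d) for r, d in zip(ref, dist)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        return [(e / total) * d for e, d in zip(exps, dist)]
    raise ValueError(f"unknown fusion mode: {mode}")

fused = fuse_features([0.9, 0.1, 0.4], [0.7, 0.3, 0.4], mode="diff")
```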

Dataset Format

The models expect a CSV file with the following structure:

dist_img,ref_img,dmos,var
I01_01_01.png,I01.png,4.57,0.496
I01_01_02.png,I01.png,4.33,0.869
I01_01_03.png,I01.png,2.67,0.789
I01_01_04.png,I01.png,1.67,0.596

Where:

  • dist_img: Filename of the distorted image
  • ref_img: Filename of the reference image
  • dmos: Differential Mean Opinion Score (quality score)
  • var: Variance of the score (optional, used for weighting)
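A minimal loader for this layout, using only the standard library (pandas works equally well). The column names come from the sample above; `load_iqa_annotations` is a hypothetical helper sketch, not a function shipped with the repo.

```python
import csv
from io import StringIO

REQUIRED = ("dist_img", "ref_img", "dmos")

def load_iqa_annotations(fp):
    """Parse IQA annotation rows; `var` is optional and defaults to 1.0."""
    reader = csv.DictReader(fp)
    missing = [c for c in REQUIRED if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"missing required columns: {missing}")
    rows = []
    for row in reader:
        rows.append({
            "dist_img": row["dist_img"],
            "ref_img": row["ref_img"],
            "dmos": float(row["dmos"]),
            "var": float(row.get("var") or 1.0),
        })
    return rows

sample = """dist_img,ref_img,dmos,var
I01_01_01.png,I01.png,4.57,0.496
I01_01_02.png,I01.png,4.33,0.869
"""
records = load_iqa_annotations(StringIO(sample))
```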

Training

Quick Start

Train the MultiScale model (recommended):

python mambavision/train_iqa.py \
    --csv_file /path/to/your/dataset.csv \
    --image_dir /path/to/your/images \
    --model multiscale_iqa_dinov2_base \
    --batch_size 8 \
    --lr 1e-4 \
    --epochs 100 \
    --output ./output/multiscale_experiments

Configuration-Based Training

  1. Update configuration files in configs/:

    • configs/iqa_config_tiny.yaml - For MambaVision-T
    • configs/iqa_config_small.yaml - For MambaVision-S
    • configs/iqa_config_base.yaml - For MambaVision-B
  2. Start training:

python mambavision/train_iqa.py --config configs/iqa_config_base.yaml

Advanced Training Options

python mambavision/train_iqa.py \
    --csv_file /path/to/dataset.csv \
    --image_dir /path/to/images \
    --model multiscale_iqa_dinov2_base \
    --fusion_type attention \
    --batch_size 8 \
    --lr 1e-4 \
    --weight_decay 1e-5 \
    --epochs 100 \
    --warmup_epochs 10 \
    --eval_metric rmse \
    --normalize_scores \
    --amp \
    --channels_last \
    --grad_checkpointing \
    --output ./output/advanced_training \
    --save_checkpoint_freq 10 \
    --log_freq 100

Key Training Parameters

| Parameter | Description | Default | Options |
|---|---|---|---|
| `--model` | Model architecture | `multiscale_iqa_dinov2_base` | `mamba_vision_iqa_T/S/B`, `multiscale_iqa_dinov2_base` |
| `--fusion_type` | Fusion strategy | `attention` | `concat`, `diff`, `attention` |
| `--batch_size` | Training batch size | 8 | Adjust based on GPU memory |
| `--lr` | Learning rate | 1e-4 | 1e-5 to 1e-3 |
| `--epochs` | Training epochs | 100 | 50-200 |
| `--amp` | Mixed precision | False | Enable for faster training |
| `--normalize_scores` | Score normalization | False | Recommended for better convergence |
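For scripting around the trainer, the documented flags map naturally onto an `argparse` parser. The sketch below is a hypothetical mirror of the table; the actual parser in `train_iqa.py` may define additional flags or different defaults.

```python
import argparse

def build_parser():
    """Hypothetical parser mirroring the documented training flags."""
    p = argparse.ArgumentParser(description="IQA training (illustrative)")
    p.add_argument("--model", default="multiscale_iqa_dinov2_base")
    p.add_argument("--fusion_type", default="attention",
                   choices=["concat", "diff", "attention"])
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--amp", action="store_true")
    p.add_argument("--normalize_scores", action="store_true")
    return p

# Example: override the batch size and enable mixed precision
args = build_parser().parse_args(["--batch_size", "4", "--amp"])
```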

Evaluation and Testing

Validation During Training

Validation is automatically performed during training. Monitor progress with TensorBoard:

python launch_tensorboard.py --logdir ./output/your_experiment/tensorboard

Post-Training Validation

Evaluate a trained model on test data:

python mambavision/validate_iqa.py \
    --csv_file /path/to/test_dataset.csv \
    --image_dir /path/to/test_images \
    --model multiscale_iqa_dinov2_base \
    --checkpoint /path/to/model_best.pth.tar \
    --split test \
    --batch_size 32 \
    --results_file validation_results.csv

Testing with MOS Scores

Test model performance on datasets with Mean Opinion Scores:

python test_iqa_with_mos.py \
    --csv_file /path/to/mos_dataset.csv \
    --image_dir /path/to/images \
    --model multiscale_iqa_dinov2_base \
    --checkpoint /path/to/checkpoint.pth \
    --batch_size 16 \
    --output_file mos_test_results.csv

Testing on Unlabeled Data

Generate quality scores for unlabeled image pairs:

python test_iqa_unlabeled.py \
    --ref_dir /path/to/reference_images \
    --dist_dir /path/to/distorted_images \
    --model multiscale_iqa_dinov2_base \
    --checkpoint /path/to/checkpoint.pth \
    --batch_size 16 \
    --output_file unlabeled_predictions.csv

Evaluation Metrics

The models are evaluated using standard IQA metrics:

  • PLCC: Pearson Linear Correlation Coefficient (higher is better)
  • SROCC: Spearman Rank Order Correlation Coefficient (higher is better)
  • RMSE: Root Mean Square Error (lower is better)
  • MAE: Mean Absolute Error (lower is better)
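The four metrics have compact definitions; here are minimal reference implementations on plain Python lists (in practice `scipy.stats.pearsonr` / `spearmanr` are the usual choices, and this `spearman` sketch does not handle tied scores).

```python
import math

def pearson(x, y):
    """PLCC: covariance normalized by the two standard deviations."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

def ranks(x):
    """Rank positions of each element (no tie handling)."""
    order = sorted(range(len(x)), key=lambda i: x[i])
    r = [0.0] * len(x)
    for rank, i in enumerate(order):
        r[i] = float(rank)
    return r

def spearman(x, y):
    """SROCC: Pearson correlation computed on the ranks."""
    return pearson(ranks(x), ranks(y))

def rmse(x, y):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)) / len(x))

def mae(x, y):
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)

pred = [1.0, 2.1, 2.9, 4.2]
mos = [1.0, 2.0, 3.0, 4.0]
srocc = spearman(pred, mos)  # perfect rank agreement
```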

Visualization

Model Architecture Visualization

Generate comprehensive model architecture diagrams:

python visualize_model_architecture.py \
    --model multiscale_iqa_dinov2_base \
    --output_dir ./visualizations \
    --save_diagrams \
    --show_details

This creates:

  • Model architecture diagram
  • Data flow visualization
  • Layer-wise parameter analysis
  • Fusion mechanism illustration

Training Progress Visualization

Monitor training with TensorBoard:

# Launch TensorBoard
python launch_tensorboard.py --logdir ./output

# Or manually
tensorboard --logdir ./output --port 6006

Custom Visualization

Test visualization tools:

python test_visualization.py

Model Variants and Performance

MultiScale Cross-Attention IQA

| Configuration | Parameters | Input Sizes | Memory | Performance |
|---|---|---|---|---|
| DINOv2 Base | ~86M | 224×224, 448×448, 896×896 | ~16GB | Best |

Original MambaVision-IQA

| Model | Parameters | Input Size | Batch Size | Memory | Speed |
|---|---|---|---|---|---|
| MambaVision-IQA-T | ~7M | 224×224 | 16 | ~6GB | Fast |
| MambaVision-IQA-S | ~26M | 224×224 | 12 | ~8GB | Medium |
| MambaVision-IQA-B | ~75M | 224×224 | 8 | ~12GB | Slow |

Configuration Files

Sample Configuration (iqa_config_base.yaml)

# Dataset Configuration
csv_file: "path/to/dataset.csv"
image_dir: "path/to/images"
train_split: 0.8
val_split: 0.1
test_split: 0.1

# Model Configuration
model: "multiscale_iqa_dinov2_base"
fusion_type: "attention"
img_size: 224
pretrained: true

# Training Configuration
batch_size: 8
lr: 1e-4
weight_decay: 1e-5
epochs: 100
warmup_epochs: 10
eval_metric: "rmse"

# Optimization
amp: true
channels_last: true
grad_checkpointing: false

# Data Processing
normalize_scores: true
score_range: [0, 5]
workers: 4

# Output
output: "./output/iqa_experiments"
save_checkpoint_freq: 10
log_freq: 100

Best Practices and Tips

Data Preparation

  1. Image Quality: Ensure high-quality, properly aligned image pairs
  2. Score Distribution: Check for balanced score distribution across quality levels
  3. Data Augmentation: Use moderate augmentation (rotation, flip) - avoid heavy distortions
  4. Normalization: Enable score normalization for better convergence
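On point 4: score normalization plausibly amounts to a min-max map of DMOS values into [0, 1], with the range fitted on the training split only so the validation and test scores don't leak into it. The helpers below are a hypothetical sketch of that idea, not the repo's actual implementation.

```python
def fit_minmax(scores):
    """Fit the normalization range on the TRAINING split only."""
    return min(scores), max(scores)

def normalize(score, lo, hi):
    """Map a raw DMOS value into [0, 1]."""
    return (score - lo) / (hi - lo)

def denormalize(score, lo, hi):
    """Map a model prediction back to the original DMOS range."""
    return score * (hi - lo) + lo

train_dmos = [1.67, 2.67, 4.33, 4.57]
lo, hi = fit_minmax(train_dmos)
top = normalize(4.57, lo, hi)  # maps the training maximum to 1.0
```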

Training Optimization

  1. Batch Size: Use largest batch size that fits in GPU memory
  2. Learning Rate: Start with 1e-4, reduce if loss plateaus
  3. Mixed Precision: Enable AMP for faster training and lower memory usage
  4. Gradient Checkpointing: Use for larger models to save memory
  5. Early Stopping: Monitor validation metrics to prevent overfitting

Model Selection

  1. MultiScale Model: Best performance, requires more resources
  2. Original MambaVision: Faster training, good for experimentation
  3. Fusion Type: Attention usually performs best but is slower
  4. Model Size: Balance between performance and computational requirements

Performance Optimization

# Recommended training command for best performance
python mambavision/train_iqa.py \
    --model multiscale_iqa_dinov2_base \
    --amp \
    --channels_last \
    --normalize_scores \
    --fusion_type attention \
    --batch_size 8 \
    --lr 1e-4 \
    --warmup_epochs 10 \
    --grad_checkpointing

Troubleshooting

Common Issues and Solutions

  1. CUDA Out of Memory

    • Reduce batch size
    • Enable gradient checkpointing
    • Use mixed precision (--amp)
    • Try smaller model variant
  2. Poor Convergence

    • Lower learning rate (1e-5)
    • Enable score normalization
    • Check data quality and distribution
    • Try different fusion type
  3. NaN Loss

    • Check input data for invalid values
    • Reduce learning rate
    • Enable gradient clipping
    • Verify score normalization
  4. Slow Training

    • Enable mixed precision (--amp)
    • Use channels_last memory format
    • Increase number of workers
    • Check data loading bottlenecks
  5. Import Errors

    • Verify all dependencies are installed
    • Check Python and CUDA versions
    • Reinstall mamba-ssm if needed

Performance Debugging

# Check GPU utilization
nvidia-smi

# Profile training (torch.profiler is an in-code API; for a drop-in
# script-level profile, use torch.utils.bottleneck instead)
python -m torch.utils.bottleneck mambavision/train_iqa.py [args]

# Memory debugging
python -c "import torch; print(torch.cuda.memory_summary())"

Output Structure

Training outputs are organized as follows:

output/
├── experiment_name/
│   ├── args.yaml                  # Training configuration
│   ├── model_best.pth.tar        # Best model checkpoint
│   ├── checkpoint_epoch_*.pth.tar # Regular checkpoints
│   ├── tensorboard/              # TensorBoard logs
│   ├── training_log.txt          # Training progress log
│   └── validation_results.csv    # Validation metrics
├── visualizations/               # Generated visualizations
│   ├── model_architecture.png
│   ├── data_flow_diagram.png
│   └── training_curves.png
└── test_results/                # Test outputs
    ├── mos_predictions.csv
    └── unlabeled_scores.csv
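Given this layout, collecting the best checkpoint from every experiment is a one-line glob. A small sketch (the filename `model_best.pth.tar` comes from the tree above; the function name is hypothetical):

```python
from pathlib import Path

def find_best_checkpoints(output_root):
    """Return the model_best.pth.tar path for each experiment under output/."""
    return sorted(Path(output_root).glob("*/model_best.pth.tar"))
```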

Advanced Usage

Distributed Training

# Multi-GPU training (torchrun supersedes the deprecated torch.distributed.launch)
torchrun --nproc_per_node=2 \
    mambavision/train_iqa.py \
    --model multiscale_iqa_dinov2_base \
    --distributed

Custom Model Development

Extend the MultiScaleIQA model:

from mambavision.models.multi_scale_cross_attention_iqa_model import MultiScaleIQA

class CustomIQAModel(MultiScaleIQA):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Add custom layers
        
    def forward(self, ref_imgs, dist_imgs):
        # Custom forward pass
        return super().forward(ref_imgs, dist_imgs)

Integration with Other Frameworks

The models can be integrated with other deep learning frameworks and pipelines. See the implementation files for detailed API documentation.

Citation

If you use this IQA implementation, please cite the original MambaVision paper:

@article{mambavision2024,
  title={MambaVision: A Hybrid Mamba-Transformer Vision Backbone},
  author={Author et al.},
  journal={arXiv preprint arXiv:xxxx.xxxxx},
  year={2024}
}

License

This project follows the same license as the original MambaVision implementation.

Support and Contributing

For issues, questions, or contributions:

  1. Check the troubleshooting section above
  2. Review existing issues in the repository
  3. Create detailed bug reports with reproduction steps
  4. Follow the coding standards for contributions

Changelog

Latest Updates

  • Added MultiScale Cross-Attention IQA model with DINOv2 backbone
  • Enhanced visualization tools for model architecture
  • Improved training stability and performance
  • Added comprehensive evaluation and testing scripts
  • Updated documentation with complete installation and usage guide
