This repository contains advanced implementations of MambaVision adapted for Full Reference Image Quality Assessment (FR-IQA). The models take two images (reference and distorted) as input and output quality scores with state-of-the-art performance.
The repository includes multiple IQA model variants:
- Original MambaVision-IQA: Adapted MambaVision with dual image input and fusion mechanisms
- MultiScale Cross-Attention IQA: Advanced model with DINOv2 backbone and multi-scale cross-attention (recommended)
- Dual Image Input: Models accept both reference and distorted images
- Multiple Fusion Strategies: Concatenation, difference, and attention-based fusion
- Multi-Scale Processing: Advanced cross-attention across different scales
- Regression Output: Continuous quality score prediction
- Standard IQA Metrics: PLCC, SROCC, RMSE, MAE evaluation
- Comprehensive Visualization: Model architecture and data flow visualization tools
- Python 3.8 or higher
- CUDA-compatible GPU (recommended)
- 8GB+ GPU memory for training
Create and activate a virtual environment:
# Using conda (recommended)
conda create -n mambavision-iqa python=3.8
conda activate mambavision-iqa
# Or using venv
python -m venv mambavision-iqa
# On Windows:
mambavision-iqa\Scripts\activate
# On Linux/Mac:
source mambavision-iqa/bin/activate

Install PyTorch with CUDA support:
# For CUDA 12.4 (recommended)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
# For CUDA 11.8
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# For CPU only (not recommended for training)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Install all required packages:
# Install from requirements.txt
pip install -r requirements.txt
# Or install manually
pip install mamba-ssm==2.2.4
pip install timm==1.0.15
pip install tensorboardX==2.6.2.2
pip install einops==0.8.1
pip install transformers==4.50.0
pip install Pillow==11.1.0
pip install requests==2.32.3
# Additional dependencies for IQA
pip install scipy scikit-learn pandas pyyaml matplotlib seaborn

For enhanced performance and features:
# Mixed precision training via NVIDIA Apex (optional). Note: `pip install apex`
# on PyPI fetches an unrelated package; build Apex from source following the
# NVIDIA/apex repository instructions, or rely on PyTorch's built-in AMP (--amp).
# For advanced visualizations
pip install graphviz torchviz
# For distributed training
pip install accelerate

Test the installation:
python -c "import torch; print(f'PyTorch version: {torch.__version__}'); print(f'CUDA available: {torch.cuda.is_available()}')"
python -c "from mambavision.models.multi_scale_cross_attention_iqa_model import MultiScaleIQA; print('MultiScaleIQA imported successfully')"

The advanced MultiScaleIQA model features:
- DINOv2 Backbone: Pre-trained vision transformer for robust feature extraction
- Multi-Scale Processing: Processes images at multiple resolutions (224×224, 448×448, 896×896)
- Cross-Attention Fusion: Advanced attention mechanism between reference and distorted features
- Hierarchical Feature Integration: Combines features from different scales and layers
- Quality Regression Head: Outputs continuous quality scores
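For a sense of scale, the number of tokens a ViT backbone must process grows quadratically with input resolution. A quick check, assuming DINOv2's standard 14×14 patch size (variable names here are illustrative, not from the repository):

```python
# Patch-token counts per input scale, assuming DINOv2's 14x14 patch size.
PATCH = 14
token_counts = {}
for side in (224, 448, 896):
    grid = side // PATCH          # tokens per side
    token_counts[side] = grid * grid
    print(f"{side}x{side} -> {grid}x{grid} grid = {grid * grid} patch tokens")
```

Doubling the side length quadruples the token count, which is why the 896×896 scale dominates memory usage.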
The original adaptation includes:
- Separate Patch Embeddings: Individual processing for reference and distorted images
- Configurable Fusion: Three fusion strategies available
- Quality Head: Regression output for quality scores
- Multiple Variants: Tiny (T), Small (S), and Base (B) configurations
- Concatenation (`concat`): Concatenates features from both images
- Difference (`diff`): Computes the difference between reference and distorted features
- Attention (`attention`): Uses a cross-attention mechanism to fuse features
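Ignoring batch and head dimensions, the three strategies reduce to simple tensor operations. A toy NumPy sketch of the underlying math (illustrative only; the actual models likely use learned projections and multi-head attention):

```python
import numpy as np

def fuse(ref_feats, dist_feats, fusion_type="concat"):
    """Toy versions of the three fusion strategies on (N, D) feature matrices."""
    if fusion_type == "concat":
        return np.concatenate([ref_feats, dist_feats], axis=-1)  # (N, 2D)
    if fusion_type == "diff":
        return ref_feats - dist_feats                            # (N, D)
    if fusion_type == "attention":
        # Distorted tokens (queries) attend over reference tokens (keys/values).
        scores = dist_feats @ ref_feats.T / np.sqrt(ref_feats.shape[-1])
        w = np.exp(scores - scores.max(-1, keepdims=True))
        w /= w.sum(-1, keepdims=True)                            # softmax rows
        return w @ ref_feats                                     # (N, D)
    raise ValueError(fusion_type)

ref, dist = np.ones((4, 8)), np.zeros((4, 8))
print(fuse(ref, dist, "concat").shape)  # (4, 16)
print(fuse(ref, dist, "diff").shape)    # (4, 8)
```

Note that `concat` doubles the feature dimension, so the quality head must be sized accordingly.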
The models expect a CSV file with the following structure:
dist_img,ref_img,dmos,var
I01_01_01.png,I01.png,4.57,0.496
I01_01_02.png,I01.png,4.33,0.869
I01_01_03.png,I01.png,2.67,0.789
I01_01_04.png,I01.png,1.67,0.596

Where:

- `dist_img`: Filename of the distorted image
- `ref_img`: Filename of the reference image
- `dmos`: Differential Mean Opinion Score (quality score)
- `var`: Variance of the score (optional, used for weighting)
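To sanity-check a dataset file, the columns can be read with Python's standard `csv` module. This sketch parses the example rows inline; in practice, point `DictReader` at your real CSV:

```python
import csv
import io

# Example rows from the format above, embedded as a string for illustration.
sample = """dist_img,ref_img,dmos,var
I01_01_01.png,I01.png,4.57,0.496
I01_01_02.png,I01.png,4.33,0.869
"""

rows = list(csv.DictReader(io.StringIO(sample)))
# Each row yields a (distorted, reference, score) triple for the data loader.
pairs = [(r["dist_img"], r["ref_img"], float(r["dmos"])) for r in rows]
print(pairs[0])  # ('I01_01_01.png', 'I01.png', 4.57)
```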
Train the MultiScale model (recommended):
python mambavision/train_iqa.py \
--csv_file /path/to/your/dataset.csv \
--image_dir /path/to/your/images \
--model multiscale_iqa_dinov2_base \
--batch_size 8 \
--lr 1e-4 \
--epochs 100 \
    --output ./output/multiscale_experiments

Update the configuration files in `configs/`:

- `configs/iqa_config_tiny.yaml` - For MambaVision-T
- `configs/iqa_config_small.yaml` - For MambaVision-S
- `configs/iqa_config_base.yaml` - For MambaVision-B

Start training:

python mambavision/train_iqa.py --config configs/iqa_config_base.yaml

A full command-line invocation with the advanced options spelled out:

python mambavision/train_iqa.py \
--csv_file /path/to/dataset.csv \
--image_dir /path/to/images \
--model multiscale_iqa_dinov2_base \
--fusion_type attention \
--batch_size 8 \
--lr 1e-4 \
--weight_decay 1e-5 \
--epochs 100 \
--warmup_epochs 10 \
--eval_metric rmse \
--normalize_scores \
--amp \
--channels_last \
--grad_checkpointing \
--output ./output/advanced_training \
--save_checkpoint_freq 10 \
    --log_freq 100

| Parameter | Description | Default | Options |
|---|---|---|---|
| `--model` | Model architecture | `multiscale_iqa_dinov2_base` | `mamba_vision_iqa_T/S/B`, `multiscale_iqa_dinov2_base` |
| `--fusion_type` | Fusion strategy | `attention` | `concat`, `diff`, `attention` |
| `--batch_size` | Training batch size | 8 | Adjust based on GPU memory |
| `--lr` | Learning rate | 1e-4 | 1e-5 to 1e-3 |
| `--epochs` | Training epochs | 100 | 50-200 |
| `--amp` | Mixed precision | False | Enable for faster training |
| `--normalize_scores` | Score normalization | False | Recommended for better convergence |
Validation is automatically performed during training. Monitor progress with TensorBoard:
python launch_tensorboard.py --logdir ./output/your_experiment/tensorboard

Evaluate a trained model on test data:
python mambavision/validate_iqa.py \
--csv_file /path/to/test_dataset.csv \
--image_dir /path/to/test_images \
--model multiscale_iqa_dinov2_base \
--checkpoint /path/to/model_best.pth.tar \
--split test \
--batch_size 32 \
    --results_file validation_results.csv

Test model performance on datasets with Mean Opinion Scores:
python test_iqa_with_mos.py \
--csv_file /path/to/mos_dataset.csv \
--image_dir /path/to/images \
--model multiscale_iqa_dinov2_base \
--checkpoint /path/to/checkpoint.pth \
--batch_size 16 \
    --output_file mos_test_results.csv

Generate quality scores for unlabeled image pairs:
python test_iqa_unlabeled.py \
--ref_dir /path/to/reference_images \
--dist_dir /path/to/distorted_images \
--model multiscale_iqa_dinov2_base \
--checkpoint /path/to/checkpoint.pth \
--batch_size 16 \
    --output_file unlabeled_predictions.csv

The models are evaluated using standard IQA metrics:
- PLCC: Pearson Linear Correlation Coefficient (higher is better)
- SROCC: Spearman Rank Order Correlation Coefficient (higher is better)
- RMSE: Root Mean Square Error (lower is better)
- MAE: Mean Absolute Error (lower is better)
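These four metrics are straightforward to compute from paired prediction and ground-truth arrays. A standalone NumPy sketch (in practice, `scipy.stats.pearsonr` and `spearmanr` provide PLCC and SROCC, including proper tie handling):

```python
import numpy as np

def iqa_metrics(pred, gt):
    pred, gt = np.asarray(pred, float), np.asarray(gt, float)
    plcc = np.corrcoef(pred, gt)[0, 1]              # Pearson linear correlation
    # Spearman = Pearson correlation of the ranks
    # (argsort-of-argsort ranking assumes no ties; scipy handles ties properly).
    rank = lambda x: np.argsort(np.argsort(x)).astype(float)
    srocc = np.corrcoef(rank(pred), rank(gt))[0, 1]
    rmse = np.sqrt(np.mean((pred - gt) ** 2))
    mae = np.mean(np.abs(pred - gt))
    return plcc, srocc, rmse, mae

plcc, srocc, rmse, mae = iqa_metrics([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8])
print(f"PLCC={plcc:.3f} SROCC={srocc:.3f} RMSE={rmse:.3f} MAE={mae:.3f}")
```

A perfectly monotonic prediction gives SROCC = 1.0 even when absolute errors (RMSE, MAE) are nonzero, which is why the correlation and error metrics are reported together.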
Generate comprehensive model architecture diagrams:
python visualize_model_architecture.py \
--model multiscale_iqa_dinov2_base \
--output_dir ./visualizations \
--save_diagrams \
    --show_details

This creates:
- Model architecture diagram
- Data flow visualization
- Layer-wise parameter analysis
- Fusion mechanism illustration
Monitor training with TensorBoard:
# Launch TensorBoard
python launch_tensorboard.py --logdir ./output
# Or manually
tensorboard --logdir ./output --port 6006

Test visualization tools:

python test_visualization.py

| Configuration | Parameters | Input Sizes | Memory | Performance |
|---|---|---|---|---|
| DINOv2 Base | ~86M | 224×224, 448×448, 896×896 | ~16GB | Best |
| Model | Parameters | Input Size | Batch Size | Memory | Speed |
|---|---|---|---|---|---|
| MambaVision-IQA-T | ~7M | 224×224 | 16 | ~6GB | Fast |
| MambaVision-IQA-S | ~26M | 224×224 | 12 | ~8GB | Medium |
| MambaVision-IQA-B | ~75M | 224×224 | 8 | ~12GB | Slow |
# Dataset Configuration
csv_file: "path/to/dataset.csv"
image_dir: "path/to/images"
train_split: 0.8
val_split: 0.1
test_split: 0.1
# Model Configuration
model: "multiscale_iqa_dinov2_base"
fusion_type: "attention"
img_size: 224
pretrained: true
# Training Configuration
batch_size: 8
lr: 1e-4
weight_decay: 1e-5
epochs: 100
warmup_epochs: 10
eval_metric: "rmse"
# Optimization
amp: true
channels_last: true
grad_checkpointing: false
# Data Processing
normalize_scores: true
score_range: [0, 5]
workers: 4
# Output
output: "./output/iqa_experiments"
save_checkpoint_freq: 10
log_freq: 100

- Image Quality: Ensure high-quality, properly aligned image pairs
- Score Distribution: Check for balanced score distribution across quality levels
- Data Augmentation: Use moderate augmentation (rotation, flip) - avoid heavy distortions
- Normalization: Enable score normalization for better convergence
- Batch Size: Use largest batch size that fits in GPU memory
- Learning Rate: Start with 1e-4, reduce if loss plateaus
- Mixed Precision: Enable AMP for faster training and lower memory usage
- Gradient Checkpointing: Use for larger models to save memory
- Early Stopping: Monitor validation metrics to prevent overfitting
- MultiScale Model: Best performance, requires more resources
- Original MambaVision: Faster training, good for experimentation
- Fusion Type: Attention usually performs best but is slower
- Model Size: Balance between performance and computational requirements
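Score normalization (the `normalize_scores` / `score_range` options above) amounts to a min-max rescale of the DMOS values. This sketch assumes the `[0, 5]` range from the sample config and shows the round trip back to the original scale:

```python
import numpy as np

# Min-max normalization of DMOS values to [0, 1], assuming the sample
# config's score_range of [0, 5]; the inverse map recovers the raw scale.
SCORE_MIN, SCORE_MAX = 0.0, 5.0

def normalize(scores):
    return (np.asarray(scores, float) - SCORE_MIN) / (SCORE_MAX - SCORE_MIN)

def denormalize(scores):
    return np.asarray(scores, float) * (SCORE_MAX - SCORE_MIN) + SCORE_MIN

dmos = [4.57, 4.33, 2.67, 1.67]
norm = normalize(dmos)
print(norm)               # values in [0, 1]
print(denormalize(norm))  # round-trips to the original DMOS
```

Remember to denormalize predictions before computing RMSE/MAE against raw DMOS values, or the error magnitudes will not be comparable across experiments.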
# Recommended training command for best performance
python mambavision/train_iqa.py \
--model multiscale_iqa_dinov2_base \
--amp \
--channels_last \
--normalize_scores \
--fusion_type attention \
--batch_size 8 \
--lr 1e-4 \
--warmup_epochs 10 \
    --grad_checkpointing
CUDA Out of Memory
- Reduce batch size
- Enable gradient checkpointing
- Use mixed precision (--amp)
- Try smaller model variant
Poor Convergence
- Lower learning rate (1e-5)
- Enable score normalization
- Check data quality and distribution
- Try different fusion type
NaN Loss
- Check input data for invalid values
- Reduce learning rate
- Enable gradient clipping
- Verify score normalization
Slow Training
- Enable mixed precision (--amp)
- Use channels_last memory format
- Increase number of workers
- Check data loading bottlenecks
Import Errors
- Verify all dependencies are installed
- Check Python and CUDA versions
- Reinstall mamba-ssm if needed
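Gradient clipping, suggested above for NaN losses, rescales gradients whenever their global L2 norm exceeds a threshold; in PyTorch this is `torch.nn.utils.clip_grad_norm_`. A NumPy illustration of the same arithmetic:

```python
import numpy as np

def clip_grad_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their global L2 norm is at most
    max_norm, mirroring what torch.nn.utils.clip_grad_norm_ does in place."""
    total = np.sqrt(sum(float((g ** 2).sum()) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-6))  # eps guards against total == 0
    return [g * scale for g in grads], total

# Two parameter tensors whose combined squared norm is 36 + 64 = 100.
grads = [np.full((2, 2), 3.0), np.full((4,), 4.0)]
clipped, norm = clip_grad_norm(grads, max_norm=1.0)
print(norm)  # 10.0
```

Because all gradients are scaled by one shared factor, clipping preserves the gradient's direction and only bounds its magnitude.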
# Check GPU utilization
nvidia-smi
# Profile training with cProfile (torch.profiler has no command-line entry
# point; use torch.profiler.profile from inside the training code instead)
python -m cProfile -o train.prof mambavision/train_iqa.py [args]
# Memory debugging
python -c "import torch; print(torch.cuda.memory_summary())"

Training outputs are organized as follows:
output/
├── experiment_name/
│ ├── args.yaml # Training configuration
│ ├── model_best.pth.tar # Best model checkpoint
│ ├── checkpoint_epoch_*.pth.tar # Regular checkpoints
│ ├── tensorboard/ # TensorBoard logs
│ ├── training_log.txt # Training progress log
│ └── validation_results.csv # Validation metrics
├── visualizations/ # Generated visualizations
│ ├── model_architecture.png
│ ├── data_flow_diagram.png
│ └── training_curves.png
└── test_results/ # Test outputs
├── mos_predictions.csv
└── unlabeled_scores.csv
# Multi-GPU training (torch.distributed.launch is deprecated in recent
# PyTorch releases; torchrun --nproc_per_node=2 is the modern equivalent)
python -m torch.distributed.launch --nproc_per_node=2 \
mambavision/train_iqa.py \
--model multiscale_iqa_dinov2_base \
    --distributed

Extend the MultiScaleIQA model:
from mambavision.models.multi_scale_cross_attention_iqa_model import MultiScaleIQA

class CustomIQAModel(MultiScaleIQA):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Add custom layers here

    def forward(self, ref_imgs, dist_imgs):
        # Custom forward pass
        return super().forward(ref_imgs, dist_imgs)

The models can be integrated with other deep learning frameworks and pipelines. See the implementation files for detailed API documentation.
If you use this IQA implementation, please cite the original MambaVision paper:
@article{mambavision2024,
title={MambaVision: A Hybrid Mamba-Transformer Vision Backbone},
author={Author et al.},
journal={arXiv preprint arXiv:xxxx.xxxxx},
year={2024}
}

This project follows the same license as the original MambaVision implementation.
For issues, questions, or contributions:
- Check the troubleshooting section above
- Review existing issues in the repository
- Create detailed bug reports with reproduction steps
- Follow the coding standards for contributions
- Added MultiScale Cross-Attention IQA model with DINOv2 backbone
- Enhanced visualization tools for model architecture
- Improved training stability and performance
- Added comprehensive evaluation and testing scripts
- Updated documentation with complete installation and usage guide