SNIP

Overview

This project provides the code for the paper "SNIP: An Adaptive Mixed Precision Framework for Subbyte Large Language Model Training" (arXiv:2602.01410), accepted at ASPLOS'26. SNIP is a research project that identifies optimal per-layer quantization configurations for large language models (LLMs) during training, balancing efficiency and quality.

Efficiency Metric

We use the fraction of FLOPs executed in fp4 rather than fp8 as a proxy metric for efficiency, since we currently do not have access to NVIDIA Blackwell GPUs (which provide hardware fp4 support).
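
For intuition, the sketch below (a hypothetical helper, not part of the repository; the function name and arguments are illustrative) computes this fraction from per-layer FLOP counts and a set of layers assigned to fp4:

def efficiency_saving(layer_flops, fp4_layers):
    # Fraction of linear-layer FLOPs executed in fp4 rather than fp8.
    # layer_flops: dict mapping layer name -> FLOPs of that layer
    # fp4_layers: set of layer names assigned to fp4
    total = sum(layer_flops.values())
    fp4 = sum(flops for name, flops in layer_flops.items() if name in fp4_layers)
    return fp4 / total

# Two equally sized layers with one of them in fp4 gives an efficiency saving of 0.5.
print(efficiency_saving({"q_proj": 1e9, "k_proj": 1e9}, {"q_proj"}))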

Quality Metrics

We employ two types of metrics for quality evaluation:

  1. Training Quality: Evaluated using the training loss and scores on test benchmarks.
  2. Proxy Metrics for Heuristics: The forward loss divergence and backward weight divergence caused by quantization, measured against the high-precision (bf16) baseline.
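
For intuition, the sketch below shows what the two proxy quantities compare: a fake-quantized run versus a bf16 run on the same batch. This is a conceptual illustration only (it assumes Hugging Face-style models that return a .loss, and it measures the divergences directly, whereas the repository's heuristics estimate them without re-running the quantized model for every candidate layer):

import torch

def forward_loss_divergence(model_bf16, model_quant, batch):
    # |loss under fake quantization - loss in bf16| on the same batch
    with torch.no_grad():
        loss_bf16 = model_bf16(**batch).loss
        loss_quant = model_quant(**batch).loss
    return (loss_quant - loss_bf16).abs().item()

def backward_weight_divergence(model_bf16, model_quant, batch):
    # Norm of the difference in weight gradients after one backward pass
    model_bf16.zero_grad(set_to_none=True)
    model_quant.zero_grad(set_to_none=True)
    model_bf16(**batch).loss.backward()
    model_quant(**batch).loss.backward()
    total = 0.0
    for p_bf16, p_quant in zip(model_bf16.parameters(), model_quant.parameters()):
        if p_bf16.grad is not None and p_quant.grad is not None:
            total += (p_quant.grad - p_bf16.grad).float().norm().item() ** 2
    return total ** 0.5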

File Structures

  • snip: The main source files.
    • fpx:
      • fake quantization: fpx_configs.py, fpx_fake_quantization.py
      • estimate forward loss divergence and backward weight divergence: estimate_error_utils.py, process_bwd_err_utils.py
    • module: the linear layer with fake quantization and stats collection
    • training: optimizer, trainer, trainer_callback
  • script: Scripts for training, collecting stats, generating quantization schemes, and evaluation.
    • eval: Uses lm-eval-harness to evaluate on test benchmarks.
    • configs: Model training and stats-collection configurations in JSON files.
    • run_train_fpx_configs.sh: Bash script to train models (bf16 or fpx).
    • run_collect_stats.sh: Bash script to run additional passes to collect stats.
    • generate_fpx_configs.py: Python script that calculates the proxy quality metrics (forward loss divergence and backward weight divergence), then uses integer linear programming (ILP) to assign a quantization precision to each layer. This script runs on the CPU and is I/O-bound.
  • tests: Test files.
  • run_clm.py: LLM training script, adapted from the Hugging Face Transformers example.

Prepare Dataset

StarCoderData

StarCoderData contains 250B tokens, and the downloaded dataset files are 311 GB. We sample only 1% of the dataset, then tokenize and group it with block_size=2048.

python script/prepare_dataset/prepare_starcoder.py \
  --source_path /path/to/starcoderdata_2.5B/ \
  --target_path /path/to/starcoderdata_2.5B_grouped_2048 \
  --tokenizer_path /path/to/tokenizer/ \
  --tokenize_and_group \
  --block_size 2048

SlimPajama

SlimPajama-627B contains 627B tokens, and the downloaded dataset files are 895 GB. We therefore use the sampled dataset SlimPajama-6B, which is roughly 1% of the original and contains about 6B tokens.

python script/prepare_dataset/prepare_slimparjama.py \
  --source_path /path/to/SlimPajama-6B/data/ \
  --target_path /path/to/SlimPajama-6B_grouped_2048 \
  --tokenizer_path /path/to/tokenizer/ \
  --block_size 2048

Combine StarCoderData and SlimPajama

TinyLlama is trained on a mix of SlimPajama and StarCoderData, so we use the following script to combine them into a single dataset:

python script/prepare_dataset/merge_datasets.py \
  --source_paths /path/to/SlimPajama-6B_grouped_2048 /path/to/starcoderdata_2.5B_grouped_2048 \
  --target_path /path/to/merged_dataset \
  --split train

cp /path/to/SlimPajama-6B_grouped_2048/validation/* /path/to/merged_dataset/validation/

echo '{"splits": ["train", "validation"]}' > /path/to/merged_dataset/dataset_dict.json

RedPajama

RedPajama-1B. We use the 1B-token sample (RedPajama-Data-1T-Sample) of the RedPajama-1T dataset.

python script/prepare_dataset/prepare_redparjama.py \
  --source_path /path/to/RedPajama-Data-1T-Sample \
  --target_path /path/to/RedPajama-1B_grouped_2048 \
  --tokenizer_path openlm-research/open_llama_3b \
  --block_size 2048

Usage

Set the training and stats-collection configurations

Edit script/example.config accordingly. This config includes the paths and hyperparameters needed for training and for collecting statistics; the fields are described below, followed by an illustrative sketch of a complete config.

Model & Checkpoints

  • model_name_or_path: Path to the base pretrained model (e.g., "meta-llama/Llama-2-7b-hf" or your own fine-tuned model directory).
  • resume_from_checkpoint: Path to a full training checkpoint that includes the optimizer state. Useful for resuming training exactly where it left off. Set to null if not resuming.
  • resume_from_checkpoint_wo_optimizer: Path to a checkpoint without optimizer state. Useful if you only want to load model weights but start with a fresh optimizer. Also set to null if not used.

Dataset

  • dataset_name: Path to the raw dataset. It will be tokenized and prepared for training.
  • grouped_dataset_name: Path to the processed dataset, where input tokens have been grouped into fixed-length chunks (like sliding windows). This avoids redundant tokenization during training. The script/prepare_dataset scripts will generate the grouped dataset.
  • cache_dir: Path to Hugging Face’s cache directory.

Model Architecture

  • num_layers: Total number of layers in your model. Used to identify and iterate over specific submodules for stats collection.
  • layer_types: List of linear layer components (like attention and MLP projections) you want to analyze. Example: "self_attn.q_proj" refers to the query projection in a self-attention block.

Training & Evaluation

  • num_ranks: Number of GPU ranks (or processes) being used. Typically = number of GPUs.

  • global_batch_size: Total batch size across all devices. Example: if per_device_train_batch_size = 1 and num_ranks = 8, then global_batch_size = 8.

  • per_device_train_batch_size / per_device_eval_batch_size: Batch size per GPU for training and evaluation, respectively.

  • learning_rate: Initial learning rate for the optimizer.

  • lr_scheduler_type: Type of learning rate scheduler to use (e.g., "constant", "cosine").

  • gradient_checkpointing: Whether to use gradient checkpointing to save memory at the cost of compute.

  • train_max_step: The step at which training ends.

  • save_steps: Save model checkpoints every n steps during training.

  • additional_bf16_fwd_every_n_steps: Inject additional bfloat16 forward passes every n steps. Used for evaluating model deviation (divergence) compared to high-precision (BF16) baseline.

Stats Collection

  • add_noise_num: Number of times to add noise for variance estimation during stats collection.

  • save_grad_dir: Directory to save gradients. Can require significant space—ensure you have room.

  • meta_parent_dir_path: Base directory where all output stats, logs, and metadata will be stored.

  • num_random: Number of random quantization schemes to generate for a given efficiency-saving target.

  • collect_stats_max_steps: Max number of steps to run while collecting stats. Set it to the training start step + 1.
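
Putting the fields above together, a config along the following lines could be written out as JSON. This is an illustrative sketch only: the field names follow the descriptions above, but every value is a placeholder, and the repository's script/example.config is the authoritative template.

import json

# Illustrative placeholder values only -- adjust paths and hyperparameters to your setup.
example_config = {
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",
    "resume_from_checkpoint": None,                # or a checkpoint path that includes optimizer state
    "resume_from_checkpoint_wo_optimizer": None,   # or a checkpoint path without optimizer state
    "dataset_name": "/path/to/raw_dataset",
    "grouped_dataset_name": "/path/to/dataset_grouped_2048",
    "cache_dir": "/path/to/hf_cache",
    "num_layers": 16,
    "layer_types": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
                    "self_attn.o_proj", "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"],
    "num_ranks": 8,
    "global_batch_size": 8,
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "learning_rate": 4e-4,
    "lr_scheduler_type": "cosine",
    "gradient_checkpointing": True,
    "train_max_step": 10000,
    "save_steps": 500,
    "additional_bf16_fwd_every_n_steps": 100,
    "add_noise_num": 3,
    "save_grad_dir": "/path/to/saved_grads",
    "meta_parent_dir_path": "/path/to/stats_output",
    "num_random": 5,
    "collect_stats_max_steps": 1,                  # training start step + 1
}

with open("script/example.config", "w") as f:
    json.dump(example_config, f, indent=2)         # writes JSON; None is serialized as null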

Collect the stats

bash script/run_collect_stats.sh --config_file script/example.config [--use_accelerate]
  • --config_file (required): Path to the JSON config file (same format as example.config). This provides the core setup for model, dataset, optimizer, and training.
  • --use_accelerate (optional): Use this flag if you're leveraging 🤗 Accelerate for multi-GPU or distributed setups. If you only use DDP (Distributed Data Parallel), there is no need to set it.

What does it do?

  1. Gets the basic metadata for estimating the proxy quality metrics.
  2. For the backward weight divergence estimation, runs a few extra passes that add noise in the backward pass of the last layer to estimate the Frobenius norm of certain gradients.
  3. Also for the backward weight divergence estimation, runs a few extra passes that add noise to the last layer in the forward pass to estimate the relationship between the forward loss error and the last layer's tensor error.
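
The sketch below illustrates only the idea behind step 3; it is not the repository's implementation (the hook-based approach, the noise scales, and the choice to perturb the last layer's output are assumptions, and an HF-style model returning a .loss is assumed):

import torch

def loss_vs_last_layer_noise(model, batch, last_layer, scales=(1e-3, 1e-2, 1e-1)):
    # Relate the forward loss error to the size of a perturbation injected
    # into the last layer's output (cf. step 3 above).
    with torch.no_grad():
        clean_loss = model(**batch).loss.item()

    records = []
    for scale in scales:
        def add_noise(module, inputs, output, scale=scale):
            noise = scale * torch.randn_like(output)
            add_noise.noise_norm = noise.norm().item()
            return output + noise

        handle = last_layer.register_forward_hook(add_noise)
        with torch.no_grad():
            noisy_loss = model(**batch).loss.item()
        handle.remove()
        records.append((add_noise.noise_norm, abs(noisy_loss - clean_loss)))

    # Each pair is (tensor-error norm, forward-loss error); fitting these points
    # estimates how sensitive the loss is to errors in the last layer's output.
    return records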

Run the analysis to generate the quantization schemes

python script/generate_fpx_configs.py --efficiency_saving <efficiency_saving> --num_random <num_random, default is 5> --random_seed <random_seed, default is 0>
  • efficiency_saving: The fraction (a float between 0 and 1) of FLOPs to run in fp4 instead of fp8.
  • num_random: The number of random quantization schemes generated that meet the efficiency_saving budget.
  • random_seed: For reproducibility of the random quantization schemes.

For a given efficiency_saving budget, this script will generate four different types of quantization schemes:

  1. best: Our heuristic, which minimizes the forward loss error and backward weight divergence during training.
  2. random: Randomly selects layers to run in fp4 instead of fp8 while meeting the efficiency_saving budget. We generate num_random such random schemes.
  3. best_quant_err: Uses the sum of the quantization errors of every tensor (input, weight, and gradient) as the proxy quality metric.
  4. best_rel_quant_err: Same as best_quant_err, but uses the relative quantization error of every tensor.

The last two show that purely local information (per-tensor quantization error) is far from the optimal quantization scheme. All quantization schemes are saved as .pt files.
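
As noted above, the best scheme is selected with integer linear programming. The sketch below shows one way such a layer-selection ILP can be formulated; the exact objective and constraints used by generate_fpx_configs.py may differ, so treat this as an illustration (here using SciPy's MILP solver, and assuming the problem is feasible):

import numpy as np
from scipy.optimize import Bounds, LinearConstraint, milp

def choose_fp4_layers(proxy_error, layer_flops, efficiency_saving):
    # proxy_error[i]: estimated quality cost of running layer i in fp4 instead of fp8
    # layer_flops[i]: FLOPs of layer i
    # Pick binary x[i] = 1 (layer i uses fp4) to minimize the total proxy cost
    # while the fp4 FLOP fraction meets the efficiency_saving budget.
    err = np.asarray(proxy_error, dtype=float)
    flops = np.asarray(layer_flops, dtype=float)
    budget = efficiency_saving * flops.sum()

    result = milp(
        c=err,                                                         # minimize sum_i err[i] * x[i]
        constraints=LinearConstraint(flops[np.newaxis, :], lb=[budget], ub=[np.inf]),
        integrality=np.ones(len(err)),                                 # x is integer-valued
        bounds=Bounds(0, 1),                                           # ... and binary
    )
    return result.x.round().astype(bool)                               # True -> layer runs in fp4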

Train the model using different precisions or different quantization schemes

bash script/run_train_fpx_configs.sh \
  --config_file path/to/example.config \
  [--efficiency_saving 0.5] \
  [--precision bf16|fp8|fp4] \
  [--use_accelerate]
  • --config_file (required): Path to the JSON config file (same format as example.config). This provides the core setup for model, dataset, optimizer, and training.
  • --efficiency_saving (optional): Set this to a floating point number to run training for all quantization schemes generated in the previous step under that efficiency-saving budget.
  • --precision (optional): Default precision for all layers: bf16, fp8, or fp4. Do not set both --precision and --efficiency_saving.
  • --use_accelerate (optional): Use this flag if you're leveraging 🤗 Accelerate for multi-GPU or distributed setups (like DeepSpeed). If you only use DDP (Distributed Data Parallel), there is no need to set it.

Evaluate on test benchmarks

Use the scripts under script/eval, which rely on lm-eval-harness, to evaluate trained checkpoints on the test benchmarks.

Q&A

Why do we need a proxy quality metric? Why not try to quantize each layer and test?

Quantizing each layer and testing the model's performance would be ideal but is computationally expensive and time-consuming. Large language models have numerous layers, and evaluating every possible quantization configuration would require significant resources. For example, a Llama 3.2 1B model has 16 blocks, and each block has 7 linear layers (Wq, Wk, Wv, Wo, gate_proj, up_proj, down_proj). In total there are 112 linear layers (excluding lm_head), which means we would need 112 additional forward+backward passes to evaluate the quality impact of quantizing each layer under even a single quantization scheme.

A proxy quality metric allows us to estimate the impact of quantization on model performance more efficiently. By using metrics like forward loss error and backward weight error, we can quickly identify promising quantization schemes without exhaustive testing. This approach balances the need for efficiency and quality, enabling us to find optimal configurations in a more practical manner. In our heuristics, we only need 1 + 3 + 3 = 7 additional passes to evaluate the quality impact when quantizing each layer using all different quantization schemes.

Why do we care about the backward weight divergence?

Backward weight divergence is crucial because it directly impacts the stability and convergence of the training process. When quantization introduces significant errors in the backward pass, it can lead to incorrect weight updates, causing the model to diverge or converge to suboptimal solutions. By minimizing backward weight divergence, we ensure that the quantized model maintains training stability and achieves performance comparable to the non-quantized baseline. This is particularly important for large language models, where even small errors can accumulate over many layers and iterations, significantly affecting the final model quality.
