diff --git a/example/lptm_quantization.ipynb b/example/lptm_quantization.ipynb
new file mode 100644
index 0000000..61e6bc8
--- /dev/null
+++ b/example/lptm_quantization.ipynb
@@ -0,0 +1,80 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from samay.model import LPTMModel\n",
+    "\n",
+    "config = {\n",
+    "    \"task_name\": \"forecasting\",\n",
+    "    \"forecast_horizon\": 192,\n",
+    "    \"head_dropout\": 0,\n",
+    "    \"weight_decay\": 0,\n",
+    "    \"max_patch\": 16,\n",
+    "    \"freeze_encoder\": True, # Freeze the patch embedding layer\n",
+    "    \"freeze_embedder\": True, # Freeze the transformer encoder\n",
+    "    \"freeze_head\": False, # The linear forecasting head must be trained\n",
+    "    \"freeze_segment\": True, # Freeze the segmentation module\n",
+    "}\n",
+    "model = LPTMModel(config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from samay.dataset import LPTMDataset\n",
+    "\n",
+    "train_dataset = LPTMDataset(\n",
+    "    name=\"ett\",\n",
+    "    datetime_col=\"date\",\n",
+    "    path=\"../data/data/ETTh1.csv\",\n",
+    "    mode=\"train\",\n",
+    "    horizon=192,\n",
+    ")\n",
+    "\n",
+    "finetuned_model = model.finetune(train_dataset)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# quantize() swaps eligible nn.Linear layers for bitsandbytes layers in place\n",
+    "model.quantize(quant_type=\"int8\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "val_dataset = LPTMDataset(\n",
+    "    name=\"ett\",\n",
+    "    datetime_col=\"date\",\n",
+    "    path=\"../data/data/ETTh1.csv\",\n",
+    "    mode=\"train\",\n",
+    "    horizon=192,\n",
+    ")\n",
+    "metrics, trues, preds, histories = model.evaluate(val_dataset, task_name=\"forecasting\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/example/moment_quantization.ipynb b/example/moment_quantization.ipynb
new file mode 100644
index 0000000..76d1227
--- /dev/null
+++ b/example/moment_quantization.ipynb
@@ -0,0 +1,171 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Moment Quantization Example\n",
+    "\n",
+    "## Loading Moment Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/nethome/sli999/anaconda3/envs/torch/lib/python3.11/site-packages/gluonts/json.py:102: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n",
+      "  warnings.warn(\n",
+      "INFO:p-2597098:t-140082893653824:moment.py:_validate_inputs:Setting d_model to 1024\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading MOMENT model from AutonLab/MOMENT-1-large\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:p-2597098:t-140082893653824:moment.py:_get_transformer_backbone:Initializing pre-trained transformer from google/flan-t5-large.\n",
+      "INFO:p-2597098:t-140082893653824:moment.py:_get_transformer_backbone:Enabling gradient checkpointing.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "import numpy as np\n",
+    "\n",
+    "src_path = os.path.abspath(os.path.join(\"..\", \"src\"))\n",
+    "if src_path not in sys.path:\n",
+    "    sys.path.insert(0, src_path)\n",
+    "\n",
+    "from samay.model import MomentModel\n",
+    "from samay.dataset import MomentDataset\n",
+    "from samay.utils import load_args\n",
+    "\n",
+    "repo = \"AutonLab/MOMENT-1-large\"\n",
+    "config = {\n",
+    "    \"task_name\": \"forecasting\",\n",
+    "    \"forecast_horizon\": 192,\n",
+    "    \"head_dropout\": 0.1,\n",
+    "    \"weight_decay\": 0,\n",
+    "    \"freeze_encoder\": True, # Freeze the patch embedding layer\n",
+    "    \"freeze_embedder\": True, # Freeze the transformer encoder\n",
+    "    \"freeze_head\": False, # The linear forecasting head must be trained\n",
+    "}\n",
+    "mmt = MomentModel(config=config, repo=repo)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Finetune Moment Model on the ETT dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/nethome/sli999/anaconda3/envs/torch/lib/python3.11/site-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 0: Train loss: 0.068\n",
+      "Epoch 1: Train loss: 0.064\n",
+      "Epoch 2: Train loss: 0.060\n",
+      "Epoch 3: Train loss: 0.056\n",
+      "Epoch 4: Train loss: 0.053\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'mse': 0.06429593,\n",
+       " 'mae': 0.05884363,\n",
+       " 'mase': 1.8647041,\n",
+       " 'mape': 0.02874577,\n",
+       " 'rmse': 0.2535664,\n",
+       " 'nrmse': 0.02665093539652813,\n",
+       " 'smape': 0.2105672,\n",
+       " 'msis': 0.046476997,\n",
+       " 'nd': 26.39926135498086}"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "train_dataset = MomentDataset(name=\"ett\", datetime_col='date', path='../src/samay/models/moment/data/ETTh1.csv',\n",
+    "                              mode='train', horizon_len=192)\n",
+    "\n",
+    "val_dataset = MomentDataset(name=\"ett\", datetime_col='date', path='../src/samay/models/moment/data/ETTh1.csv',\n",
+    "                            mode='test', horizon_len=192)\n",
+    "\n",
+    "\n",
+    "finetuned_model = mmt.finetune(train_dataset, task_name=\"forecasting\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# quantize() swaps eligible nn.Linear layers for bitsandbytes layers in place\n",
+    "mmt.quantize(quant_type=\"int8\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mmt.evaluate(val_dataset, task_name=\"forecasting\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "torch",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/samay/model.py b/src/samay/model.py
index 3bbd803..55ce1d8 100644
--- a/src/samay/model.py
+++ b/src/samay/model.py
@@ -53,7 +53,7 @@
     TinyTimeMixerForPrediction,
 )
 from .utils import cleanup_dataloader, get_least_used_gpu, quantile_loss, visualize
-
+from .quantization import quantize_linear_layers
 
 class Basemodel:
     def __init__(self, config=None, repo=None):
@@ -1159,6 +1159,7 @@ def finetune(self, dataset: LPTMDataset, task_name: str = "forecasting", **kwargs):
         max_epoch = 5 if "epoch" not in kwargs else kwargs["epoch"]
         max_norm = 5.0 if "norm" not in kwargs else kwargs["norm"]
         mask_ratio = 0.25 if "mask_ratio" not in kwargs else kwargs["mask_ratio"]
+        quantization = False if "quantization" not in kwargs else kwargs["quantization"]
 
         if task_name == "imputation" or task_name == "detection":
             mask_generator = Masking(mask_ratio=mask_ratio)
@@ -1276,6 +1277,13 @@ def finetune(self, dataset: LPTMDataset, task_name: str = "forecasting", **kwargs):
 
             scheduler.step()
 
+        return self.model
+
+    def quantize(self, quant_type="int8", device="cuda"):
+        self.model.eval()
+        self.model = self.model.to(device)
+        with torch.no_grad():
+            self.model = quantize_linear_layers(self.model, quantization_type=quant_type)
         return self.model
 
     def evaluate(self, dataset: LPTMDataset, task_name: str = "forecasting", metric_only=False, **kwargs):
@@ -1724,7 +1730,14 @@ def finetune(self, dataset: MomentDataset, task_name: str = "forecasting", **kwargs):
 
             scheduler.step()
 
         return self.model
-
+
+    def quantize(self, quant_type="int8", device="cuda"):
+        self.model.eval()
+        self.model = self.model.to(device)
+        with torch.no_grad():
+            self.model = quantize_linear_layers(self.model, quantization_type=quant_type)
+        return self.model
+
     def plot(self, dataset: MomentDataset, task_name: str = "forecasting"):
         """Visualize results from the MOMENT model.
diff --git a/src/samay/quantization.py b/src/samay/quantization.py
new file mode 100644
index 0000000..781a972
--- /dev/null
+++ b/src/samay/quantization.py
@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+import bitsandbytes as bnb
+
+
+def quantize_linear_layers(module, threshold=6.0, quantization_type="int8"):
+    """Recursively replace nn.Linear layers with in_features >= 128 by
+    bitsandbytes Linear8bitLt (int8) or Linear4bit (nf4) layers."""
+    for name, child in module.named_children():
+        if isinstance(child, nn.Linear) and child.in_features >= 128:
+            if quantization_type == "int8":
+                q = bnb.nn.Linear8bitLt(
+                    child.in_features,
+                    child.out_features,
+                    bias=(child.bias is not None),
+                    threshold=threshold,
+                    has_fp16_weights=False,
+                )
+            elif quantization_type == "nf4":
+                q = bnb.nn.Linear4bit(
+                    child.in_features,
+                    child.out_features,
+                    bias=(child.bias is not None),
+                    quant_type="nf4",
+                    compute_dtype=torch.float16,
+                )
+            else:
+                raise ValueError(f"Unsupported quantization_type: {quantization_type}")
+            with torch.no_grad():
+                q.weight.copy_(child.weight)
+                if child.bias is not None:
+                    q.bias.copy_(child.bias)
+            setattr(module, name, q)
+        else:
+            quantize_linear_layers(child, threshold=threshold, quantization_type=quantization_type)
+    return module
+
+
+# The demo below mirrors the example LPTM notebook. It is guarded so that
+# importing this module (e.g. from samay.model) does not load a model or
+# trigger a circular import.
+if __name__ == "__main__":
+    from samay.dataset import LPTMDataset
+    from samay.model import LPTMModel
+
+    # CHANGE THESE
+    USE_CUDA = False
+    QUANT_TYPE = "int8"
+
+    DEVICE = torch.device("cuda" if (USE_CUDA and torch.cuda.is_available()) else "cpu")
+    print("Using device:", DEVICE)
+    print("Quantization type:", QUANT_TYPE)
+
+    # LOADING MODEL -- FROM THE EXAMPLE LPTM NOTEBOOK
+    config = {
+        "task_name": "forecasting",
+        "forecast_horizon": 192,
+        "head_dropout": 0,
+        "weight_decay": 0,
+        "max_patch": 16,
+        "freeze_encoder": True,
+        "freeze_embedder": True,
+        "freeze_head": False,
+        "freeze_segment": True,
+    }
+
+    lptm = LPTMModel(config)
+    lptm.model = lptm.model.to(DEVICE)
+
+    # QUANTIZATION
+    print("Before quantization (bytes):",
+          sum(p.numel() * p.element_size() for p in lptm.model.parameters()))
+
+    lptm.model = quantize_linear_layers(lptm.model, quantization_type=QUANT_TYPE)
+    lptm.model = lptm.model.to(DEVICE)
+
+    # CHECKING IF QUANTIZATION IS SUCCESSFUL
+    def proof_report_lptm(lptm):
+        print("PROOF REPORT: bitsandbytes quantization (LPTM)")
+
+        m = lptm.model
+
+        n8 = sum(1 for x in m.modules() if isinstance(x, bnb.nn.Linear8bitLt))
+        n4 = sum(1 for x in m.modules() if isinstance(x, bnb.nn.Linear4bit))
+        nlin = sum(1 for x in m.modules()
+                   if isinstance(x, nn.Linear)
+                   and not isinstance(x, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)))
+
+        print("Linear8bitLt layers:", n8)
+        print("Linear4bit layers: ", n4)
+        print("Pure nn.Linear left:", nlin)
+
+        if QUANT_TYPE == "int8" and n8 == 0:
+            raise RuntimeError("INT8 quantization FAILED")
+
+        if QUANT_TYPE == "nf4" and n4 == 0:
+            raise RuntimeError("NF4 quantization FAILED")
+
+        # Note: the bitsandbytes int8/4-bit kernels expect a CUDA device, so this
+        # forward check is only expected to pass with USE_CUDA = True.
+        for name, layer in m.named_modules():
+            if isinstance(layer, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
+                with torch.no_grad():
+                    x = torch.randn(4, layer.in_features, device=DEVICE)
+                    y = layer(x)
+                print("Tested layer:", name, "->", tuple(y.shape))
+                break
+
+        print("PASS: Quantized layers exist and execute\n")
+
+    proof_report_lptm(lptm)
+
+    # EVALUATION WITH MEMORY TRACKING
+    val_dataset = LPTMDataset(
+        name="ett",
+        datetime_col="date",
+        path="data/data/ETTh1.csv",
+        mode="train",
+        horizon=192,
+    )
+
+    def gpu_mem_mb():
+        if DEVICE.type != "cuda":
+            return None
+        return torch.cuda.max_memory_allocated() / 1024**2
+
+    if DEVICE.type == "cuda":
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.synchronize()
+
+    print("GPU memory before eval (MB):", gpu_mem_mb())
+
+    try:
+        print("Evaluating model...")
+        avg_loss, trues, preds, histories = lptm.evaluate(
+            val_dataset,
+            task_name="forecasting"
+        )
+
+        if DEVICE.type == "cuda":
+            torch.cuda.synchronize()
+
+        print("GPU peak memory during eval (MB):", gpu_mem_mb())
+        print("Inference SUCCESS")
+        print("Avg loss:", avg_loss)
+        print("Num predictions:", len(preds) if hasattr(preds, "__len__") else "n/a")
+
+    except Exception as e:
+        print("Inference FAILED")
+        print("Error:", repr(e))
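
Quick sanity check (not part of the diff): a minimal sketch of how the new quantize_linear_layers helper can be exercised on a toy model, assuming bitsandbytes is installed. The module swap itself runs on CPU; actually executing the int8/nf4 layers requires moving the model to a CUDA device first.

    import torch.nn as nn

    from samay.quantization import quantize_linear_layers

    # Only Linear layers with in_features >= 128 are eligible for swapping.
    toy = nn.Sequential(
        nn.Linear(256, 512),  # swapped (in_features = 256)
        nn.ReLU(),
        nn.Linear(512, 64),   # swapped (in_features = 512)
        nn.Linear(64, 8),     # kept as nn.Linear (in_features = 64 < 128)
    )
    toy = quantize_linear_layers(toy, quantization_type="int8")

    print([type(m).__name__ for m in toy.modules() if isinstance(m, nn.Linear)])
    # expected: ['Linear8bitLt', 'Linear8bitLt', 'Linear']
    # To actually run the quantized layers, move the model to GPU first: toy = toy.cuda()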