80 changes: 80 additions & 0 deletions example/lptm_quantization.ipynb
@@ -0,0 +1,80 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from samay.model import LPTMModel\n",
"\n",
"config = {\n",
" \"task_name\": \"forecasting\",\n",
" \"forecast_horizon\": 192,\n",
" \"head_dropout\": 0,\n",
" \"weight_decay\": 0,\n",
" \"max_patch\": 16,\n",
" \"freeze_encoder\": True, # Freeze the patch embedding layer\n",
" \"freeze_embedder\": True, # Freeze the transformer encoder\n",
" \"freeze_head\": False, # The linear forecasting head must be trained\n",
" \"freeze_segment\": True, # Freeze the segmention module\n",
"}\n",
"model = LPTMModel(config)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from samay.dataset import LPTMDataset\n",
"\n",
"train_dataset = LPTMDataset(\n",
" name=\"ett\",\n",
" datetime_col=\"date\",\n",
" path=\"../data/data/ETTh1.csv\",\n",
" mode=\"train\",\n",
" horizon=192,\n",
")\n",
"\n",
"finetuned_model = model.finetune(train_dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from src.samay.model import LPTMModel\n",
"\n",
"model = model.quantize(quant_type = \"int8\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val_dataset = LPTMDataset(\n",
" name=\"ett\",\n",
" datetime_col=\"date\",\n",
" path=\"../data/data/ETTh1.csv\",\n",
" mode=\"train\",\n",
" horizon=192,\n",
")\n",
"metrics, trues, preds, histories = model.evaluate(val_dataset, task_name=\"forecasting\")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
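A natural follow-up to this notebook is to check what the int8 conversion costs in accuracy. The sketch below is a hypothetical variant of the last two cells, not part of this PR: it assumes the fine-tuned `model` and a test-mode `val_dataset` from the cells above are in scope, and that the `metrics` value returned by `evaluate` is a dict keyed by metric name (as the MOMENT metrics output in the next notebook suggests).

```python
# Hypothetical follow-up (not in this PR): measure the accuracy cost of int8 quantization.
# Assumes `model` is the fine-tuned LPTMModel and `val_dataset` a test-mode LPTMDataset
# from the cells above, and that evaluate() returns (metrics, trues, preds, histories).
fp32_metrics, *_ = model.evaluate(val_dataset, task_name="forecasting")

# quantize() rewrites model.model in place, so no reassignment is needed.
model.quantize(quant_type="int8")
int8_metrics, *_ = model.evaluate(val_dataset, task_name="forecasting")

for key in ("mse", "mae", "smape"):
    print(f"{key}: fp32={fp32_metrics[key]:.4f}  int8={int8_metrics[key]:.4f}")
```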
171 changes: 171 additions & 0 deletions example/moment_quantization.ipynb
@@ -0,0 +1,171 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Moment Forecasting Example\n",
"\n",
"## Loading Moment Model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/nethome/sli999/anaconda3/envs/torch/lib/python3.11/site-packages/gluonts/json.py:102: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n",
" warnings.warn(\n",
"INFO:p-2597098:t-140082893653824:moment.py:_validate_inputs:Setting d_model to 1024\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading MOMENT model from AutonLab/MOMENT-1-large\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:p-2597098:t-140082893653824:moment.py:_get_transformer_backbone:Initializing pre-trained transformer from google/flan-t5-large.\n",
"INFO:p-2597098:t-140082893653824:moment.py:_get_transformer_backbone:Enabling gradient checkpointing.\n"
]
}
],
"source": [
"import os\n",
"import sys\n",
"import numpy as np\n",
"\n",
"src_path = os.path.abspath(os.path.join(\"..\", \"src\"))\n",
"if src_path not in sys.path:\n",
" sys.path.insert(0, src_path)\n",
"\n",
"from samay.model import MomentModel\n",
"from samay.dataset import MomentDataset\n",
"from samay.utils import load_args\n",
"\n",
"repo = \"AutonLab/MOMENT-1-large\"\n",
"config = {\n",
" \"task_name\": \"forecasting\",\n",
" \"forecast_horizon\": 192,\n",
" \"head_dropout\": 0.1,\n",
" \"weight_decay\": 0,\n",
" \"freeze_encoder\": True, # Freeze the patch embedding layer\n",
" \"freeze_embedder\": True, # Freeze the transformer encoder\n",
" \"freeze_head\": False, # The linear forecasting head must be trained\n",
"}\n",
"mmt = MomentModel(config=config, repo=repo)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Finetune Moment Model on the ETT dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/nethome/sli999/anaconda3/envs/torch/lib/python3.11/site-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0: Train loss: 0.068\n",
"Epoch 1: Train loss: 0.064\n",
"Epoch 2: Train loss: 0.060\n",
"Epoch 3: Train loss: 0.056\n",
"Epoch 4: Train loss: 0.053\n"
]
},
{
"data": {
"text/plain": [
"{'mse': 0.06429593,\n",
" 'mae': 0.05884363,\n",
" 'mase': 1.8647041,\n",
" 'mape': 0.02874577,\n",
" 'rmse': 0.2535664,\n",
" 'nrmse': 0.02665093539652813,\n",
" 'smape': 0.2105672,\n",
" 'msis': 0.046476997,\n",
" 'nd': 26.39926135498086}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_dataset = MomentDataset(name=\"ett\", datetime_col='date', path='../src/samay/models/moment/data/ETTh1.csv', \n",
" mode='train', horizon_len=192)\n",
"\n",
"val_dataset = MomentDataset(name=\"ett\", datetime_col='date', path='../src/samay/models/moment/data/ETTh1.csv',\n",
" mode='test', horizon_len=192)\n",
"\n",
"\n",
"finetuned_model = mmt.finetune(train_dataset, task_name=\"forecasting\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from src.samay.model import MomentModel\n",
"\n",
"mmt = mmt.quantize(quant_type = \"int8\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mmt.evaluate(val_dataset, task_name=\"forecasting\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
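Accuracy aside, the main payoff of int8 quantization is a smaller serialized model. The snippet below is a rough, hypothetical way to measure that for this notebook: it serializes the underlying torch module's `state_dict` before and after quantization and compares byte counts. It assumes the wrapper keeps the raw module in `mmt.model` (as the new `quantize` method in `model.py` does) and would replace the quantization cell above rather than follow it.

```python
import io

import torch


def state_dict_size_mb(module: torch.nn.Module) -> float:
    """Serialize a module's state_dict into memory and return its size in MB."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getbuffer().nbytes / (1024 ** 2)


# Assumes `mmt` is the fine-tuned MomentModel wrapper from the cells above.
before_mb = state_dict_size_mb(mmt.model)
mmt.quantize(quant_type="int8")
after_mb = state_dict_size_mb(mmt.model)

print(f"state_dict size: {before_mb:.1f} MB -> {after_mb:.1f} MB")
```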
17 changes: 15 additions & 2 deletions src/samay/model.py
@@ -53,7 +53,7 @@
TinyTimeMixerForPrediction,
)
from .utils import cleanup_dataloader, get_least_used_gpu, quantile_loss, visualize

from .quantization import quantize_linear_layers

class Basemodel:
def __init__(self, config=None, repo=None):
@@ -1159,6 +1159,7 @@ def finetune(self, dataset: LPTMDataset, task_name: str = "forecasting", **kwarg
max_epoch = 5 if "epoch" not in kwargs else kwargs["epoch"]
max_norm = 5.0 if "norm" not in kwargs else kwargs["norm"]
mask_ratio = 0.25 if "mask_ratio" not in kwargs else kwargs["mask_ratio"]
quantization = False if "quantization" not in kwargs else kwargs["quantization"]

if task_name == "imputation" or task_name == "detection":
mask_generator = Masking(mask_ratio=mask_ratio)
@@ -1276,6 +1277,11 @@

scheduler.step()

    def quantize(self, quant_type="int8", device="cuda"):
        self.model.eval()
        self.model = self.model.to(device)
        with torch.no_grad():
            self.model = quantize_linear_layers(self.model, quantization_type=quant_type)
        return self.model

def evaluate(self, dataset: LPTMDataset, task_name: str = "forecasting", metric_only=False, **kwargs):
@@ -1724,7 +1730,14 @@ def finetune(self, dataset: MomentDataset, task_name: str = "forecasting", **kwa
scheduler.step()

return self.model


    def quantize(self, quant_type="int8", device="cuda"):
        self.model.eval()
        self.model = self.model.to(device)
        with torch.no_grad():
            self.model = quantize_linear_layers(self.model, quantization_type=quant_type)
        return self.model

def plot(self, dataset: MomentDataset, task_name: str = "forecasting"):
"""Visualize results from the MOMENT model.

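The `model.py` changes import `quantize_linear_layers` from a `quantization` module that is not part of this diff (presumably `src/samay/quantization.py`). For reviewers, here is one plausible sketch of such a helper built on PyTorch's dynamic quantization; the actual implementation in the PR may differ, and only the function name, the `quantization_type` keyword, and the `"int8"` value are taken from how the new `quantize` methods call it.

```python
# Hypothetical sketch of the imported helper; not the implementation shipped in this PR.
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic


def quantize_linear_layers(model: nn.Module, quantization_type: str = "int8") -> nn.Module:
    """Return a copy of `model` whose nn.Linear layers are dynamically quantized.

    quantization_type: "int8" quantizes Linear weights to qint8; "float16" stores them as fp16.
    """
    dtype_map = {"int8": torch.qint8, "float16": torch.float16}
    if quantization_type not in dtype_map:
        raise ValueError(f"Unsupported quantization_type: {quantization_type!r}")

    # quantize_dynamic swaps every nn.Linear for a quantized counterpart and
    # leaves all other modules untouched.
    return quantize_dynamic(model, {nn.Linear}, dtype=dtype_map[quantization_type])
```

If the real helper is built on this API, one detail worth double-checking in review is the `.to(device)` call in the new `quantize` methods: PyTorch's dynamic-quantization kernels run on CPU, so quantizing after moving the model to `cuda` may not behave as intended.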