diff --git a/examples/eeg/eeg_models/BIOT_tuev_eeg_event_classification.ipynb b/examples/eeg/eeg_models/BIOT_tuev_eeg_event_classification.ipynb new file mode 100644 index 000000000..dbd808375 --- /dev/null +++ b/examples/eeg/eeg_models/BIOT_tuev_eeg_event_classification.ipynb @@ -0,0 +1,632 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a2b5eb60", + "metadata": {}, + "source": [ + "## 1. Environment Setup\n", + "Seed the random generators, import core dependencies, and detect the training device." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f5284e16", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on device: cuda\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from pyhealth.datasets import TUEVDataset\n", + "from pyhealth.tasks import EEGEventsTUEV\n", + "from pyhealth.datasets.splitter import split_by_sample\n", + "from pyhealth.datasets.utils import get_dataloader\n", + "from pyhealth.models import BIOT\n", + "\n", + "SEED = 42\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(SEED)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Running on device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1c999e55", + "metadata": {}, + "source": [ + "## 2. Load TUEV Dataset\n", + "Point to the TUEV dataset root and load the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d1230c58", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config path provided, using default config\n", + "Using both train and eval subsets\n", + "Using cached metadata from /home/jp65/.cache/pyhealth/tuev\n", + "Initializing tuev dataset from /home/jp65/.cache/pyhealth/tuev (dev mode: True)\n", + "No cache_dir provided. Using default cache dir: /home/jp65/.cache/pyhealth/fe851030-67cc-5fe9-8de6-9cac645af5a7\n", + "Dataset: tuev\n", + "Dev mode: True\n", + "Number of patients: 189\n", + "Number of events: 259\n" + ] + } + ], + "source": [ + "dataset = TUEVDataset(\n", + " root='/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf', # Update this path\n", + " dev=True\n", + ")\n", + "dataset.stats()" + ] + }, + { + "cell_type": "markdown", + "id": "ff3f040f", + "metadata": {}, + "source": [ + "## 3. Prepare PyHealth Dataset\n", + "Set the task for the dataset and convert raw samples into PyHealth format for EEG event classification."
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "66f68916", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task EEG_events for tuev base dataset...\n", + "Found cached processed samples at /home/jp65/.cache/pyhealth/fe851030-67cc-5fe9-8de6-9cac645af5a7/tasks/EEG_events_d595851a-c8f6-5f1d-bdfb-b3d527c2deb0/samples_47c27255-9fc0-5271-bd99-638ffecdb1cc.ld, skipping processing.\n", + "Total task samples: 53377\n", + "Input schema: {'signal': 'tensor'}\n", + "Output schema: {'label': 'multiclass'}\n", + "\n", + "Sample keys: dict_keys(['patient_id', 'signal_file', 'signal', 'offending_channel', 'label'])\n", + "Signal shape: torch.Size([16, 1000])\n", + "Label: 5\n" + ] + } + ], + "source": [ + "sample_dataset = dataset.set_task(EEGEventsTUEV(\n", + " resample_rate=200, # Resample rate\n", + " bandpass_filter=(0.1, 75.0), # Bandpass filter\n", + " notch_filter=50.0, # Notch filter\n", + "))\n", + "\n", + "print(f\"Total task samples: {len(sample_dataset)}\")\n", + "print(f\"Input schema: {sample_dataset.input_schema}\")\n", + "print(f\"Output schema: {sample_dataset.output_schema}\")\n", + "\n", + "# Inspect a sample\n", + "sample = sample_dataset[0]\n", + "print(f\"\\nSample keys: {sample.keys()}\")\n", + "print(f\"Signal shape: {sample['signal'].shape}\")\n", + "print(f\"Label: {sample['label']}\")" + ] + }, + { + "cell_type": "markdown", + "id": "37f99c63", + "metadata": {}, + "source": [ + "## 4. Split Dataset\n", + "Divide the processed samples into training, validation, and test subsets before building dataloaders." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c01a076f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train/Val/Test sizes: 37363, 5338, 10676\n" + ] + } + ], + "source": [ + "BATCH_SIZE = 32\n", + "\n", + "train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=SEED)\n", + "print(f\"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}\")\n", + "\n", + "train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)\n", + "val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE) if len(val_ds) else None\n", + "test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE) if len(test_ds) else None\n", + "\n", + "if len(train_loader) == 0:\n", + " raise RuntimeError(\"The training loader is empty. Increase the dataset size or adjust the split ratios.\")" + ] + }, + { + "cell_type": "markdown", + "id": "f6dcd48f", + "metadata": {}, + "source": [ + "## 5. Inspect Batch Structure\n", + "Peek at the first training batch to understand feature shapes and data structure." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1d490449", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch structure:\n", + " patient_id: list(len=32)\n", + " signal_file: list(len=32)\n", + " signal: Tensor(shape=(32, 16, 1000))\n", + " offending_channel: list(len=32)\n", + " label: Tensor(shape=(32,))\n" + ] + } + ], + "source": [ + "first_batch = next(iter(train_loader))\n", + "\n", + "def describe(value):\n", + " if hasattr(value, \"shape\"):\n", + " return f\"{type(value).__name__}(shape={tuple(value.shape)})\"\n", + " if isinstance(value, (list, tuple)):\n", + " return f\"{type(value).__name__}(len={len(value)})\"\n", + " return type(value).__name__\n", + "\n", + "batch_summary = {key: describe(value) for key, value in first_batch.items()}\n", + "print(\"Batch structure:\")\n", + "for key, desc in batch_summary.items():\n", + " print(f\" {key}: {desc}\")" + ] + }, + { + "cell_type": "markdown", + "id": "73afd561", + "metadata": {}, + "source": [ + "## 6. Instantiate BIOT\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7236ddc0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total parameters: 3,187,734\n", + "Trainable parameters: 3,187,718\n" + ] + } + ], + "source": [ + "model = BIOT(\n", + " dataset=sample_dataset,\n", + " emb_size=256,\n", + " heads= 8,\n", + " depth=4,\n", + " n_fft=200,\n", + " hop_length=100,\n", + " n_classes=6,\n", + " n_channels=16).to(device)\n", + "\n", + "total_params = sum(p.numel() for p in model.parameters())\n", + "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "print(f\"Total parameters: {total_params:,}\")\n", + "print(f\"Trainable parameters: {trainable_params:,}\")" + ] + }, + { + "cell_type": "markdown", + "id": "912ec100", + "metadata": {}, + "source": [ + "## 7. Test Forward Pass\n", + "Verify the model can process a batch and compute outputs." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "11d7f9c5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jp65/miniconda3/envs/pyhealth/lib/python3.12/site-packages/torch/functional.py:730: UserWarning: A window was not provided. A rectangular window will be applied,which is known to cause spectral leakage. Other windows such as torch.hann_window or torch.hamming_window are recommended to reduce spectral leakage.To suppress this warning and use a rectangular window, explicitly set `window=torch.ones(n_fft, device=)`. 
(Triggered internally at /pytorch/aten/src/ATen/native/SpectralOps.cpp:836.)\n", + " return _VF.stft( # type: ignore[attr-defined]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output keys: dict_keys(['loss', 'y_prob', 'y_true', 'logit'])\n", + "Loss: 119.1745\n", + "Logits shape: torch.Size([32, 6])\n", + "y_prob shape: torch.Size([32, 6])\n", + "Embeddings shape: torch.Size([32, 256])\n" + ] + } + ], + "source": [ + "# Move batch to device\n", + "test_batch = {key: value.to(device) if hasattr(value, 'to') else value \n", + " for key, value in first_batch.items()}\n", + "\n", + "# Forward pass\n", + "with torch.no_grad():\n", + " outputs = model(**test_batch)\n", + "\n", + "print(\"Output keys:\", outputs.keys())\n", + "print(f\"Loss: {outputs['loss'].item():.4f}\")\n", + "print(f\"Logits shape: {outputs['logit'].shape}\")\n", + "print(f\"y_prob shape: {outputs['y_prob'].shape}\")\n", + "\n", + "# Get embeddings\n", + "embeddings = model.get_embeddings(**test_batch)\n", + "print(f\"Embeddings shape: {embeddings['embeddings'].shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b0818f3b", + "metadata": {}, + "source": [ + "## 8. Train Model\n", + "Train the model using PyHealth's Trainer:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "5521de25", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BIOT(\n", + " (biot): BIOTClassifier(\n", + " (biot): BIOTEncoder(\n", + " (patch_embedding): PatchFrequencyEmbedding(\n", + " (projection): Linear(in_features=101, out_features=256, bias=True)\n", + " )\n", + " (transformer): LinearAttentionTransformer(\n", + " (layers): SequentialSequence(\n", + " (layers): ModuleList(\n", + " (0-3): 4 x ModuleList(\n", + " (0): PreNorm(\n", + " (fn): SelfAttention(\n", + " (local_attn): LocalAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " )\n", + " (to_q): Linear(in_features=256, out_features=256, bias=False)\n", + " (to_k): Linear(in_features=256, out_features=256, bias=False)\n", + " (to_v): Linear(in_features=256, out_features=256, bias=False)\n", + " (to_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " )\n", + " (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): PreNorm(\n", + " (fn): Chunk(\n", + " (fn): FeedForward(\n", + " (w1): Linear(in_features=256, out_features=1024, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (w2): Linear(in_features=1024, out_features=256, bias=True)\n", + " )\n", + " )\n", + " (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (positional_encoding): PositionalEncoding(\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (channel_tokens): Embedding(16, 256)\n", + " )\n", + " (classifier): ClassificationHead(\n", + " (clshead): Sequential(\n", + " (0): ELU(alpha=1.0)\n", + " (1): Linear(in_features=256, out_features=6, bias=True)\n", + " )\n", + " )\n", + " )\n", + ")\n", + "Metrics: ['balanced_accuracy', 'cohen_kappa']\n", + "Device: cuda\n", + "\n", + "Training:\n", + "Batch size: 32\n", + "Optimizer: \n", + "Optimizer params: {'lr': 0.001}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: \n", + "Monitor: cohen_kappa\n", + "Monitor criterion: max\n", + "Epochs: 3\n", + "Patience: None\n", + "\n" + ] + }, + { + "data": { + 
"application/vnd.jupyter.widget-view+json": { + "model_id": "785afab7b89a43778e2a877acd16db00", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0 / 3: 0%| | 0/1168 [00:00 torch.FloatTensor: + """ + Args: + x: `embeddings`, shape (batch, max_len, d_model) + Returns: + `encoder input`, shape (batch, max_len, d_model) + """ + x = x + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class BIOTEncoder(nn.Module): + def __init__( + self, + emb_size=256, + heads=8, + depth=4, + n_channels=16, + n_fft=200, + hop_length=100, + **kwargs + ): + super().__init__() + + self.n_fft = n_fft + self.hop_length = hop_length + + self.patch_embedding = PatchFrequencyEmbedding( + emb_size=emb_size, n_freq=self.n_fft // 2 + 1 + ) + self.transformer = LinearAttentionTransformer( + dim=emb_size, + heads=heads, + depth=depth, + max_seq_len=1024, + attn_layer_dropout=0.2, # dropout right after self-attention layer + attn_dropout=0.2, # dropout post-attention + ) + self.positional_encoding = PositionalEncoding(emb_size) + + # channel token, N_channels >= your actual channels + self.channel_tokens = nn.Embedding(n_channels, 256) + self.index = nn.Parameter( + torch.LongTensor(range(n_channels)), requires_grad=False + ) + + def stft(self, sample): + spectral = torch.stft( + input = sample.squeeze(1), + n_fft = self.n_fft, + hop_length = self.hop_length, + center = False, + onesided = True, + return_complex = True, + ) + return torch.abs(spectral) + + def forward(self, x, n_channel_offset=0, perturb=False): + """ + x: [batch_size, channel, ts] + output: [batch_size, emb_size] + """ + emb_seq = [] + for i in range(x.shape[1]): + channel_spec_emb = self.stft(x[:, i : i + 1, :]) + channel_spec_emb = self.patch_embedding(channel_spec_emb) + batch_size, ts, _ = channel_spec_emb.shape + # (batch_size, ts, emb) + channel_token_emb = ( + self.channel_tokens(self.index[i + n_channel_offset]) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch_size, ts, 1) + ) + # (batch_size, ts, emb) + channel_emb = self.positional_encoding(channel_spec_emb + channel_token_emb) + + # perturb + if perturb: + ts = channel_emb.shape[1] + ts_new = np.random.randint(ts // 2, ts) + selected_ts = np.random.choice(range(ts), ts_new, replace=False) + channel_emb = channel_emb[:, selected_ts] + emb_seq.append(channel_emb) + + # (batch_size, 16 * ts, emb) + emb = torch.cat(emb_seq, dim=1) + # (batch_size, emb) + emb = self.transformer(emb).mean(dim=1) + return emb + + +class BIOTClassifier(nn.Module): + def __init__(self, emb_size=256, heads=8, depth=4, n_classes=6, **kwargs): + super().__init__() + self.biot = BIOTEncoder(emb_size=emb_size, heads=heads, depth=depth, **kwargs) + self.classifier = ClassificationHead(emb_size, n_classes) + + def get_embeddings(self, x): + x = self.biot(x) + return x + + def forward(self, x): + x = self.biot(x) + x = self.classifier(x) + return x + + +class BIOT(BaseModel): + """BIOT: Biosignal transformer for cross-data learning in the wild + Citation: + Yang, Chaoqi, M. Westover, and Jimeng Sun. "Biot: Biosignal transformer for cross-data learning in the wild." Advances in Neural Information Processing Systems 36 (2023): 78240-78260. + + The BIOT model encodes multichannel biosignal data (such as EEG) into compact feature representations + using spectral patch embeddings, channel positional encodings, and a transformer encoder. 
+ + The model expects as input: + - Raw temporal biosignals: shape (batch_size, n_channels, n_time) + + Args: + dataset: the dataset to train or evaluate the model (must be compatible with SampleDataset). + emb_size: embedding dimension for token/channel representations. Default is 256. + heads: number of transformer attention heads. Default is 8. + depth: number of transformer encoder layers. Default is 4. + n_fft: number of frequency bins used in the STFT transform. Default is 200. + hop_length: hop length for the STFT transform. Default is 100. + n_classes: number of output classes for the classification head. Default is 6. + n_channels: number of channels in the biosignal data. Default is 18. + This includes the 16 channels of the TUEV dataset plus 2 additional channels used for the sleep datasets. + Examples: + >>> import torch + >>> from pyhealth.datasets import TUEVDataset + >>> from pyhealth.tasks import EEGEventsTUEV + >>> from pyhealth.models import BIOT + >>> dataset = TUEVDataset(root="/path/to/tuev") + >>> sample_dataset = dataset.set_task(EEGEventsTUEV()) + >>> model = BIOT(dataset=sample_dataset, + ... emb_size=256, + ... heads=8, + ... depth=4, + ... n_fft=200, + ... hop_length=100, + ... n_classes=6, + ... n_channels=18, + ... ) + >>> model.load_pretrained_weights("pretrained-models/EEG-six-datasets-18-channels.ckpt") + >>> # Pretrained weights for the BIOT model trained on the EEG-six-datasets dataset with 18 channels. + >>> # Provided by the authors: https://github.com/ycq091044/BIOT/blob/main/pretrained-models/EEG-six-datasets-18-channels.ckpt + >>> output = model(signal=torch.randn(8, 18, TIME_STEPS), label=torch.zeros(8, dtype=torch.long)) # signal: (batch, channels, time) + """ + + def __init__(self, + dataset: SampleDataset, + emb_size: int = 256, + heads: int = 8, + depth: int = 4, + n_fft: int = 200, + hop_length: int = 100, + n_classes: int = 6, + n_channels: int = 18, + **kwargs): + super().__init__(dataset=dataset) + _get_linear_attention_transformer() + self.biot = BIOTClassifier(emb_size=emb_size, + heads=heads, + depth=depth, + n_classes=n_classes, + n_channels=n_channels, + n_fft=n_fft, + hop_length=hop_length) + + def load_pretrained_weights(self, checkpoint_path: str, strict: bool = False, map_location: str = None): + """Load pre-trained weights from checkpoint. + + Args: + checkpoint_path: path to the checkpoint file. + strict: whether to strictly enforce key matching. Default is False. + map_location: device to map the loaded tensors. Default is None. + """ + if map_location is None: + map_location = str(self.device) + + checkpoint = torch.load(checkpoint_path, map_location=map_location) + + if "model_state_dict" in checkpoint: + state_dict = checkpoint["model_state_dict"] + elif "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + self.biot.load_state_dict(state_dict, strict=strict) + print(f"✓ Successfully loaded weights from {checkpoint_path}") + + + def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]: + """Forward propagation. + + Args: + **kwargs: keyword arguments containing 'signal' and the task label. + + Returns: + a dictionary containing loss, y_prob, y_true, and logit.
+ """ + signal = kwargs.get("signal") + if signal is None: + raise ValueError("'signal' must be provided in inputs") + signal = signal.to(self.device) + logits = self.biot(signal) + label_key = self.label_keys[0] + y_true = kwargs[label_key].to(self.device) + + loss_fn = self.get_loss_function() + loss = loss_fn(logits, y_true) + y_prob = self.prepare_y_prob(logits) + + results = { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + } + return results + + def get_embeddings(self, **kwargs: Any) -> Dict[str, torch.Tensor]: + """Get embeddings. + + Args: + **kwargs: keyword arguments containing 'signal'. + + Returns: + a dictionary containing embeddings. + """ + signal = kwargs.get("signal") + if signal is None: + raise ValueError("'signal' must be provided in inputs") + signal = signal.to(self.device) + embeddings = self.biot.get_embeddings(signal) + return { + "embeddings": embeddings, + } + + + +if __name__ == "__main__": + _get_linear_attention_transformer() + print("Testing BIOT model...") + model = BIOTClassifier(emb_size=256, heads=8, depth=4, n_classes=6, n_channels=18, n_fft=200, hop_length=100) + print(f"✓ Created BIOTClassifier: {model.__class__.__name__}") + + batch_size = 2 + n_channels = 18 + n_time = 10 + n_samples = 200*n_time + + dummy_signal = torch.randn(batch_size, n_channels, n_samples) + logits = model(dummy_signal) + print(f"✓ BIOTClassifier forward pass:") + print(f" Logits shape: {logits.shape}") + + print("\n✓ All tests passed!") \ No newline at end of file diff --git a/tests/core/test_biot.py b/tests/core/test_biot.py new file mode 100644 index 000000000..76a1170c7 --- /dev/null +++ b/tests/core/test_biot.py @@ -0,0 +1,185 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import BIOT + + +class TestBIOT(unittest.TestCase): + """Test cases for the BIOT model.""" + + def setUp(self): + """Set up test data and model.""" + n_channels = 18 + n_time = 10 + n_fft = 200 + hop_length = 100 + n_samples = n_fft * n_time # 2000 + + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "signal": torch.randn(n_channels, n_samples).numpy().tolist(), + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-0", + "signal": torch.randn(n_channels, n_samples).numpy().tolist(), + "label": 0, + }, + { + "patient_id": "patient-2", + "visit_id": "visit-0", + "signal": torch.randn(n_channels, n_samples).numpy().tolist(), + "label": 2, + }, + { + "patient_id": "patient-3", + "visit_id": "visit-0", + "signal": torch.randn(n_channels, n_samples).numpy().tolist(), + "label": 3, + }, + ] + + self.input_schema = { + "signal": "tensor", + } + self.output_schema = {"label": "multiclass"} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test_biot", + ) + + self.model = BIOT( + dataset=self.dataset, + emb_size=256, + heads=8, + depth=4, + n_fft=200, + hop_length=100, + n_classes=6, + n_channels=18, + ) + + def test_model_initialization(self): + """Test that the BIOT model initializes correctly.""" + self.assertIsInstance(self.model, BIOT) + self.assertIsNotNone(self.model.biot) + self.assertEqual(len(self.model.feature_keys), 1) + self.assertIn("signal", self.model.feature_keys) + + def test_model_forward(self): + """Test that the BIOT forward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = 
next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_true"].shape[0], 2) + self.assertEqual(ret["logit"].shape[0], 2) + self.assertEqual(ret["logit"].shape[1], 6) # n_classes + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + """Test that the BIOT backward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + ret = self.model(**data_batch) + ret["loss"].backward() + + has_gradient = any( + param.requires_grad and param.grad is not None + for param in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters have gradients after backward pass") + + def test_model_different_batch_sizes(self): + """Test BIOT with different batch sizes.""" + for batch_size in [1, 2, 4]: + train_loader = get_dataloader(self.dataset, batch_size=batch_size, shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + actual_batch = min(batch_size, len(self.samples)) + self.assertEqual(ret["y_prob"].shape[0], actual_batch) + self.assertEqual(ret["y_true"].shape[0], actual_batch) + + def test_model_output_probabilities(self): + """Test that output probabilities are valid.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + y_prob = ret["y_prob"] + # Probabilities should be between 0 and 1 + self.assertTrue(torch.all(y_prob >= 0), "Probabilities contain negative values") + self.assertTrue(torch.all(y_prob <= 1), "Probabilities exceed 1") + + def test_missing_signal_raises_error(self): + """Test that missing 'signal' input raises ValueError.""" + with self.assertRaises((ValueError, KeyError)): + self.model(label=torch.tensor([0, 1])) + + def test_model_different_n_classes(self): + """Test BIOT with different number of classes.""" + model_binary = BIOT( + dataset=self.dataset, + emb_size=256, + heads=8, + depth=4, + n_fft=200, + hop_length=100, + n_classes=2, + n_channels=18, + ) + + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model_binary(**data_batch) + + self.assertEqual(ret["logit"].shape[1], 2) + + def test_model(self): + """Test BIOT""" + model_small = BIOT( + dataset=self.dataset, + emb_size=256, + heads=4, + depth=2, + n_fft=200, + hop_length=100, + n_classes=6, + n_channels=18, + ) + + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model_small(**data_batch) + + self.assertIn("loss", ret) + self.assertEqual(ret["logit"].shape[1], 6) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file
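
For reviewers who want to exercise the patch end to end, below is a minimal training sketch that stitches the pieces of this PR together the same way the notebook does. The dataset root is a placeholder; the Trainer import and call signature follow PyHealth's generic training API and are not part of this diff, while the optimizer settings (Adam, lr=1e-3), metrics, monitor, and epoch count mirror the trainer printout captured in the notebook.

import torch
from pyhealth.datasets import TUEVDataset
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader
from pyhealth.models import BIOT
from pyhealth.tasks import EEGEventsTUEV
from pyhealth.trainer import Trainer  # assumed: PyHealth's standard Trainer, not added by this diff

# Build the TUEV event-classification samples exactly as in the notebook.
dataset = TUEVDataset(root="/path/to/tuh_eeg_events/v2.0.0/edf", dev=True)  # placeholder path
sample_dataset = dataset.set_task(
    EEGEventsTUEV(resample_rate=200, bandpass_filter=(0.1, 75.0), notch_filter=50.0)
)

# 70/10/20 sample-level split and dataloaders, matching the notebook.
train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=42)
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32)
test_loader = get_dataloader(test_ds, batch_size=32)

# 16-channel BIOT configuration used by the notebook (6 TUEV event classes).
model = BIOT(dataset=sample_dataset, emb_size=256, heads=8, depth=4,
             n_fft=200, hop_length=100, n_classes=6, n_channels=16)

trainer = Trainer(model=model, metrics=["balanced_accuracy", "cohen_kappa"],
                  device="cuda" if torch.cuda.is_available() else "cpu")
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=3,
    optimizer_params={"lr": 1e-3},
    monitor="cohen_kappa",
)
print(trainer.evaluate(test_loader))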