From 037d91a333fa7372688af65faaf7a8f2c562014b Mon Sep 17 00:00:00 2001 From: Steier <637682@bah.com> Date: Sun, 8 Feb 2026 10:48:21 -0600 Subject: [PATCH 1/2] =?UTF-8?q?[Model]=20Add=20GraphCare=20=E2=80=94=20KG-?= =?UTF-8?q?enhanced=20EHR=20predictions=20(ICLR=202024)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/api/models.rst | 1 + docs/api/models/pyhealth.models.GraphCare.rst | 0 examples/graphcare_tutorial.ipynb | 902 ++++++++++++++++++ pyhealth/models/__init__.py | 1 + pyhealth/models/graphcare.py | 506 ++++++++++ pyhealth/models/graphcare_utils.py | 558 +++++++++++ tests/core/test_graphcare.py | 528 ++++++++++ 7 files changed, 2496 insertions(+) create mode 100644 docs/api/models/pyhealth.models.GraphCare.rst create mode 100644 examples/graphcare_tutorial.ipynb create mode 100644 pyhealth/models/graphcare.py create mode 100644 pyhealth/models/graphcare_utils.py create mode 100644 tests/core/test_graphcare.py diff --git a/docs/api/models.rst b/docs/api/models.rst index 9a63b1bb5..0d721cd12 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -17,6 +17,7 @@ We implement the following models for supporting multiple healthcare predictive models/pyhealth.models.TransformersModel models/pyhealth.models.RETAIN models/pyhealth.models.GAMENet + models/pyhealth.models.GraphCare models/pyhealth.models.MICRON models/pyhealth.models.SafeDrug models/pyhealth.models.MoleRec diff --git a/docs/api/models/pyhealth.models.GraphCare.rst b/docs/api/models/pyhealth.models.GraphCare.rst new file mode 100644 index 000000000..e69de29bb diff --git a/examples/graphcare_tutorial.ipynb b/examples/graphcare_tutorial.ipynb new file mode 100644 index 000000000..83e979b43 --- /dev/null +++ b/examples/graphcare_tutorial.ipynb @@ -0,0 +1,902 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GraphCare Tutorial\n", + "\n", + "This notebook demonstrates how to use the **GraphCare** model for healthcare predictions with personalized knowledge graphs in PyHealth.\n", + "\n", + "**Contributors:** Josh Steier\n", + "\n", + "**Paper:** Pengcheng Jiang et al. *GraphCare: Enhancing Healthcare Predictions with Personalized Knowledge Graphs.* ICLR 2024.\n", + "\n", + "## Overview\n", + "\n", + "1. Understand the GraphCare architecture\n", + "2. Build synthetic patient knowledge graphs\n", + "3. Instantiate the model with different GNN backbones\n", + "4. Train and evaluate on binary classification\n", + "5. Inspect attention weights for interpretability" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Environment Setup\n", + "\n", + "Import required libraries and set seeds for reproducibility." 
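,
+    "\n",
+    "\n",
+    "GraphCare relies on `torch-geometric` for graph batching and message passing. If the imports in the next cells fail, install it first with `pip install torch-geometric`."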
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on device: cpu\n" + ] + } + ], + "source": [ + "import os\n", + "import random\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "SEED = 42\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Running on device: {device}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "All imports successful!\n" + ] + } + ], + "source": [ + "from torch_geometric.data import Data, Batch\n", + "from torch_geometric.loader import DataLoader as PyGDataLoader\n", + "\n", + "from pyhealth.models.graphcare import GraphCare, BiAttentionGNNConv\n", + "\n", + "print(\"All imports successful!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Understanding GraphCare Architecture\n", + "\n", + "GraphCare operates on **personalized patient knowledge graphs** derived from EHR codes.\n", + "\n", + "### Architecture Overview\n", + "\n", + "```\n", + "Patient EHR Record\n", + " ├── Conditions ──┐\n", + " ├── Procedures ──┤── Concept-specific KGs ── Patient KG\n", + " └── Medications ──┘ │\n", + " ▼\n", + " ┌───────────────┐\n", + " │ GNN Encoder │\n", + " │ (BAT/GAT/GIN) │\n", + " │ + Bi-Attention │\n", + " └───────┬───────┘\n", + " │\n", + " ┌──────────────┼──────────────┐\n", + " ▼ ▼ ▼\n", + " graph pool node embed joint (concat)\n", + " │ │ │\n", + " └──────────────┴──────────────┘\n", + " │\n", + " ▼\n", + " MLP Head\n", + " │\n", + " ▼\n", + " Prediction\n", + "```\n", + "\n", + "### Key Components\n", + "\n", + "- **Bi-Attention (BAT):** Visit-level (α) and node-level (β) attention with temporal decay\n", + "- **Temporal Decay:** λ_j = exp(γ(V−j)) — more recent visits get higher weight\n", + "- **Patient Modes:** `graph` (global pool), `node` (EHR-node avg), `joint` (both concatenated)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Create Synthetic Patient Graphs\n", + "\n", + "In a real setting, patient KGs are constructed from EHR codes using LLM-prompted subgraphs and medical KGs (see the [GraphCare repo](https://github.com/pat-jj/GraphCare) for the generation pipeline).\n", + "\n", + "Here we create synthetic data to demonstrate the model API." 
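,
+    "\n",
+    "\n",
+    "Either way, every patient graph is a `torch_geometric.data.Data` object carrying the same fields (the helper defined in the next cell builds exactly these):\n",
+    "\n",
+    "```python\n",
+    "# Fields GraphCare.forward() expects on each patient Data object:\n",
+    "#   edge_index         (2, E)   local edge list for the subgraph\n",
+    "#   y                  (N,)     global KG node IDs (not class labels)\n",
+    "#   relation           (E,)     relation ID per edge\n",
+    "#   visit_padded_node  (max_visit, num_nodes)  binary visit indicators\n",
+    "#   ehr_nodes          (num_nodes,)            binary direct-EHR indicators\n",
+    "#   label              task label tensor\n",
+    "```"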
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Node embeddings: torch.Size([500, 64])\n", + "Relation embeddings: torch.Size([50, 64])\n" + ] + } + ], + "source": [ + "# --- Global KG parameters ---\n", + "NUM_NODES = 500 # Total cluster nodes in the KG\n", + "NUM_RELS = 50 # Total relation types\n", + "MAX_VISIT = 5 # Max visits per patient\n", + "EMBEDDING_DIM = 64 # Pre-trained embedding dim\n", + "HIDDEN_DIM = 64 # Model hidden dim\n", + "NUM_PATIENTS = 200 # Synthetic patients\n", + "\n", + "# --- Fake pre-trained embeddings ---\n", + "# In practice these come from word2vec / TransE on the KG\n", + "node_emb = torch.randn(NUM_NODES, EMBEDDING_DIM)\n", + "rel_emb = torch.randn(NUM_RELS, EMBEDDING_DIM)\n", + "\n", + "print(f\"Node embeddings: {node_emb.shape}\")\n", + "print(f\"Relation embeddings: {rel_emb.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created 200 patient graphs\n", + "Example graph: Data(edge_index=[2, 49], y=[35], relation=[49], visit_padded_node=[5, 500], ehr_nodes=[500], label=[1], num_nodes=35)\n" + ] + } + ], + "source": [ + "def create_synthetic_patient_graph(patient_idx, num_nodes, num_rels, max_visit):\n", + " \"\"\"Create a single synthetic patient KG as a PyG Data object.\n", + " \n", + " Each patient graph is a subgraph of the global KG, with:\n", + " - Node IDs (y): indices into the global KG node embedding table\n", + " - Relation IDs: indices into the global relation embedding table\n", + " - visit_padded_node: binary (max_visit, num_nodes) indicating which\n", + " KG nodes appear in each visit\n", + " - ehr_nodes: binary (num_nodes,) indicating direct EHR nodes\n", + " - label: binary mortality label\n", + " \"\"\"\n", + " # Random subgraph size\n", + " n = random.randint(15, 40) # nodes in this patient's subgraph\n", + " e = random.randint(n, n * 3) # edges\n", + " \n", + " # Local edge indices (within the subgraph)\n", + " src = torch.randint(0, n, (e,))\n", + " dst = torch.randint(0, n, (e,))\n", + " \n", + " # Node IDs: which global KG nodes are in this subgraph\n", + " y = torch.randint(0, num_nodes, (n,))\n", + " \n", + " # Relation IDs per edge\n", + " relation = torch.randint(0, num_rels, (e,))\n", + " \n", + " # Visit-padded node indicators\n", + " vpn = torch.zeros(max_visit, num_nodes)\n", + " for v in range(max_visit):\n", + " # Each visit activates some random KG nodes\n", + " active = y[torch.randint(0, n, (random.randint(2, 8),))]\n", + " vpn[v, active] = 1.0\n", + " \n", + " # Direct EHR nodes (subset of the patient's KG nodes)\n", + " ehr = torch.zeros(num_nodes)\n", + " ehr_active = y[torch.randint(0, n, (random.randint(3, 10),))]\n", + " ehr[ehr_active] = 1.0\n", + " \n", + " # Binary label (mortality)\n", + " label = torch.tensor([float(patient_idx % 2)])\n", + " \n", + " data = Data(\n", + " edge_index=torch.stack([src, dst]),\n", + " y=y,\n", + " relation=relation,\n", + " visit_padded_node=vpn,\n", + " ehr_nodes=ehr,\n", + " label=label,\n", + " )\n", + " data.num_nodes = n\n", + " return data\n", + "\n", + "\n", + "# Create dataset\n", + "all_graphs = [\n", + " create_synthetic_patient_graph(i, NUM_NODES, NUM_RELS, MAX_VISIT)\n", + " for i in range(NUM_PATIENTS)\n", + "]\n", + "\n", + "print(f\"Created {len(all_graphs)} patient graphs\")\n", + "print(f\"Example graph: {all_graphs[0]}\")" + ] + }, + 
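{
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Because the synthetic labels simply alternate (`patient_idx % 2`) and the graph structure is random, the training curves and metrics later in this notebook are a smoke test of the pipeline rather than a measure of model quality."
+   ]
+  },
+ 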
{ + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train/Val/Test sizes: 160, 20, 20\n" + ] + } + ], + "source": [ + "# --- Train / Val / Test split ---\n", + "n_train = int(0.8 * NUM_PATIENTS)\n", + "n_val = int(0.1 * NUM_PATIENTS)\n", + "\n", + "train_graphs = all_graphs[:n_train]\n", + "val_graphs = all_graphs[n_train:n_train + n_val]\n", + "test_graphs = all_graphs[n_train + n_val:]\n", + "\n", + "BATCH_SIZE = 16\n", + "\n", + "train_loader = PyGDataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)\n", + "val_loader = PyGDataLoader(val_graphs, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)\n", + "test_loader = PyGDataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)\n", + "\n", + "print(f\"Train/Val/Test sizes: {len(train_graphs)}, {len(val_graphs)}, {len(test_graphs)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Inspect Batch Structure\n", + "\n", + "Examine what a PyG batch looks like after collation." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch attributes:\n", + " edge_index: torch.Size([2, 783])\n", + " y (node_ids): torch.Size([436])\n", + " relation: torch.Size([783])\n", + " batch vector: torch.Size([436]) (max=15)\n", + " visit_padded_node: torch.Size([80, 500])\n", + " ehr_nodes: torch.Size([8000])\n", + " label: torch.Size([16])\n", + "\n", + "Note: visit_padded_node and ehr_nodes need reshaping before forward().\n", + " visit_node reshaped: (16, 5, 500)\n", + " ehr_nodes reshaped: (16, 500)\n" + ] + } + ], + "source": [ + "batch = next(iter(train_loader))\n", + "\n", + "print(\"Batch attributes:\")\n", + "print(f\" edge_index: {batch.edge_index.shape}\")\n", + "print(f\" y (node_ids): {batch.y.shape}\")\n", + "print(f\" relation: {batch.relation.shape}\")\n", + "print(f\" batch vector: {batch.batch.shape} (max={batch.batch.max().item()})\")\n", + "print(f\" visit_padded_node: {batch.visit_padded_node.shape}\")\n", + "print(f\" ehr_nodes: {batch.ehr_nodes.shape}\")\n", + "print(f\" label: {batch.label.shape}\")\n", + "print()\n", + "print(\"Note: visit_padded_node and ehr_nodes need reshaping before forward().\")\n", + "print(f\" visit_node reshaped: ({BATCH_SIZE}, {MAX_VISIT}, {NUM_NODES})\")\n", + "print(f\" ehr_nodes reshaped: ({BATCH_SIZE}, {NUM_NODES})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Instantiate GraphCare\n", + "\n", + "Compare the three GNN backbones (BAT, GAT, GIN) and three patient modes." 
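,
+    "\n",
+    "\n",
+    "The `patient_mode` choice is what moves the parameter counts below: `joint` concatenates the pooled-graph vector with the EHR-node vector, so its MLP head takes `hidden_dim * 2` inputs, while `graph` and `node` feed a single `hidden_dim` vector.\n",
+    "\n",
+    "```python\n",
+    "# From GraphCare.__init__: input width of the prediction head\n",
+    "in_dim = hidden_dim * 2 if patient_mode == \"joint\" else hidden_dim\n",
+    "```"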
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GraphCare Configuration Comparison\n", + "============================================================\n", + " BAT/joint — 549,941 params\n", + " BAT/graph — 549,877 params\n", + " BAT/node — 549,877 params\n", + " GAT/joint — 550,067 params\n", + " GIN/joint — 549,811 params\n" + ] + } + ], + "source": [ + "print(\"GraphCare Configuration Comparison\")\n", + "print(\"=\" * 60)\n", + "\n", + "configs = [\n", + " (\"BAT\", \"joint\"),\n", + " (\"BAT\", \"graph\"),\n", + " (\"BAT\", \"node\"),\n", + " (\"GAT\", \"joint\"),\n", + " (\"GIN\", \"joint\"),\n", + "]\n", + "\n", + "for gnn, mode in configs:\n", + " model = GraphCare(\n", + " num_nodes=NUM_NODES,\n", + " num_rels=NUM_RELS,\n", + " max_visit=MAX_VISIT,\n", + " embedding_dim=EMBEDDING_DIM,\n", + " hidden_dim=HIDDEN_DIM,\n", + " out_channels=1,\n", + " layers=2,\n", + " node_emb=node_emb,\n", + " rel_emb=rel_emb,\n", + " gnn=gnn,\n", + " patient_mode=mode,\n", + " )\n", + " n_params = sum(p.numel() for p in model.parameters())\n", + " print(f\" {gnn:>3s}/{mode:<6s} — {n_params:>10,} params\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.1 Forward Pass Verification" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input: 488 total nodes across 16 patient graphs\n", + "Output logits: torch.Size([16, 1])\n", + "Predictions: [0.508 0.484 0.51 0.493 0.505 0.522 0.501 0.493 0.528 0.482 0.509 0.447\n", + " 0.505 0.492 0.509 0.502]\n" + ] + } + ], + "source": [ + "model = GraphCare(\n", + " num_nodes=NUM_NODES,\n", + " num_rels=NUM_RELS,\n", + " max_visit=MAX_VISIT,\n", + " embedding_dim=EMBEDDING_DIM,\n", + " hidden_dim=HIDDEN_DIM,\n", + " out_channels=1,\n", + " layers=2,\n", + " dropout=0.5,\n", + " decay_rate=0.01,\n", + " node_emb=node_emb,\n", + " rel_emb=rel_emb,\n", + " gnn=\"BAT\",\n", + " patient_mode=\"joint\",\n", + " use_alpha=True,\n", + " use_beta=True,\n", + " use_edge_attn=True,\n", + ").to(device)\n", + "\n", + "# Reshape batch tensors\n", + "batch = next(iter(train_loader))\n", + "batch = batch.to(device)\n", + "\n", + "node_ids = batch.y\n", + "rel_ids = batch.relation\n", + "edge_index = batch.edge_index\n", + "batch_vec = batch.batch\n", + "visit_node = batch.visit_padded_node.reshape(BATCH_SIZE, MAX_VISIT, NUM_NODES).float()\n", + "ehr_nodes = batch.ehr_nodes.reshape(BATCH_SIZE, NUM_NODES).float()\n", + "\n", + "# Forward pass\n", + "model.eval()\n", + "with torch.no_grad():\n", + " logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes)\n", + "\n", + "print(f\"Input: {batch.y.shape[0]} total nodes across {BATCH_SIZE} patient graphs\")\n", + "print(f\"Output logits: {logits.shape}\")\n", + "print(f\"Predictions: {torch.sigmoid(logits).squeeze().cpu().numpy().round(3)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Training Loop\n", + "\n", + "Train GraphCare with a standard PyTorch loop. Note that GraphCare uses `torch_geometric.loader.DataLoader` rather than PyHealth's `Trainer`, since the data pipeline requires PyG batching." 
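,
+    "\n",
+    "\n",
+    "All three loaders were built with `drop_last=True`, so every batch holds exactly `BATCH_SIZE` graphs. That is what makes the fixed `(BATCH_SIZE, MAX_VISIT, NUM_NODES)` reshape of `visit_padded_node` inside the loop safe."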
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import roc_auc_score, average_precision_score\n", + "\n", + "\n", + "def train_one_epoch(model, loader, optimizer, device):\n", + " \"\"\"Train for one epoch.\"\"\"\n", + " model.train()\n", + " total_loss = 0.0\n", + " \n", + " for data in loader:\n", + " data = data.to(device)\n", + " optimizer.zero_grad()\n", + " \n", + " bs = BATCH_SIZE\n", + " vn = data.visit_padded_node.reshape(bs, MAX_VISIT, NUM_NODES).float()\n", + " en = data.ehr_nodes.reshape(bs, NUM_NODES).float()\n", + " \n", + " logits = model(\n", + " node_ids=data.y,\n", + " rel_ids=data.relation,\n", + " edge_index=data.edge_index,\n", + " batch=data.batch,\n", + " visit_node=vn,\n", + " ehr_nodes=en,\n", + " in_drop=True,\n", + " )\n", + " \n", + " labels = data.label.reshape(bs, -1).float()\n", + " loss = F.binary_cross_entropy_with_logits(logits, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_loss += loss.item()\n", + " \n", + " return total_loss / len(loader)\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def evaluate(model, loader, device):\n", + " \"\"\"Evaluate and return metrics.\"\"\"\n", + " model.eval()\n", + " y_true_all, y_prob_all = [], []\n", + " \n", + " for data in loader:\n", + " data = data.to(device)\n", + " bs = BATCH_SIZE\n", + " vn = data.visit_padded_node.reshape(bs, MAX_VISIT, NUM_NODES).float()\n", + " en = data.ehr_nodes.reshape(bs, NUM_NODES).float()\n", + " \n", + " logits = model(\n", + " node_ids=data.y,\n", + " rel_ids=data.relation,\n", + " edge_index=data.edge_index,\n", + " batch=data.batch,\n", + " visit_node=vn,\n", + " ehr_nodes=en,\n", + " )\n", + " \n", + " labels = data.label.reshape(bs, -1)\n", + " y_prob_all.append(torch.sigmoid(logits).cpu())\n", + " y_true_all.append(labels.cpu())\n", + " \n", + " y_true = torch.cat(y_true_all).numpy()\n", + " y_prob = torch.cat(y_prob_all).numpy()\n", + " \n", + " return {\n", + " \"roc_auc\": roc_auc_score(y_true, y_prob),\n", + " \"pr_auc\": average_precision_score(y_true, y_prob),\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: BAT/joint, 549,941 parameters\n", + "Training for 10 epochs...\n", + "\n", + "Epoch 1 | loss=0.7127 | val_roc_auc=0.5938 | val_pr_auc=0.6341\n", + "Epoch 2 | loss=0.6975 | val_roc_auc=0.5625 | val_pr_auc=0.6046\n", + "Epoch 3 | loss=0.6992 | val_roc_auc=0.5625 | val_pr_auc=0.5265\n", + "Epoch 4 | loss=0.6872 | val_roc_auc=0.5625 | val_pr_auc=0.5225\n", + "Epoch 5 | loss=0.6639 | val_roc_auc=0.5312 | val_pr_auc=0.5072\n", + "Epoch 6 | loss=0.6854 | val_roc_auc=0.4844 | val_pr_auc=0.4868\n", + "Epoch 7 | loss=0.6784 | val_roc_auc=0.5000 | val_pr_auc=0.4957\n", + "Epoch 8 | loss=0.6564 | val_roc_auc=0.4219 | val_pr_auc=0.4709\n", + "Epoch 9 | loss=0.6403 | val_roc_auc=0.3438 | val_pr_auc=0.4432\n", + "Epoch 10 | loss=0.6539 | val_roc_auc=0.3438 | val_pr_auc=0.4446\n" + ] + } + ], + "source": [ + "# --- Training ---\n", + "EPOCHS = 10\n", + "LR = 1e-3\n", + "\n", + "model = GraphCare(\n", + " num_nodes=NUM_NODES,\n", + " num_rels=NUM_RELS,\n", + " max_visit=MAX_VISIT,\n", + " embedding_dim=EMBEDDING_DIM,\n", + " hidden_dim=HIDDEN_DIM,\n", + " out_channels=1,\n", + " layers=2,\n", + " dropout=0.5,\n", + " decay_rate=0.01,\n", + " node_emb=node_emb,\n", + " rel_emb=rel_emb,\n", + " gnn=\"BAT\",\n", + " patient_mode=\"joint\",\n", + " 
drop_rate=0.1,\n", + ").to(device)\n", + "\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)\n", + "total_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Model: BAT/joint, {total_params:,} parameters\")\n", + "print(f\"Training for {EPOCHS} epochs...\\n\")\n", + "\n", + "for epoch in range(1, EPOCHS + 1):\n", + " train_loss = train_one_epoch(model, train_loader, optimizer, device)\n", + " val_metrics = evaluate(model, val_loader, device)\n", + " \n", + " print(\n", + " f\"Epoch {epoch:2d} | loss={train_loss:.4f} | \"\n", + " f\"val_roc_auc={val_metrics['roc_auc']:.4f} | \"\n", + " f\"val_pr_auc={val_metrics['pr_auc']:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Evaluate on Test Set" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test Results\n", + "========================================\n", + " roc_auc: 0.7969\n", + " pr_auc: 0.8318\n" + ] + } + ], + "source": [ + "test_metrics = evaluate(model, test_loader, device)\n", + "\n", + "print(\"Test Results\")\n", + "print(\"=\" * 40)\n", + "for k, v in test_metrics.items():\n", + " print(f\" {k}: {v:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. Attention Interpretability\n", + "\n", + "GraphCare's BAT layers produce interpretable attention weights:\n", + "- **Alpha (α):** Visit-level attention — which visits matter most\n", + "- **Beta (β):** Node-level attention with temporal decay — which nodes matter, weighted by recency" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Attention weight shapes (per layer):\n", + " Layer 1:\n", + " alpha: torch.Size([16, 5, 500]) (batch, max_visit, num_nodes)\n", + " beta: torch.Size([16, 5, 1]) (batch, max_visit, 1)\n", + " attn: torch.Size([914, 1]) (num_edges_in_batch, 1)\n", + " edge_w: torch.Size([914, 1]) (num_edges_in_batch, 1)\n", + " Layer 2:\n", + " alpha: torch.Size([16, 5, 500]) (batch, max_visit, num_nodes)\n", + " beta: torch.Size([16, 5, 1]) (batch, max_visit, 1)\n", + " attn: torch.Size([914, 1]) (num_edges_in_batch, 1)\n", + " edge_w: torch.Size([914, 1]) (num_edges_in_batch, 1)\n" + ] + } + ], + "source": [ + "# Get attention weights\n", + "batch = next(iter(test_loader))\n", + "batch = batch.to(device)\n", + "\n", + "vn = batch.visit_padded_node.reshape(BATCH_SIZE, MAX_VISIT, NUM_NODES).float()\n", + "en = batch.ehr_nodes.reshape(BATCH_SIZE, NUM_NODES).float()\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " logits, alphas, betas, attns, edge_ws = model(\n", + " node_ids=batch.y,\n", + " rel_ids=batch.relation,\n", + " edge_index=batch.edge_index,\n", + " batch=batch.batch,\n", + " visit_node=vn,\n", + " ehr_nodes=en,\n", + " store_attn=True,\n", + " )\n", + "\n", + "print(\"Attention weight shapes (per layer):\")\n", + "for i, (a, b, att, ew) in enumerate(zip(alphas, betas, attns, edge_ws)):\n", + " print(f\" Layer {i+1}:\")\n", + " print(f\" alpha: {a.shape} (batch, max_visit, num_nodes)\")\n", + " print(f\" beta: {b.shape} (batch, max_visit, 1)\")\n", + " print(f\" attn: {att.shape} (num_edges_in_batch, 1)\")\n", + " print(f\" edge_w: {ew.shape} (num_edges_in_batch, 1)\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Note: Recent visits (higher index) get boosted by temporal decay λ_j = exp(γ(V−j))\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\637682\\AppData\\Local\\Temp\\ipykernel_25424\\2141043452.py:18: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown\n", + " plt.show()\n" + ] + } + ], + "source": [ + "# --- Visualise visit-level attention for first patient ---\n", + "import matplotlib\n", + "matplotlib.use(\"Agg\") # Non-interactive backend\n", + "import matplotlib.pyplot as plt\n", + "\n", + "patient_idx = 0\n", + "\n", + "# Beta weights for patient 0, layer 1: (max_visit, 1)\n", + "beta_patient = betas[0][patient_idx].squeeze().cpu().numpy()\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 3))\n", + "visits = [f\"Visit {j+1}\" for j in range(MAX_VISIT)]\n", + "bars = ax.bar(visits, beta_patient, color=\"steelblue\")\n", + "ax.set_ylabel(\"β attention × temporal decay\")\n", + "ax.set_title(\"Visit-Level Attention Weights (Patient 0, Layer 1)\")\n", + "ax.axhline(y=0, color=\"gray\", linestyle=\"--\", linewidth=0.5)\n", + "plt.tight_layout()\n", + "plt.show()\n", + "print(\"Note: Recent visits (higher index) get boosted by temporal decay λ_j = exp(γ(V−j))\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 10 most attended KG nodes for Patient 0:\n", + " Node ID Importance\n", + "------------------------\n", + " 492 1.0000\n", + " 103 1.0000\n", + " 450 1.0000\n", + " 92 1.0000\n", + " 153 1.0000\n", + " 167 1.0000\n", + " 205 1.0000\n", + " 204 1.0000\n", + " 200 1.0000\n", + " 216 1.0000\n", + "\n", + "In a real setting, these node IDs map to medical concepts\n", + "(conditions, procedures, drugs) via the cluster mapping.\n" + ] + } + ], + "source": [ + "# --- Top attended nodes ---\n", + "# Alpha for patient 0, layer 1: (max_visit, num_nodes)\n", + "alpha_patient = alphas[0][patient_idx].cpu() # (max_visit, num_nodes)\n", + "\n", + "# Sum across visits to find globally important nodes\n", + "node_importance = alpha_patient.sum(dim=0).numpy() # (num_nodes,)\n", + "\n", + "top_k = 10\n", + "top_indices = np.argsort(node_importance)[-top_k:][::-1]\n", + "\n", + "print(f\"Top {top_k} most attended KG nodes for Patient 0:\")\n", + "print(f\"{'Node ID':>8} {'Importance':>12}\")\n", + "print(\"-\" * 24)\n", + "for idx in top_indices:\n", + " print(f\"{idx:>8d} {node_importance[idx]:>12.4f}\")\n", + "print()\n", + "print(\"In a real setting, these node IDs map to medical concepts\")\n", + "print(\"(conditions, procedures, drugs) via the cluster mapping.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. Using Pre-Built KG Artifacts (Real Data)\n", + "\n", + "For real EHR data with pre-built KGs, use the `graphcare_utils` module:\n", + "\n", + "```python\n", + "from pyhealth.models.graphcare_utils import (\n", + " load_kg_artifacts,\n", + " prepare_graphcare_data,\n", + " build_graphcare_dataloaders,\n", + " reshape_batch_tensors,\n", + ")\n", + "\n", + "# 1. 
Load pre-built KG artifacts\n",
+    "artifacts = load_kg_artifacts(\n",
+    "    sample_dataset_path=\"sample_dataset_mimic3_mortality_th015.pkl\",\n",
+    "    graph_path=\"graph_mimic3_mortality_th015.pkl\",\n",
+    "    ent_emb_path=\"entity_embedding.pkl\",\n",
+    "    rel_emb_path=\"relation_embedding.pkl\",\n",
+    "    cluster_path=\"clusters_th015.json\",\n",
+    "    cluster_rel_path=\"clusters_rel_th015.json\",\n",
+    "    ccscm_id2clus_path=\"ccscm_id2clus.json\",\n",
+    "    ccsproc_id2clus_path=\"ccsproc_id2clus.json\",\n",
+    ")\n",
+    "\n",
+    "# 2. Prepare data (labels, embeddings, splits)\n",
+    "prepared = prepare_graphcare_data(artifacts, task=\"mortality\")\n",
+    "\n",
+    "# 3. Build PyG DataLoaders\n",
+    "train_loader, val_loader, test_loader = build_graphcare_dataloaders(\n",
+    "    prepared, batch_size=64,\n",
+    ")\n",
+    "\n",
+    "# 4. Build model\n",
+    "model = GraphCare(\n",
+    "    num_nodes=prepared[\"num_nodes\"],\n",
+    "    num_rels=prepared[\"num_rels\"],\n",
+    "    max_visit=prepared[\"max_visit\"],\n",
+    "    embedding_dim=prepared[\"node_emb\"].shape[1],\n",
+    "    hidden_dim=128,\n",
+    "    out_channels=prepared[\"task_config\"][\"out_channels\"],\n",
+    "    node_emb=prepared[\"node_emb\"],\n",
+    "    rel_emb=prepared[\"rel_emb\"],\n",
+    "    gnn=\"BAT\",\n",
+    "    patient_mode=\"joint\",\n",
+    ")\n",
+    "\n",
+    "# 5. Use reshape_batch_tensors in the training loop\n",
+    "for data in train_loader:\n",
+    "    batch_tensors = reshape_batch_tensors(\n",
+    "        data, batch_size=64,\n",
+    "        max_visit=prepared[\"max_visit\"],\n",
+    "        num_nodes=prepared[\"num_nodes\"],\n",
+    "    )\n",
+    "    labels = batch_tensors.pop(\"label\")  # the dict includes the label\n",
+    "    logits = model(**batch_tensors)\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Summary\n",
+    "\n",
+    "### GraphCare Model\n",
+    "\n",
+    "Predicts healthcare outcomes using personalized patient knowledge graphs with bi-attention augmented GNNs.\n",
+    "\n",
+    "| Component | Options | Description |\n",
+    "|-----------|---------|-------------|\n",
+    "| **GNN Backbone** | `BAT` (default), `GAT`, `GIN` | Message-passing layer; BAT adds bi-attention + edge attention |\n",
+    "| **Patient Mode** | `joint` (default), `graph`, `node` | How to produce patient-level representation |\n",
+    "| **Attention** | α (visit), β (node) | Visit-level and node-level attention with temporal decay |\n",
+    "| **Temporal Decay** | λ_j = exp(γ(V−j)) | Exponential weighting favouring recent visits |\n",
+    "\n",
+    "### Key Features\n",
+    "\n",
+    "- Pre-trained node/relation embeddings (from TransE or word2vec on KG)\n",
+    "- Optional edge dropout for regularisation\n",
+    "- Attention weights for clinical interpretability\n",
+    "- Supports mortality, readmission, drug recommendation, length-of-stay tasks\n",
+    "\n",
+    "### Files\n",
+    "\n",
+    "| File | Purpose |\n",
+    "|------|------|\n",
+    "| `pyhealth/models/graphcare.py` | Model implementation (GraphCare + BiAttentionGNNConv) |\n",
+    "| `pyhealth/models/graphcare_utils.py` | Data pipeline utilities (KG loading, subgraph extraction, dataloaders) |\n",
+    "| `examples/graphcare_tutorial.ipynb` | This end-to-end tutorial |\n",
+    "| `tests/core/test_graphcare.py` | Unit tests |"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 
4
+}
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index dee4dffdf..5c8f48265 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -7,6 +7,7 @@
 from .deepr import Deepr, DeeprLayer
 from .embedding import EmbeddingModel
 from .gamenet import GAMENet, GAMENetLayer
+from .graphcare import GraphCare, BiAttentionGNNConv
 from .logistic_regression import LogisticRegression
 from .gan import GAN
 from .gnn import GAT, GCN
diff --git a/pyhealth/models/graphcare.py b/pyhealth/models/graphcare.py
new file mode 100644
index 000000000..e440d71e3
--- /dev/null
+++ b/pyhealth/models/graphcare.py
@@ -0,0 +1,506 @@
+"""GraphCare model for PyHealth.
+
+Paper: Pengcheng Jiang et al. GraphCare: Enhancing Healthcare Predictions
+with Personalized Knowledge Graphs. ICLR 2024.
+
+This model constructs personalized patient-level knowledge graphs from
+medical codes (conditions, procedures, drugs) using pre-built code-level
+knowledge subgraphs, then applies a GNN (GAT, GIN, or BAT) with
+bi-attention pooling for downstream healthcare prediction tasks.
+
+Note:
+    This model requires ``torch-geometric`` to be installed::
+
+        pip install torch-geometric
+
+Note:
+    This model requires pre-computed knowledge graphs for medical codes.
+    See the GraphCare paper and the original implementation at
+    https://github.com/pat-jj/GraphCare for graph generation details.
+"""
+
+from typing import Optional, Tuple, Union
+import logging
+import random
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor  # used in annotations even when torch_geometric is absent
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Lazy imports for torch_geometric
+# ---------------------------------------------------------------------------
+try:
+    from torch_geometric.nn import GATConv, GINConv
+    from torch_geometric.nn.conv import MessagePassing
+    from torch_geometric.nn import global_mean_pool
+    from torch_geometric.data import Data, Batch
+    from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
+
+    HAS_TORCH_GEOMETRIC = True
+except ImportError:
+    HAS_TORCH_GEOMETRIC = False
+
+
+def _check_torch_geometric():
+    """Raises ImportError with install instructions if torch_geometric is missing."""
+    if not HAS_TORCH_GEOMETRIC:
+        raise ImportError(
+            "GraphCare requires torch-geometric. "
+            "Install it with: pip install torch-geometric"
+        )
+
+
+# ===========================================================================
+# BiAttentionGNNConv – the BAT message-passing layer from the paper (§3.3)
+# ===========================================================================
+
+
+class BiAttentionGNNConv(MessagePassing if HAS_TORCH_GEOMETRIC else nn.Module):
+    r"""Bi-Attention augmented GNN convolution (BAT layer).
+
+    This is a GIN-style message-passing layer augmented with:
+
+    * **node-level attention** (``attn``) injected from the outer model,
+    * **edge-level attention** via a learnable projection of relation
+      embeddings (``W_R``).
+
+    The message for edge :math:`(j \to i)` is:
+
+    .. math::
+        m_{j \to i} = \text{ReLU}\bigl(
+            x_j \cdot \alpha_{j} + W_R(e_{ji}) \cdot e_{ji}
+        \bigr)
+
+    where :math:`\alpha_j` is the pre-computed bi-attention weight for node
+    *j* and :math:`e_{ji}` is the relation embedding on that edge.
+
+    Args:
+        nn: A neural network module applied after aggregation (typically
+            ``nn.Linear(hidden_dim, hidden_dim)``).
+        eps: Initial value of :math:`\varepsilon` for the self-loop weight.
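+            The aggregated neighbour messages are combined with the
+            central node as :math:`(1 + \varepsilon) \cdot x_i`, as in GIN.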
+ train_eps: If ``True``, :math:`\varepsilon` is a learnable parameter. + edge_dim: Dimension of edge (relation) features. + edge_attn: If ``True``, use edge attention via ``W_R``. + """ + + def __init__( + self, + nn: torch.nn.Module, + eps: float = 0.0, + train_eps: bool = False, + edge_dim: Optional[int] = None, + edge_attn: bool = True, + **kwargs, + ): + _check_torch_geometric() + kwargs.setdefault("aggr", "add") + super().__init__(**kwargs) + self.nn = nn + self.initial_eps = eps + self.edge_attn = edge_attn + + if edge_attn: + self.W_R = torch.nn.Linear(edge_dim, 1) + else: + self.W_R = None + + if train_eps: + self.eps = torch.nn.Parameter(torch.Tensor([eps])) + else: + self.register_buffer("eps", torch.Tensor([eps])) + + self.reset_parameters() + + def reset_parameters(self): + self.nn.reset_parameters() + self.eps.data.fill_(self.initial_eps) + if self.W_R is not None: + self.W_R.reset_parameters() + + def forward( + self, + x: Union[Tensor, "OptPairTensor"], + edge_index: "Adj", + edge_attr: "OptTensor" = None, + size: "Size" = None, + attn: Tensor = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + if isinstance(x, Tensor): + x = (x, x) + + # propagate computes messages and aggregates + out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size, attn=attn) + + x_r = x[1] + if x_r is not None: + out = out + (1 + self.eps) * x_r + + if self.W_R is not None: + w_rel = self.W_R(edge_attr) + else: + w_rel = None + + return self.nn(out), w_rel + + def message(self, x_j: Tensor, edge_attr: Tensor, attn: Tensor) -> Tensor: + if self.edge_attn: + w_rel = self.W_R(edge_attr) + out = (x_j * attn + w_rel * edge_attr).relu() + else: + out = (x_j * attn).relu() + return out + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(nn={self.nn})" + + +# =========================================================================== +# GraphCare – main model (PyHealth-compatible) +# =========================================================================== + + +class GraphCare(nn.Module): + r"""GraphCare model for EHR-based healthcare predictions. + + This is a **graph-level** model that operates on pre-constructed + patient knowledge graphs (KGs) via ``torch_geometric``. It is + designed to be called **outside** the standard PyHealth ``Trainer`` + loop because the data pipeline requires ``torch_geometric.loader.DataLoader`` + rather than PyHealth's default collation. + + The architecture has three components: + + 1. **Embedding layer** – maps node IDs and relation IDs to dense + vectors using (optionally pre-trained) embedding tables, then + projects them to ``hidden_dim``. + 2. **GNN encoder** – *L* layers of message passing. Three GNN + back-ends are supported: + + * ``"BAT"`` – Bi-Attention augmented GNN (default, from the paper). + Uses per-node attention weights derived from visit-level + (alpha) and node-level (beta) attention with temporal decay. + * ``"GAT"`` – standard Graph Attention Network. + * ``"GIN"`` – Graph Isomorphism Network. + + 3. **Patient representation head** – produces a patient-level vector + from the node embeddings, with three modes: + + * ``"graph"`` – global mean pool over all graph nodes. + * ``"node"`` – weighted average of direct EHR-node embeddings. + * ``"joint"`` – concatenation of both (default). + + Args: + num_nodes: Total number of nodes in the knowledge graph. + num_rels: Total number of relation types. + max_visit: Maximum number of visits per patient. + embedding_dim: Dimension of pre-trained node/relation embeddings. 
+ hidden_dim: Hidden dimension used throughout the model. + out_channels: Number of output classes / labels. + layers: Number of GNN layers. Default ``3``. + dropout: Dropout rate for the final MLP. Default ``0.5``. + decay_rate: Temporal decay rate :math:`\gamma` for visit + weighting: :math:`\lambda_j = e^{\gamma (V - j)}`. + Default ``0.01``. + node_emb: Optional pre-trained node embedding tensor of shape + ``[num_nodes, embedding_dim]``. If ``None``, embeddings + are learned from scratch. + rel_emb: Optional pre-trained relation embedding tensor of shape + ``[num_rels, embedding_dim]``. If ``None``, embeddings + are learned from scratch. + freeze: If ``True``, freeze pre-trained embeddings. + patient_mode: Patient representation mode – one of + ``"joint"`` (default), ``"graph"``, or ``"node"``. + use_alpha: Use visit-level (alpha) attention. Default ``True``. + use_beta: Use node-level (beta) attention with temporal decay. + Default ``True``. + use_edge_attn: Use edge (relation) attention in BAT layers. + Default ``True``. + gnn: GNN backbone – ``"BAT"`` (default), ``"GAT"``, or ``"GIN"``. + attn_init: Optional tensor of shape ``[num_nodes]`` for + initialising the diagonal of alpha attention weights. + drop_rate: Edge dropout rate during training. Default ``0.0``. + self_attn: Initial value of the GIN-style self-loop weight + :math:`\varepsilon` in the BAT layer. Default ``0.0``. + + Example: + + The model is instantiated and called directly with + ``torch_geometric``-style batched data:: + + >>> model = GraphCare( + ... num_nodes=5000, num_rels=100, max_visit=20, + ... embedding_dim=128, hidden_dim=128, out_channels=1, + ... ) + >>> # node_ids, rel_ids, edge_index, batch from PyG DataLoader + >>> logits = model(node_ids, rel_ids, edge_index, batch, + ... 
visit_node, ehr_nodes) + """ + + def __init__( + self, + num_nodes: int, + num_rels: int, + max_visit: int, + embedding_dim: int, + hidden_dim: int, + out_channels: int, + layers: int = 3, + dropout: float = 0.5, + decay_rate: float = 0.01, + node_emb: Optional[torch.Tensor] = None, + rel_emb: Optional[torch.Tensor] = None, + freeze: bool = False, + patient_mode: str = "joint", + use_alpha: bool = True, + use_beta: bool = True, + use_edge_attn: bool = True, + gnn: str = "BAT", + attn_init: Optional[torch.Tensor] = None, + drop_rate: float = 0.0, + self_attn: float = 0.0, + ): + super().__init__() + _check_torch_geometric() + + assert patient_mode in ("joint", "graph", "node"), \ + f"patient_mode must be 'joint', 'graph', or 'node', got '{patient_mode}'" + assert gnn in ("BAT", "GAT", "GIN"), \ + f"gnn must be 'BAT', 'GAT', or 'GIN', got '{gnn}'" + + self.gnn = gnn + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + self.decay_rate = decay_rate + self.patient_mode = patient_mode + self.use_alpha = use_alpha + self.use_beta = use_beta + self.edge_attn = use_edge_attn + self.drop_rate = drop_rate + self.num_nodes = num_nodes + self.num_rels = num_rels + self.max_visit = max_visit + self.num_layers = layers + self.dropout = dropout + + # --- Temporal decay weights: lambda_j = exp(gamma * (V - j)) --- + j = torch.arange(max_visit).float() + lambda_j = ( + torch.exp(self.decay_rate * (max_visit - j)) + .unsqueeze(0) + .reshape(1, max_visit, 1) + .float() + ) + self.register_buffer("lambda_j", lambda_j) + + # --- Embeddings --- + if node_emb is None: + self.node_emb = nn.Embedding(num_nodes, embedding_dim) + else: + self.node_emb = nn.Embedding.from_pretrained(node_emb.float(), freeze=freeze) + + if rel_emb is None: + self.rel_emb = nn.Embedding(num_rels, embedding_dim) + else: + self.rel_emb = nn.Embedding.from_pretrained(rel_emb.float(), freeze=freeze) + + # --- Projection to hidden_dim --- + self.lin = nn.Linear(embedding_dim, hidden_dim) + + # --- Per-layer modules --- + self.alpha_attn = nn.ModuleDict() + self.beta_attn = nn.ModuleDict() + self.conv = nn.ModuleDict() + + for layer_idx in range(1, layers + 1): + k = str(layer_idx) + + # Visit-level attention (alpha) + if self.use_alpha: + self.alpha_attn[k] = nn.Linear(num_nodes, num_nodes) + if attn_init is not None: + attn_init_f = attn_init.float() + attn_init_matrix = torch.eye(num_nodes).float() * attn_init_f + self.alpha_attn[k].weight.data.copy_(attn_init_matrix) + else: + nn.init.xavier_normal_(self.alpha_attn[k].weight) + + # Node-level attention (beta) with temporal decay + if self.use_beta: + self.beta_attn[k] = nn.Linear(num_nodes, 1) + nn.init.xavier_normal_(self.beta_attn[k].weight) + + # GNN convolution + if self.gnn == "BAT": + self.conv[k] = BiAttentionGNNConv( + nn.Linear(hidden_dim, hidden_dim), + edge_dim=hidden_dim, + edge_attn=self.edge_attn, + eps=self_attn, + ) + elif self.gnn == "GAT": + self.conv[k] = GATConv(hidden_dim, hidden_dim) + elif self.gnn == "GIN": + self.conv[k] = GINConv(nn.Linear(hidden_dim, hidden_dim)) + + # --- Final MLP head --- + if self.patient_mode == "joint": + self.MLP = nn.Linear(hidden_dim * 2, out_channels) + else: + self.MLP = nn.Linear(hidden_dim, out_channels) + + def forward( + self, + node_ids: torch.Tensor, + rel_ids: torch.Tensor, + edge_index: torch.Tensor, + batch: torch.Tensor, + visit_node: torch.Tensor, + ehr_nodes: Optional[torch.Tensor] = None, + store_attn: bool = False, + in_drop: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + 
"""Forward pass. + + Args: + node_ids: Node ID tensor of shape ``[total_nodes_in_batch]``. + These are the ``data.y`` values from the PyG batch + (original node IDs used to look up embeddings). + rel_ids: Relation ID tensor of shape ``[total_edges_in_batch]``. + Used to look up relation embeddings. + edge_index: Edge index tensor of shape ``[2, total_edges_in_batch]``. + batch: Batch assignment vector of shape ``[total_nodes_in_batch]``, + mapping each node to its graph index in the batch. + visit_node: Padded visit-node tensor of shape + ``[batch_size, max_visit, num_nodes]``. Binary indicator of + which KG nodes appear in each patient visit. + ehr_nodes: Direct EHR node indicator of shape + ``[batch_size, num_nodes]`` (one-hot). Required when + ``patient_mode`` is ``"node"`` or ``"joint"``. + store_attn: If ``True``, return intermediate attention weights + for interpretability. + in_drop: If ``True`` and ``drop_rate > 0``, randomly drop edges + during training. + + Returns: + Logits tensor of shape ``[batch_size, out_channels]``. + If ``store_attn`` is ``True``, also returns alpha, beta, + attention, and edge weight lists. + """ + # --- Optional edge dropout --- + if in_drop and self.drop_rate > 0: + edge_count = edge_index.size(1) + edges_to_remove = int(edge_count * self.drop_rate) + if edges_to_remove > 0: + indices_to_remove = set(random.sample(range(edge_count), edges_to_remove)) + keep = [i for i in range(edge_count) if i not in indices_to_remove] + edge_index = edge_index[:, keep].to(edge_index.device) + rel_ids = rel_ids[keep] + + # --- Embed & project --- + x = self.lin(self.node_emb(node_ids).float()) + edge_attr = self.lin(self.rel_emb(rel_ids).float()) + + batch_size = batch.max().item() + 1 + + if store_attn: + alpha_weights_list, beta_weights_list = [], [] + attention_weights_list, edge_weights_list = [], [] + + # --- GNN layers with bi-attention --- + for layer_idx in range(1, self.num_layers + 1): + k = str(layer_idx) + + # Compute alpha: visit-level attention (batch, max_visit, num_nodes) + if self.use_alpha: + alpha = torch.softmax( + self.alpha_attn[k](visit_node.float()), dim=1 + ) + + # Compute beta: node-level attention with temporal decay + if self.use_beta: + beta = ( + torch.tanh(self.beta_attn[k](visit_node.float())) + * self.lambda_j + ) + + # Combine alpha and beta + if self.use_alpha and self.use_beta: + attn = alpha * beta + elif self.use_alpha: + attn = alpha * torch.ones( + batch_size, self.max_visit, 1, device=edge_index.device + ) + elif self.use_beta: + attn = beta * torch.ones( + batch_size, self.max_visit, self.num_nodes, + device=edge_index.device, + ) + else: + attn = torch.ones( + batch_size, self.max_visit, self.num_nodes, + device=edge_index.device, + ) + + # Sum over visits → (batch, num_nodes) + attn = torch.sum(attn, dim=1) + + # Index into per-edge attention: for each edge (j→i), + # get attn[batch_of_j, node_id_of_j] + xj_node_ids = node_ids[edge_index[0]] + xj_batch = batch[edge_index[0]] + attn_per_edge = attn[xj_batch, xj_node_ids].reshape(-1, 1) + + # Apply GNN conv + if self.gnn == "BAT": + x, w_rel = self.conv[k](x, edge_index, edge_attr, attn=attn_per_edge) + else: + x = self.conv[k](x, edge_index) + w_rel = None + + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + if store_attn: + alpha_weights_list.append(alpha if self.use_alpha else None) + beta_weights_list.append(beta if self.use_beta else None) + attention_weights_list.append(attn_per_edge) + edge_weights_list.append(w_rel) + + # --- Patient 
representation --- + if self.patient_mode in ("joint", "graph"): + x_graph = global_mean_pool(x, batch) + x_graph = F.dropout(x_graph, p=self.dropout, training=self.training) + + if self.patient_mode in ("joint", "node"): + # Weighted average of direct EHR node embeddings + # ehr_nodes: (batch_size, num_nodes) — binary indicators + x_node = torch.stack([ + ehr_nodes[i].view(1, -1) @ self.node_emb.weight + / torch.sum(ehr_nodes[i]).clamp(min=1) + for i in range(batch_size) + ]) + x_node = self.lin(x_node).squeeze(1) + x_node = F.dropout(x_node, p=self.dropout, training=self.training) + + # --- Prediction head --- + if self.patient_mode == "joint": + x_concat = torch.cat((x_graph, x_node), dim=1) + x_concat = F.dropout(x_concat, p=self.dropout, training=self.training) + logits = self.MLP(x_concat) + elif self.patient_mode == "graph": + logits = self.MLP(x_graph) + else: # "node" + logits = self.MLP(x_node) + + if store_attn: + return ( + logits, + alpha_weights_list, + beta_weights_list, + attention_weights_list, + edge_weights_list, + ) + return logits \ No newline at end of file diff --git a/pyhealth/models/graphcare_utils.py b/pyhealth/models/graphcare_utils.py new file mode 100644 index 000000000..f4cfa4531 --- /dev/null +++ b/pyhealth/models/graphcare_utils.py @@ -0,0 +1,558 @@ +"""GraphCare data pipeline utilities for PyHealth. + +This module provides utilities to convert PyHealth ``SampleDataset`` records +and pre-built knowledge graph artifacts into ``torch_geometric`` Data objects +suitable for the :class:`~pyhealth.models.graphcare.GraphCare` model. + +The pipeline mirrors the original GraphCare implementation +(https://github.com/pat-jj/GraphCare) but is refactored for cleaner +integration with PyHealth. + +Typical usage:: + + from pyhealth.models.graphcare_utils import ( + load_kg_artifacts, + prepare_graphcare_data, + build_graphcare_dataloaders, + ) + + # 1. Load pre-built KG artifacts + artifacts = load_kg_artifacts( + sample_dataset_path="sample_dataset_mimic3_mortality_th015.pkl", + graph_path="graph_mimic3_mortality_th015.pkl", + ent_emb_path="entity_embedding.pkl", + rel_emb_path="relation_embedding.pkl", + cluster_path="clusters_th015.json", + cluster_rel_path="clusters_rel_th015.json", + ccscm_id2clus_path="ccscm_id2clus.json", + ccsproc_id2clus_path="ccsproc_id2clus.json", + ) + + # 2. Prepare PyG-compatible data + prepared = prepare_graphcare_data(artifacts, task="mortality") + + # 3. Build dataloaders + train_loader, val_loader, test_loader = build_graphcare_dataloaders( + prepared, batch_size=64, + ) +""" + +from typing import Dict, List, Optional, Tuple, Any +import os +import json +import pickle +import logging + +import torch +import numpy as np +from copy import deepcopy + +logger = logging.getLogger(__name__) + +# Lazy import torch_geometric +try: + from torch_geometric.data import Data, Batch + from torch_geometric.loader import DataLoader as PyGDataLoader + from torch_geometric.utils import from_networkx, k_hop_subgraph + + HAS_TORCH_GEOMETRIC = True +except ImportError: + HAS_TORCH_GEOMETRIC = False + + +def _check_torch_geometric(): + if not HAS_TORCH_GEOMETRIC: + raise ImportError( + "GraphCare data utilities require torch-geometric. " + "Install with: pip install torch-geometric" + ) + + +# =========================================================================== +# 1. 
Loading pre-built KG artifacts +# =========================================================================== + + +def load_kg_artifacts( + sample_dataset_path: str, + graph_path: str, + ent_emb_path: str, + rel_emb_path: str, + cluster_path: str, + cluster_rel_path: str, + ccscm_id2clus_path: str, + ccsproc_id2clus_path: str, + atc3_id2clus_path: Optional[str] = None, +) -> Dict[str, Any]: + """Load all pre-built KG artifacts required by GraphCare. + + These artifacts are produced by the GraphCare KG construction pipeline + (LLM prompting + subgraph sampling + node/edge clustering). See the + original repo for generation scripts. + + Args: + sample_dataset_path: Path to pickled PyHealth SampleDataset list. + graph_path: Path to pickled NetworkX graph (the global KG). + ent_emb_path: Path to pickled entity (node) embeddings. + rel_emb_path: Path to pickled relation embeddings. + cluster_path: Path to JSON node cluster mapping. + cluster_rel_path: Path to JSON relation cluster mapping. + ccscm_id2clus_path: Path to JSON CCS-CM code → cluster ID mapping. + ccsproc_id2clus_path: Path to JSON CCS-Proc code → cluster ID mapping. + atc3_id2clus_path: Optional path to JSON ATC3 drug code → cluster ID + mapping. Required for mortality/readmission tasks that include + drug features. + + Returns: + Dictionary with keys: ``sample_dataset``, ``graph``, ``ent_emb``, + ``rel_emb``, ``cluster_map``, ``cluster_rel_map``, + ``ccscm_id2clus``, ``ccsproc_id2clus``, ``atc3_id2clus``. + """ + def _load_pkl(path): + with open(path, "rb") as f: + return pickle.load(f) + + def _load_json(path): + with open(path, "r") as f: + return json.load(f) + + artifacts = { + "sample_dataset": _load_pkl(sample_dataset_path), + "graph": _load_pkl(graph_path), + "ent_emb": _load_pkl(ent_emb_path), + "rel_emb": _load_pkl(rel_emb_path), + "cluster_map": _load_json(cluster_path), + "cluster_rel_map": _load_json(cluster_rel_path), + "ccscm_id2clus": _load_json(ccscm_id2clus_path), + "ccsproc_id2clus": _load_json(ccsproc_id2clus_path), + "atc3_id2clus": _load_json(atc3_id2clus_path) if atc3_id2clus_path else None, + } + + logger.info( + f"Loaded KG artifacts: {len(artifacts['sample_dataset'])} patients, " + f"graph has {artifacts['graph'].number_of_nodes()} nodes / " + f"{artifacts['graph'].number_of_edges()} edges" + ) + return artifacts + + +# =========================================================================== +# 2. Labelling & subgraph extraction +# =========================================================================== + + +def _flatten(lst): + """Recursively flatten nested lists.""" + result = [] + for item in lst: + if isinstance(item, list): + result.extend(_flatten(item)) + else: + result.append(item) + return result + + +def label_ehr_nodes( + sample_dataset: List[Dict], + task: str, + num_nodes: int, + ccscm_id2clus: Dict[str, str], + ccsproc_id2clus: Dict[str, str], + atc3_id2clus: Optional[Dict[str, str]] = None, +) -> List[Dict]: + """Add ``ehr_node_set`` (one-hot) to each patient record. + + Maps each patient's conditions, procedures, and (optionally) drugs + to their cluster node IDs and creates a binary indicator vector. + + Args: + sample_dataset: List of patient sample dicts from PyHealth. + task: Task name — ``"mortality"``, ``"readmission"``, ``"drugrec"``, + or ``"lenofstay"``. + num_nodes: Total number of cluster nodes in the KG. + ccscm_id2clus: CCS-CM condition code → cluster ID mapping. + ccsproc_id2clus: CCS-Proc procedure code → cluster ID mapping. 
+ atc3_id2clus: ATC3 drug code → cluster ID mapping (required for + mortality/readmission). + + Returns: + The same dataset list, with ``ehr_node_set`` added to each record. + """ + for patient in sample_dataset: + nodes = [] + + for condition in _flatten(patient["conditions"]): + if condition in ccscm_id2clus: + ehr_node = int(ccscm_id2clus[condition]) + nodes.append(ehr_node) + patient["node_set"].append(ehr_node) + + for procedure in _flatten(patient["procedures"]): + if procedure in ccsproc_id2clus: + ehr_node = int(ccsproc_id2clus[procedure]) + nodes.append(ehr_node) + patient["node_set"].append(ehr_node) + + if task in ("mortality", "readmission") and atc3_id2clus is not None: + for drug in _flatten(patient.get("drugs", [])): + if drug in atc3_id2clus: + ehr_node = int(atc3_id2clus[drug]) + nodes.append(ehr_node) + patient["node_set"].append(ehr_node) + + node_vec = np.zeros(num_nodes) + if nodes: + node_vec[nodes] = 1 + patient["ehr_node_set"] = torch.tensor(node_vec) + + return sample_dataset + + +def get_rel_emb_from_clusters(cluster_rel_map: Dict) -> torch.Tensor: + """Extract relation embeddings from the cluster relation mapping. + + Args: + cluster_rel_map: Dict mapping relation cluster ID (str) to + a dict containing ``"embedding"`` key. + + Returns: + Tensor of shape ``[num_rels, emb_dim]``. + """ + rel_emb = [] + for i in range(len(cluster_rel_map)): + rel_emb.append(cluster_rel_map[str(i)]["embedding"][0]) + return torch.tensor(np.array(rel_emb)) + + +def extract_patient_subgraph( + G_tg: "Data", + patient: Dict, + task: str, + k_hop: int = 2, +) -> "Data": + """Extract a patient-specific subgraph from the global KG. + + Uses k-hop subgraph extraction centred on the patient's node set, + then attaches task labels and visit/EHR node metadata. + + Args: + G_tg: The global KG as a PyG ``Data`` object. + patient: A single patient record dict with keys ``node_set``, + ``visit_padded_node``, ``ehr_node_set``, ``patient_id``, + and task-specific label fields. + task: Task name for label extraction. + k_hop: Number of hops for subgraph extraction. Default ``2``. + + Returns: + A PyG ``Data`` object representing the patient's subgraph with + attached metadata. + """ + _check_torch_geometric() + + node_set = patient["node_set"] + if len(node_set) == 0: + # Return a minimal graph if no nodes + P = Data( + edge_index=torch.zeros(2, 0, dtype=torch.long), + y=torch.zeros(0, dtype=torch.long), + relation=torch.zeros(0, dtype=torch.long), + ) + else: + nodes, _, _, edge_mask = k_hop_subgraph( + torch.tensor(node_set), k_hop, G_tg.edge_index + ) + mask_idx = torch.where(edge_mask)[0] + L = G_tg.edge_subgraph(mask_idx) + P = L.subgraph(torch.tensor(node_set)) + + # Attach label + if task == "drugrec": + P.label = patient["drugs_ind"] + elif task == "lenofstay": + label = np.zeros(10) + label[patient["label"]] = 1 + P.label = torch.tensor(label) + else: # mortality, readmission + P.label = patient["label"] + + # Attach visit and EHR node info + P.visit_padded_node = patient["visit_padded_node"] + P.ehr_nodes = patient["ehr_node_set"] + P.patient_id = patient["patient_id"] + + return P + + +# =========================================================================== +# 3. Dataset & DataLoader +# =========================================================================== + + +class GraphCareDataset(torch.utils.data.Dataset): + """PyTorch Dataset that lazily extracts patient subgraphs. 
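+
+    Building each k-hop subgraph on demand keeps memory bounded: graphs
+    are constructed one at a time as the DataLoader iterates instead of
+    being materialised up front.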
+ + Each ``__getitem__`` call extracts the k-hop subgraph for one patient + from the global KG, attaches labels and metadata, and returns a + ``torch_geometric.data.Data`` object. + + Args: + G_tg: Global KG as a PyG Data object. + dataset: List of patient record dicts. + task: Task name. + k_hop: Number of hops for subgraph extraction. + """ + + def __init__(self, G_tg, dataset: List[Dict], task: str, k_hop: int = 2): + self.G_tg = G_tg + self.dataset = dataset + self.task = task + self.k_hop = k_hop + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + patient = self.dataset[idx] + # Skip patients with empty node sets by falling back + while len(patient["node_set"]) == 0 and idx > 0: + idx -= 1 + patient = self.dataset[idx] + return extract_patient_subgraph( + self.G_tg, patient, self.task, self.k_hop + ) + + +# =========================================================================== +# 4. High-level prepare & build functions +# =========================================================================== + + +def get_task_config(task: str, sample_dataset: List[Dict]) -> Dict[str, Any]: + """Get task-specific configuration. + + Args: + task: One of ``"mortality"``, ``"readmission"``, ``"drugrec"``, + ``"lenofstay"``. + + Returns: + Dict with ``mode``, ``out_channels``, and ``loss_fn``. + """ + import torch.nn.functional as F_ + + if task in ("mortality", "readmission"): + return { + "mode": "binary", + "out_channels": 1, + "loss_fn": F_.binary_cross_entropy_with_logits, + } + elif task == "drugrec": + return { + "mode": "multilabel", + "out_channels": len(sample_dataset[0]["drugs_ind"]), + "loss_fn": F_.binary_cross_entropy_with_logits, + } + elif task == "lenofstay": + return { + "mode": "multiclass", + "out_channels": 10, + "loss_fn": F_.cross_entropy, + } + else: + raise ValueError(f"Unknown task: {task}") + + +def _split_by_patient( + dataset: List[Dict], + ratios: List[float], + seed: int = 528, +) -> Tuple[List[Dict], List[Dict], List[Dict]]: + """Split patient records into train/val/test by patient ID. + + Args: + dataset: List of patient record dicts (must have ``"patient_id"``). + ratios: Three-element list of train/val/test ratios (must sum to 1). + seed: Random seed for reproducibility. + + Returns: + Tuple of (train, val, test) record lists. + """ + import random as _rng + + patient_ids = sorted(set(p["patient_id"] for p in dataset)) + _rng.seed(seed) + _rng.shuffle(patient_ids) + + n = len(patient_ids) + n_train = int(n * ratios[0]) + n_val = int(n * ratios[1]) + + train_ids = set(patient_ids[:n_train]) + val_ids = set(patient_ids[n_train : n_train + n_val]) + test_ids = set(patient_ids[n_train + n_val :]) + + train = [p for p in dataset if p["patient_id"] in train_ids] + val = [p for p in dataset if p["patient_id"] in val_ids] + test = [p for p in dataset if p["patient_id"] in test_ids] + + return train, val, test + + +def prepare_graphcare_data( + artifacts: Dict[str, Any], + task: str, + split_ratios: Tuple[float, float, float] = (0.8, 0.1, 0.1), + seed: int = 528, +) -> Dict[str, Any]: + """Prepare all data needed for GraphCare training. + + Takes the raw artifacts from :func:`load_kg_artifacts` and produces: + + * Labelled patient records with ``ehr_node_set`` + * The global KG converted to PyG format + * Node and relation embedding tensors + * Task configuration (mode, out_channels, loss_fn) + * Train/val/test splits + + Args: + artifacts: Output of :func:`load_kg_artifacts`. + task: Task name. + split_ratios: Train/val/test split ratios. 
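+            Must sum to 1.0; defaults to ``(0.8, 0.1, 0.1)``.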
+    """
+    _check_torch_geometric()
+
+    sample_dataset = artifacts["sample_dataset"]
+    graph = artifacts["graph"]
+    cluster_map = artifacts["cluster_map"]
+    cluster_rel_map = artifacts["cluster_rel_map"]
+
+    # Label EHR nodes
+    num_cluster_nodes = len(cluster_map)
+    sample_dataset = label_ehr_nodes(
+        sample_dataset,
+        task=task,
+        num_nodes=num_cluster_nodes,
+        ccscm_id2clus=artifacts["ccscm_id2clus"],
+        ccsproc_id2clus=artifacts["ccsproc_id2clus"],
+        atc3_id2clus=artifacts["atc3_id2clus"],
+    )
+
+    # Convert NetworkX graph to PyG
+    G_tg = from_networkx(graph)
+
+    # Embeddings
+    rel_emb = get_rel_emb_from_clusters(cluster_rel_map)
+    node_emb = G_tg.x
+
+    # Task config
+    task_config = get_task_config(task, sample_dataset)
+
+    # Split
+    train_dataset, val_dataset, test_dataset = _split_by_patient(
+        sample_dataset, list(split_ratios), seed=seed
+    )
+
+    # All records share the same visit padding, so any record works here.
+    max_visit = sample_dataset[0]["visit_padded_node"].shape[0]
+
+    return {
+        "G_tg": G_tg,
+        "node_emb": node_emb,
+        "rel_emb": rel_emb,
+        "num_nodes": node_emb.shape[0],
+        "num_rels": rel_emb.shape[0],
+        "max_visit": max_visit,
+        "task_config": task_config,
+        "train_dataset": train_dataset,
+        "val_dataset": val_dataset,
+        "test_dataset": test_dataset,
+        "task": task,
+    }
+
+
+def build_graphcare_dataloaders(
+    prepared: Dict[str, Any],
+    batch_size: int = 64,
+    k_hop: int = 2,
+) -> Tuple["PyGDataLoader", "PyGDataLoader", "PyGDataLoader"]:
+    """Build PyG DataLoaders for GraphCare.
+
+    Args:
+        prepared: Output of :func:`prepare_graphcare_data`.
+        batch_size: Batch size.
+        k_hop: Number of hops for subgraph extraction.
+
+    Returns:
+        Tuple of (train_loader, val_loader, test_loader).
+    """
+    _check_torch_geometric()
+
+    G_tg = prepared["G_tg"]
+    task = prepared["task"]
+
+    train_set = GraphCareDataset(G_tg, prepared["train_dataset"], task, k_hop)
+    val_set = GraphCareDataset(G_tg, prepared["val_dataset"], task, k_hop)
+    test_set = GraphCareDataset(G_tg, prepared["test_dataset"], task, k_hop)
+
+    # drop_last=True on every split keeps each batch at exactly
+    # ``batch_size``, which the fixed-shape reshape in
+    # ``reshape_batch_tensors`` relies on; any partial tail batch is dropped.
+    train_loader = PyGDataLoader(
+        train_set, batch_size=batch_size, shuffle=True, drop_last=True
+    )
+    val_loader = PyGDataLoader(
+        val_set, batch_size=batch_size, shuffle=False, drop_last=True
+    )
+    test_loader = PyGDataLoader(
+        test_set, batch_size=batch_size, shuffle=False, drop_last=True
+    )
+
+    return train_loader, val_loader, test_loader
+
+
+def reshape_batch_tensors(
+    data: "Batch",
+    batch_size: int,
+    max_visit: int,
+    num_nodes: int,
+    patient_mode: str = "joint",
+) -> Dict[str, torch.Tensor]:
+    """Reshape batched tensors from PyG DataLoader for GraphCare forward pass.
+
+    The PyG DataLoader concatenates per-graph tensors. This function
+    reshapes them back into the shapes expected by ``GraphCare.forward()``.
+
+    Args:
+        data: A batched ``torch_geometric.data.Batch`` object.
+        batch_size: Number of graphs in the batch.
+        max_visit: Maximum visits per patient.
+        num_nodes: Total KG nodes (for reshaping visit_padded_node).
+        patient_mode: Patient mode to determine if ehr_nodes is needed.
+
+    Returns:
+        Dict with keys ``node_ids``, ``rel_ids``, ``edge_index``,
+        ``batch``, ``visit_node``, ``ehr_nodes``, ``label``.
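+
+    Example:
+        A minimal sketch (assumes ``model`` is a ``GraphCare`` instance and
+        ``prepared``/``train_loader`` come from the prepare/build helpers
+        above, with a matching ``batch_size``)::
+
+            for data in train_loader:
+                tensors = reshape_batch_tensors(
+                    data,
+                    batch_size=64,
+                    max_visit=prepared["max_visit"],
+                    num_nodes=prepared["num_nodes"],
+                    patient_mode="joint",
+                )
+                logits = model(
+                    tensors["node_ids"], tensors["rel_ids"],
+                    tensors["edge_index"], tensors["batch"],
+                    tensors["visit_node"], tensors["ehr_nodes"],
+                )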
+    """
+    result = {
+        "node_ids": data.y,
+        "rel_ids": data.relation,
+        "edge_index": data.edge_index,
+        "batch": data.batch,
+        "visit_node": data.visit_padded_node.reshape(
+            batch_size, max_visit, num_nodes
+        ).float(),
+    }
+
+    # ``graph`` mode pools over the whole subgraph and never reads
+    # ``ehr_nodes``; the other modes need the per-patient EHR indicator.
+    if patient_mode != "graph":
+        result["ehr_nodes"] = data.ehr_nodes.reshape(
+            batch_size, num_nodes
+        ).float()
+    else:
+        result["ehr_nodes"] = None
+
+    result["label"] = data.label.reshape(batch_size, -1)
+
+    return result
\ No newline at end of file
diff --git a/tests/core/test_graphcare.py b/tests/core/test_graphcare.py
new file mode 100644
index 000000000..aafe03f4e
--- /dev/null
+++ b/tests/core/test_graphcare.py
@@ -0,0 +1,528 @@
+# Author: Josh Steier
+# Description: Tests for GraphCare model
+
+"""Test cases for GraphCare and BiAttentionGNNConv.
+
+Run with: python -m pytest tests/core/test_graphcare.py -v
+
+Note: Requires torch-geometric to be installed.
+"""
+
+import random
+import unittest
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from pyhealth.models.graphcare import (
+    GraphCare,
+    BiAttentionGNNConv,
+    _check_torch_geometric,
+)
+
+try:
+    from torch_geometric.data import Data, Batch
+    from torch_geometric.nn import global_mean_pool
+
+    HAS_TORCH_GEOMETRIC = True
+except ImportError:
+    HAS_TORCH_GEOMETRIC = False
+
+
+def _make_fake_batch(
+    batch_size=4,
+    num_nodes=200,
+    num_rels=30,
+    max_visit=5,
+    nodes_per_graph_range=(10, 30),
+):
+    """Create a fake PyG batch for testing.
+
+    Returns:
+        Tuple of (node_ids, rel_ids, edge_index, batch_vec, visit_node,
+        ehr_nodes, labels).
+    """
+    graphs = []
+    for _ in range(batch_size):
+        n = random.randint(*nodes_per_graph_range)
+        e = random.randint(n, n * 3)
+        src = torch.randint(0, n, (e,))
+        dst = torch.randint(0, n, (e,))
+        y = torch.randint(0, num_nodes, (n,))
+        relation = torch.randint(0, num_rels, (e,))
+
+        vpn = torch.zeros(max_visit, num_nodes)
+        for v in range(max_visit):
+            active = torch.randint(0, num_nodes, (random.randint(1, 10),))
+            vpn[v, active] = 1.0
+
+        ehr = torch.zeros(num_nodes)
+        ehr[torch.randint(0, num_nodes, (5,))] = 1.0
+
+        data = Data(
+            edge_index=torch.stack([src, dst]),
+            y=y,
+            relation=relation,
+            visit_padded_node=vpn,
+            ehr_nodes=ehr,
+            label=torch.tensor([1.0]),
+        )
+        data.num_nodes = n
+        graphs.append(data)
+
+    batched = Batch.from_data_list(graphs)
+    node_ids = batched.y
+    rel_ids = batched.relation
+    edge_index = batched.edge_index
+    batch_vec = batched.batch
+    visit_node = batched.visit_padded_node.reshape(batch_size, max_visit, num_nodes)
+    ehr_nodes = batched.ehr_nodes.reshape(batch_size, num_nodes)
+    labels = batched.label.reshape(batch_size, -1)
+
+    return node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, labels
+
+
+@unittest.skipUnless(HAS_TORCH_GEOMETRIC, "torch-geometric not installed")
+class TestBiAttentionGNNConv(unittest.TestCase):
+    """Test cases for the BiAttentionGNNConv layer."""
+
+    def test_forward_with_edge_attn(self):
+        """Test BAT conv with edge attention enabled."""
+        hidden_dim = 32
+        conv = BiAttentionGNNConv(
+            nn.Linear(hidden_dim, hidden_dim),
+            edge_dim=hidden_dim,
+            edge_attn=True,
+        )
+
+        num_nodes, num_edges = 20, 40
+        x = torch.randn(num_nodes, hidden_dim)
+        edge_index = torch.randint(0, num_nodes, (2, num_edges))
+        edge_attr = torch.randn(num_edges, hidden_dim)
+        attn = torch.randn(num_edges, 1)
+
+        out, w_rel = conv(x, edge_index, edge_attr, attn=attn)
+
+        self.assertEqual(out.shape, (num_nodes, hidden_dim))
+        self.assertIsNotNone(w_rel)
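+        # ``w_rel`` carries one relation-attention weight per edge when
+        # ``edge_attn=True``.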
+        self.assertEqual(w_rel.shape, (num_edges, 1))
+
+    def test_forward_without_edge_attn(self):
+        """Test BAT conv with edge attention disabled."""
+        hidden_dim = 32
+        conv = BiAttentionGNNConv(
+            nn.Linear(hidden_dim, hidden_dim),
+            edge_dim=hidden_dim,
+            edge_attn=False,
+        )
+
+        num_nodes, num_edges = 20, 40
+        x = torch.randn(num_nodes, hidden_dim)
+        edge_index = torch.randint(0, num_nodes, (2, num_edges))
+        edge_attr = torch.randn(num_edges, hidden_dim)
+        attn = torch.randn(num_edges, 1)
+
+        out, w_rel = conv(x, edge_index, edge_attr, attn=attn)
+
+        self.assertEqual(out.shape, (num_nodes, hidden_dim))
+        self.assertIsNone(w_rel)
+
+    def test_gradient_flow(self):
+        """Test gradients flow through BAT conv."""
+        hidden_dim = 16
+        conv = BiAttentionGNNConv(
+            nn.Linear(hidden_dim, hidden_dim),
+            edge_dim=hidden_dim,
+            edge_attn=True,
+        )
+
+        x = torch.randn(10, hidden_dim, requires_grad=True)
+        edge_index = torch.randint(0, 10, (2, 20))
+        edge_attr = torch.randn(20, hidden_dim)
+        attn = torch.randn(20, 1)
+
+        out, _ = conv(x, edge_index, edge_attr, attn=attn)
+        loss = out.sum()
+        loss.backward()
+
+        self.assertIsNotNone(x.grad)
+        self.assertFalse(torch.all(x.grad == 0))
+
+    def test_reset_parameters(self):
+        """Test reset_parameters doesn't crash."""
+        conv = BiAttentionGNNConv(
+            nn.Linear(32, 32), edge_dim=32, edge_attn=True
+        )
+        conv.reset_parameters()  # Should not raise
+
+
+@unittest.skipUnless(HAS_TORCH_GEOMETRIC, "torch-geometric not installed")
+class TestGraphCare(unittest.TestCase):
+    """Test cases for the GraphCare model."""
+
+    NUM_NODES = 200
+    NUM_RELS = 30
+    MAX_VISIT = 5
+    EMBEDDING_DIM = 64
+    HIDDEN_DIM = 64
+    OUT_CHANNELS = 1
+    BATCH_SIZE = 4
+
+    def _make_model(self, gnn="BAT", patient_mode="joint", **kwargs):
+        """Helper to create a GraphCare model with default test params."""
+        defaults = dict(
+            num_nodes=self.NUM_NODES,
+            num_rels=self.NUM_RELS,
+            max_visit=self.MAX_VISIT,
+            embedding_dim=self.EMBEDDING_DIM,
+            hidden_dim=self.HIDDEN_DIM,
+            out_channels=self.OUT_CHANNELS,
+            layers=2,
+            dropout=0.5,
+            decay_rate=0.01,
+            node_emb=torch.randn(self.NUM_NODES, self.EMBEDDING_DIM),
+            rel_emb=torch.randn(self.NUM_RELS, self.EMBEDDING_DIM),
+            gnn=gnn,
+            patient_mode=patient_mode,
+        )
+        defaults.update(kwargs)
+        return GraphCare(**defaults)
+
+    def _make_batch(self):
+        """Helper to create fake batch data."""
+        return _make_fake_batch(
+            batch_size=self.BATCH_SIZE,
+            num_nodes=self.NUM_NODES,
+            num_rels=self.NUM_RELS,
+            max_visit=self.MAX_VISIT,
+        )
+
+    # --- Output shape tests for all GNN × patient_mode combos ---
+
+    def test_bat_joint_output_shape(self):
+        """Test BAT/joint produces correct output shape."""
+        model = self._make_model(gnn="BAT", patient_mode="joint")
+        node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch()
+
+        model.eval()
+        with torch.no_grad():
+            logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes)
+
+        self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS))
+
+    def test_bat_graph_output_shape(self):
+        """Test BAT/graph produces correct output shape."""
+        model = self._make_model(gnn="BAT", patient_mode="graph")
+        node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch()
+
+        model.eval()
+        with torch.no_grad():
+            logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes=None)
+
+        self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS))
+
+    def test_bat_node_output_shape(self):
+        """Test BAT/node produces correct output shape."""
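+        # ``node`` mode averages the embeddings of the patient's EHR nodes,
+        # so ``ehr_nodes`` must be supplied to ``forward``.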
+ model = self._make_model(gnn="BAT", patient_mode="node") + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS)) + + def test_gat_joint_output_shape(self): + """Test GAT/joint produces correct output shape.""" + model = self._make_model(gnn="GAT", patient_mode="joint") + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS)) + + def test_gin_joint_output_shape(self): + """Test GIN/joint produces correct output shape.""" + model = self._make_model(gnn="GIN", patient_mode="joint") + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS)) + + # --- Multi-label / multi-class output --- + + def test_multilabel_output(self): + """Test model works with multi-label output.""" + model = self._make_model(out_channels=50) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, 50)) + + def test_multiclass_output(self): + """Test model works with multi-class (10-way) output.""" + model = self._make_model(out_channels=10) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, 10)) + + # --- Backward pass --- + + def test_backward_pass_bat(self): + """Test gradients flow through full BAT model.""" + model = self._make_model(gnn="BAT", patient_mode="joint") + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.train() + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + loss = F.binary_cross_entropy_with_logits( + logits, torch.ones(self.BATCH_SIZE, self.OUT_CHANNELS) + ) + loss.backward() + + # Check at least some parameters have gradients + has_grad = any( + p.grad is not None and p.grad.abs().sum() > 0 + for p in model.parameters() + ) + self.assertTrue(has_grad) + + def test_backward_pass_gat(self): + """Test gradients flow through full GAT model.""" + model = self._make_model(gnn="GAT", patient_mode="joint") + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.train() + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + loss = F.binary_cross_entropy_with_logits( + logits, torch.ones(self.BATCH_SIZE, self.OUT_CHANNELS) + ) + loss.backward() + + def test_backward_pass_gin(self): + """Test gradients flow through full GIN model.""" + model = self._make_model(gnn="GIN", patient_mode="joint") + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.train() + logits = model(node_ids, rel_ids, edge_index, batch_vec, 
visit_node, ehr_nodes) + loss = F.binary_cross_entropy_with_logits( + logits, torch.ones(self.BATCH_SIZE, self.OUT_CHANNELS) + ) + loss.backward() + + # --- Edge dropout --- + + def test_edge_dropout(self): + """Test edge dropout doesn't crash and produces valid output.""" + model = self._make_model(drop_rate=0.3) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.train() + logits = model( + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, + in_drop=True, + ) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS)) + self.assertFalse(torch.isnan(logits).any()) + + def test_no_edge_dropout_at_eval(self): + """Test edge dropout is not applied during eval.""" + model = self._make_model(drop_rate=0.5) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits1 = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + logits2 = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + # Without dropout, same input → same output + self.assertTrue(torch.allclose(logits1, logits2)) + + # --- store_attn --- + + def test_store_attn(self): + """Test store_attn returns attention weights.""" + num_layers = 2 + model = self._make_model(gnn="BAT", patient_mode="joint", layers=num_layers) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + result = model( + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, + store_attn=True, + ) + + logits, alphas, betas, attns, edge_ws = result + + self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS)) + self.assertEqual(len(alphas), num_layers) + self.assertEqual(len(betas), num_layers) + self.assertEqual(len(attns), num_layers) + self.assertEqual(len(edge_ws), num_layers) + + # Alpha shape: (batch, max_visit, num_nodes) + self.assertEqual(alphas[0].shape, (self.BATCH_SIZE, self.MAX_VISIT, self.NUM_NODES)) + # Beta shape: (batch, max_visit, 1) + self.assertEqual(betas[0].shape, (self.BATCH_SIZE, self.MAX_VISIT, 1)) + + def test_store_attn_disabled(self): + """Test store_attn=False returns just logits.""" + model = self._make_model(gnn="BAT") + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + result = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertIsInstance(result, torch.Tensor) + + # --- Attention flags --- + + def test_no_alpha(self): + """Test model works with alpha attention disabled.""" + model = self._make_model(use_alpha=False, use_beta=True) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS)) + + def test_no_beta(self): + """Test model works with beta attention disabled.""" + model = self._make_model(use_alpha=True, use_beta=False) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS)) + + def test_no_alpha_no_beta(self): + """Test model works with both attentions disabled.""" + model = 
self._make_model(use_alpha=False, use_beta=False) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS)) + + # --- Embedding options --- + + def test_learned_embeddings(self): + """Test model works with learned (not pre-trained) embeddings.""" + model = self._make_model(node_emb=None, rel_emb=None) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS)) + + def test_frozen_embeddings(self): + """Test frozen pre-trained embeddings don't get gradients.""" + node_emb = torch.randn(self.NUM_NODES, self.EMBEDDING_DIM) + model = self._make_model(node_emb=node_emb, freeze=True) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.train() + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + loss = logits.sum() + loss.backward() + + self.assertFalse(model.node_emb.weight.requires_grad) + + def test_attn_init(self): + """Test attention initialization with pre-computed weights.""" + attn_init = torch.randn(self.NUM_NODES) + model = self._make_model(attn_init=attn_init) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes) + + self.assertEqual(logits.shape, (self.BATCH_SIZE, self.OUT_CHANNELS)) + + # --- Device movement --- + + def test_to_device(self): + """Test lambda_j is a registered buffer and moves with the model.""" + model = self._make_model() + model = model.to("cpu") + + # lambda_j should be a registered buffer, visible in state_dict + self.assertIn("lambda_j", model.state_dict()) + self.assertEqual(model.lambda_j.device, torch.device("cpu")) + + # --- Invalid inputs --- + + def test_invalid_gnn(self): + """Test invalid GNN type raises error.""" + with self.assertRaises(AssertionError): + self._make_model(gnn="INVALID") + + def test_invalid_patient_mode(self): + """Test invalid patient_mode raises error.""" + with self.assertRaises(AssertionError): + self._make_model(patient_mode="INVALID") + + # --- Parameter count --- + + def test_joint_has_more_params_than_graph(self): + """Joint mode should have more params due to wider MLP.""" + model_joint = self._make_model(patient_mode="joint") + model_graph = self._make_model(patient_mode="graph") + + params_joint = sum(p.numel() for p in model_joint.parameters()) + params_graph = sum(p.numel() for p in model_graph.parameters()) + + self.assertGreater(params_joint, params_graph) + + # --- Numerical sanity --- + + def test_no_nan_in_output(self): + """Test model output contains no NaN values.""" + for gnn in ["BAT", "GAT", "GIN"]: + for mode in ["joint", "graph", "node"]: + model = self._make_model(gnn=gnn, patient_mode=mode) + node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes, _ = self._make_batch() + + model.eval() + with torch.no_grad(): + logits = model( + node_ids, rel_ids, edge_index, batch_vec, + visit_node, + ehr_nodes if mode != "graph" else None, + ) + + self.assertFalse( + torch.isnan(logits).any(), + f"NaN in output for gnn={gnn}, 
mode={mode}", + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) \ No newline at end of file From ccf85761d8f575a3bc63c1d47534423a6961a5c8 Mon Sep 17 00:00:00 2001 From: Steier <637682@bah.com> Date: Sun, 8 Feb 2026 12:42:24 -0600 Subject: [PATCH 2/2] fix: move 'from torch import Tensor' outside torch_geometric try/except Tensor is from torch (always available), not torch_geometric. Having it inside the try block caused NameError at import time when torch_geometric is not installed, breaking the entire pyhealth.models import chain in CI. --- pyhealth/models/graphcare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/models/graphcare.py b/pyhealth/models/graphcare.py index e440d71e3..7b8a8108e 100644 --- a/pyhealth/models/graphcare.py +++ b/pyhealth/models/graphcare.py @@ -26,6 +26,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor logger = logging.getLogger(__name__) @@ -38,7 +39,6 @@ from torch_geometric.nn import global_mean_pool from torch_geometric.data import Data, Batch from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size - from torch import Tensor HAS_TORCH_GEOMETRIC = True except ImportError: