Skip to content
39 changes: 39 additions & 0 deletions examples/foundation_ehr/multimodal_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import os

# PyHealth Packages
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks.ehr_foundational_model_mimic4 import EHRFoundationalModelMIMIC4
from pyhealth.tasks.base_task import BaseTask

# Load MIMIC4 Files
# There's probably better ways dealing with this on the cluster, but working locally for now
# (see: https://github.com/sunlabuiuc/PyHealth/blob/master/examples/mortality_prediction/multimodal_mimic4_minimal.py)

# NOTE(review): hardcoded to a local checkout — consider reading this from an
# environment variable before running on the cluster (see comment above).
PYHEALTH_REPO_ROOT = '/Users/wpang/Desktop/PyHealth'

# Data roots mirror the physionet.org directory layout under the repo root.
EHR_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimiciv/2.2")  # structured MIMIC-IV v2.2 tables
NOTE_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimic-iv-note/2.2")  # clinical notes (discharge/radiology)
CXR_ROOT = os.path.join(PYHEALTH_REPO_ROOT,"srv/local/data/physionet.org/files/mimic-cxr-jpg/2.0.0")  # chest X-ray JPGs (unused below)
CACHE_DIR = os.path.join(PYHEALTH_REPO_ROOT,"srv/local/data/wp/pyhealth_cache")  # dataset/task cache location

if __name__ == "__main__":
    # Build the multimodal MIMIC-IV dataset: structured EHR tables plus
    # free-text clinical notes, cached for reuse across runs.
    mimic4 = MIMIC4Dataset(
        ehr_root=EHR_ROOT,
        note_root=NOTE_ROOT,
        ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"],
        note_tables=["discharge", "radiology"],
        cache_dir=CACHE_DIR,
        num_workers=8,
        # dev=True
    )

    # Derive foundation-model samples by applying the multimodal task.
    fm_task = EHRFoundationalModelMIMIC4()
    task_samples = mimic4.set_task(fm_task, cache_dir=f"{CACHE_DIR}/task", num_workers=8)

    # Sanity check: inspect the first generated sample.
    print(task_samples[0])
36 changes: 15 additions & 21 deletions pyhealth/processors/tuple_time_text_processor.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To get past the serialization problems, we could simply do:

    texts, time_diffs = value
    time_tensor = torch.tensor(time_diffs, dtype=torch.float32)
    return pickle.dumps(texts), time_tensor

However, I think we may need to think of a better approach here, one that includes a tokenizer in our processor here.

Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
- List[str]: Clinical text entries (e.g., discharge notes, progress notes)
- List[float]: Time differences between entries (in any time unit)

Output: Tuple[List[str], torch.Tensor, str]
- List[str]: Same text entries (unmodified)
- torch.Tensor: 1D float tensor of time differences
- str: Type tag for automatic modality routing (default: "note")
Output: str (JSON)
JSON string with keys:
- "texts": List[str] - Same text entries (unmodified)
- "time_diffs": List[float] - Time differences (unmodified)
- "type_tag": str - Type tag for automatic modality routing (default: "note")

Use Case:
This processor enables automatic modality bucketing in multimodal pipelines.
Expand Down Expand Up @@ -51,7 +52,7 @@
"""

from typing import Any, List, Tuple
import torch
import json
from .base_processor import FeatureProcessor
from . import register_processor

Expand All @@ -73,30 +74,23 @@ def __init__(self, type_tag: str = "note"):
super().__init__()
self.type_tag = type_tag

def process(self, value: Tuple[List[str], List[float]]) -> str:
    """Serialize a (texts, time_diffs) pair into a JSON string.

    Packing the pair — together with the processor's type tag — into a
    single JSON string lets litdata store the sample natively. Downstream
    code recovers the dict with json.loads(); it has the keys "texts",
    "time_diffs", and "type_tag".

    Args:
        value: Tuple containing:
            - List[str]: Text entries (clinical notes, observations, etc.)
            - List[float]: Time differences corresponding to each text entry

    Returns:
        str: JSON string with keys "texts", "time_diffs", "type_tag".
    """
    texts, time_diffs = value
    payload = {
        "texts": texts,
        "time_diffs": time_diffs,
        "type_tag": self.type_tag,
    }
    return json.dumps(payload)

def size(self):
"""Return the size of the processor vocabulary (not applicable for this processor)."""
Expand Down
140 changes: 140 additions & 0 deletions pyhealth/tasks/ehr_foundational_model_mimic4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from datetime import datetime
from typing import Any, Dict, List, Optional

from pyhealth.tasks.base_task import BaseTask

class EHRFoundationalModelMIMIC4(BaseTask):

task_name: str = "EHRFoundationalModelMIMIC4"

def __init__(self):
"""Initialize the EHR Foundational Model task."""
self.input_schema: Dict[str, str] = {
"discharge_note_times": "tuple_time_text",
"radiology_note_times": "tuple_time_text",
}
self.output_schema: Dict[str, str] = {"mortality": "binary"}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found another bug, mortality should be a "binary" variable for our purposes here.

Copy link
Contributor Author

@will-pang will-pang Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes — good catch! I temporarily made it a regression task for testing because I was getting this error:

Traceback (most recent call last):
  File "/Users/wpang/Desktop/PyHealth/examples/foundation_ehr/multimodal_task.py", line 35, in <module>
    samples = dataset.set_task(task, cache_dir=f"{CACHE_DIR}/task", num_workers=8)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wpang/Desktop/PyHealth/pyhealth/datasets/base_dataset.py", line 936, in set_task
    builder.fit(dataset)
  File "/Users/wpang/Desktop/PyHealth/pyhealth/datasets/sample_dataset.py", line 167, in fit
    processor.fit(samples, key)
  File "/Users/wpang/Desktop/PyHealth/pyhealth/processors/label_processor.py", line 25, in fit
    raise ValueError(f"Expected 2 unique labels, got {len(all_labels)}")
ValueError: Expected 2 unique labels, got 1

I suspect it was because every patient in the dev sample I was working with had a mortality label of 0, so for testing purposes I switched it to a regression task. I've made the update!

def _clean_text(self, text: Optional[str]) -> Optional[str]:
"""Return text if non-empty, otherwise None."""
return text if text else None

def _compute_time_diffs(self, notes_with_timestamps, anchor_time=None):
# TODO: Add docstrings.
# anchor_time is in case we want it to normalize/center the time on admission time or something like that.

if not notes_with_timestamps: # TODO: Maybe I should move this somewhere else as it's not relevant to time diffs
return (["<missing>"], [0.0]) # TODO: How should we handle notes with missing timestamps?
result = []
for i, (text, timestamp) in enumerate(notes_with_timestamps):
if anchor_time is not None:
diff = (timestamp - anchor_time).total_seconds() / 3600
elif i == 0:
diff = 0.0
else:
diff = (timestamp - notes_with_timestamps[i - 1][1]).total_seconds() / 3600
result.append((text, diff))
texts, time_diffs = zip(*result)
return (list(texts), list(time_diffs))

def __call__(self, patient: Any) -> List[Dict[str, Any]]:
    """Build one multimodal sample for a patient: aggregated notes + mortality label.

    Aggregates discharge and radiology notes across the patient's admissions
    (stopping before any admission with the death flag) and labels the sample
    with whether a death flag appears in the admission sequence.

    Args:
        patient: PyHealth patient object exposing ``get_events`` and
            ``patient_id`` (project type — attributes assumed from usage).

    Returns:
        A single-element list with keys "patient_id", "discharge_note_times",
        "radiology_note_times", "mortality"; or [] when the patient has no
        demographics, no admissions, or only a terminal admission.
    """
    # Get demographic info to filter by age
    # NOTE(review): no age filter is actually applied below — demographics is
    # only used as an existence check. Confirm whether filtering is intended.
    demographics = patient.get_events(event_type="patients")
    if not demographics:
        return []

    demographics = demographics[0]

    # Get visits
    admissions = patient.get_events(event_type="admissions")
    if len(admissions) == 0:
        return []

    # Determine which admissions to process iteratively
    # Check each admission's NEXT admission for mortality flag
    admissions_to_process = []
    mortality_label = 0

    for i, admission in enumerate(admissions):
        # Check if THIS admission has the death flag
        # (flag may arrive as int or string depending on the loader, hence [1, "1"])
        if admission.hospital_expire_flag in [1, "1"]:
            # Patient died in this admission - set mortality label
            # but don't include this admission's data
            mortality_label = 1
            break

        # Check if there's a next admission with death flag
        if i + 1 < len(admissions):
            next_admission = admissions[i + 1]
            if next_admission.hospital_expire_flag in [1, "1"]:
                # Next admission has death - include current, set mortality
                admissions_to_process.append(admission)
                mortality_label = 1
                break

        # No death in current or next - include this admission
        admissions_to_process.append(admission)

    if len(admissions_to_process) == 0:
        return []

    # Get first admission time as reference for lab time calculations
    # NOTE(review): currently unused — lab-event handling appears unfinished.
    first_admission_time = admissions_to_process[0].timestamp

    # Aggregated data across all admissions
    all_discharge_notes_timestamped = []  # List of (note_text, timestamp) tuples
    all_radiology_notes_timestamped = []  # List of (note_text, timestamp) tuples

    # Process each admission and aggregate data
    for admission in admissions_to_process:
        # Parse admission discharge time for lab events filtering
        try:
            admission_dischtime = datetime.strptime(
                admission.dischtime, "%Y-%m-%d %H:%M:%S"
            )
        except (ValueError, AttributeError):
            # If we can't parse discharge time, skip this admission
            continue

        # Skip if discharge is before admission (data quality issue)
        if admission_dischtime < admission.timestamp:
            continue

        # Get notes using hadm_id filtering
        discharge_notes = patient.get_events(
            event_type="discharge", filters=[("hadm_id", "==", admission.hadm_id)]
        )
        radiology_notes = patient.get_events(
            event_type="radiology", filters=[("hadm_id", "==", admission.hadm_id)]
        )

        # Extract and aggregate notes as individual items in lists
        # Note: attribute is "text" (from mimic4_note.yaml), not "discharge"/"radiology"
        for note in discharge_notes:
            try:
                note_text = self._clean_text(note.text)
                if note_text:
                    all_discharge_notes_timestamped.append((note_text, note.timestamp))
            except AttributeError:
                # Note event without a "text" attribute — skip silently.
                pass

        for note in radiology_notes:
            try:
                note_text = self._clean_text(note.text)
                if note_text:
                    all_radiology_notes_timestamped.append((note_text, note.timestamp))
            except AttributeError:
                pass

    # Convert (note_text, timestamp) tuples to (note_text, time_diff_hours) tuples
    discharge_note_times = self._compute_time_diffs(all_discharge_notes_timestamped)
    radiology_note_times = self._compute_time_diffs(all_radiology_notes_timestamped)

    return [
        {
            "patient_id": patient.patient_id,
            "discharge_note_times": discharge_note_times,
            "radiology_note_times": radiology_note_times,
            "mortality": mortality_label,
        }
    ]