diff --git a/examples/foundation_ehr/multimodal_task.py b/examples/foundation_ehr/multimodal_task.py
new file mode 100644
index 000000000..d4b87bb26
--- /dev/null
+++ b/examples/foundation_ehr/multimodal_task.py
@@ -0,0 +1,39 @@
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+import os
+
+# PyHealth Packages
+from pyhealth.datasets import MIMIC4Dataset
+from pyhealth.tasks.ehr_foundational_model_mimic4 import EHRFoundationalModelMIMIC4
+from pyhealth.tasks.base_task import BaseTask
+
+# Load MIMIC4 Files
+# There are probably better ways to handle this on the cluster, but this works locally for now
+# (see: https://github.com/sunlabuiuc/PyHealth/blob/master/examples/mortality_prediction/multimodal_mimic4_minimal.py)
+
+PYHEALTH_REPO_ROOT = '/Users/wpang/Desktop/PyHealth'
+
+EHR_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimiciv/2.2")
+NOTE_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimic-iv-note/2.2")
+CXR_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimic-cxr-jpg/2.0.0")
+CACHE_DIR = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/wp/pyhealth_cache")
+
+if __name__ == "__main__":
+
+    dataset = MIMIC4Dataset(
+        ehr_root=EHR_ROOT,
+        note_root=NOTE_ROOT,
+        ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"],
+        note_tables=["discharge", "radiology"],
+        cache_dir=CACHE_DIR,
+        num_workers=8,
+        # dev=True
+    )
+
+    # Apply multimodal task
+    task = EHRFoundationalModelMIMIC4()
+    samples = dataset.set_task(task, cache_dir=f"{CACHE_DIR}/task", num_workers=8)
+
+    # Get and print sample
+    sample = samples[0]
+    print(sample)
\ No newline at end of file
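For anyone running the script above: after set_task, each tuple_time_text field of a sample is stored as a JSON string (see the processor change below). A rough inspection sketch, assuming the `samples` object from the script and a dict-like sample layout with the field names from the task's input_schema; this is not part of the diff:

import json

# Rough sketch (assumes `samples` from the example above; the sample layout is an assumption).
sample = samples[0]
print(list(sample.keys()))

# Each tuple_time_text field is serialized JSON; decode it to recover the pieces.
notes = json.loads(sample["discharge_note_times"])
print(len(notes["texts"]), "discharge notes")
print("time diffs (hours):", notes["time_diffs"])
print("modality tag:", notes["type_tag"])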
diff --git a/pyhealth/processors/tuple_time_text_processor.py b/pyhealth/processors/tuple_time_text_processor.py
index ecc59c030..6542fa153 100644
--- a/pyhealth/processors/tuple_time_text_processor.py
+++ b/pyhealth/processors/tuple_time_text_processor.py
@@ -9,10 +9,11 @@
         - List[str]: Clinical text entries (e.g., discharge notes, progress notes)
         - List[float]: Time differences between entries (in any time unit)
 
-    Output: Tuple[List[str], torch.Tensor, str]
-        - List[str]: Same text entries (unmodified)
-        - torch.Tensor: 1D float tensor of time differences
-        - str: Type tag for automatic modality routing (default: "note")
+    Output: str (JSON)
+        JSON string with keys:
+        - "texts": List[str] - Same text entries (unmodified)
+        - "time_diffs": List[float] - Time differences (unmodified)
+        - "type_tag": str - Type tag for automatic modality routing (default: "note")
 
     Use Case:
         This processor enables automatic modality bucketing in multimodal pipelines.
@@ -51,7 +52,7 @@
 """
 
 from typing import Any, List, Tuple
-import torch
+import json
 
 from .base_processor import FeatureProcessor
 from . import register_processor
@@ -73,30 +74,23 @@ def __init__(self, type_tag: str = "note"):
         super().__init__()
         self.type_tag = type_tag
 
-    def process(self, value: Tuple[List[str], List[float]]) -> Tuple[List[str], torch.Tensor, str]:
+    def process(self, value: Tuple[List[str], List[float]]) -> str:
         """Process a tuple of texts and time differences.
-
+
+        Serializes the data as a JSON string so litdata can store it natively.
+        Downstream code should parse with json.loads() to recover the dict
+        with keys "texts", "time_diffs", and "type_tag".
+
         Args:
             value: Tuple containing:
                 - List[str]: Text entries (clinical notes, observations, etc.)
                 - List[float]: Time differences corresponding to each text entry
-
+
         Returns:
-            Tuple containing:
-                - List[str]: Original text entries (unmodified)
-                - torch.Tensor: 1D float tensor of time differences [shape: (N,)]
-                - str: Type tag for modality routing
-
-        Example:
-            >>> processor = TupleTimeTextProcessor(type_tag="clinical_note")
-            >>> texts = ["Note 1", "Note 2"]
-            >>> times = [0.0, 24.0]  # hours
-            >>> result = processor.process((texts, times))
-            >>> print(result[1])  # tensor([0., 24.])
+            str: JSON string with keys "texts", "time_diffs", "type_tag".
         """
         texts, time_diffs = value
-        time_tensor = torch.tensor(time_diffs, dtype=torch.float32)
-        return texts, time_tensor, self.type_tag
+        return json.dumps({"texts": texts, "time_diffs": time_diffs, "type_tag": self.type_tag})
 
     def size(self):
         """Return the size of the processor vocabulary (not applicable for this processor)."""
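A quick round-trip sketch of the new JSON output, adapted from the doctest that was removed above; it only uses TupleTimeTextProcessor.process and json.loads, nothing else is assumed:

import json
from pyhealth.processors.tuple_time_text_processor import TupleTimeTextProcessor

# process() now returns a JSON string instead of (texts, tensor, tag).
processor = TupleTimeTextProcessor(type_tag="clinical_note")
encoded = processor.process((["Note 1", "Note 2"], [0.0, 24.0]))  # 24-hour gap between notes
decoded = json.loads(encoded)
print(decoded["texts"])       # ['Note 1', 'Note 2']
print(decoded["time_diffs"])  # [0.0, 24.0]
print(decoded["type_tag"])    # 'clinical_note'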
diff --git a/pyhealth/tasks/ehr_foundational_model_mimic4.py b/pyhealth/tasks/ehr_foundational_model_mimic4.py
new file mode 100644
index 000000000..bc6be375f
--- /dev/null
+++ b/pyhealth/tasks/ehr_foundational_model_mimic4.py
@@ -0,0 +1,140 @@
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+from pyhealth.tasks.base_task import BaseTask
+
+class EHRFoundationalModelMIMIC4(BaseTask):
+
+    task_name: str = "EHRFoundationalModelMIMIC4"
+
+    def __init__(self):
+        """Initialize the EHR Foundational Model task."""
+        self.input_schema: Dict[str, str] = {
+            "discharge_note_times": "tuple_time_text",
+            "radiology_note_times": "tuple_time_text",
+        }
+        self.output_schema: Dict[str, str] = {"mortality": "binary"}
+
+    def _clean_text(self, text: Optional[str]) -> Optional[str]:
+        """Return text if non-empty, otherwise None."""
+        return text if text else None
+
+    def _compute_time_diffs(self, notes_with_timestamps, anchor_time=None):
+        """Convert (text, timestamp) pairs into parallel lists of texts and hour offsets.
+
+        If anchor_time is given (e.g., admission time), each offset is measured
+        from that anchor; otherwise the first note gets 0.0 and each later note
+        is measured from the previous note's timestamp.
+        """
+        if not notes_with_timestamps:  # TODO: Maybe move this check upstream; it is not specific to time diffs
+            return ([""], [0.0])  # TODO: How should we handle notes with missing timestamps?
+        result = []
+        for i, (text, timestamp) in enumerate(notes_with_timestamps):
+            if anchor_time is not None:
+                diff = (timestamp - anchor_time).total_seconds() / 3600
+            elif i == 0:
+                diff = 0.0
+            else:
+                diff = (timestamp - notes_with_timestamps[i - 1][1]).total_seconds() / 3600
+            result.append((text, diff))
+        texts, time_diffs = zip(*result)
+        return (list(texts), list(time_diffs))
+
+    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
+        # Get demographic info; skip patients without a demographics record
+        demographics = patient.get_events(event_type="patients")
+        if not demographics:
+            return []
+
+        demographics = demographics[0]
+
+        # Get visits
+        admissions = patient.get_events(event_type="admissions")
+        if len(admissions) == 0:
+            return []
+
+        # Determine which admissions to process iteratively
+        # Check each admission's NEXT admission for mortality flag
+        admissions_to_process = []
+        mortality_label = 0
+
+        for i, admission in enumerate(admissions):
+            # Check if THIS admission has the death flag
+            if admission.hospital_expire_flag in [1, "1"]:
+                # Patient died in this admission - set mortality label
+                # but don't include this admission's data
+                mortality_label = 1
+                break
+
+            # Check if there's a next admission with death flag
+            if i + 1 < len(admissions):
+                next_admission = admissions[i + 1]
+                if next_admission.hospital_expire_flag in [1, "1"]:
+                    # Next admission has death - include current, set mortality
+                    admissions_to_process.append(admission)
+                    mortality_label = 1
+                    break
+
+            # No death in current or next - include this admission
+            admissions_to_process.append(admission)
+
+        if len(admissions_to_process) == 0:
+            return []
+
+        # First admission time, kept as a reference anchor (e.g., for future lab time calculations)
+        first_admission_time = admissions_to_process[0].timestamp
+
+        # Aggregated data across all admissions
+        all_discharge_notes_timestamped = []  # List of (note_text, timestamp) tuples
+        all_radiology_notes_timestamped = []  # List of (note_text, timestamp) tuples
+
+        # Process each admission and aggregate data
+        for admission in admissions_to_process:
+            # Parse admission discharge time for lab events filtering
+            try:
+                admission_dischtime = datetime.strptime(
+                    admission.dischtime, "%Y-%m-%d %H:%M:%S"
+                )
+            except (ValueError, AttributeError):
+                # If we can't parse discharge time, skip this admission
+                continue
+
+            # Skip if discharge is before admission (data quality issue)
+            if admission_dischtime < admission.timestamp:
+                continue
+
+            # Get notes using hadm_id filtering
+            discharge_notes = patient.get_events(
+                event_type="discharge", filters=[("hadm_id", "==", admission.hadm_id)]
+            )
+            radiology_notes = patient.get_events(
+                event_type="radiology", filters=[("hadm_id", "==", admission.hadm_id)]
+            )
+
+            # Extract and aggregate notes as individual items in lists
+            # Note: attribute is "text" (from mimic4_note.yaml), not "discharge"/"radiology"
+            for note in discharge_notes:
+                try:
+                    note_text = self._clean_text(note.text)
+                    if note_text:
+                        all_discharge_notes_timestamped.append((note_text, note.timestamp))
+                except AttributeError:
+                    pass
+
+            for note in radiology_notes:
+                try:
+                    note_text = self._clean_text(note.text)
+                    if note_text:
+                        all_radiology_notes_timestamped.append((note_text, note.timestamp))
+                except AttributeError:
+                    pass
+
+        # Convert (note_text, timestamp) pairs into parallel lists of texts and time diffs (in hours)
+        discharge_note_times = self._compute_time_diffs(all_discharge_notes_timestamped)
+        radiology_note_times = self._compute_time_diffs(all_radiology_notes_timestamped)
+
+        return [
+            {
+                "patient_id": patient.patient_id,
+                "discharge_note_times": discharge_note_times,
+                "radiology_note_times": radiology_note_times,
+                "mortality": mortality_label,
+            }
+        ]
\ No newline at end of file
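To make the time-difference convention concrete: without anchor_time the first note gets 0.0 and each later note is measured in hours from the previous note; with anchor_time every note is measured from the anchor. A standalone sketch with toy timestamps, calling the private helper purely for illustration (not part of the diff):

from datetime import datetime

from pyhealth.tasks.ehr_foundational_model_mimic4 import EHRFoundationalModelMIMIC4

task = EHRFoundationalModelMIMIC4()
notes = [
    ("Admission note", datetime(2024, 1, 1, 8, 0)),
    ("Progress note", datetime(2024, 1, 1, 20, 0)),
    ("Discharge summary", datetime(2024, 1, 3, 8, 0)),
]

# Consecutive deltas: first note is 0.0, then hours since the previous note.
print(task._compute_time_diffs(notes))
# (['Admission note', 'Progress note', 'Discharge summary'], [0.0, 12.0, 36.0])

# Anchored deltas: hours since the anchor (e.g., admission time) for every note.
print(task._compute_time_diffs(notes, anchor_time=datetime(2024, 1, 1, 8, 0)))
# (['Admission note', 'Progress note', 'Discharge summary'], [0.0, 12.0, 48.0])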