From 5bdaa76393ffccc550ac9db507df78c2d13700ac Mon Sep 17 00:00:00 2001 From: William Pang Date: Tue, 10 Feb 2026 16:03:29 -0800 Subject: [PATCH 01/11] Create ehr_foundation_model task --- .../tasks/ehr_foundational_model_mimic4.py | 156 ++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 pyhealth/tasks/ehr_foundational_model_mimic4.py diff --git a/pyhealth/tasks/ehr_foundational_model_mimic4.py b/pyhealth/tasks/ehr_foundational_model_mimic4.py new file mode 100644 index 000000000..2841bf318 --- /dev/null +++ b/pyhealth/tasks/ehr_foundational_model_mimic4.py @@ -0,0 +1,156 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pyhealth.tasks.base_task import BaseTask + +class EHRFoundationalModelMIMIC4(BaseTask): + + task_name: str = "EHRFoundationalModelMIMIC4" + + def __init__(self): + """Initialize the EHR Foundational Model task.""" + self.input_schema: Dict[str, str] = { + "discharge": "raw", + "radiology": "raw", + "discharge_note_time_diffs": "tensor", + "radiology_note_time_diffs": "tensor", + } + self.output_schema: Dict[str, str] = {"mortality": "regression"} + + def _clean_text(self, text: Optional[str]) -> Optional[str]: + """Return text if non-empty, otherwise None.""" + return text if text else None + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + # Get demographic info to filter by age + demographics = patient.get_events(event_type="patients") + if not demographics: + return [] + + demographics = demographics[0] + + # Get visits + admissions = patient.get_events(event_type="admissions") + if len(admissions) == 0: + return [] + + # Determine which admissions to process iteratively + # Check each admission's NEXT admission for mortality flag + admissions_to_process = [] + mortality_label = 0 + + for i, admission in enumerate(admissions): + # Check if THIS admission has the death flag + if admission.hospital_expire_flag in [1, "1"]: + # Patient died in this admission - set mortality label + # but don't include this admission's data + mortality_label = 1 + break + + # Check if there's a next admission with death flag + if i + 1 < len(admissions): + next_admission = admissions[i + 1] + if next_admission.hospital_expire_flag in [1, "1"]: + # Next admission has death - include current, set mortality + admissions_to_process.append(admission) + mortality_label = 1 + break + + # No death in current or next - include this admission + admissions_to_process.append(admission) + + if len(admissions_to_process) == 0: + return [] + + # Get first admission time as reference for lab time calculations + first_admission_time = admissions_to_process[0].timestamp + + # Aggregated data across all admissions + all_discharge_notes = [] # List of individual discharge notes + all_radiology_notes = [] # List of individual radiology notes + all_discharge_notes_timestamps = [] # List of individual discharge notes timestamps + all_radiology_notes_timestamps = [] # List of individual discharge notes timestamps + discharge_note_time_diffs = [] + radiology_notes_time_diffs = [] + + + # Process each admission and aggregate data + for admission in admissions_to_process: + # Parse admission discharge time for lab events filtering + try: + admission_dischtime = datetime.strptime( + admission.dischtime, "%Y-%m-%d %H:%M:%S" + ) + except (ValueError, AttributeError): + # If we can't parse discharge time, skip this admission + continue + + # Skip if discharge is before admission (data quality issue) + if admission_dischtime < admission.timestamp: + continue + + # Get notes using hadm_id filtering + discharge_notes = patient.get_events( + event_type="discharge", filters=[("hadm_id", "==", admission.hadm_id)] + ) + radiology_notes = patient.get_events( + event_type="radiology", filters=[("hadm_id", "==", admission.hadm_id)] + ) + + # Extract and aggregate notes as individual items in lists + # Note: attribute is "text" (from mimic4_note.yaml), not "discharge"/"radiology" + for note in discharge_notes: + try: + note_text = self._clean_text(note.text) + if note_text: + all_discharge_notes.append(note_text) + all_discharge_notes_timestamps.append(note.timestamp) + except AttributeError: + pass + + for note in radiology_notes: + try: + note_text = self._clean_text(note.text) + if note_text: + all_radiology_notes.append(note_text) + all_radiology_notes_timestamps.append(note.timestamp) + except AttributeError: + pass + + # Sort discharge_notes by timestamp + all_discharge_notes_timestamps.sort() + all_radiology_notes_timestamps.sort() + + # Compute time difference for discharge notes (hours) + discharge_note_time_diffs = [0.0] + [ + (curr - prev).total_seconds() / 3600 + for prev, curr in zip(all_discharge_notes_timestamps, all_discharge_notes_timestamps[1:]) + ] + + # Compute time difference for radiology notes (hours) + radiology_note_time_diffs = [0.0] + [ + (curr - prev).total_seconds() / 3600 + for prev, curr in zip(all_radiology_notes_timestamps, all_radiology_notes_timestamps[1:]) + ] + + # ===== MODALITY REQUIREMENTS ===== + # Check notes - need at least one discharge OR radiology note + has_notes = len(all_discharge_notes) > 0 or len(all_radiology_notes) > 0 + + #Return empty list if any required modality is missing + if not ( + has_notes + ): + return [] + + + return [ + { + "patient_id": patient.patient_id, + "discharge": all_discharge_notes, + "discharge_note_time_diffs": discharge_note_time_diffs, + "radiology": all_radiology_notes, + "radiology_note_time_diffs": radiology_note_time_diffs, + "mortality": mortality_label, + } + ] \ No newline at end of file From 5a49febc19e4cbdd8dfe551325ca41bdfdb6af3d Mon Sep 17 00:00:00 2001 From: William Pang Date: Wed, 11 Feb 2026 06:14:56 -0800 Subject: [PATCH 02/11] Add example for testing --- examples/foundation_ehr/multimodal_task.py | 39 ++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 examples/foundation_ehr/multimodal_task.py diff --git a/examples/foundation_ehr/multimodal_task.py b/examples/foundation_ehr/multimodal_task.py new file mode 100644 index 000000000..9a52be423 --- /dev/null +++ b/examples/foundation_ehr/multimodal_task.py @@ -0,0 +1,39 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional +import os + +# PyHealth Packages +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.tasks.ehr_foundational_model_mimic4 import EHRFoundationalModelMIMIC4 +from pyhealth.tasks.base_task import BaseTask + +# Load MIMIC4 Files +# There's probably better ways dealing with this on the cluster, but working locally for now +# (see: https://github.com/sunlabuiuc/PyHealth/blob/master/examples/mortality_prediction/multimodal_mimic4_minimal.py) + +PYHEALTH_REPO_ROOT = #'/Users/wpang/Desktop/PyHealth' + +EHR_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimiciv/2.2") +NOTE_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimic-iv-note/2.2") +CXR_ROOT = os.path.join(PYHEALTH_REPO_ROOT,"srv/local/data/physionet.org/files/mimic-cxr-jpg/2.0.0") +CACHE_DIR = os.path.join(PYHEALTH_REPO_ROOT,"srv/local/data/wp/pyhealth_cache") + +if __name__ == "__main__": + + dataset = MIMIC4Dataset( + ehr_root=EHR_ROOT, + note_root=NOTE_ROOT, + ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"], + note_tables=["discharge", "radiology"], + cache_dir=CACHE_DIR, + num_workers=16, + dev=True + ) + + # Apply multimodal task + task = EHRFoundationalModelMIMIC4() + samples = dataset.set_task(task, cache_dir=f"{CACHE_DIR}/task", num_workers=8) + + # Get and print sample + sample = samples[0] + print(sample) \ No newline at end of file From 97773117de8062e6ecb3a9a5fd6b4a4c2c3b7029 Mon Sep 17 00:00:00 2001 From: William Pang Date: Wed, 11 Feb 2026 06:27:52 -0800 Subject: [PATCH 03/11] Update ehr_foundational_model_mimic4.py --- pyhealth/tasks/ehr_foundational_model_mimic4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/tasks/ehr_foundational_model_mimic4.py b/pyhealth/tasks/ehr_foundational_model_mimic4.py index 2841bf318..ef40cc5f9 100644 --- a/pyhealth/tasks/ehr_foundational_model_mimic4.py +++ b/pyhealth/tasks/ehr_foundational_model_mimic4.py @@ -15,7 +15,7 @@ def __init__(self): "discharge_note_time_diffs": "tensor", "radiology_note_time_diffs": "tensor", } - self.output_schema: Dict[str, str] = {"mortality": "regression"} + self.output_schema: Dict[str, str] = {"mortality": "binary"} def _clean_text(self, text: Optional[str]) -> Optional[str]: """Return text if non-empty, otherwise None.""" From 60641c099031d6db6ccb381005c6e20d58b6b9fe Mon Sep 17 00:00:00 2001 From: William Pang Date: Wed, 11 Feb 2026 19:31:35 -0800 Subject: [PATCH 04/11] Update ehr_foundational_model_mimic4.py --- pyhealth/tasks/ehr_foundational_model_mimic4.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/pyhealth/tasks/ehr_foundational_model_mimic4.py b/pyhealth/tasks/ehr_foundational_model_mimic4.py index ef40cc5f9..4f9ec2f50 100644 --- a/pyhealth/tasks/ehr_foundational_model_mimic4.py +++ b/pyhealth/tasks/ehr_foundational_model_mimic4.py @@ -10,10 +10,8 @@ class EHRFoundationalModelMIMIC4(BaseTask): def __init__(self): """Initialize the EHR Foundational Model task.""" self.input_schema: Dict[str, str] = { - "discharge": "raw", - "radiology": "raw", - "discharge_note_time_diffs": "tensor", - "radiology_note_time_diffs": "tensor", + "discharge_note_times": "tuple_time_text", + "radiology_note_times": "tuple_time_text", } self.output_schema: Dict[str, str] = {"mortality": "binary"} @@ -69,9 +67,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: all_discharge_notes = [] # List of individual discharge notes all_radiology_notes = [] # List of individual radiology notes all_discharge_notes_timestamps = [] # List of individual discharge notes timestamps - all_radiology_notes_timestamps = [] # List of individual discharge notes timestamps - discharge_note_time_diffs = [] - radiology_notes_time_diffs = [] + all_radiology_notes_timestamps = [] # List of individual radiology notes timestamps # Process each admission and aggregate data @@ -147,10 +143,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: return [ { "patient_id": patient.patient_id, - "discharge": all_discharge_notes, - "discharge_note_time_diffs": discharge_note_time_diffs, - "radiology": all_radiology_notes, - "radiology_note_time_diffs": radiology_note_time_diffs, + "discharge_note_times": (all_discharge_notes, discharge_note_time_diffs), + "radiology_note_times": (all_radiology_notes, radiology_note_time_diffs), "mortality": mortality_label, } ] \ No newline at end of file From e04f54e123e1bfcb93e3befeddab343a9a6425c6 Mon Sep 17 00:00:00 2001 From: William Pang Date: Wed, 11 Feb 2026 20:19:03 -0800 Subject: [PATCH 05/11] Update ehr_foundational_model_mimic4.py --- .../tasks/ehr_foundational_model_mimic4.py | 54 +++++++++---------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/pyhealth/tasks/ehr_foundational_model_mimic4.py b/pyhealth/tasks/ehr_foundational_model_mimic4.py index 4f9ec2f50..036dd6a2a 100644 --- a/pyhealth/tasks/ehr_foundational_model_mimic4.py +++ b/pyhealth/tasks/ehr_foundational_model_mimic4.py @@ -13,12 +13,27 @@ def __init__(self): "discharge_note_times": "tuple_time_text", "radiology_note_times": "tuple_time_text", } - self.output_schema: Dict[str, str] = {"mortality": "binary"} + self.output_schema: Dict[str, str] = {"mortality": "regression"} def _clean_text(self, text: Optional[str]) -> Optional[str]: """Return text if non-empty, otherwise None.""" return text if text else None + def _compute_time_diffs(self, notes_with_timestamps, anchor_time=None): + if not notes_with_timestamps: + return ([], []) + result = [] + for i, (text, timestamp) in enumerate(notes_with_timestamps): + if anchor_time is not None: + diff = (timestamp - anchor_time).total_seconds() / 3600 + elif i == 0: + diff = 0.0 + else: + diff = (timestamp - notes_with_timestamps[i - 1][1]).total_seconds() / 3600 + result.append((text, diff)) + texts, time_diffs = zip(*result) + return (list(texts), list(time_diffs)) + def __call__(self, patient: Any) -> List[Dict[str, Any]]: # Get demographic info to filter by age demographics = patient.get_events(event_type="patients") @@ -64,11 +79,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: first_admission_time = admissions_to_process[0].timestamp # Aggregated data across all admissions - all_discharge_notes = [] # List of individual discharge notes - all_radiology_notes = [] # List of individual radiology notes - all_discharge_notes_timestamps = [] # List of individual discharge notes timestamps - all_radiology_notes_timestamps = [] # List of individual radiology notes timestamps - + all_discharge_notes_timestamped = [] # List of (note_text, timestamp) tuples + all_radiology_notes_timestamped = [] # List of (note_text, timestamp) tuples # Process each admission and aggregate data for admission in admissions_to_process: @@ -99,8 +111,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: try: note_text = self._clean_text(note.text) if note_text: - all_discharge_notes.append(note_text) - all_discharge_notes_timestamps.append(note.timestamp) + all_discharge_notes_timestamped.append((note_text, note.timestamp)) except AttributeError: pass @@ -108,30 +119,17 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: try: note_text = self._clean_text(note.text) if note_text: - all_radiology_notes.append(note_text) - all_radiology_notes_timestamps.append(note.timestamp) + all_radiology_notes_timestamped.append((note_text, note.timestamp)) except AttributeError: pass - # Sort discharge_notes by timestamp - all_discharge_notes_timestamps.sort() - all_radiology_notes_timestamps.sort() - - # Compute time difference for discharge notes (hours) - discharge_note_time_diffs = [0.0] + [ - (curr - prev).total_seconds() / 3600 - for prev, curr in zip(all_discharge_notes_timestamps, all_discharge_notes_timestamps[1:]) - ] - - # Compute time difference for radiology notes (hours) - radiology_note_time_diffs = [0.0] + [ - (curr - prev).total_seconds() / 3600 - for prev, curr in zip(all_radiology_notes_timestamps, all_radiology_notes_timestamps[1:]) - ] + # Convert (note_text, timestamp) tuples to (note_text, time_diff_hours) tuples + discharge_note_time_diffs = self._compute_time_diffs(all_discharge_notes_timestamped) + radiology_note_time_diffs = self._compute_time_diffs(all_radiology_notes_timestamped) # ===== MODALITY REQUIREMENTS ===== # Check notes - need at least one discharge OR radiology note - has_notes = len(all_discharge_notes) > 0 or len(all_radiology_notes) > 0 + has_notes = len(discharge_note_time_diffs) > 0 or len(radiology_note_time_diffs) > 0 #Return empty list if any required modality is missing if not ( @@ -143,8 +141,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: return [ { "patient_id": patient.patient_id, - "discharge_note_times": (all_discharge_notes, discharge_note_time_diffs), - "radiology_note_times": (all_radiology_notes, radiology_note_time_diffs), + "discharge_note_times": discharge_note_time_diffs, + "radiology_note_times": radiology_note_time_diffs, "mortality": mortality_label, } ] \ No newline at end of file From 13e46f04801f85e555a849b349bb9b395ca7d7fe Mon Sep 17 00:00:00 2001 From: William Pang Date: Thu, 12 Feb 2026 07:41:33 -0800 Subject: [PATCH 06/11] Add handling of missing notes --- .../tasks/ehr_foundational_model_mimic4.py | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/pyhealth/tasks/ehr_foundational_model_mimic4.py b/pyhealth/tasks/ehr_foundational_model_mimic4.py index 036dd6a2a..78097965e 100644 --- a/pyhealth/tasks/ehr_foundational_model_mimic4.py +++ b/pyhealth/tasks/ehr_foundational_model_mimic4.py @@ -19,9 +19,9 @@ def _clean_text(self, text: Optional[str]) -> Optional[str]: """Return text if non-empty, otherwise None.""" return text if text else None - def _compute_time_diffs(self, notes_with_timestamps, anchor_time=None): - if not notes_with_timestamps: - return ([], []) + def _compute_time_diffs(self, notes_with_timestamps, anchor_time=None): + if not notes_with_timestamps: # TODO: Maybe I should move this somewhere else as it's not relevant to time diffs + return ([""], [0.0]) # TODO: How should we handle notes with missing timestamps? result = [] for i, (text, timestamp) in enumerate(notes_with_timestamps): if anchor_time is not None: @@ -124,25 +124,14 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: pass # Convert (note_text, timestamp) tuples to (note_text, time_diff_hours) tuples - discharge_note_time_diffs = self._compute_time_diffs(all_discharge_notes_timestamped) - radiology_note_time_diffs = self._compute_time_diffs(all_radiology_notes_timestamped) - - # ===== MODALITY REQUIREMENTS ===== - # Check notes - need at least one discharge OR radiology note - has_notes = len(discharge_note_time_diffs) > 0 or len(radiology_note_time_diffs) > 0 - - #Return empty list if any required modality is missing - if not ( - has_notes - ): - return [] - + discharge_note_times = self._compute_time_diffs(all_discharge_notes_timestamped) + radiology_note_times = self._compute_time_diffs(all_radiology_notes_timestamped) return [ { "patient_id": patient.patient_id, - "discharge_note_times": discharge_note_time_diffs, - "radiology_note_times": radiology_note_time_diffs, + "discharge_note_times": discharge_note_times, + "radiology_note_times": radiology_note_times, "mortality": mortality_label, } ] \ No newline at end of file From f456e5352948c31d1faf4c2d8749260fb0769faa Mon Sep 17 00:00:00 2001 From: William Pang Date: Thu, 12 Feb 2026 08:08:03 -0800 Subject: [PATCH 07/11] Update ehr_foundational_model_mimic4.py --- pyhealth/tasks/ehr_foundational_model_mimic4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyhealth/tasks/ehr_foundational_model_mimic4.py b/pyhealth/tasks/ehr_foundational_model_mimic4.py index 78097965e..382d38d9d 100644 --- a/pyhealth/tasks/ehr_foundational_model_mimic4.py +++ b/pyhealth/tasks/ehr_foundational_model_mimic4.py @@ -13,13 +13,14 @@ def __init__(self): "discharge_note_times": "tuple_time_text", "radiology_note_times": "tuple_time_text", } - self.output_schema: Dict[str, str] = {"mortality": "regression"} + self.output_schema: Dict[str, str] = {"mortality": "binary"} def _clean_text(self, text: Optional[str]) -> Optional[str]: """Return text if non-empty, otherwise None.""" return text if text else None def _compute_time_diffs(self, notes_with_timestamps, anchor_time=None): + # TODO: Add docstrings. anchor_time is in case we want it to normalize the time on admission time or something like that. if not notes_with_timestamps: # TODO: Maybe I should move this somewhere else as it's not relevant to time diffs return ([""], [0.0]) # TODO: How should we handle notes with missing timestamps? result = [] From 9eea8fe814fc0427433008124340cf42808e63f3 Mon Sep 17 00:00:00 2001 From: William Pang Date: Thu, 12 Feb 2026 08:09:27 -0800 Subject: [PATCH 08/11] update comments --- pyhealth/tasks/ehr_foundational_model_mimic4.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyhealth/tasks/ehr_foundational_model_mimic4.py b/pyhealth/tasks/ehr_foundational_model_mimic4.py index 382d38d9d..7c29e35d8 100644 --- a/pyhealth/tasks/ehr_foundational_model_mimic4.py +++ b/pyhealth/tasks/ehr_foundational_model_mimic4.py @@ -20,7 +20,9 @@ def _clean_text(self, text: Optional[str]) -> Optional[str]: return text if text else None def _compute_time_diffs(self, notes_with_timestamps, anchor_time=None): - # TODO: Add docstrings. anchor_time is in case we want it to normalize the time on admission time or something like that. + # TODO: Add docstrings. + # anchor_time is in case we want it to normalize the time on admission time or something like that. + if not notes_with_timestamps: # TODO: Maybe I should move this somewhere else as it's not relevant to time diffs return ([""], [0.0]) # TODO: How should we handle notes with missing timestamps? result = [] From eafd9297a1734c4724ee6c74ee77932ea82aef09 Mon Sep 17 00:00:00 2001 From: William Pang Date: Thu, 12 Feb 2026 08:09:43 -0800 Subject: [PATCH 09/11] update comments --- pyhealth/tasks/ehr_foundational_model_mimic4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/tasks/ehr_foundational_model_mimic4.py b/pyhealth/tasks/ehr_foundational_model_mimic4.py index 7c29e35d8..bc6be375f 100644 --- a/pyhealth/tasks/ehr_foundational_model_mimic4.py +++ b/pyhealth/tasks/ehr_foundational_model_mimic4.py @@ -21,7 +21,7 @@ def _clean_text(self, text: Optional[str]) -> Optional[str]: def _compute_time_diffs(self, notes_with_timestamps, anchor_time=None): # TODO: Add docstrings. - # anchor_time is in case we want it to normalize the time on admission time or something like that. + # anchor_time is in case we want it to normalize/center the time on admission time or something like that. if not notes_with_timestamps: # TODO: Maybe I should move this somewhere else as it's not relevant to time diffs return ([""], [0.0]) # TODO: How should we handle notes with missing timestamps? From fe53e89a1fca3ecdde3017af90abee2ea546ed70 Mon Sep 17 00:00:00 2001 From: William Pang Date: Thu, 12 Feb 2026 12:13:41 -0800 Subject: [PATCH 10/11] Update tuple_time_text_processor.py --- .../processors/tuple_time_text_processor.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/pyhealth/processors/tuple_time_text_processor.py b/pyhealth/processors/tuple_time_text_processor.py index ecc59c030..6542fa153 100644 --- a/pyhealth/processors/tuple_time_text_processor.py +++ b/pyhealth/processors/tuple_time_text_processor.py @@ -9,10 +9,11 @@ - List[str]: Clinical text entries (e.g., discharge notes, progress notes) - List[float]: Time differences between entries (in any time unit) - Output: Tuple[List[str], torch.Tensor, str] - - List[str]: Same text entries (unmodified) - - torch.Tensor: 1D float tensor of time differences - - str: Type tag for automatic modality routing (default: "note") + Output: str (JSON) + JSON string with keys: + - "texts": List[str] - Same text entries (unmodified) + - "time_diffs": List[float] - Time differences (unmodified) + - "type_tag": str - Type tag for automatic modality routing (default: "note") Use Case: This processor enables automatic modality bucketing in multimodal pipelines. @@ -51,7 +52,7 @@ """ from typing import Any, List, Tuple -import torch +import json from .base_processor import FeatureProcessor from . import register_processor @@ -73,30 +74,23 @@ def __init__(self, type_tag: str = "note"): super().__init__() self.type_tag = type_tag - def process(self, value: Tuple[List[str], List[float]]) -> Tuple[List[str], torch.Tensor, str]: + def process(self, value: Tuple[List[str], List[float]]) -> str: """Process a tuple of texts and time differences. - + + Serializes the data as a JSON string so litdata can store it natively. + Downstream code should parse with json.loads() to recover the dict + with keys "texts", "time_diffs", and "type_tag". + Args: value: Tuple containing: - List[str]: Text entries (clinical notes, observations, etc.) - List[float]: Time differences corresponding to each text entry - + Returns: - Tuple containing: - - List[str]: Original text entries (unmodified) - - torch.Tensor: 1D float tensor of time differences [shape: (N,)] - - str: Type tag for modality routing - - Example: - >>> processor = TupleTimeTextProcessor(type_tag="clinical_note") - >>> texts = ["Note 1", "Note 2"] - >>> times = [0.0, 24.0] # hours - >>> result = processor.process((texts, times)) - >>> print(result[1]) # tensor([0., 24.]) + str: JSON string with keys "texts", "time_diffs", "type_tag". """ texts, time_diffs = value - time_tensor = torch.tensor(time_diffs, dtype=torch.float32) - return texts, time_tensor, self.type_tag + return json.dumps({"texts": texts, "time_diffs": time_diffs, "type_tag": self.type_tag}) def size(self): """Return the size of the processor vocabulary (not applicable for this processor).""" From 0e1df7711b9b240f1ae0673a112b39e6a440e564 Mon Sep 17 00:00:00 2001 From: William Pang Date: Thu, 12 Feb 2026 12:23:42 -0800 Subject: [PATCH 11/11] Update multimodal_task.py --- examples/foundation_ehr/multimodal_task.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/foundation_ehr/multimodal_task.py b/examples/foundation_ehr/multimodal_task.py index 9a52be423..d4b87bb26 100644 --- a/examples/foundation_ehr/multimodal_task.py +++ b/examples/foundation_ehr/multimodal_task.py @@ -11,7 +11,7 @@ # There's probably better ways dealing with this on the cluster, but working locally for now # (see: https://github.com/sunlabuiuc/PyHealth/blob/master/examples/mortality_prediction/multimodal_mimic4_minimal.py) -PYHEALTH_REPO_ROOT = #'/Users/wpang/Desktop/PyHealth' +PYHEALTH_REPO_ROOT = '/Users/wpang/Desktop/PyHealth' EHR_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimiciv/2.2") NOTE_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimic-iv-note/2.2") @@ -26,8 +26,8 @@ ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"], note_tables=["discharge", "radiology"], cache_dir=CACHE_DIR, - num_workers=16, - dev=True + num_workers=8, + # dev=True ) # Apply multimodal task