-
Notifications
You must be signed in to change notification settings - Fork 564
[Contribution] Create ehr_foundation_model task #840
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
5bdaa76
5a49feb
9777311
6ecf289
60641c0
e04f54e
13e46f0
f456e53
9eea8fe
eafd929
fe53e89
0e1df77
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| from datetime import datetime | ||
| from typing import Any, Dict, List, Optional | ||
| import os | ||
|
|
||
| # PyHealth Packages | ||
| from pyhealth.datasets import MIMIC4Dataset | ||
| from pyhealth.tasks.ehr_foundational_model_mimic4 import EHRFoundationalModelMIMIC4 | ||
| from pyhealth.tasks.base_task import BaseTask | ||
|
|
||
| # Load MIMIC4 Files | ||
| # There's probably better ways dealing with this on the cluster, but working locally for now | ||
| # (see: https://github.com/sunlabuiuc/PyHealth/blob/master/examples/mortality_prediction/multimodal_mimic4_minimal.py) | ||
|
|
||
| PYHEALTH_REPO_ROOT = '/Users/wpang/Desktop/PyHealth' | ||
|
|
||
| EHR_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimiciv/2.2") | ||
| NOTE_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "srv/local/data/physionet.org/files/mimic-iv-note/2.2") | ||
| CXR_ROOT = os.path.join(PYHEALTH_REPO_ROOT,"srv/local/data/physionet.org/files/mimic-cxr-jpg/2.0.0") | ||
| CACHE_DIR = os.path.join(PYHEALTH_REPO_ROOT,"srv/local/data/wp/pyhealth_cache") | ||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
| dataset = MIMIC4Dataset( | ||
| ehr_root=EHR_ROOT, | ||
| note_root=NOTE_ROOT, | ||
| ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"], | ||
| note_tables=["discharge", "radiology"], | ||
| cache_dir=CACHE_DIR, | ||
| num_workers=8, | ||
| # dev=True | ||
| ) | ||
|
|
||
| # Apply multimodal task | ||
| task = EHRFoundationalModelMIMIC4() | ||
| samples = dataset.set_task(task, cache_dir=f"{CACHE_DIR}/task", num_workers=8) | ||
|
|
||
| # Get and print sample | ||
| sample = samples[0] | ||
| print(sample) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| from datetime import datetime | ||
| from typing import Any, Dict, List, Optional | ||
|
|
||
| from pyhealth.tasks.base_task import BaseTask | ||
|
|
||
| class EHRFoundationalModelMIMIC4(BaseTask): | ||
|
|
||
| task_name: str = "EHRFoundationalModelMIMIC4" | ||
|
|
||
| def __init__(self): | ||
| """Initialize the EHR Foundational Model task.""" | ||
| self.input_schema: Dict[str, str] = { | ||
| "discharge_note_times": "tuple_time_text", | ||
| "radiology_note_times": "tuple_time_text", | ||
| } | ||
| self.output_schema: Dict[str, str] = {"mortality": "binary"} | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Found another bug, mortality should be a "binary" variable for our purposes here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes, haha good catch! I made it into I wonder if it was because in the |
||
| def _clean_text(self, text: Optional[str]) -> Optional[str]: | ||
| """Return text if non-empty, otherwise None.""" | ||
| return text if text else None | ||
|
|
||
| def _compute_time_diffs(self, notes_with_timestamps, anchor_time=None): | ||
| # TODO: Add docstrings. | ||
| # anchor_time is in case we want it to normalize/center the time on admission time or something like that. | ||
|
|
||
| if not notes_with_timestamps: # TODO: Maybe I should move this somewhere else as it's not relevant to time diffs | ||
| return (["<missing>"], [0.0]) # TODO: How should we handle notes with missing timestamps? | ||
| result = [] | ||
| for i, (text, timestamp) in enumerate(notes_with_timestamps): | ||
| if anchor_time is not None: | ||
| diff = (timestamp - anchor_time).total_seconds() / 3600 | ||
| elif i == 0: | ||
| diff = 0.0 | ||
| else: | ||
| diff = (timestamp - notes_with_timestamps[i - 1][1]).total_seconds() / 3600 | ||
| result.append((text, diff)) | ||
| texts, time_diffs = zip(*result) | ||
| return (list(texts), list(time_diffs)) | ||
|
|
||
| def __call__(self, patient: Any) -> List[Dict[str, Any]]: | ||
| # Get demographic info to filter by age | ||
| demographics = patient.get_events(event_type="patients") | ||
| if not demographics: | ||
| return [] | ||
|
|
||
| demographics = demographics[0] | ||
|
|
||
| # Get visits | ||
| admissions = patient.get_events(event_type="admissions") | ||
| if len(admissions) == 0: | ||
| return [] | ||
|
|
||
| # Determine which admissions to process iteratively | ||
| # Check each admission's NEXT admission for mortality flag | ||
| admissions_to_process = [] | ||
| mortality_label = 0 | ||
|
|
||
| for i, admission in enumerate(admissions): | ||
| # Check if THIS admission has the death flag | ||
| if admission.hospital_expire_flag in [1, "1"]: | ||
| # Patient died in this admission - set mortality label | ||
| # but don't include this admission's data | ||
| mortality_label = 1 | ||
| break | ||
|
|
||
| # Check if there's a next admission with death flag | ||
| if i + 1 < len(admissions): | ||
| next_admission = admissions[i + 1] | ||
| if next_admission.hospital_expire_flag in [1, "1"]: | ||
| # Next admission has death - include current, set mortality | ||
| admissions_to_process.append(admission) | ||
| mortality_label = 1 | ||
| break | ||
|
|
||
| # No death in current or next - include this admission | ||
| admissions_to_process.append(admission) | ||
|
|
||
| if len(admissions_to_process) == 0: | ||
| return [] | ||
|
|
||
| # Get first admission time as reference for lab time calculations | ||
| first_admission_time = admissions_to_process[0].timestamp | ||
|
|
||
| # Aggregated data across all admissions | ||
| all_discharge_notes_timestamped = [] # List of (note_text, timestamp) tuples | ||
| all_radiology_notes_timestamped = [] # List of (note_text, timestamp) tuples | ||
|
|
||
| # Process each admission and aggregate data | ||
| for admission in admissions_to_process: | ||
| # Parse admission discharge time for lab events filtering | ||
| try: | ||
| admission_dischtime = datetime.strptime( | ||
| admission.dischtime, "%Y-%m-%d %H:%M:%S" | ||
| ) | ||
| except (ValueError, AttributeError): | ||
| # If we can't parse discharge time, skip this admission | ||
| continue | ||
|
|
||
| # Skip if discharge is before admission (data quality issue) | ||
| if admission_dischtime < admission.timestamp: | ||
| continue | ||
|
|
||
| # Get notes using hadm_id filtering | ||
| discharge_notes = patient.get_events( | ||
| event_type="discharge", filters=[("hadm_id", "==", admission.hadm_id)] | ||
| ) | ||
| radiology_notes = patient.get_events( | ||
| event_type="radiology", filters=[("hadm_id", "==", admission.hadm_id)] | ||
| ) | ||
|
|
||
| # Extract and aggregate notes as individual items in lists | ||
| # Note: attribute is "text" (from mimic4_note.yaml), not "discharge"/"radiology" | ||
| for note in discharge_notes: | ||
| try: | ||
| note_text = self._clean_text(note.text) | ||
| if note_text: | ||
| all_discharge_notes_timestamped.append((note_text, note.timestamp)) | ||
| except AttributeError: | ||
| pass | ||
|
|
||
| for note in radiology_notes: | ||
| try: | ||
| note_text = self._clean_text(note.text) | ||
| if note_text: | ||
| all_radiology_notes_timestamped.append((note_text, note.timestamp)) | ||
| except AttributeError: | ||
| pass | ||
|
|
||
| # Convert (note_text, timestamp) tuples to (note_text, time_diff_hours) tuples | ||
| discharge_note_times = self._compute_time_diffs(all_discharge_notes_timestamped) | ||
| radiology_note_times = self._compute_time_diffs(all_radiology_notes_timestamped) | ||
|
|
||
| return [ | ||
| { | ||
| "patient_id": patient.patient_id, | ||
| "discharge_note_times": discharge_note_times, | ||
| "radiology_note_times": radiology_note_times, | ||
| "mortality": mortality_label, | ||
| } | ||
| ] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can just do to get pass the serialization problems:
texts, time_diffs = value time_tensor = torch.tensor(time_diffs, dtype=torch.float32) return pickle.dumps(texts), time_tensorHowever, I think we may need to think of a better approach here, one that includes a tokenizer in our processor here.