Skip to content

[Contribution] Create ehr_foundation_model task#840

Draft
will-pang wants to merge 12 commits intosunlabuiuc:masterfrom
will-pang:FoundationalEHR/wp-create-multimodal-task-notes
Draft

[Contribution] Create ehr_foundation_model task#840
will-pang wants to merge 12 commits intosunlabuiuc:masterfrom
will-pang:FoundationalEHR/wp-create-multimodal-task-notes

Conversation

@will-pang
Copy link
Contributor

@will-pang will-pang commented Feb 11, 2026

Contributor Information

  • Name: William Pang
  • Contribution Type: Task
  • Notes: Google Doc

Description

v0.1

Per John's feedback, I've incorporated a few changes:

  • Incorporate Rian's tuple_time_text_processor, which now feeds in radiology and discharge notes as (note_text, time_diff_hours) tuples
    • I think we still need to discuss how to compute time_diff, in the sense that whether the timestamps from note.timestamp give a proper chronology of time or not. If we need to arrange in chronological order, I probably need to add a .sort(lambda x: x['time_stamp']) function or something equivalent.
  • If the note is missing, we now return:
    return (["<missing>"], [0.0]) # TODO: How should we handle notes with missing timestamps? 
    
  • The code seems to bug out on the preprocessor. I'm not as familiar with litdata for serialization, but per Claude:

When litdata sees a tuple like (List[str], List[float], str), it flattens it and infers a single serializer for the whole field based on the first element it encounters. So if the first element it hits is a str, it uses the string serializer for everything — then fails when it encounters a float.


v0

More of a draft PR as I'm still fairly new to the inner workings of the package, but here'a few things that I still think needs to be done:

  • Seems to work when I run:
  dataset = MIMIC4Dataset(
          ehr_root=EHR_ROOT,
          note_root=NOTE_ROOT,
          ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"],
          note_tables=["discharge", "radiology"],
          cache_dir=CACHE_DIR,
          num_workers=16,
          dev = True
      )
  
      task = EHRFoundationalModelMIMIC4()    
      samples = dataset.set_task(task, cache_dir=f"{CACHE_DIR}/task", num_workers=8)

but when I run it outside of dev mode, I run into this error:

AttributeError: 'str' object has no attribute 'dtype'                                                             
/Users/wpang/.local/share/uv/python/cpython-3.12.12-macos-aarch64-none/lib/python3.12/multiprocessing/resource_tr 
acker.py:279: UserWarning: resource_tracker: There appear to be 10 leaked semaphore objects to clean up at        
shutdown                                                                                                          
warnings.warn('resource_tracker: There appear to be %d '      

Claude says that this relates to the notes varying by length from patient to patient (e.g., patient A might have 4 radiology notes and 2 discharge notes, whereas patient B might have 2 radiology notes and 5 discharge notes), but I'm a little stuck as I am still getting comfortable with the architecture of the package.


Testing Notes

  • To test, you can run this script: examples/foundation_ehr/multimodal_task.py

@will-pang will-pang changed the title Create ehr_foundation_model task [Contribution] Create ehr_foundation_model task Feb 11, 2026
Copy link
Collaborator

@jhnwu3 jhnwu3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, some quick comments, because now I think I know what's happening that's happened to me before.

What you can probably do is just make sure every patient has a List[str]. For patients without a type of note, you can just append "<missing>" to denote a missing note or something of that sort or "". We'll probably have to standardize.

Another thing we can do is make sure to align our definitions with Rian's tuple time processors:

https://pyhealth.readthedocs.io/en/latest/api/processors/pyhealth.processors.TupleTimeTextProcessor.html

So you don't have to define a times feature.

i.e instead of

input_schema = "discharge_times" : 'tensor', "discharge" : 'raw'
what you can do is do:

input_schema = "discharge_note_times" : "tuple_time_text"

where each discharge_note_times is a (notes, times)

Let me know if this helps!

# Aggregated data across all admissions
all_discharge_notes = [] # List of individual discharge notes
all_radiology_notes = [] # List of individual radiology notes
all_discharge_notes_timestamps = [] # List of individual discharge notes timestamps
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also possible that our timestamps may not be being correctly computed? I'd have think about it.

"radiology_note_time_diffs": "tensor",
}
self.output_schema: Dict[str, str] = {"mortality": "regression"}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found another bug, mortality should be a "binary" variable for our purposes here.

Copy link
Contributor Author

@will-pang will-pang Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, haha good catch! I made it into regression for testing because I was getting this error:

Traceback (most recent call last):
  File "/Users/wpang/Desktop/PyHealth/examples/foundation_ehr/multimodal_task.py", line 35, in <module>
    samples = dataset.set_task(task, cache_dir=f"{CACHE_DIR}/task", num_workers=8)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wpang/Desktop/PyHealth/pyhealth/datasets/base_dataset.py", line 936, in set_task
    builder.fit(dataset)
  File "/Users/wpang/Desktop/PyHealth/pyhealth/datasets/sample_dataset.py", line 167, in fit
    processor.fit(samples, key)
  File "/Users/wpang/Desktop/PyHealth/pyhealth/processors/label_processor.py", line 25, in fit
    raise ValueError(f"Expected 2 unique labels, got {len(all_labels)}")
ValueError: Expected 2 unique labels, got 1

I wonder if it was because in the dev sample I was working with haha all the patients had a mortality label of 0, so for testing purposes I just made it a regression task. I've made the update!

Copy link
Collaborator

@jhnwu3 jhnwu3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some other things we definitely need to revisit and follow now that I understand how the processors work, but fortunately PyHealth is really flexible so it should be doable:

  1. In our Task, we'll need explicitly define the processor classes themselves with arguments https://pyhealth.readthedocs.io/en/latest/api/processors.html -> this documentation should explain how to define processor arguments in a task
  2. The TupleTimeTextProcessor will need to leverage a HuggingFace Tokenizer so all texts will be tokenized into a [T x L] tensor of tokens, with a time tensor of Tdimension.
  3. The TimeImageProcessor I think fortunately works as it should with the litdata expectations, just two tensors.
  4. The TextEmbedding model will need to assume inputs are already tokenized

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can just do to get pass the serialization problems:
texts, time_diffs = value time_tensor = torch.tensor(time_diffs, dtype=torch.float32) return pickle.dumps(texts), time_tensor

However, I think we may need to think of a better approach here, one that includes a tokenizer in our processor here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants