124 changes: 97 additions & 27 deletions pyhealth/datasets/base_dataset.py
@@ -309,6 +309,19 @@ def __init__(
tables (List[str]): List of table names to load.
dataset_name (Optional[str]): Name of the dataset. Defaults to class name.
config_path (Optional[str]): Path to the configuration YAML file.
cache_dir (Optional[str | Path]): Directory for caching processed data.
Behavior depends on the value passed:

- **None** (default): Auto-generates a cache path under the default
pyhealth cache directory, namespaced by a UUID derived from the
dataset configuration (root, tables, dataset name, and dev mode).
- **str or Path**: Treated identically; a UUID subdirectory derived
from the dataset configuration is appended to the given path, so
different table sets sharing the same directory don't collide.

In both cases, cache files also embed the UUID in their filenames
(e.g., ``global_event_df_{uuid}.parquet``) for an extra layer of isolation.
num_workers (int): Number of worker processes for parallel operations.
dev (bool): Whether to run in dev mode (limits to 1000 patients).
"""
if len(set(tables)) != len(tables):
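
The resolution rules described in the docstring above can be summarized in a short standalone sketch. This is not the PR's code: `resolve_cache_dir` is a hypothetical helper that mirrors the `_init_cache_dir` logic in the next hunk, and it assumes `platformdirs` is available (the library pyhealth uses for its default cache location).

```python
import json
import uuid
from pathlib import Path

import platformdirs


def resolve_cache_dir(cache_dir, root, tables, dataset_name, dev=False) -> Path:
    # Deterministic UUID over the dataset configuration; sorted tables and
    # sorted JSON keys keep the hash independent of argument order.
    id_str = json.dumps(
        {"root": root, "tables": sorted(tables),
         "dataset_name": dataset_name, "dev": dev},
        sort_keys=True,
    )
    config_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str))
    if cache_dir is None:
        # None: default pyhealth cache dir, namespaced by the config UUID.
        base = Path(platformdirs.user_cache_dir(appname="pyhealth"))
        resolved = base / config_uuid
    else:
        # str and Path alike: append the config UUID so different table
        # sets sharing one directory do not collide.
        resolved = Path(cache_dir) / config_uuid
    resolved.mkdir(parents=True, exist_ok=True)
    return resolved
```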
@@ -326,45 +339,82 @@ def __init__(
)

# Cached attributes
self._cache_dir = cache_dir
self._cache_dir = self._init_cache_dir(cache_dir)
self._global_event_df = None
self._unique_patient_ids = None

@property
def cache_dir(self) -> Path:
def _init_cache_dir(self, cache_dir: str | Path | None) -> Path:
"""Returns the cache directory path.
The cache structure is as follows::

    tmp/                        # Temporary files during processing
    global_event_df.parquet/    # Cached global event dataframe
    tasks/                      # Cached task-specific data; see the set_task method
The cache directory is determined by the ``cache_dir`` argument passed
to ``__init__``:

- **None**: Auto-generated under the default pyhealth cache directory,
namespaced by a UUID derived from the dataset configuration.
- **str or Path**: Treated identically; a UUID subdirectory derived
from the dataset configuration is appended to the given path.

Cache files within the directory also include the UUID in their
filenames (e.g., ``global_event_df_{uuid}.parquet``) to prevent
collisions between different table configurations.

The cache structure within the directory is::

    tmp/                               # Temporary files during processing
    global_event_df_{uuid}.parquet/    # Cached global event dataframe
    tasks/                             # Cached task-specific data (see set_task)
        {task_name}_{uuid}/            # Cache dir for a specific task (name + args)
            task_df_{schema_uuid}.ld/  # Intermediate task dataframe (schema hash)
            samples_{proc_uuid}.ld/    # Final processed samples after processors

Returns:
Path: The cache directory path.
Path: The resolved cache directory path.
"""
if self._cache_dir is None:
id_str = json.dumps(
{
"root": self.root,
"tables": sorted(self.tables),
"dataset_name": self.dataset_name,
"dev": self.dev,
},
sort_keys=True,
)
id_str = json.dumps(
{
"root": self.root,
"tables": sorted(self.tables),
"dataset_name": self.dataset_name,
"dev": self.dev,
},
sort_keys=True,
)

if cache_dir is None:
cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str(
uuid.uuid5(uuid.NAMESPACE_DNS, id_str)
)
cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}")
self._cache_dir = cache_dir
else:
# Ensure the explicitly provided cache_dir exists
cache_dir = Path(self._cache_dir)
# Ensure separate cache directories for different table configurations by appending a UUID suffix
cache_dir = Path(cache_dir) / str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str))
cache_dir.mkdir(parents=True, exist_ok=True)
self._cache_dir = cache_dir
return Path(self._cache_dir)


def _get_cache_uuid(self) -> str:
"""Get the cache UUID for this dataset configuration.

Returns a deterministic UUID computed from tables, root, dataset_name,
and dev mode. This is used to create unique filenames within the cache
directory so that different table configurations don't collide.
"""
if not hasattr(self, '_cache_uuid') or self._cache_uuid is None:
id_str = json.dumps(
{
"root": self.root,
"tables": sorted(self.tables),
"dataset_name": self.dataset_name,
"dev": self.dev,
},
sort_keys=True,
)
self._cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str))
return self._cache_uuid
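
As a sanity check on `_get_cache_uuid`, here is a self-contained sketch of the determinism it relies on; the `config_uuid` helper below is illustrative, not part of the PR.

```python
import json
import uuid


def config_uuid(root, tables, dataset_name, dev=False) -> str:
    # Same recipe as _get_cache_uuid: uuid5 is a deterministic SHA-1-based
    # hash, so the result is stable across processes and machines.
    id_str = json.dumps(
        {"root": root, "tables": sorted(tables),
         "dataset_name": dataset_name, "dev": dev},
        sort_keys=True,
    )
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str))


# Stable under table reordering...
assert config_uuid("/data", ["labs", "diagnoses"], "demo") == \
       config_uuid("/data", ["diagnoses", "labs"], "demo")
# ...but distinct for different table sets sharing a cache directory.
assert config_uuid("/data", ["labs"], "demo") != \
       config_uuid("/data", ["labs", "diagnoses"], "demo")
```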

def create_tmpdir(self) -> Path:
"""Creates and returns a new temporary directory within the cache.

@@ -499,9 +549,12 @@ def global_event_df(self) -> pl.LazyFrame:
self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore

if self._global_event_df is None:
ret_path = self.cache_dir / "global_event_df.parquet"
ret_path = self.cache_dir / f"global_event_df_{self._get_cache_uuid()}.parquet"
if not ret_path.exists():
logger.info(f"No cached event dataframe found. Creating: {ret_path}")
self._event_transform(ret_path)
else:
logger.info(f"Found cached event dataframe: {ret_path}")
self._global_event_df = ret_path

return pl.scan_parquet(
@@ -822,13 +875,16 @@ def set_task(
"""Processes the base dataset to generate the task-specific sample dataset.
The cache structure is as follows::

    task_df.ld/                   # Intermediate task dataframe after task transformation
    samples_{uuid}.ld/            # Final processed samples after applying processors
        schema.pkl                # Saved SampleBuilder schema
        *.bin                     # Processed sample files
    samples_{uuid}.ld/
    task_df_{schema_uuid}.ld/     # Intermediate task dataframe (schema-aware)
    samples_{proc_uuid}.ld/       # Final processed samples after applying processors
        schema.pkl                # Saved SampleBuilder schema
        *.bin                     # Processed sample files
    samples_{proc_uuid}.ld/
    ...

The task_df path includes a hash of the task's input/output schemas,
so changing schemas automatically invalidates the cached task dataframe.

Args:
task (Optional[BaseTask]): The task to set. Uses default task if None.
num_workers (int): Number of workers for multi-threading. Default is ``self.num_workers``.
@@ -903,9 +959,23 @@ def set_task(
default=str
)

task_df_path = Path(cache_dir) / "task_df.ld"
# Hash based ONLY on task schemas (not the task instance) to avoid
# recursion issues. This ensures task_df is invalidated when schemas change.
task_schema_params = json.dumps(
{
"input_schema": task.input_schema,
"output_schema": task.output_schema,
},
sort_keys=True,
default=str
)
task_schema_hash = uuid.uuid5(uuid.NAMESPACE_DNS, task_schema_params)

task_df_path = Path(cache_dir) / f"task_df_{task_schema_hash}.ld"
samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld"

logger.info(f"Task cache paths: task_df={task_df_path}, samples={samples_path}")

task_df_path.mkdir(parents=True, exist_ok=True)
samples_path.mkdir(parents=True, exist_ok=True)
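
The invalidation behavior this buys can be demonstrated with a hedged standalone sketch; `task_df_path` and the example schema dicts are hypothetical (real schemas come from the task object), but the hashing recipe matches the hunk above.

```python
import json
import uuid
from pathlib import Path


def task_df_path(cache_dir: Path, input_schema: dict, output_schema: dict) -> Path:
    # Mirrors the PR: hash only the task's input/output schemas, not the
    # whole task instance.
    params = json.dumps(
        {"input_schema": input_schema, "output_schema": output_schema},
        sort_keys=True,
        default=str,
    )
    return cache_dir / f"task_df_{uuid.uuid5(uuid.NAMESPACE_DNS, params)}.ld"


cache = Path("/tmp/pyhealth-demo")
p1 = task_df_path(cache, {"conditions": "sequence"}, {"label": "binary"})
p2 = task_df_path(cache, {"conditions": "sequence"}, {"label": "multiclass"})
assert p1 != p2  # changing a schema yields a fresh task_df cache path
```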
