diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index 7faffef60..3930795f6 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -309,6 +309,19 @@ def __init__(
             tables (List[str]): List of table names to load.
             dataset_name (Optional[str]): Name of the dataset. Defaults to class name.
             config_path (Optional[str]): Path to the configuration YAML file.
+            cache_dir (Optional[str | Path]): Directory for caching processed data.
+                Behavior depends on the value passed:
+
+                - **None** (default): Auto-generates a cache directory under the
+                  default pyhealth cache location, named with a UUID derived from
+                  the dataset configuration (root, tables, dataset name, dev mode),
+                  so different configurations don't collide.
+                - **str** or **Path**: Used as the parent cache directory; a
+                  configuration-specific UUID subdirectory is appended so that
+                  different table configurations sharing the same directory
+                  don't collide.
+
+                In both cases, cache files also carry the UUID in their filenames
+                (e.g., ``global_event_df_{uuid}.parquet``).
+
             num_workers (int): Number of worker processes for parallel operations.
             dev (bool): Whether to run in dev mode (limits to 1000 patients).
         """
         if len(set(tables)) != len(tables):
@@ -326,32 +339,48 @@ def __init__(
             )
 
         # Cached attributes
-        self._cache_dir = cache_dir
+        self._cache_dir = self._init_cache_dir(cache_dir)
         self._global_event_df = None
         self._unique_patient_ids = None
 
-    @property
-    def cache_dir(self) -> Path:
+    def _init_cache_dir(self, cache_dir: str | Path | None) -> Path:
         """Returns the cache directory path.
 
-        The cache structure is as follows::
-            tmp/ # Temporary files during processing
-            global_event_df.parquet/ # Cached global event dataframe
-            tasks/ # Cached task-specific data, please see set_task method
+        The cache directory is determined by the value of ``cache_dir`` passed
+        to ``__init__``:
+
+        - **None**: Auto-generated under the default pyhealth cache directory,
+          named with a UUID derived from the dataset configuration.
+        - **str** or **Path**: Used as the parent directory, with a
+          configuration-specific UUID subdirectory appended.
+
+        Cache files within the directory also include the UUID in their
+        filenames (e.g., ``global_event_df_{uuid}.parquet``) to prevent
+        collisions between different table configurations.
+
+        The cache structure within the resolved directory is::
+
+            tmp/                                # Temporary files during processing
+            global_event_df_{uuid}.parquet/     # Cached global event dataframe
+            tasks/                              # Cached task-specific data, see set_task
+                {task_name}_{uuid}/             # Cache for a specific task, keyed by its name and args
+                    task_df_{uuid}.ld/          # Intermediate task dataframe, keyed by the task schema
+                    samples_{uuid}.ld/          # Final processed samples after applying processors
 
         Returns:
-            Path: The cache directory path.
+            Path: The resolved cache directory path.
         """
-        if self._cache_dir is None:
-            id_str = json.dumps(
-                {
-                    "root": self.root,
-                    "tables": sorted(self.tables),
-                    "dataset_name": self.dataset_name,
-                    "dev": self.dev,
-                },
-                sort_keys=True,
-            )
+        id_str = json.dumps(
+            {
+                "root": self.root,
+                "tables": sorted(self.tables),
+                "dataset_name": self.dataset_name,
+                "dev": self.dev,
+            },
+            sort_keys=True,
+        )
+
+        if cache_dir is None:
             cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str(
                 uuid.uuid5(uuid.NAMESPACE_DNS, id_str)
             )
@@ -359,12 +388,33 @@ def cache_dir(self) -> Path:
             logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}")
             self._cache_dir = cache_dir
         else:
-            # Ensure the explicitly provided cache_dir exists
-            cache_dir = Path(self._cache_dir)
+            # Ensure separate cache directories for different table
+            # configurations by appending a configuration-specific UUID suffix.
+            cache_dir = Path(cache_dir) / str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str))
             cache_dir.mkdir(parents=True, exist_ok=True)
             self._cache_dir = cache_dir
         return Path(self._cache_dir)
+
+    def _get_cache_uuid(self) -> str:
+        """Returns the cache UUID for this dataset configuration.
+
+        The UUID is deterministic: it is computed from the root, tables,
+        dataset_name, and dev mode. It is used to build unique filenames
+        within the cache directory so that different table configurations
+        don't collide.
+        """
+        if not hasattr(self, "_cache_uuid") or self._cache_uuid is None:
+            id_str = json.dumps(
+                {
+                    "root": self.root,
+                    "tables": sorted(self.tables),
+                    "dataset_name": self.dataset_name,
+                    "dev": self.dev,
+                },
+                sort_keys=True,
+            )
+            self._cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str))
+        return self._cache_uuid
+
     def create_tmpdir(self) -> Path:
         """Creates and returns a new temporary directory within the cache.
 
@@ -499,9 +549,12 @@ def global_event_df(self) -> pl.LazyFrame:
         self._main_guard(type(self).global_event_df.fget.__name__)  # type: ignore
 
         if self._global_event_df is None:
-            ret_path = self.cache_dir / "global_event_df.parquet"
+            ret_path = self.cache_dir / f"global_event_df_{self._get_cache_uuid()}.parquet"
             if not ret_path.exists():
+                logger.info(f"No cached event dataframe found. Creating: {ret_path}")
                 self._event_transform(ret_path)
+            else:
+                logger.info(f"Found cached event dataframe: {ret_path}")
             self._global_event_df = ret_path
 
         return pl.scan_parquet(
@@ -822,13 +875,16 @@ def set_task(
         """Processes the base dataset to generate the task-specific sample dataset.
 
         The cache structure is as follows::
-            task_df.ld/ # Intermediate task dataframe after task transformation
-            samples_{uuid}.ld/ # Final processed samples after applying processors
-                schema.pkl # Saved SampleBuilder schema
-                *.bin # Processed sample files
-            samples_{uuid}.ld/
+
+            task_df_{schema_uuid}.ld/     # Intermediate task dataframe, keyed by the task schema hash
+            samples_{proc_uuid}.ld/       # Final processed samples after applying processors
+                schema.pkl                # Saved SampleBuilder schema
+                *.bin                     # Processed sample files
+            samples_{proc_uuid}.ld/
                 ...
 
+        The ``task_df`` path includes a hash of the task's input/output
+        schemas, so changing the schemas automatically invalidates the cached
+        task dataframe.
+
         Args:
             task (Optional[BaseTask]): The task to set. Uses default task if None.
             num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`.
@@ -903,9 +959,23 @@ def set_task(
                 default=str
             )
 
-            task_df_path = Path(cache_dir) / "task_df.ld"
+            # Hash based ONLY on the task's input/output schemas (not the full
+            # task instance) to avoid recursion issues when serializing. This
+            # ensures task_df is invalidated whenever the schemas change.
+            task_schema_params = json.dumps(
+                {
+                    "input_schema": task.input_schema,
+                    "output_schema": task.output_schema,
+                },
+                sort_keys=True,
+                default=str
+            )
+            task_schema_hash = uuid.uuid5(uuid.NAMESPACE_DNS, task_schema_params)
+
+            task_df_path = Path(cache_dir) / f"task_df_{task_schema_hash}.ld"
             samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld"
 
+            logger.info(f"Task cache paths: task_df={task_df_path}, samples={samples_path}")
+
             task_df_path.mkdir(parents=True, exist_ok=True)
             samples_path.mkdir(parents=True, exist_ok=True)
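
For reviewers, a minimal standalone sketch of the deterministic cache-key derivation used above, written against the standard library only; the configuration values below are hypothetical and not taken from the PR:

    import json
    import uuid
    from pathlib import Path

    # Hypothetical dataset configuration; mirrors the id_str fields used in _init_cache_dir.
    config = {
        "root": "/data/mimic4/3.0",
        "tables": sorted(["diagnoses_icd", "procedures_icd"]),
        "dataset_name": "MIMIC4Dataset",
        "dev": False,
    }

    # Serializing with sort_keys=True makes the string, and therefore the UUID,
    # deterministic for a given configuration.
    id_str = json.dumps(config, sort_keys=True)
    cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str))

    # With an explicit cache_dir, the UUID becomes a subdirectory, and cache
    # files inside it also carry the UUID suffix.
    cache_dir = Path("/tmp/pyhealth_cache") / cache_uuid
    print(cache_dir / f"global_event_df_{cache_uuid}.parquet")

Running the snippet twice with the same configuration prints the same path, while adding or removing a table changes the UUID and therefore the cache location.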