From d4f8beb8e8a87ff75d7bdd007b7dfac5d6d96d82 Mon Sep 17 00:00:00 2001 From: John Wu Date: Mon, 9 Feb 2026 16:56:27 -0600 Subject: [PATCH 1/5] init commit, tldr the cache_dir or files should probably indicate table names --- pyhealth/datasets/base_dataset.py | 78 ++++++++++++++++++++++++------- 1 file changed, 60 insertions(+), 18 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 7faffef60..44748e05f 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -309,6 +309,21 @@ def __init__( tables (List[str]): List of table names to load. dataset_name (Optional[str]): Name of the dataset. Defaults to class name. config_path (Optional[str]): Path to the configuration YAML file. + cache_dir (Optional[str | Path]): Directory for caching processed data. + Behavior depends on the type passed: + + - **None** (default): Auto-generates a cache path under the default + pyhealth cache directory with a UUID subdirectory derived from the + dataset configuration (tables, root, dev mode). Different table sets + automatically get separate caches. + - **str**: Treated as a base name. A UUID suffix will be appended to + the directory name to prevent cache collisions between different table + configurations. For example, ``"/my/cache"`` becomes + ``"/my/cache_"``. A warning is logged showing the transformation. + - **Path**: Used as-is with NO modification. Use this when you want + full control over the exact cache directory path. You are responsible + for ensuring different table configurations don't share the same path. + num_workers (int): Number of worker processes for parallel operations. dev (bool): Whether to run in dev mode (limits to 1000 patients). """ if len(set(tables)) != len(tables): @@ -333,37 +348,64 @@ def __init__( @property def cache_dir(self) -> Path: """Returns the cache directory path. - The cache structure is as follows:: + + The cache directory is determined by the type of ``cache_dir`` passed + to ``__init__``: + + - **None**: Auto-generated under default pyhealth cache with UUID subdir. + - **str**: UUID suffix appended to directory name (e.g., ``cache_``). + - **Path**: Used exactly as-is (no UUID appended). Pass ``Path(...)`` to + opt out of automatic UUID suffixing and use an exact directory. + + The cache structure within the directory is:: tmp/ # Temporary files during processing global_event_df.parquet/ # Cached global event dataframe - tasks/ # Cached task-specific data, please see set_task method + tasks/ # Cached task-specific data Returns: - Path: The cache directory path. + Path: The resolved cache directory path. """ + # If already computed (Path object), return it directly. + # This also handles the case where the user passed Path() explicitly + # at init time -- it's used as-is with no modification. + if isinstance(self._cache_dir, Path): + return self._cache_dir + + # Generate UUID based on dataset configuration (tables, root, etc.) + # to ensure different table sets get isolated cache directories. 
+ id_str = json.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, + sort_keys=True, + ) + cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + if self._cache_dir is None: - id_str = json.dumps( - { - "root": self.root, - "tables": sorted(self.tables), - "dataset_name": self.dataset_name, - "dev": self.dev, - }, - sort_keys=True, - ) - cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( - uuid.uuid5(uuid.NAMESPACE_DNS, id_str) - ) + # No cache_dir provided: use default pyhealth cache with UUID subdir + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / cache_uuid cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}") self._cache_dir = cache_dir else: - # Ensure the explicitly provided cache_dir exists - cache_dir = Path(self._cache_dir) + # String provided: append UUID to directory name for table isolation + base_path = Path(self._cache_dir) + cache_dir = base_path.parent / f"{base_path.name}_{cache_uuid}" cache_dir.mkdir(parents=True, exist_ok=True) + logger.warning( + f"cache_dir was provided as a string: '{self._cache_dir}'. " + f"A UUID suffix has been appended for table-specific isolation: " + f"'{cache_dir}'. Different table configurations will use separate " + f"cache directories. To use an exact path with no modification, " + f"pass cache_dir=Path('{self._cache_dir}') instead." + ) self._cache_dir = cache_dir - return Path(self._cache_dir) + + return self._cache_dir def create_tmpdir(self) -> Path: """Creates and returns a new temporary directory within the cache. From 840efbc8c237621a61875be63d0e9b45a171a010 Mon Sep 17 00:00:00 2001 From: John Wu Date: Tue, 10 Feb 2026 14:08:39 -0600 Subject: [PATCH 2/5] better approach is to hash the filename rather than the directory name --- pyhealth/datasets/base_dataset.py | 120 ++++++++++++++++++------------ 1 file changed, 72 insertions(+), 48 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 44748e05f..6c663a77e 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -313,16 +313,14 @@ def __init__( Behavior depends on the type passed: - **None** (default): Auto-generates a cache path under the default - pyhealth cache directory with a UUID subdirectory derived from the - dataset configuration (tables, root, dev mode). Different table sets - automatically get separate caches. - - **str**: Treated as a base name. A UUID suffix will be appended to - the directory name to prevent cache collisions between different table - configurations. For example, ``"/my/cache"`` becomes - ``"/my/cache_"``. A warning is logged showing the transformation. - - **Path**: Used as-is with NO modification. Use this when you want - full control over the exact cache directory path. You are responsible - for ensuring different table configurations don't share the same path. + pyhealth cache directory. Cache files include a UUID in their + filenames (e.g., ``global_event_df_{uuid}.parquet``) derived from + the dataset configuration, so different table sets don't collide. + - **str**: Used as the cache directory path. Cache files include a + UUID in their filenames to prevent collisions between different + table configurations sharing the same directory. + - **Path**: Used as-is with NO modification. Cache files still include + UUID in their filenames for isolation. 
num_workers (int): Number of worker processes for parallel operations. dev (bool): Whether to run in dev mode (limits to 1000 patients). """ @@ -352,16 +350,19 @@ def cache_dir(self) -> Path: The cache directory is determined by the type of ``cache_dir`` passed to ``__init__``: - - **None**: Auto-generated under default pyhealth cache with UUID subdir. - - **str**: UUID suffix appended to directory name (e.g., ``cache_``). - - **Path**: Used exactly as-is (no UUID appended). Pass ``Path(...)`` to - opt out of automatic UUID suffixing and use an exact directory. + - **None**: Auto-generated under default pyhealth cache directory. + - **str**: Used as-is as the cache directory path. + - **Path**: Used exactly as-is (no modification). + + Cache files within the directory include UUID suffixes in their + filenames (e.g., ``global_event_df_{uuid}.parquet``) to prevent + collisions between different table configurations. The cache structure within the directory is:: - tmp/ # Temporary files during processing - global_event_df.parquet/ # Cached global event dataframe - tasks/ # Cached task-specific data + tmp/ # Temporary files during processing + global_event_df_{uuid}.parquet/ # Cached global event dataframe + tasks/ # Cached task-specific data Returns: Path: The resolved cache directory path. @@ -372,41 +373,44 @@ def cache_dir(self) -> Path: if isinstance(self._cache_dir, Path): return self._cache_dir - # Generate UUID based on dataset configuration (tables, root, etc.) - # to ensure different table sets get isolated cache directories. - id_str = json.dumps( - { - "root": self.root, - "tables": sorted(self.tables), - "dataset_name": self.dataset_name, - "dev": self.dev, - }, - sort_keys=True, - ) - cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) - if self._cache_dir is None: - # No cache_dir provided: use default pyhealth cache with UUID subdir - cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / cache_uuid + # No cache_dir provided: use default pyhealth cache directory + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / "datasets" cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}") self._cache_dir = cache_dir else: - # String provided: append UUID to directory name for table isolation - base_path = Path(self._cache_dir) - cache_dir = base_path.parent / f"{base_path.name}_{cache_uuid}" + # String provided: use as-is (file-based isolation via UUID in filenames) + cache_dir = Path(self._cache_dir) cache_dir.mkdir(parents=True, exist_ok=True) - logger.warning( - f"cache_dir was provided as a string: '{self._cache_dir}'. " - f"A UUID suffix has been appended for table-specific isolation: " - f"'{cache_dir}'. Different table configurations will use separate " - f"cache directories. To use an exact path with no modification, " - f"pass cache_dir=Path('{self._cache_dir}') instead." + logger.info( + f"Using cache dir: {cache_dir} " + f"(cache files will include UUID suffix for table isolation)" ) self._cache_dir = cache_dir return self._cache_dir + def _get_cache_uuid(self) -> str: + """Get the cache UUID for this dataset configuration. + + Returns a deterministic UUID computed from tables, root, dataset_name, + and dev mode. This is used to create unique filenames within the cache + directory so that different table configurations don't collide. 
+ """ + if not hasattr(self, '_cache_uuid') or self._cache_uuid is None: + id_str = json.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, + sort_keys=True, + ) + self._cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + return self._cache_uuid + def create_tmpdir(self) -> Path: """Creates and returns a new temporary directory within the cache. @@ -541,9 +545,12 @@ def global_event_df(self) -> pl.LazyFrame: self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore if self._global_event_df is None: - ret_path = self.cache_dir / "global_event_df.parquet" + ret_path = self.cache_dir / f"global_event_df_{self._get_cache_uuid()}.parquet" if not ret_path.exists(): + logger.info(f"No cached event dataframe found. Creating: {ret_path}") self._event_transform(ret_path) + else: + logger.info(f"Found cached event dataframe: {ret_path}") self._global_event_df = ret_path return pl.scan_parquet( @@ -864,13 +871,16 @@ def set_task( """Processes the base dataset to generate the task-specific sample dataset. The cache structure is as follows:: - task_df.ld/ # Intermediate task dataframe after task transformation - samples_{uuid}.ld/ # Final processed samples after applying processors - schema.pkl # Saved SampleBuilder schema - *.bin # Processed sample files - samples_{uuid}.ld/ + task_df_{schema_uuid}.ld/ # Intermediate task dataframe (schema-aware) + samples_{proc_uuid}.ld/ # Final processed samples after applying processors + schema.pkl # Saved SampleBuilder schema + *.bin # Processed sample files + samples_{proc_uuid}.ld/ ... + The task_df path includes a hash of the task's input/output schemas, + so changing schemas automatically invalidates the cached task dataframe. + Args: task (Optional[BaseTask]): The task to set. Uses default task if None. num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`. @@ -945,9 +955,23 @@ def set_task( default=str ) - task_df_path = Path(cache_dir) / "task_df.ld" + # Hash based ONLY on task schemas (not the task instance) to avoid + # recursion issues. This ensures task_df is invalidated when schemas change. 
+ task_schema_params = json.dumps( + { + "input_schema": task.input_schema, + "output_schema": task.output_schema, + }, + sort_keys=True, + default=str + ) + task_schema_hash = uuid.uuid5(uuid.NAMESPACE_DNS, task_schema_params) + + task_df_path = Path(cache_dir) / f"task_df_{task_schema_hash}.ld" samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld" + logger.info(f"Task cache paths: task_df={task_df_path}, samples={samples_path}") + task_df_path.mkdir(parents=True, exist_ok=True) samples_path.mkdir(parents=True, exist_ok=True) From 5e108ba359980c3fbb06c0ddf858799bd338aa59 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 11 Feb 2026 18:49:33 -0600 Subject: [PATCH 3/5] Fix cache dir --- pyhealth/datasets/base_dataset.py | 48 +++++++++++++++++-------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 6c663a77e..3930795f6 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -339,12 +339,11 @@ def __init__( ) # Cached attributes - self._cache_dir = cache_dir + self._cache_dir = self._init_cache_dir(cache_dir) self._global_event_df = None self._unique_patient_ids = None - @property - def cache_dir(self) -> Path: + def _init_cache_dir(self, cache_dir: str | Path | None) -> Path: """Returns the cache directory path. The cache directory is determined by the type of ``cache_dir`` passed @@ -360,36 +359,41 @@ def cache_dir(self) -> Path: The cache structure within the directory is:: - tmp/ # Temporary files during processing - global_event_df_{uuid}.parquet/ # Cached global event dataframe - tasks/ # Cached task-specific data + tmp/ # Temporary files during processing + {uuid}/ # Cache files for this dataset configuration + global_event_df.parquet/ # Cached global event dataframe + tasks/ # Cached task-specific data + {task_name}_{uuid}/ # Cached data for specific task based on task name and its args + task_df_{uuid}.ld/ # Intermediate task dataframe based on schema + samples_{uuid}.ld/ # Final processed samples after applying processors Returns: Path: The resolved cache directory path. """ - # If already computed (Path object), return it directly. - # This also handles the case where the user passed Path() explicitly - # at init time -- it's used as-is with no modification. - if isinstance(self._cache_dir, Path): - return self._cache_dir - - if self._cache_dir is None: - # No cache_dir provided: use default pyhealth cache directory - cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / "datasets" + id_str = json.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, + sort_keys=True, + ) + + if cache_dir is None: + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( + uuid.uuid5(uuid.NAMESPACE_DNS, id_str) + ) cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"No cache_dir provided. 
Using default cache dir: {cache_dir}") self._cache_dir = cache_dir else: - # String provided: use as-is (file-based isolation via UUID in filenames) - cache_dir = Path(self._cache_dir) + # Ensure separate cache directories for different table configurations by appending a UUID suffix + cache_dir = Path(self._cache_dir) / str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) cache_dir.mkdir(parents=True, exist_ok=True) - logger.info( - f"Using cache dir: {cache_dir} " - f"(cache files will include UUID suffix for table isolation)" - ) self._cache_dir = cache_dir + return Path(self._cache_dir) - return self._cache_dir def _get_cache_uuid(self) -> str: """Get the cache UUID for this dataset configuration. From fa63c430a26b1b3ad0af27a994568edb21796c6f Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Sun, 15 Feb 2026 19:41:38 +0000 Subject: [PATCH 4/5] Simplify new caching behavior code and add unit tests --- pyhealth/datasets/base_dataset.py | 119 +++++++++------------------- pyhealth/datasets/sample_dataset.py | 1 + tests/core/test_caching.py | 106 +++++++++++++++++++++++-- 3 files changed, 137 insertions(+), 89 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 3930795f6..e7066e264 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -125,17 +125,17 @@ def _litdata_merge(cache_dir: Path) -> None: """ from litdata.streaming.writer import _INDEX_FILENAME files = os.listdir(cache_dir) - + # Return if the index already exists if _INDEX_FILENAME in files: return index_files = [f for f in files if f.endswith(_INDEX_FILENAME)] - + # Return if there are no index files to merge if len(index_files) == 0: raise ValueError("There are zero samples in the dataset, please check the task and processors.") - + BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB").merge(num_workers=len(index_files)) @@ -311,16 +311,11 @@ def __init__( config_path (Optional[str]): Path to the configuration YAML file. cache_dir (Optional[str | Path]): Directory for caching processed data. Behavior depends on the type passed: - - - **None** (default): Auto-generates a cache path under the default - pyhealth cache directory. Cache files include a UUID in their - filenames (e.g., ``global_event_df_{uuid}.parquet``) derived from - the dataset configuration, so different table sets don't collide. - - **str**: Used as the cache directory path. Cache files include a - UUID in their filenames to prevent collisions between different - table configurations sharing the same directory. - - **Path**: Used as-is with NO modification. Cache files still include - UUID in their filenames for isolation. + + - **None** (default): Auto-generates a cache path under the default + pyhealth cache directory. + - **str** or **Path**: Used as the root cache directory path. A UUID + is appended to the provided path to capture dataset configuration. num_workers (int): Number of worker processes for parallel operations. dev (bool): Whether to run in dev mode (limits to 1000 patients). """ @@ -339,7 +334,7 @@ def __init__( ) # Cached attributes - self._cache_dir = self._init_cache_dir(cache_dir) + self.cache_dir = self._init_cache_dir(cache_dir) self._global_event_df = None self._unique_patient_ids = None @@ -350,70 +345,44 @@ def _init_cache_dir(self, cache_dir: str | Path | None) -> Path: to ``__init__``: - **None**: Auto-generated under default pyhealth cache directory. - - **str**: Used as-is as the cache directory path. - - **Path**: Used exactly as-is (no modification). 
- - Cache files within the directory include UUID suffixes in their - filenames (e.g., ``global_event_df_{uuid}.parquet``) to prevent - collisions between different table configurations. + - **str** or **Path: Used as the root cache directory path. A UUID + is appended to the provided path to capture dataset configuration. The cache structure within the directory is:: - tmp/ # Temporary files during processing - {uuid}/ # Cache files for this dataset configuration + {dataset_uuid}/ # Cache files for this dataset configuration + tmp/ # Temporary files during processing global_event_df.parquet/ # Cached global event dataframe tasks/ # Cached task-specific data - {task_name}_{uuid}/ # Cached data for specific task based on task name and its args - task_df_{uuid}.ld/ # Intermediate task dataframe based on schema - samples_{uuid}.ld/ # Final processed samples after applying processors + {task_name}_{task_uuid}/ # Cached data for specific task based on task name, schema, and args + task_df.ld/ # Intermediate task dataframe based on schema + samples.ld/ # Final processed samples after applying processors Returns: Path: The resolved cache directory path. """ id_str = json.dumps( { - "root": self.root, + "root": str(self.root), "tables": sorted(self.tables), "dataset_name": self.dataset_name, "dev": self.dev, }, sort_keys=True, ) - + + id = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + if cache_dir is None: - cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( - uuid.uuid5(uuid.NAMESPACE_DNS, id_str) - ) + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / id cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}") - self._cache_dir = cache_dir else: # Ensure separate cache directories for different table configurations by appending a UUID suffix - cache_dir = Path(self._cache_dir) / str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + cache_dir = Path(cache_dir) / id cache_dir.mkdir(parents=True, exist_ok=True) - self._cache_dir = cache_dir - return Path(self._cache_dir) - - - def _get_cache_uuid(self) -> str: - """Get the cache UUID for this dataset configuration. - - Returns a deterministic UUID computed from tables, root, dataset_name, - and dev mode. This is used to create unique filenames within the cache - directory so that different table configurations don't collide. - """ - if not hasattr(self, '_cache_uuid') or self._cache_uuid is None: - id_str = json.dumps( - { - "root": self.root, - "tables": sorted(self.tables), - "dataset_name": self.dataset_name, - "dev": self.dev, - }, - sort_keys=True, - ) - self._cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) - return self._cache_uuid + logger.info(f"Using provided cache_dir: {cache_dir}") + return Path(cache_dir) def create_tmpdir(self) -> Path: """Creates and returns a new temporary directory within the cache. @@ -549,7 +518,7 @@ def global_event_df(self) -> pl.LazyFrame: self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore if self._global_event_df is None: - ret_path = self.cache_dir / f"global_event_df_{self._get_cache_uuid()}.parquet" + ret_path = self.cache_dir / f"global_event_df.parquet" if not ret_path.exists(): logger.info(f"No cached event dataframe found. Creating: {ret_path}") self._event_transform(ret_path) @@ -875,21 +844,17 @@ def set_task( """Processes the base dataset to generate the task-specific sample dataset. 
The cache structure is as follows:: - task_df_{schema_uuid}.ld/ # Intermediate task dataframe (schema-aware) - samples_{proc_uuid}.ld/ # Final processed samples after applying processors - schema.pkl # Saved SampleBuilder schema - *.bin # Processed sample files - samples_{proc_uuid}.ld/ - ... - - The task_df path includes a hash of the task's input/output schemas, - so changing schemas automatically invalidates the cached task dataframe. + {task_name}_{task_uuid}/ # Cached data for specific task based on task name, schema, and args + task_df.ld/ # Intermediate task dataframe based on schema + samples.ld/ # Final processed samples after applying processors + schema.pkl # Saved SampleBuilder schema + *.bin # Processed sample files Args: task (Optional[BaseTask]): The task to set. Uses default task if None. num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`. cache_dir (Optional[str]): Directory to cache samples after task transformation, - but without applying processors. Default is {self.cache_dir}/tasks/{task_name}_{uuid5(vars(task))}. + but without applying processors. Default is {self.cache_dir}/tasks. cache_format (str): Deprecated. Only "parquet" is supported now. input_processors (Optional[Dict[str, FeatureProcessor]]): Pre-fitted input processors. If provided, these will be used @@ -921,7 +886,11 @@ def set_task( ) task_params = json.dumps( - vars(task), + { + **vars(task), + "input_schema": task.input_schema, + "output_schema": task.output_schema, + }, sort_keys=True, default=str ) @@ -959,26 +928,14 @@ def set_task( default=str ) - # Hash based ONLY on task schemas (not the task instance) to avoid - # recursion issues. This ensures task_df is invalidated when schemas change. - task_schema_params = json.dumps( - { - "input_schema": task.input_schema, - "output_schema": task.output_schema, - }, - sort_keys=True, - default=str - ) - task_schema_hash = uuid.uuid5(uuid.NAMESPACE_DNS, task_schema_params) - - task_df_path = Path(cache_dir) / f"task_df_{task_schema_hash}.ld" + task_df_path = Path(cache_dir) / "task_df.ld" samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld" logger.info(f"Task cache paths: task_df={task_df_path}, samples={samples_path}") task_df_path.mkdir(parents=True, exist_ok=True) samples_path.mkdir(parents=True, exist_ok=True) - + if not (samples_path / "index.json").exists(): # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index c7b719f62..134bede35 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -296,6 +296,7 @@ def __init__( """ super().__init__(path, **kwargs) + self.path = path self.dataset_name = "" if dataset_name is None else dataset_name self.task_name = "" if task_name is None else task_name diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index fa832c31d..0a6ac7dfe 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -43,16 +43,45 @@ def __call__(self, patient): return samples +class MockTask2(BaseTask): + """Second mock task with a different output schema than the first""" + task_name = "test_task" + input_schema = {"test_attribute": "raw"} + output_schema = {"test_label": "multiclass"} + + def __init__(self, param=None): + self.call_count = 0 + if param: + self.param = param + + def __call__(self, patient): + """Return mock samples based on 
patient data.""" + # Extract patient's test data from the patient's data source + patient_data = patient.data_source + self.call_count += 1 + + samples = [] + for row in patient_data.iter_rows(named=True): + sample = { + "test_attribute": row["test/test_attribute"], + "test_label": row["test/test_label"], + "patient_id": row["patient_id"], + } + samples.append(sample) + + return samples + + class MockDataset(BaseDataset): """Mock dataset for testing purposes.""" - def __init__(self, cache_dir: str | Path | None = None): + def __init__(self, root: str = "", tables = [], dataset_name = "TestDataset", cache_dir: str | Path | None = None, dev = False): super().__init__( - root="", - tables=[], - dataset_name="TestDataset", + root=root, + tables=tables, + dataset_name=dataset_name, cache_dir=cache_dir, - dev=False, + dev=dev, ) def load_data(self) -> dd.DataFrame: @@ -162,7 +191,7 @@ def test_set_task_writes_cache_and_metadata(self): def test_default_cache_dir_is_used(self): """When cache_dir is omitted, default cache dir should be used.""" task_params = json.dumps( - {"call_count": 0}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0}, sort_keys=True, default=str ) @@ -199,14 +228,16 @@ def test_tasks_with_diff_param_values_get_diff_caches(self): sample_dataset1 = self.dataset.set_task(MockTask(param=1)) sample_dataset2 = self.dataset.set_task(MockTask(param=2)) + self.assertNotEqual(sample_dataset1.path, sample_dataset2.path) + task_params1 = json.dumps( - {"call_count": 0, "param": 2}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0, "param": 1}, sort_keys=True, default=str ) task_params2 = json.dumps( - {"call_count": 0, "param": 2}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0, "param": 2}, sort_keys=True, default=str ) @@ -225,6 +256,65 @@ def test_tasks_with_diff_param_values_get_diff_caches(self): sample_dataset1.close() sample_dataset2.close() + def test_tasks_with_diff_output_schemas_get_diff_caches(self): + sample_dataset1 = self.dataset.set_task(MockTask()) + sample_dataset2 = self.dataset.set_task(MockTask2()) + + self.assertNotEqual(sample_dataset1.path, sample_dataset2.path) + + task_params1 = json.dumps( + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0}, + sort_keys=True, + default=str + ) + + task_params2 = json.dumps( + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "multiclass"}, "call_count": 0}, + sort_keys=True, + default=str + ) + + task_cache1 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params1)}" + task_cache2 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params2)}" + + self.assertTrue(task_cache1.exists()) + self.assertTrue(task_cache2.exists()) + self.assertTrue((task_cache1 / "task_df.ld" / "index.json").exists()) + self.assertTrue((task_cache2 / "task_df.ld" / "index.json").exists()) + self.assertTrue((self.dataset.cache_dir / "global_event_df.parquet").exists()) + self.assertEqual(len(sample_dataset1), 4) + self.assertEqual(len(sample_dataset2), 4) + + sample_dataset1.close() + sample_dataset2.close() + + def test_datasets_with_diff_roots_get_diff_caches(self): + dataset1 = MockDataset(root=tempfile.TemporaryDirectory().name, cache_dir=self.temp_dir.name) + dataset2 = 
MockDataset(root=tempfile.TemporaryDirectory().name, cache_dir=self.temp_dir.name) + + self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir) + + def test_datasets_with_diff_tables_get_diff_caches(self): + dataset1 = MockDataset(tables=["one", "two", ], cache_dir=self.temp_dir.name) + dataset2 = MockDataset(tables=["one", "two", "three"], cache_dir=self.temp_dir.name) + dataset3 = MockDataset(tables=["one", "three"], cache_dir=self.temp_dir.name) + dataset4 = MockDataset(tables=[], cache_dir=self.temp_dir.name) + + caches = [dataset1.cache_dir, dataset2.cache_dir, dataset3.cache_dir, dataset4.cache_dir] + + self.assertEqual(len(caches), len(set(caches))) + + def test_datasets_with_diff_names_get_diff_caches(self): + dataset1 = MockDataset(dataset_name="one", cache_dir=self.temp_dir.name) + dataset2 = MockDataset(dataset_name="two", cache_dir=self.temp_dir.name) + + self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir) + + def test_datasets_with_diff_dev_values_get_diff_caches(self): + dataset1 = MockDataset(dev=True, cache_dir=self.temp_dir.name) + dataset2 = MockDataset(dev=False, cache_dir=self.temp_dir.name) + + self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir) if __name__ == "__main__": unittest.main() From a70e74b958f3eb286ce24d88cab1147502b8e336 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Sun, 15 Feb 2026 19:55:48 +0000 Subject: [PATCH 5/5] Remove unnecessary f-string --- pyhealth/datasets/base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index e7066e264..aa260f303 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -518,7 +518,7 @@ def global_event_df(self) -> pl.LazyFrame: self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore if self._global_event_df is None: - ret_path = self.cache_dir / f"global_event_df.parquet" + ret_path = self.cache_dir / "global_event_df.parquet" if not ret_path.exists(): logger.info(f"No cached event dataframe found. Creating: {ret_path}") self._event_transform(ret_path)
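
For reference, the cache-keying behavior this series converges on (patches 3-5) can be summarized with a minimal standalone sketch. The helper name resolve_cache_dir and the example roots/paths below are illustrative only, not part of the patch or of the pyhealth API; the JSON-over-uuid5 derivation mirrors _init_cache_dir as it stands after PATCH 5/5.

    # Sketch of the dataset-level cache directory derivation (illustrative,
    # not the actual pyhealth API; names like resolve_cache_dir are made up).
    import json
    import uuid
    from pathlib import Path

    import platformdirs


    def resolve_cache_dir(root, tables, dataset_name, dev, cache_dir=None):
        # Deterministic UUID over the dataset configuration: the same
        # root/tables/name/dev always maps to the same directory, and
        # different table sets never collide.
        id_str = json.dumps(
            {
                "root": str(root),
                "tables": sorted(tables),
                "dataset_name": dataset_name,
                "dev": dev,
            },
            sort_keys=True,
        )
        dataset_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str))
        base = (
            Path(platformdirs.user_cache_dir(appname="pyhealth"))
            if cache_dir is None
            else Path(cache_dir)
        )
        resolved = base / dataset_uuid
        resolved.mkdir(parents=True, exist_ok=True)
        return resolved


    # Different table configurations under the same user-supplied cache_dir
    # land in different subdirectories, which is what the new unit tests
    # (test_datasets_with_diff_tables_get_diff_caches, etc.) assert.
    a = resolve_cache_dir("/data/ehr", ["patients", "admissions"], "Demo", False, "/tmp/ph")
    b = resolve_cache_dir("/data/ehr", ["patients"], "Demo", False, "/tmp/ph")
    assert a != b

Task-level caches follow the same pattern one level down: set_task hashes vars(task) together with the task's input_schema and output_schema, and stores task_df.ld and samples_{proc_uuid}.ld under {cache_dir}/tasks/{task_name}_{task_uuid}, so changing either a task argument or a schema invalidates the cached task dataframe without touching the shared global_event_df.parquet.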