diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index 8a100a2d19e..223867e0cae 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -3,7 +3,7 @@ import struct from pathlib import Path -from typing import Iterable, Union, Dict, Mapping, Tuple, List +from typing import Iterable, Union, Dict, Mapping, Tuple, List, Optional import numpy as np import pandas as pd @@ -286,8 +286,41 @@ class FileFeatureStorage(FileStorageMixin, FeatureStorage): def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict = None, **kwargs): super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs) self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri) + self._resolved_file_name: Optional[str] = None self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin" + @property + def uri(self) -> Path: + if self.freq not in self.support_freq: + raise ValueError(f"{self.storage_name}: {self.provider_uri} does not contain data for {self.freq}") + + feature_dir = self.dpm.get_data_uri(self.freq).joinpath(f"{self.storage_name}s") + + if self._resolved_file_name is None: + default_uri = feature_dir.joinpath(self.file_name) + if default_uri.exists(): + self._resolved_file_name = self.file_name + else: + self._resolved_file_name = self._resolve_case_sensitive_file_name(feature_dir) or self.file_name + + return feature_dir.joinpath(self._resolved_file_name) + + def _resolve_case_sensitive_file_name(self, feature_dir: Path) -> Optional[str]: + if not feature_dir.exists(): + return None + + normalized_leaf = f"{self.field.lower()}.{self.freq.lower()}.bin" + expected_inst = self.instrument.lower() + + for instrument_dir in feature_dir.iterdir(): + if not instrument_dir.is_dir() or instrument_dir.name.lower() != expected_inst: + continue + for candidate_file in instrument_dir.iterdir(): + if candidate_file.is_file() and candidate_file.name.lower() == normalized_leaf: + return str(Path(instrument_dir.name, candidate_file.name)) + + return None + def clear(self): with self.uri.open("wb") as _: pass diff --git a/tests/storage_tests/test_storage.py b/tests/storage_tests/test_storage.py index 92fed34ecda..f646f72df80 100644 --- a/tests/storage_tests/test_storage.py +++ b/tests/storage_tests/test_storage.py @@ -4,6 +4,7 @@ from pathlib import Path from collections.abc import Iterable +import tempfile import numpy as np from qlib.tests import TestAutoData @@ -168,3 +169,23 @@ def test_feature_storage(self): print(feature[:].empty) with self.assertRaises(ValueError): print(feature.data.empty) + + def test_feature_storage_resolves_case_sensitive_instrument_dir(self): + with tempfile.TemporaryDirectory() as tmp_dir: + feature_root = Path(tmp_dir).joinpath("features", "ABB-U") + feature_root.mkdir(parents=True, exist_ok=True) + np.array([0.0, 1.0, 2.0, 3.0], dtype="