From 8ed6f96d0f39983b3d36cded7d92fd52b2cf2542 Mon Sep 17 00:00:00 2001 From: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Date: Fri, 20 Feb 2026 10:26:19 +0900 Subject: [PATCH 1/2] Add JP Prime Yahoo collector support --- scripts/data_collector/README.md | 4 +- scripts/data_collector/utils.py | 130 ++++++++++++++++- scripts/data_collector/yahoo/README.md | 14 +- scripts/data_collector/yahoo/collector.py | 68 ++++++++- tests/test_yahoo_collector_jp.py | 168 ++++++++++++++++++++++ 5 files changed, 370 insertions(+), 14 deletions(-) create mode 100644 tests/test_yahoo_collector_jp.py diff --git a/scripts/data_collector/README.md b/scripts/data_collector/README.md index d0058b33e2c..f485ae175e5 100644 --- a/scripts/data_collector/README.md +++ b/scripts/data_collector/README.md @@ -4,7 +4,7 @@ Scripts for data collection -- yahoo: get *US/CN* stock data from *Yahoo Finance* +- yahoo: get *US/CN/IN/BR/JP* stock data from *Yahoo Finance* - fund: get fund data from *http://fund.eastmoney.com* - cn_index: get *CN index* from *http://www.csindex.com.cn*, *CSI300*/*CSI100* - us_index: get *US index* from *https://en.wikipedia.org/wiki*, *SP500*/*NASDAQ100*/*DJIA*/*SP400* @@ -57,4 +57,4 @@ Scripts for data collection | Component | required data | |---------------------------------------------------|--------------------------------| | Data retrieval | Features, Calendar, Instrument | - | Backtest | **Features[Price/Volume]**, Calendar, Instruments | \ No newline at end of file + | Backtest | **Features[Price/Volume]**, Calendar, Instruments | diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index bf87d0de5ea..bf49e2d2c46 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -9,8 +9,9 @@ import pickle import requests import functools +from io import BytesIO from pathlib import Path -from typing import Iterable, Tuple, List +from typing import Iterable, Tuple, List, Optional import numpy as np import pandas as pd @@ -36,14 +37,18 @@ "US_ALL": "^GSPC", "IN_ALL": "^NSEI", "BR_ALL": "^BVSP", + "JP_ALL": "^N225", } +JPX_LISTED_COMPANIES_URL = "https://www.jpx.co.jp/markets/statistics-equities/misc/tvdivq0000001vg2-att/data_j.xls" + _BENCH_CALENDAR_LIST = None _ALL_CALENDAR_LIST = None _HS_SYMBOLS = None _US_SYMBOLS = None _IN_SYMBOLS = None _BR_SYMBOLS = None +_JP_SYMBOLS = None _EN_FUND_SYMBOLS = None _CALENDAR_MAP = {} @@ -51,13 +56,20 @@ MINIMUM_SYMBOLS_NUM = 3900 +def _normalize_calendar_timestamp(value) -> pd.Timestamp: + ts = pd.Timestamp(value) + if ts.tzinfo is not None: + ts = ts.tz_localize(None) + return ts.normalize() + + def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: """get SH/SZ history calendar list Parameters ---------- bench_code: str - value from ["CSI300", "CSI500", "ALL", "US_ALL"] + value from ["CSI300", "CSI500", "ALL", "US_ALL", "IN_ALL", "BR_ALL", "JP_ALL"] Returns ------- @@ -72,11 +84,15 @@ def _get_calendar(url): calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: - if bench_code.startswith("US_") or bench_code.startswith("IN_") or bench_code.startswith("BR_"): - print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code])) - print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")) - df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") - calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() + if ( + bench_code.startswith("US_") + or bench_code.startswith("IN_") + or bench_code.startswith("BR_") + or bench_code.startswith("JP_") + ): + _ticker = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]) + df = _ticker.history(interval="1d", period="max") + calendar = sorted({_normalize_calendar_timestamp(v) for v in df.index.get_level_values(level="date")}) else: if bench_code.upper() == "ALL": import akshare as ak # pylint: disable=C0415 @@ -448,6 +464,106 @@ def _format(s_): return _BR_SYMBOLS +def _normalize_jpx_column_name(col_name: str) -> str: + return str(col_name).replace(" ", "").replace("\u3000", "").replace("\n", "").strip().lower() + + +def _find_jpx_column(columns: list, exact_candidates: list, keyword_candidates: list) -> Optional[str]: + normalized_map = {col: _normalize_jpx_column_name(col) for col in columns} + exact_candidates = {_normalize_jpx_column_name(col) for col in exact_candidates} + keyword_candidates = [_normalize_jpx_column_name(col) for col in keyword_candidates] + + for _col, _normalized_col in normalized_map.items(): + if _normalized_col in exact_candidates: + return _col + + for _col, _normalized_col in normalized_map.items(): + if all(_keyword in _normalized_col for _keyword in keyword_candidates): + return _col + + return None + + +def _extract_jp_prime_symbols(df: pd.DataFrame) -> list: + if df is None or df.empty: + raise ValueError("JPX listed companies file is empty") + + code_col = _find_jpx_column( + columns=df.columns.tolist(), + exact_candidates=["コード", "銘柄コード", "code", "securitycode"], + keyword_candidates=["コード"], + ) + if code_col is None: + raise ValueError("Unable to find stock code column in JPX listed companies file") + + market_col = _find_jpx_column( + columns=df.columns.tolist(), + exact_candidates=["市場・商品区分", "市場商品区分", "市場区分", "marketsegment"], + keyword_candidates=["市場", "区分"], + ) + if market_col is None: + raise ValueError("Unable to find market classification column in JPX listed companies file") + + domestic_col = _find_jpx_column( + columns=df.columns.tolist(), + exact_candidates=["内外株式区分", "内外区分", "domesticforeign"], + keyword_candidates=["内外", "区分"], + ) + + market_series = df[market_col].astype(str) + prime_mask = market_series.str.contains("プライム", na=False) + + if market_series.str.contains("内国株式", na=False).any(): + domestic_mask = market_series.str.contains("内国株式", na=False) + elif domestic_col is not None: + domestic_mask = df[domestic_col].astype(str).str.contains("内国株式", na=False) + else: + domestic_mask = market_series.str.contains("内国株式", na=False) + + target_df = df.loc[prime_mask & domestic_mask, [code_col]].copy() + if target_df.empty: + raise ValueError("No JPX Prime domestic stocks found in listed companies file") + + symbols = ( + target_df[code_col] + .astype(str) + .str.extract(r"(\d{4})", expand=False) + .dropna() + .apply(lambda code: f"{code}.T") + .drop_duplicates() + .sort_values() + .tolist() + ) + if not symbols: + raise ValueError("No valid JP stock symbols extracted from JPX listed companies file") + return symbols + + +def get_jp_stock_symbols() -> list: + """get JP Prime (domestic stock) symbols""" + + global _JP_SYMBOLS # pylint: disable=W0603 + + @deco_retry + def _get_jpx_listed_companies_df(): + resp = requests.get(JPX_LISTED_COMPANIES_URL, timeout=None) + if resp.status_code != 200: + raise ValueError(f"request error, status_code={resp.status_code}") + try: + return pd.read_excel(BytesIO(resp.content), dtype=str) + except Exception as excel_error: + try: + return pd.read_html(BytesIO(resp.content))[0].astype(str) + except Exception as html_error: + raise ValueError( + f"failed to parse JPX listed companies file: excel_error={excel_error}, html_error={html_error}" + ) from html_error + + if _JP_SYMBOLS is None: + _JP_SYMBOLS = _extract_jp_prime_symbols(_get_jpx_listed_companies_df()) + return _JP_SYMBOLS + + def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: """get en fund symbols diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index c12a2383a40..8d501baac44 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -55,6 +55,7 @@ pip install -r requirements.txt ### Collector *YahooFinance* data to qlib > collector *YahooFinance* data and *dump* into `qlib` format. > If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data. +> For `region=JP`, the symbol universe is **TSE Prime (domestic stocks)** from JPX listed companies file. 1. download data to csv: `python scripts/data_collector/yahoo/collector.py download_data` This will download the raw data such as high, low, open, close, adjclose price from yahoo to a local directory. One file per symbol. @@ -63,7 +64,8 @@ pip install -r requirements.txt - `source_dir`: save the directory - `interval`: `1d` or `1min`, by default `1d` > **due to the limitation of the *YahooFinance API*, only the last month's data is available in `1min`** - - `region`: `CN` or `US` or `IN` or `BR`, by default `CN` + - `region`: `CN` or `US` or `IN` or `BR` or `JP`, by default `CN` + > `JP` supports `1d` only - `delay`: `time.sleep(delay)`, by default *0.5* - `start`: start datetime, by default *"2000-01-01"*; *closed interval(including start)* - `end`: end datetime, by default `pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`; *open interval(excluding end)* @@ -92,6 +94,9 @@ pip install -r requirements.txt python collector.py download_data --source_dir ~/.qlib/stock_data/source/br_data --start 2003-01-03 --end 2022-03-01 --delay 1 --interval 1d --region BR # br 1min data python collector.py download_data --source_dir ~/.qlib/stock_data/source/br_data_1min --delay 1 --interval 1min --region BR + + # jp 1d data (TSE Prime domestic stocks) + python collector.py download_data --source_dir ~/.qlib/stock_data/source/jp_data --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region JP ``` 2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data` @@ -105,7 +110,8 @@ pip install -r requirements.txt - `max_workers`: number of concurrent, by default *1* - `interval`: `1d` or `1min`, by default `1d` > if **`interval == 1min`**, `qlib_data_1d_dir` cannot be `None` - - `region`: `CN` or `US` or `IN`, by default `CN` + - `region`: `CN` or `US` or `IN` or `BR` or `JP`, by default `CN` + > `JP` supports `1d` only - `date_field_name`: column *name* identifying time in csv files, by default `date` - `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol` - `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None` @@ -133,6 +139,9 @@ pip install -r requirements.txt # normalize 1min br python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/br_data --source_dir ~/.qlib/stock_data/source/br_data_1min --normalize_dir ~/.qlib/stock_data/source/br_1min_nor --region BR --interval 1min + + # normalize 1d jp + python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/jp_data --normalize_dir ~/.qlib/stock_data/source/jp_1d_nor --region JP --interval 1d ``` 3. dump data: `python scripts/dump_bin.py dump_all` @@ -222,4 +231,3 @@ pip install -r requirements.txt # get all symbol data # df = D.features(D.instruments("all"), ["$close"], freq="1min") ``` - diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 82660f1112b..874236bc319 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -38,6 +38,7 @@ get_us_stock_symbols, get_in_stock_symbols, get_br_stock_symbols, + get_jp_stock_symbols, generate_minutes_calendar_from_daily, calc_adjusted_price, ) @@ -364,6 +365,33 @@ class YahooCollectorBR1min(YahooCollectorBR): retry = 2 +class YahooCollectorJP(YahooCollector, ABC): + def get_instrument_list(self): + logger.info("get JP Prime (domestic stock) symbols......") + symbols = get_jp_stock_symbols() + logger.info(f"get {len(symbols)} symbols.") + return symbols + + def download_index_data(self): + pass + + def normalize_symbol(self, symbol): + return code_to_fname(symbol).upper() + + @property + def _timezone(self): + return "Asia/Tokyo" + + +class YahooCollectorJP1d(YahooCollectorJP): + pass + + +class YahooCollectorJP1min(YahooCollectorJP): + def __init__(self, *args, **kwargs): + raise ValueError("JP region does not support 1min data collection") + + class YahooNormalize(BaseNormalize): COLUMNS = ["open", "close", "high", "low", "volume"] DAILY_FORMAT = "%Y-%m-%d" @@ -720,6 +748,27 @@ def symbol_to_yahoo(self, symbol): return fname_to_code(symbol) +class YahooNormalizeJP: + def _get_calendar_list(self) -> Iterable[pd.Timestamp]: + return get_calendar_list("JP_ALL") + + +class YahooNormalizeJP1d(YahooNormalizeJP, YahooNormalize1d): + pass + + +class YahooNormalizeJP1dExtend(YahooNormalizeJP, YahooNormalize1dExtend): + pass + + +class YahooNormalizeJP1min(YahooNormalizeJP, YahooNormalize1min): + def __init__(self, *args, **kwargs): + raise ValueError("JP region does not support 1min normalization") + + def symbol_to_yahoo(self, symbol): + return fname_to_code(symbol) + + class Run(BaseRun): def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): """ @@ -735,11 +784,15 @@ def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval= interval: str freq, value from [1min, 1d], default 1d region: str - region, value from ["CN", "US", "BR"], default "CN" + region, value from ["CN", "US", "IN", "BR", "JP"], default "CN" """ super().__init__(source_dir, normalize_dir, max_workers, interval) self.region = region + def _validate_region_interval(self): + if self.region.upper() == "JP" and self.interval.lower() == "1min": + raise ValueError("JP region does not support 1min data") + @property def collector_class_name(self): return f"YahooCollector{self.region.upper()}{self.interval}" @@ -792,6 +845,7 @@ def download_data( # get 1m data $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ + self._validate_region_interval() if self.interval == "1d" and pd.Timestamp(end) > pd.Timestamp(datetime.datetime.now().strftime("%Y-%m-%d")): raise ValueError(f"end_date: {end} is greater than the current date.") @@ -828,6 +882,7 @@ def normalize_data( $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min """ + self._validate_region_interval() if self.interval.lower() == "1min": if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists(): raise ValueError( @@ -937,6 +992,7 @@ def update_data_to_bin( check_data_length: int = None, delay: float = 1, exists_skip: bool = False, + limit_nums: int = None, ): """update yahoo data to bin @@ -953,6 +1009,8 @@ def update_data_to_bin( time.sleep(delay), default 1 exists_skip: bool exists skip, by default False + limit_nums: int + using for debug, by default None Notes ----- If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day @@ -981,7 +1039,13 @@ def update_data_to_bin( # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 - self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length) + self.download_data( + delay=delay, + start=trading_date, + end=end_date, + check_data_length=check_data_length, + limit_nums=limit_nums, + ) # NOTE: a larger max_workers setting here would be faster self.max_workers = ( max(multiprocessing.cpu_count() - 2, 1) diff --git a/tests/test_yahoo_collector_jp.py b/tests/test_yahoo_collector_jp.py new file mode 100644 index 00000000000..56166ba1263 --- /dev/null +++ b/tests/test_yahoo_collector_jp.py @@ -0,0 +1,168 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +import importlib as stdlib_importlib +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +import pandas as pd + +ROOT_DIR = Path(__file__).resolve().parent.parent +SCRIPTS_DIR = ROOT_DIR.joinpath("scripts") + +sys.path.insert(0, str(SCRIPTS_DIR)) + +from data_collector import utils as dc_utils # noqa: E402 +from data_collector.yahoo import collector as yahoo_collector # noqa: E402 + + +class TestYahooCollectorJP(unittest.TestCase): + def setUp(self): + dc_utils._JP_SYMBOLS = None # pylint: disable=W0212 + dc_utils._CALENDAR_MAP.pop("JP_ALL", None) # pylint: disable=W0212 + + @staticmethod + def _import_module_side_effect(module_name: str): + if module_name == "collector": + return yahoo_collector + return stdlib_importlib.import_module(module_name) + + def test_extract_jp_prime_symbols(self): + source_df = pd.DataFrame( + { + "コード": ["7203", "6758", "1301", "9432", "1301"], + "市場・商品区分": [ + "プライム(内国株式)", + "スタンダード(内国株式)", + "プライム(内国株式)", + "プライム(外国株式)", + "プライム(内国株式)", + ], + } + ) + + symbols = dc_utils._extract_jp_prime_symbols(source_df) # pylint: disable=W0212 + self.assertEqual(symbols, ["1301.T", "7203.T"]) + + def test_extract_jp_prime_symbols_missing_columns(self): + with self.assertRaisesRegex(ValueError, "stock code column"): + dc_utils._extract_jp_prime_symbols(pd.DataFrame({"市場・商品区分": ["プライム(内国株式)"]})) # pylint: disable=W0212 + + with self.assertRaisesRegex(ValueError, "market classification column"): + dc_utils._extract_jp_prime_symbols(pd.DataFrame({"コード": ["7203"]})) # pylint: disable=W0212 + + def test_get_jp_stock_symbols_from_jpx(self): + source_df = pd.DataFrame( + { + "コード": ["7203", "1301", "0001", "8306"], + "市場・商品区分": ["プライム(内国株式)", "プライム(内国株式)", "ETF・ETN", "プライム(内国株式)"], + } + ) + + class _Resp: + status_code = 200 + content = b"dummy" + + with patch("data_collector.utils.requests.get", return_value=_Resp()) as mock_get: + with patch("data_collector.utils.pd.read_excel", return_value=source_df) as mock_read_excel: + symbols = dc_utils.get_jp_stock_symbols() + symbols_cached = dc_utils.get_jp_stock_symbols() + + self.assertEqual(symbols, ["1301.T", "7203.T", "8306.T"]) + self.assertEqual(symbols_cached, ["1301.T", "7203.T", "8306.T"]) + self.assertEqual(mock_get.call_count, 1) + self.assertEqual(mock_read_excel.call_count, 1) + + def test_run_class_resolution_for_jp_1d(self): + with patch("data_collector.base.importlib.import_module", side_effect=self._import_module_side_effect): + with tempfile.TemporaryDirectory() as tmp_dir: + run = yahoo_collector.Run( + source_dir=tmp_dir, + normalize_dir=tmp_dir, + max_workers=1, + interval="1d", + region="JP", + ) + + self.assertEqual(run.collector_class_name, "YahooCollectorJP1d") + self.assertEqual(run.normalize_class_name, "YahooNormalizeJP1d") + self.assertIs(getattr(yahoo_collector, run.collector_class_name), yahoo_collector.YahooCollectorJP1d) + self.assertIs(getattr(yahoo_collector, run.normalize_class_name), yahoo_collector.YahooNormalizeJP1d) + + def test_run_jp_1min_is_not_supported(self): + with patch("data_collector.base.importlib.import_module", side_effect=self._import_module_side_effect): + with tempfile.TemporaryDirectory() as tmp_dir: + run = yahoo_collector.Run( + source_dir=tmp_dir, + normalize_dir=tmp_dir, + max_workers=1, + interval="1min", + region="JP", + ) + with self.assertRaisesRegex(ValueError, "JP region does not support 1min data"): + run.download_data(start="2024-01-01", end="2024-01-05") + with self.assertRaisesRegex(ValueError, "JP region does not support 1min data"): + run.normalize_data(qlib_data_1d_dir=tmp_dir) + + def test_get_calendar_list_jp_normalizes_timezone_values(self): + multi_index = pd.MultiIndex.from_arrays( + [ + ["^N225", "^N225", "^N225"], + [ + pd.Timestamp("2024-01-05"), + pd.Timestamp("2024-01-04 10:26:15+09:00"), + pd.Timestamp("2024-01-05 11:30:00+09:00"), + ], + ], + names=["symbol", "date"], + ) + history_df = pd.DataFrame({"close": [1, 2, 3]}, index=multi_index) + + class FakeTicker: + def __init__(self, *args, **kwargs): + pass + + def history(self, *args, **kwargs): + return history_df + + with patch("data_collector.utils.Ticker", FakeTicker): + calendar = dc_utils.get_calendar_list("JP_ALL") + + self.assertEqual(calendar, [pd.Timestamp("2024-01-04"), pd.Timestamp("2024-01-05")]) + + def test_update_data_to_bin_jp_skips_index_components(self): + with tempfile.TemporaryDirectory() as tmp_dir: + qlib_dir = Path(tmp_dir).joinpath("qlib_data") + qlib_dir.joinpath("calendars").mkdir(parents=True) + qlib_dir.joinpath("calendars/day.txt").write_text("2024-01-04\n2024-01-05\n", encoding="utf-8") + + with patch("data_collector.base.importlib.import_module", side_effect=self._import_module_side_effect): + run = yahoo_collector.Run( + source_dir=Path(tmp_dir).joinpath("source"), + normalize_dir=Path(tmp_dir).joinpath("normalize"), + max_workers=1, + interval="1d", + region="JP", + ) + + with patch("data_collector.yahoo.collector.exists_qlib_data", return_value=True): + with patch.object(yahoo_collector.Run, "download_data", return_value=None) as mock_download: + with patch.object(yahoo_collector.Run, "normalize_data_1d_extend", return_value=None) as mock_normalize_ext: + with patch("data_collector.yahoo.collector.DumpDataUpdate") as mock_dump_cls: + with patch("data_collector.yahoo.collector.importlib.import_module") as mock_import: + run.update_data_to_bin(qlib_data_1d_dir=str(qlib_dir), end_date="2024-01-06") + + mock_download.assert_called_once_with( + delay=1, start="2024-01-04", end="2024-01-06", check_data_length=None, limit_nums=None + ) + mock_normalize_ext.assert_called_once() + self.assertEqual(Path(mock_normalize_ext.call_args.args[0]).resolve(), qlib_dir.resolve()) + mock_dump_cls.return_value.dump.assert_called_once() + mock_import.assert_not_called() + + +if __name__ == "__main__": + unittest.main() From d0e1ae3d4ddc183beff6cc2617ad69239fe163df Mon Sep 17 00:00:00 2001 From: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:55:24 +0900 Subject: [PATCH 2/2] Add JP ETF/ETN symbols to Yahoo collector universe --- scripts/data_collector/utils.py | 25 +++++++++++++++-------- scripts/data_collector/yahoo/README.md | 4 ++-- scripts/data_collector/yahoo/collector.py | 2 +- tests/test_yahoo_collector_jp.py | 23 +++++++++++++-------- 4 files changed, 35 insertions(+), 19 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index bf49e2d2c46..25747246b8c 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -9,6 +9,7 @@ import pickle import requests import functools +import unicodedata from io import BytesIO from pathlib import Path from typing import Iterable, Tuple, List, Optional @@ -468,6 +469,10 @@ def _normalize_jpx_column_name(col_name: str) -> str: return str(col_name).replace(" ", "").replace("\u3000", "").replace("\n", "").strip().lower() +def _normalize_jpx_market_value(value: str) -> str: + return unicodedata.normalize("NFKC", str(value)).upper() + + def _find_jpx_column(columns: list, exact_candidates: list, keyword_candidates: list) -> Optional[str]: normalized_map = {col: _normalize_jpx_column_name(col) for col in columns} exact_candidates = {_normalize_jpx_column_name(col) for col in exact_candidates} @@ -511,18 +516,22 @@ def _extract_jp_prime_symbols(df: pd.DataFrame) -> list: ) market_series = df[market_col].astype(str) - prime_mask = market_series.str.contains("プライム", na=False) + normalized_market_series = market_series.map(_normalize_jpx_market_value) + prime_mask = normalized_market_series.str.contains("プライム", na=False) + etf_etn_mask = normalized_market_series.str.contains(r"ETF|ETN", na=False) - if market_series.str.contains("内国株式", na=False).any(): - domestic_mask = market_series.str.contains("内国株式", na=False) + if normalized_market_series.str.contains("内国株式", na=False).any(): + domestic_mask = normalized_market_series.str.contains("内国株式", na=False) elif domestic_col is not None: - domestic_mask = df[domestic_col].astype(str).str.contains("内国株式", na=False) + domestic_mask = ( + df[domestic_col].astype(str).map(_normalize_jpx_market_value).str.contains("内国株式", na=False) + ) else: - domestic_mask = market_series.str.contains("内国株式", na=False) + domestic_mask = normalized_market_series.str.contains("内国株式", na=False) - target_df = df.loc[prime_mask & domestic_mask, [code_col]].copy() + target_df = df.loc[(prime_mask & domestic_mask) | etf_etn_mask, [code_col]].copy() if target_df.empty: - raise ValueError("No JPX Prime domestic stocks found in listed companies file") + raise ValueError("No JPX Prime domestic stocks or ETF/ETN found in listed companies file") symbols = ( target_df[code_col] @@ -540,7 +549,7 @@ def _extract_jp_prime_symbols(df: pd.DataFrame) -> list: def get_jp_stock_symbols() -> list: - """get JP Prime (domestic stock) symbols""" + """get JP Prime (domestic stock) and ETF/ETN symbols""" global _JP_SYMBOLS # pylint: disable=W0603 diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index 8d501baac44..71859b79b34 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -55,7 +55,7 @@ pip install -r requirements.txt ### Collector *YahooFinance* data to qlib > collector *YahooFinance* data and *dump* into `qlib` format. > If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data. -> For `region=JP`, the symbol universe is **TSE Prime (domestic stocks)** from JPX listed companies file. +> For `region=JP`, the symbol universe is **TSE Prime (domestic stocks) + ETF/ETN** from JPX listed companies file. 1. download data to csv: `python scripts/data_collector/yahoo/collector.py download_data` This will download the raw data such as high, low, open, close, adjclose price from yahoo to a local directory. One file per symbol. @@ -95,7 +95,7 @@ pip install -r requirements.txt # br 1min data python collector.py download_data --source_dir ~/.qlib/stock_data/source/br_data_1min --delay 1 --interval 1min --region BR - # jp 1d data (TSE Prime domestic stocks) + # jp 1d data (TSE Prime domestic stocks + ETF/ETN) python collector.py download_data --source_dir ~/.qlib/stock_data/source/jp_data --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region JP ``` 2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data` diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 874236bc319..684b1ab489d 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -367,7 +367,7 @@ class YahooCollectorBR1min(YahooCollectorBR): class YahooCollectorJP(YahooCollector, ABC): def get_instrument_list(self): - logger.info("get JP Prime (domestic stock) symbols......") + logger.info("get JP Prime (domestic stock) + ETF/ETN symbols......") symbols = get_jp_stock_symbols() logger.info(f"get {len(symbols)} symbols.") return symbols diff --git a/tests/test_yahoo_collector_jp.py b/tests/test_yahoo_collector_jp.py index 56166ba1263..f10f704338b 100644 --- a/tests/test_yahoo_collector_jp.py +++ b/tests/test_yahoo_collector_jp.py @@ -30,22 +30,23 @@ def _import_module_side_effect(module_name: str): return yahoo_collector return stdlib_importlib.import_module(module_name) - def test_extract_jp_prime_symbols(self): + def test_extract_jp_prime_symbols_with_etf_etn(self): source_df = pd.DataFrame( { - "コード": ["7203", "6758", "1301", "9432", "1301"], + "コード": ["7203", "6758", "1301", "9432", "1489", "1489"], "市場・商品区分": [ "プライム(内国株式)", "スタンダード(内国株式)", "プライム(内国株式)", "プライム(外国株式)", - "プライム(内国株式)", + "ETF・ETN", + "ETF・ETN", ], } ) symbols = dc_utils._extract_jp_prime_symbols(source_df) # pylint: disable=W0212 - self.assertEqual(symbols, ["1301.T", "7203.T"]) + self.assertEqual(symbols, ["1301.T", "1489.T", "7203.T"]) def test_extract_jp_prime_symbols_missing_columns(self): with self.assertRaisesRegex(ValueError, "stock code column"): @@ -57,8 +58,14 @@ def test_extract_jp_prime_symbols_missing_columns(self): def test_get_jp_stock_symbols_from_jpx(self): source_df = pd.DataFrame( { - "コード": ["7203", "1301", "0001", "8306"], - "市場・商品区分": ["プライム(内国株式)", "プライム(内国株式)", "ETF・ETN", "プライム(内国株式)"], + "コード": ["7203", "1301", "0001", "8306", "1489"], + "市場・商品区分": [ + "プライム(内国株式)", + "プライム(内国株式)", + "ETF・ETN", + "プライム(内国株式)", + "ETF・ETN", + ], } ) @@ -71,8 +78,8 @@ class _Resp: symbols = dc_utils.get_jp_stock_symbols() symbols_cached = dc_utils.get_jp_stock_symbols() - self.assertEqual(symbols, ["1301.T", "7203.T", "8306.T"]) - self.assertEqual(symbols_cached, ["1301.T", "7203.T", "8306.T"]) + self.assertEqual(symbols, ["0001.T", "1301.T", "1489.T", "7203.T", "8306.T"]) + self.assertEqual(symbols_cached, ["0001.T", "1301.T", "1489.T", "7203.T", "8306.T"]) self.assertEqual(mock_get.call_count, 1) self.assertEqual(mock_read_excel.call_count, 1)