From 2de6a82ca9d0d34052ebabace94455709625df38 Mon Sep 17 00:00:00 2001 From: Ayush Ojha Date: Sat, 31 Jan 2026 19:26:14 -0800 Subject: [PATCH] fix: initialize best_param before training loop to prevent UnboundLocalError When the validation score never improves (e.g. due to NaN scores from missing data), the if branch that assigns best_param never runs, so best_param stays unbound and model.load_state_dict(best_param) raises UnboundLocalError after the loop. Initialize best_param with the model's initial state before the training loop, following the pattern already used in pytorch_gru.py. This ensures the model falls back to its initial weights if no epoch improves the score. Fixes #1794 --- qlib/contrib/model/pytorch_adarnn.py | 1 + qlib/contrib/model/pytorch_add.py | 1 + qlib/contrib/model/pytorch_alstm.py | 1 + qlib/contrib/model/pytorch_alstm_ts.py | 1 + qlib/contrib/model/pytorch_gats.py | 1 + qlib/contrib/model/pytorch_gats_ts.py | 1 + qlib/contrib/model/pytorch_gru_ts.py | 1 + qlib/contrib/model/pytorch_hist.py | 1 + qlib/contrib/model/pytorch_igmtf.py | 1 + qlib/contrib/model/pytorch_krnn.py | 1 + qlib/contrib/model/pytorch_localformer.py | 1 + qlib/contrib/model/pytorch_localformer_ts.py | 1 + qlib/contrib/model/pytorch_lstm.py | 1 + qlib/contrib/model/pytorch_lstm_ts.py | 1 + qlib/contrib/model/pytorch_sandwich.py | 1 + qlib/contrib/model/pytorch_sfm.py | 1 + qlib/contrib/model/pytorch_tabnet.py | 1 + qlib/contrib/model/pytorch_tcn.py | 1 + qlib/contrib/model/pytorch_tcn_ts.py | 1 + qlib/contrib/model/pytorch_transformer.py | 1 + qlib/contrib/model/pytorch_transformer_ts.py | 1 + 21 files changed, 21 insertions(+) diff --git a/qlib/contrib/model/pytorch_adarnn.py b/qlib/contrib/model/pytorch_adarnn.py index c1585a6ac0a..781462470c8 100644 --- a/qlib/contrib/model/pytorch_adarnn.py +++ b/qlib/contrib/model/pytorch_adarnn.py @@ -268,6 +268,7 @@ def fit( best_epoch = 0 weight_mat, dist_mat = None, None + best_param = copy.deepcopy(self.model.state_dict()) for step in range(self.n_epochs): 
self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_add.py b/qlib/contrib/model/pytorch_add.py index c94a03ecc31..17125c34cec 100644 --- a/qlib/contrib/model/pytorch_add.py +++ b/qlib/contrib/model/pytorch_add.py @@ -318,6 +318,7 @@ def bootstrap_fit(self, x_train, y_train, m_train, x_valid, y_valid, m_valid): y_train_values = np.squeeze(y_train.values) m_train_values = np.squeeze(m_train.values.astype(int)) + best_param = copy.deepcopy(self.ADD_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index d1c619ebf41..24818c5b581 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -235,6 +235,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.ALSTM_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 95b5cf95d8b..a83082bbf2b 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -255,6 +255,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.ALSTM_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 16439b3783a..e766c68324f 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -269,6 +269,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.GAT_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git 
a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 09f0ac08b25..71b6266a3f0 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -283,6 +283,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.GAT_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index 65da5ac4b40..eb7c8879630 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -249,6 +249,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.GRU_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_hist.py b/qlib/contrib/model/pytorch_hist.py index 779cde9c859..eac36be6513 100644 --- a/qlib/contrib/model/pytorch_hist.py +++ b/qlib/contrib/model/pytorch_hist.py @@ -300,6 +300,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.HIST_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_igmtf.py b/qlib/contrib/model/pytorch_igmtf.py index 0bddc5a0f5f..9902690fb45 100644 --- a/qlib/contrib/model/pytorch_igmtf.py +++ b/qlib/contrib/model/pytorch_igmtf.py @@ -294,6 +294,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.igmtf_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_krnn.py b/qlib/contrib/model/pytorch_krnn.py index d97920b4dc5..f5cacc7abca 100644 --- a/qlib/contrib/model/pytorch_krnn.py +++ b/qlib/contrib/model/pytorch_krnn.py @@ 
-458,6 +458,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.krnn_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_localformer.py b/qlib/contrib/model/pytorch_localformer.py index 42851dd6a28..6d95dd475d4 100644 --- a/qlib/contrib/model/pytorch_localformer.py +++ b/qlib/contrib/model/pytorch_localformer.py @@ -184,6 +184,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_localformer_ts.py b/qlib/contrib/model/pytorch_localformer_ts.py index ae60a399682..f014883f3bc 100644 --- a/qlib/contrib/model/pytorch_localformer_ts.py +++ b/qlib/contrib/model/pytorch_localformer_ts.py @@ -171,6 +171,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index 3ba09097acd..4c62b8e3a95 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -230,6 +230,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.lstm_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index a0fc34d5832..72f35cbba28 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -244,6 +244,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.LSTM_model.state_dict()) for step in 
range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_sandwich.py b/qlib/contrib/model/pytorch_sandwich.py index 344368143ff..5fd3135afab 100644 --- a/qlib/contrib/model/pytorch_sandwich.py +++ b/qlib/contrib/model/pytorch_sandwich.py @@ -328,6 +328,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.sandwich_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index c971f1a58c5..8854baa0afb 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -385,6 +385,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.sfm_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 3c698edade3..02ece6ffa6f 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -185,6 +185,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.tabnet_model.state_dict()) for epoch_idx in range(self.n_epochs): self.logger.info("epoch: %s" % (epoch_idx)) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_tcn.py b/qlib/contrib/model/pytorch_tcn.py index f6e7e953a00..a28d9601c09 100755 --- a/qlib/contrib/model/pytorch_tcn.py +++ b/qlib/contrib/model/pytorch_tcn.py @@ -240,6 +240,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.tcn_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_tcn_ts.py b/qlib/contrib/model/pytorch_tcn_ts.py index 
a6cc38885c3..0cc644bbe5c 100755 --- a/qlib/contrib/model/pytorch_tcn_ts.py +++ b/qlib/contrib/model/pytorch_tcn_ts.py @@ -234,6 +234,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.TCN_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_transformer.py b/qlib/contrib/model/pytorch_transformer.py index d05b9f4cad1..39cae8299b8 100644 --- a/qlib/contrib/model/pytorch_transformer.py +++ b/qlib/contrib/model/pytorch_transformer.py @@ -183,6 +183,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_transformer_ts.py b/qlib/contrib/model/pytorch_transformer_ts.py index 70590e03e5f..de5a3589ff5 100644 --- a/qlib/contrib/model/pytorch_transformer_ts.py +++ b/qlib/contrib/model/pytorch_transformer_ts.py @@ -169,6 +169,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...")