diff --git a/qlib/contrib/model/pytorch_adarnn.py b/qlib/contrib/model/pytorch_adarnn.py index c1585a6ac0..781462470c 100644 --- a/qlib/contrib/model/pytorch_adarnn.py +++ b/qlib/contrib/model/pytorch_adarnn.py @@ -268,6 +268,7 @@ def fit( best_epoch = 0 weight_mat, dist_mat = None, None + best_param = copy.deepcopy(self.model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_add.py b/qlib/contrib/model/pytorch_add.py index c94a03ecc3..17125c34ce 100644 --- a/qlib/contrib/model/pytorch_add.py +++ b/qlib/contrib/model/pytorch_add.py @@ -318,6 +318,7 @@ def bootstrap_fit(self, x_train, y_train, m_train, x_valid, y_valid, m_valid): y_train_values = np.squeeze(y_train.values) m_train_values = np.squeeze(m_train.values.astype(int)) + best_param = copy.deepcopy(self.ADD_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index d1c619ebf4..24818c5b58 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -235,6 +235,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.ALSTM_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 95b5cf95d8..a83082bbf2 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -255,6 +255,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.ALSTM_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 16439b3783..e766c68324 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -269,6 +269,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.GAT_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 09f0ac08b2..71b6266a3f 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -283,6 +283,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.GAT_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index 65da5ac4b4..eb7c887963 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -249,6 +249,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.GRU_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_hist.py b/qlib/contrib/model/pytorch_hist.py index 779cde9c85..eac36be651 100644 --- a/qlib/contrib/model/pytorch_hist.py +++ b/qlib/contrib/model/pytorch_hist.py @@ -300,6 +300,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.HIST_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_igmtf.py b/qlib/contrib/model/pytorch_igmtf.py index 0bddc5a0f5..9902690fb4 100644 --- a/qlib/contrib/model/pytorch_igmtf.py +++ b/qlib/contrib/model/pytorch_igmtf.py @@ -294,6 +294,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.igmtf_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_krnn.py b/qlib/contrib/model/pytorch_krnn.py index d97920b4dc..f5cacc7abc 100644 --- a/qlib/contrib/model/pytorch_krnn.py +++ b/qlib/contrib/model/pytorch_krnn.py @@ -458,6 +458,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.krnn_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_localformer.py b/qlib/contrib/model/pytorch_localformer.py index 42851dd6a2..6d95dd475d 100644 --- a/qlib/contrib/model/pytorch_localformer.py +++ b/qlib/contrib/model/pytorch_localformer.py @@ -184,6 +184,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_localformer_ts.py b/qlib/contrib/model/pytorch_localformer_ts.py index ae60a39968..f014883f3b 100644 --- a/qlib/contrib/model/pytorch_localformer_ts.py +++ b/qlib/contrib/model/pytorch_localformer_ts.py @@ -171,6 +171,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index 3ba09097ac..4c62b8e3a9 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -230,6 +230,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.lstm_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index a0fc34d583..72f35cbba2 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -244,6 +244,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.LSTM_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_sandwich.py b/qlib/contrib/model/pytorch_sandwich.py index 344368143f..5fd3135afa 100644 --- a/qlib/contrib/model/pytorch_sandwich.py +++ b/qlib/contrib/model/pytorch_sandwich.py @@ -328,6 +328,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.sandwich_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index c971f1a58c..8854baa0af 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -385,6 +385,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.sfm_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 3c698edade..02ece6ffa6 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -185,6 +185,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.tabnet_model.state_dict()) for epoch_idx in range(self.n_epochs): self.logger.info("epoch: %s" % (epoch_idx)) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_tcn.py b/qlib/contrib/model/pytorch_tcn.py index f6e7e953a0..a28d9601c0 100755 --- a/qlib/contrib/model/pytorch_tcn.py +++ b/qlib/contrib/model/pytorch_tcn.py @@ -240,6 +240,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.tcn_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_tcn_ts.py b/qlib/contrib/model/pytorch_tcn_ts.py index a6cc38885c..0cc644bbe5 100755 --- a/qlib/contrib/model/pytorch_tcn_ts.py +++ b/qlib/contrib/model/pytorch_tcn_ts.py @@ -234,6 +234,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.TCN_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_transformer.py b/qlib/contrib/model/pytorch_transformer.py index d05b9f4cad..39cae8299b 100644 --- a/qlib/contrib/model/pytorch_transformer.py +++ b/qlib/contrib/model/pytorch_transformer.py @@ -183,6 +183,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") diff --git a/qlib/contrib/model/pytorch_transformer_ts.py b/qlib/contrib/model/pytorch_transformer_ts.py index 70590e03e5..de5a3589ff 100644 --- a/qlib/contrib/model/pytorch_transformer_ts.py +++ b/qlib/contrib/model/pytorch_transformer_ts.py @@ -169,6 +169,7 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...")