diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py
index 83bd3550b..224b9dbe1 100644
--- a/freqtrade/freqai/freqai_interface.py
+++ b/freqtrade/freqai/freqai_interface.py
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from collections import deque
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import datasieve.transforms as ds
 import numpy as np
@@ -108,7 +108,7 @@ class IFreqaiModel(ABC):
         self.data_provider: Optional[DataProvider] = None
         self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
         self.can_short = True  # overridden in start() with strategy.can_short
-        self.model: Any = None
+        self.model: Union[Any, None] = None
         if self.ft_params.get('principal_component_analysis', False) and self.continual_learning:
             self.ft_params.update({'principal_component_analysis': False})
             logger.warning('User tried to use PCA with continual learning. Deactivating PCA.')
@@ -257,6 +257,21 @@ class IFreqaiModel(ABC):
                 if self.freqai_info.get('write_metrics_to_disk', False):
                     self.dd.save_metric_tracker_to_disk()
 
+
+    def _train_model(self, dataframe_train, pair, dk, tr_backtest) -> Union[Any, None]:
+        try:
+            self.tb_logger = get_tb_logger(self.dd.model_type, dk.data_path,
+                                           self.activate_tensorboard)
+            model = self.train(dataframe_train, pair, dk)
+            self.tb_logger.close()
+            return model
+        except Exception as msg:
+            logger.warning(
+                f"Training {pair} raised exception {msg.__class__.__name__} "
+                f"from {tr_backtest.start_fmt} to {tr_backtest.stop_fmt}. "
+                f"Message: {msg}, skipping.", exc_info=True)
+            return None
+
     def start_backtesting(
         self, dataframe: DataFrame, metadata: dict, dk: FreqaiDataKitchen, strategy: IStrategy
     ) -> FreqaiDataKitchen:
@@ -344,35 +359,26 @@ class IFreqaiModel(ABC):
             if not self.model_exists(dk):
                 dk.find_features(dataframe_train)
                 dk.find_labels(dataframe_train)
+                self.model = self._train_model(dataframe_train, pair, dk, tr_backtest)
 
-                try:
-                    self.tb_logger = get_tb_logger(self.dd.model_type, dk.data_path,
-                                                   self.activate_tensorboard)
-                    self.model = self.train(dataframe_train, pair, dk)
-                    self.tb_logger.close()
-                except Exception as msg:
-                    logger.warning(
-                        f"Training {pair} raised exception {msg.__class__.__name__}. "
-                        f"Message: {msg}, skipping.", exc_info=True)
-                    self.model = None
-
-                self.dd.pair_dict[pair]["trained_timestamp"] = int(tr_train.stopts)
-                if self.plot_features and self.model is not None:
-                    plot_feature_importance(self.model, pair, dk, self.plot_features)
-                if self.save_backtest_models and self.model is not None:
-                    logger.info('Saving backtest model to disk.')
-                    self.dd.save_data(self.model, pair, dk)
-                else:
-                    logger.info('Saving metadata to disk.')
-                    self.dd.save_metadata(dk)
-
+                if self.model:
+                    self.dd.pair_dict[pair]["trained_timestamp"] = int(tr_train.stopts)
+                    if self.plot_features and self.model is not None:
+                        plot_feature_importance(self.model, pair, dk, self.plot_features)
+                    if self.save_backtest_models and self.model is not None:
+                        logger.info('Saving backtest model to disk.')
+                        self.dd.save_data(self.model, pair, dk)
+                    else:
+                        logger.info('Saving metadata to disk.')
+                        self.dd.save_metadata(dk)
             else:
                 self.model = self.dd.load_data(pair, dk)
 
-            pred_df, do_preds = self.predict(dataframe_backtest, dk)
-            append_df = dk.get_predictions_to_append(pred_df, do_preds, dataframe_backtest)
-            dk.append_predictions(append_df)
-            dk.save_backtesting_prediction(append_df)
+            if self.model:
+                pred_df, do_preds = self.predict(dataframe_backtest, dk)
+                append_df = dk.get_predictions_to_append(pred_df, do_preds, dataframe_backtest)
+                dk.append_predictions(append_df)
+                dk.save_backtesting_prediction(append_df)
             self.backtesting_fit_live_predictions(dk)
 
         dk.fill_predictions(dataframe)