fix for merge

This commit is contained in:
Aleksey Savin 2024-04-23 22:42:04 +00:00
parent 4b9f0c2fc2
commit 49487afc86

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections import deque
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import datasieve.transforms as ds
import numpy as np
@ -108,7 +108,7 @@ class IFreqaiModel(ABC):
self.data_provider: Optional[DataProvider] = None
self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
self.can_short = True # overridden in start() with strategy.can_short
self.model: Any = None
self.model: Union[Any, None] = None
if self.ft_params.get('principal_component_analysis', False) and self.continual_learning:
self.ft_params.update({'principal_component_analysis': False})
logger.warning('User tried to use PCA with continual learning. Deactivating PCA.')
@ -257,6 +257,21 @@ class IFreqaiModel(ABC):
if self.freqai_info.get('write_metrics_to_disk', False):
self.dd.save_metric_tracker_to_disk()
def _train_model(self, dataframe_train, pair, dk, tr_backtest) -> Union[Any, None]:
try:
self.tb_logger = get_tb_logger(self.dd.model_type, dk.data_path,
self.activate_tensorboard)
model = self.train(dataframe_train, pair, dk)
self.tb_logger.close()
return model
except Exception as msg:
logger.warning(
f"Training {pair} raised exception {msg.__class__.__name__}. "
f"from {tr_backtest.start_fmt} to {tr_backtest.stop_fmt}."
f"Message: {msg}, skipping.", exc_info=True)
return None
def start_backtesting(
self, dataframe: DataFrame, metadata: dict, dk: FreqaiDataKitchen, strategy: IStrategy
) -> FreqaiDataKitchen:
@ -344,35 +359,26 @@ class IFreqaiModel(ABC):
if not self.model_exists(dk):
dk.find_features(dataframe_train)
dk.find_labels(dataframe_train)
self.model = self._train_model(dataframe_train, pair, dk, tr_backtest)
try:
self.tb_logger = get_tb_logger(self.dd.model_type, dk.data_path,
self.activate_tensorboard)
self.model = self.train(dataframe_train, pair, dk)
self.tb_logger.close()
except Exception as msg:
logger.warning(
f"Training {pair} raised exception {msg.__class__.__name__}. "
f"Message: {msg}, skipping.", exc_info=True)
self.model = None
self.dd.pair_dict[pair]["trained_timestamp"] = int(tr_train.stopts)
if self.plot_features and self.model is not None:
plot_feature_importance(self.model, pair, dk, self.plot_features)
if self.save_backtest_models and self.model is not None:
logger.info('Saving backtest model to disk.')
self.dd.save_data(self.model, pair, dk)
else:
logger.info('Saving metadata to disk.')
self.dd.save_metadata(dk)
if self.model:
self.dd.pair_dict[pair]["trained_timestamp"] = int(tr_train.stopts)
if self.plot_features and self.model is not None:
plot_feature_importance(self.model, pair, dk, self.plot_features)
if self.save_backtest_models and self.model is not None:
logger.info('Saving backtest model to disk.')
self.dd.save_data(self.model, pair, dk)
else:
logger.info('Saving metadata to disk.')
self.dd.save_metadata(dk)
else:
self.model = self.dd.load_data(pair, dk)
pred_df, do_preds = self.predict(dataframe_backtest, dk)
append_df = dk.get_predictions_to_append(pred_df, do_preds, dataframe_backtest)
dk.append_predictions(append_df)
dk.save_backtesting_prediction(append_df)
if self.model:
pred_df, do_preds = self.predict(dataframe_backtest, dk)
append_df = dk.get_predictions_to_append(pred_df, do_preds, dataframe_backtest)
dk.append_predictions(append_df)
dk.save_backtesting_prediction(append_df)
self.backtesting_fit_live_predictions(dk)
dk.fill_predictions(dataframe)