From bd870e233128d655ac89a091c1aa6a8b2196c0d7 Mon Sep 17 00:00:00 2001
From: robcaulk
Date: Wed, 24 Aug 2022 16:32:14 +0200
Subject: [PATCH] fix monitor bug, set default values in case user doesn't set
 params

---
 freqtrade/freqai/RL/BaseReinforcementLearningModel.py      | 4 ++--
 freqtrade/freqai/prediction_models/ReinforcementLearner.py | 3 ++-
 .../prediction_models/ReinforcementLearner_multiproc.py    | 4 ++--
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index 6660709bd..1bc3505e1 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -42,7 +42,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.eval_callback: EvalCallback = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
-        self.continual_retraining = self.rl_config['continual_retraining']
+        self.continual_retraining = self.rl_config.get('continual_retraining', False)
         if self.model_type in SB3_MODELS:
             import_str = 'stable_baselines3'
         elif self.model_type in SB3_CONTRIB_MODELS:
@@ -289,7 +289,7 @@ class MyRLEnv(Base5ActionRLEnv):
             return 0.
 
         pnl = self.get_unrealized_profit()
-        max_trade_duration = self.rl_config['max_trade_duration_candles']
+        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 100)
         trade_duration = self._current_tick - self._last_trade_tick
 
         factor = 1
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index 254fd32b0..f7f016ab4 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -32,6 +32,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             logger.info('Continual training activated - starting training from previously '
                         'trained agent.')
             model = self.dd.model_dictionary[dk.pair]
+            model.tensorboard_log = Path(dk.data_path / "tensorboard")
             model.set_env(self.train_env)
 
         model.learn(
@@ -61,7 +62,7 @@ class MyRLEnv(Base5ActionRLEnv):
             return 0.
 
         pnl = self.get_unrealized_profit()
-        max_trade_duration = self.rl_config['max_trade_duration_candles']
+        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 100)
         trade_duration = self._current_tick - self._last_trade_tick
 
         factor = 1
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
index 17281e2d0..3a4c245aa 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -26,10 +26,10 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
 
         # model arch
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=[512, 512, 512])
+                             net_arch=[512, 512, 256])
 
         model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
-                                tensorboard_log=Path(dk.data_path / "tensorboard"),
+                                tensorboard_log=Path(dk.full_path / "tensorboard"),
                                 **self.freqai_info['model_training_parameters']
                                 )
 