From 4baa36bdcf449e224eaa4c69001bc2c503253988 Mon Sep 17 00:00:00 2001
From: sonnhfit
Date: Fri, 19 Aug 2022 01:49:11 +0700
Subject: [PATCH] fix persist a single training environment for PPO

---
 config_examples/config_freqai-rl.example.json |  8 +---
 freqtrade/freqai/RL/Base3ActionRLEnv.py       | 23 ++++++++--
 .../ReinforcementLearningPPO.py               | 45 ++++++++++++-------
 3 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/config_examples/config_freqai-rl.example.json b/config_examples/config_freqai-rl.example.json
index ccc977705..1af872552 100644
--- a/config_examples/config_freqai-rl.example.json
+++ b/config_examples/config_freqai-rl.example.json
@@ -79,13 +79,9 @@
             "random_state": 1,
             "shuffle": false
         },
-        "model_training_parameters": {
+        "model_training_parameters": {
            "learning_rate": 0.00025,
            "gamma": 0.9,
-           "target_update_interval": 5000,
-           "buffer_size": 50000,
-           "exploration_initial_eps":1,
-           "exploration_final_eps": 0.1,
            "verbose": 1
        },
        "rl_config": {
@@ -103,4 +99,4 @@
     "internals": {
         "process_throttle_secs": 5
     }
-}
\ No newline at end of file
+}
diff --git a/freqtrade/freqai/RL/Base3ActionRLEnv.py b/freqtrade/freqai/RL/Base3ActionRLEnv.py
index 9d17b982d..df53c729b 100644
--- a/freqtrade/freqai/RL/Base3ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base3ActionRLEnv.py
@@ -1,13 +1,16 @@
 import logging
 from enum import Enum
-# from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import gym
 import numpy as np
+import pandas as pd
 from gym import spaces
 from gym.utils import seeding
 from pandas import DataFrame
+
+# from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+
 
 logger = logging.getLogger(__name__)
@@ -43,6 +46,9 @@ class Base3ActionRLEnv(gym.Env):
         self.id = id
         self.seed(seed)
+        self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
+
+    def reset_env(self, df, prices, window_size, reward_kwargs, starting_point=True):
 
         self.df = df
         self.signal_features = self.df
         self.prices = prices
@@ -54,7 +60,7 @@
         self.fee = 0.0015
 
         # # spaces
-        self.shape = (window_size, self.signal_features.shape[1])
+        self.shape = (window_size, self.signal_features.shape[1] + 2)
         self.action_space = spaces.Discrete(len(Actions))
         self.observation_space = spaces.Box(
             low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
@@ -165,7 +171,16 @@
         return observation, step_reward, self._done, info
 
     def _get_observation(self):
-        return self.signal_features[(self._current_tick - self.window_size):self._current_tick]
+        features_window = self.signal_features[(
+            self._current_tick - self.window_size):self._current_tick]
+        features_and_state = DataFrame(np.zeros((len(features_window), 2)),
+                                       columns=['current_profit_pct', 'position'],
+                                       index=features_window.index)
+
+        features_and_state['current_profit_pct'] = self.get_unrealized_profit()
+        features_and_state['position'] = self._position.value
+        features_and_state = pd.concat([features_window, features_and_state], axis=1)
+        return features_and_state
 
     def get_unrealized_profit(self):
 
@@ -307,7 +322,7 @@
     def prev_price(self) -> float:
         return self.prices.iloc[self._current_tick - 1].open
 
-    def sharpe_ratio(self):
+    def sharpe_ratio(self) -> float:
         if len(self.close_trade_profit) == 0:
             return 0.
         returns = np.array(self.close_trade_profit)
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearningPPO.py b/freqtrade/freqai/prediction_models/ReinforcementLearningPPO.py
index 5dc7735d3..993ac263b 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearningPPO.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearningPPO.py
@@ -1,16 +1,17 @@
+import gc
 import logging
 from typing import Any, Dict  # , Tuple
 
 import numpy as np  # import numpy.typing as npt
 import torch as th
-from pandas import DataFrame
 from stable_baselines3 import PPO
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.monitor import Monitor
 
-from freqtrade.freqai.RL.Base3ActionRLEnv import Base3ActionRLEnv, Actions, Positions
-from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
+from freqtrade.freqai.RL.Base3ActionRLEnv import Actions, Base3ActionRLEnv, Positions
+from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
 
 
 logger = logging.getLogger(__name__)
@@ -21,23 +22,15 @@ class ReinforcementLearningPPO(BaseReinforcementLearningModel):
     User created Reinforcement Learning Model prediction model.
     """
 
-    def fit_rl(self, data_dictionary: Dict[str, Any], pair: str, dk: FreqaiDataKitchen,
-               prices_train: DataFrame, prices_test: DataFrame):
+    def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
 
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
 
         eval_freq = self.freqai_info["rl_config"]["eval_cycles"] * len(test_df)
         total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
 
-        # environments
-        train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
-                            reward_kwargs=self.reward_params)
-        eval = MyRLEnv(df=test_df, prices=prices_test,
-                       window_size=self.CONV_WIDTH, reward_kwargs=self.reward_params)
-        eval_env = Monitor(eval, ".")
-
         path = dk.data_path
-        eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
+        eval_callback = EvalCallback(self.eval_env, best_model_save_path=f"{path}/",
                                      log_path=f"{path}/ppo/logs/", eval_freq=int(eval_freq),
                                      deterministic=True, render=False)
@@ -45,8 +38,8 @@
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
                              net_arch=[256, 256, 128])
 
-        model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
-                    tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=0.00025,
+        model = PPO('MlpPolicy', self.train_env, policy_kwargs=policy_kwargs,
+                    tensorboard_log=f"{path}/ppo/tensorboard/",
                     **self.freqai_info['model_training_parameters']
                     )
 
@@ -55,12 +48,34 @@
             callback=eval_callback
         )
 
+        del model
         best_model = PPO.load(dk.data_path / "best_model")
         print('Training finished!')
+        gc.collect()
 
         return best_model
 
+    def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test):
+        """
+        User overrides this as shown here if they are using a custom MyRLEnv
+        """
+        train_df = data_dictionary["train_features"]
+        test_df = data_dictionary["test_features"]
+
+        # environments
+        if not self.train_env:
+            self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
+                                     reward_kwargs=self.reward_params)
+            self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
+                                            window_size=self.CONV_WIDTH,
+                                            reward_kwargs=self.reward_params), ".")
+        else:
+            self.train_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params)
+            self.eval_env.reset_env(test_df, prices_test, self.CONV_WIDTH, self.reward_params)
+            self.train_env.reset()
+            self.eval_env.reset()
+
 
 class MyRLEnv(Base3ActionRLEnv):
     """
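
A minimal standalone sketch of the observation construction introduced in Base3ActionRLEnv._get_observation() above, not part of the patch: the feature window gains two state columns, current_profit_pct and position, which is why self.shape is widened to signal_features.shape[1] + 2. The column names come from the patch; the feature values, the 0.012 profit and the position code 1 are placeholders for illustration.

import numpy as np
import pandas as pd

# Rebuild the augmented observation outside the env, using dummy inputs.
window_size = 5
features_window = pd.DataFrame(np.random.randn(window_size, 3),
                               columns=['f0', 'f1', 'f2'])  # placeholder features

state = pd.DataFrame(np.zeros((len(features_window), 2)),
                     columns=['current_profit_pct', 'position'],
                     index=features_window.index)
state['current_profit_pct'] = 0.012  # stand-in for get_unrealized_profit()
state['position'] = 1                # stand-in for self._position.value

observation = pd.concat([features_window, state], axis=1)
assert observation.shape == (window_size, 3 + 2)  # matches shape[1] + 2 in the patch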
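
A minimal, self-contained sketch of the persistence pattern the patch introduces: build the environment once, then re-seed it in place via reset_env() on later retrainings instead of constructing a new instance each cycle. TinyEnv and the window list are made-up stand-ins; only the create-once/reset-afterwards flow mirrors set_train_and_eval_environments() above.

class TinyEnv:
    # Stand-in for MyRLEnv: holds a data window and can be re-seeded in place.
    def __init__(self, df):
        self.reset_env(df)

    def reset_env(self, df):
        self.df = df    # swap in the new training window
        self._tick = 0  # restart internal state instead of rebuilding the object

train_env = None
for window in (['w1'], ['w2'], ['w3']):   # stand-in retraining windows
    if not train_env:
        train_env = TinyEnv(window)       # first retraining: build once
    else:
        train_env.reset_env(window)       # later retrainings: reuse and re-seed

In the patched model the analogous calls are set_train_and_eval_environments(data_dictionary, prices_train, prices_test), presumably invoked by the base class before each fit_rl(data_dictionary, dk), with self.train_env and self.eval_env persisted between retrainings.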