Merge pull request #8336 from richardjozsa/develop

Added the latest Gymnasium version 0.28 (to be released shortly).
Robert Caulk 2023-05-01 07:32:37 +02:00 committed by GitHub
commit c26099280f
10 changed files with 67 additions and 41 deletions

freqtrade/freqai/RL/Base3ActionRLEnv.py

@@ -1,7 +1,7 @@
 import logging
 from enum import Enum

-from gym import spaces
+from gymnasium import spaces

 from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
@@ -94,9 +94,12 @@ class Base3ActionRLEnv(BaseEnvironment):
         observation = self._get_observation()
+        # user can play with time if they want
+        truncated = False
         self._update_history(info)

-        return observation, step_reward, self._done, info
+        return observation, step_reward, self._done, truncated, info

     def is_tradesignal(self, action: int) -> bool:
         """

freqtrade/freqai/RL/Base4ActionRLEnv.py

@@ -1,7 +1,7 @@
 import logging
 from enum import Enum

-from gym import spaces
+from gymnasium import spaces

 from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
@@ -96,9 +96,12 @@ class Base4ActionRLEnv(BaseEnvironment):
         observation = self._get_observation()
+        # user can play with time if they want
+        truncated = False
         self._update_history(info)

-        return observation, step_reward, self._done, info
+        return observation, step_reward, self._done, truncated, info

     def is_tradesignal(self, action: int) -> bool:
         """

freqtrade/freqai/RL/Base5ActionRLEnv.py

@@ -1,7 +1,7 @@
 import logging
 from enum import Enum

-from gym import spaces
+from gymnasium import spaces

 from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
@@ -101,10 +101,12 @@ class Base5ActionRLEnv(BaseEnvironment):
         )
         observation = self._get_observation()
+        # user can play with time if they want
+        truncated = False
         self._update_history(info)

-        return observation, step_reward, self._done, info
+        return observation, step_reward, self._done, truncated, info

     def is_tradesignal(self, action: int) -> bool:
         """

freqtrade/freqai/RL/BaseEnvironment.py

@@ -4,11 +4,11 @@ from abc import abstractmethod
 from enum import Enum
 from typing import Optional, Type, Union

-import gym
+import gymnasium as gym
 import numpy as np
 import pandas as pd
-from gym import spaces
-from gym.utils import seeding
+from gymnasium import spaces
+from gymnasium.utils import seeding
 from pandas import DataFrame
@@ -127,6 +127,14 @@ class BaseEnvironment(gym.Env):
         self.history: dict = {}
         self.trade_history: list = []

+    def get_attr(self, attr: str):
+        """
+        Returns the attribute of the environment
+        :param attr: attribute to return
+        :return: attribute
+        """
+        return getattr(self, attr)
+
     @abstractmethod
     def set_action_space(self):
         """
@@ -203,7 +211,7 @@ class BaseEnvironment(gym.Env):
         self.close_trade_profit = []
         self._total_unrealized_profit = 1

-        return self._get_observation()
+        return self._get_observation(), self.history

     @abstractmethod
     def step(self, action: int):
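
The reset() change follows the same Gymnasium convention: it must return an (observation, info) tuple rather than a bare observation, and the new get_attr helper simply exposes an environment attribute by name via getattr. A rough sketch of a custom env skeleton that follows this contract; the class and attribute names below are illustrative only, not the freqtrade environment:

# Sketch of the Gymnasium reset()/step() contract for a custom env.
from typing import Any, Dict, Optional

import gymnasium as gym
import numpy as np
from gymnasium import spaces


class ToyEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
        self.action_space = spaces.Discrete(3)
        self.history: Dict[str, Any] = {}

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        obs = np.zeros(3, dtype=np.float32)
        # Gymnasium expects (observation, info); the freqtrade change returns
        # self.history as the info payload.
        return obs, self.history

    def step(self, action: int):
        obs = self.observation_space.sample()
        reward = 0.0
        terminated = False   # natural episode end
        truncated = False    # time-limit style cutoff
        return obs, reward, terminated, truncated, {}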

freqtrade/freqai/RL/BaseReinforcementLearningModel.py

@@ -6,7 +6,7 @@ from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

-import gym
+import gymnasium as gym
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
@@ -16,13 +16,13 @@ from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.utils import set_random_seed
-from stable_baselines3.common.vec_env import SubprocVecEnv
+from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor

 from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.freqai_interface import IFreqaiModel
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
-from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
+from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions
 from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
 from freqtrade.persistence import Trade
@@ -46,8 +46,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
         th.set_num_threads(self.max_threads)
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
-        self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
-        self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
+        self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
+        self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
         self.eval_callback: Optional[EvalCallback] = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
@@ -431,9 +431,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         return 0.


-def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
+def make_env(MyRLEnv: Type[BaseEnvironment], env_id: str, rank: int,
              seed: int, train_df: DataFrame, price: DataFrame,
-             monitor: bool = False,
              env_info: Dict[str, Any] = {}) -> Callable:
     """
     Utility function for multiprocessed env.
@@ -450,8 +449,7 @@ def make_env(MyRLEnv: Type[BaseEnvironment], env_id: str, rank: int,
         env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
                       **env_info)
-        if monitor:
-            env = Monitor(env)
         return env
     set_random_seed(seed)
     return _init
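
With per-environment Monitor wrapping removed from make_env, episode statistics are instead collected once at the vectorized level (see the multiproc learner below). A hedged sketch of that pattern as a standalone script with a toy environment, not the freqtrade classes:

# Sketch: wrap a SubprocVecEnv in VecMonitor instead of wrapping each worker
# env in Monitor. VecMonitor records episode reward/length for all workers.
from typing import Callable

import gymnasium as gym
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor


def make_env(env_id: str, rank: int, seed: int = 0) -> Callable:
    """Factory returning a picklable initializer for one worker env."""
    def _init() -> gym.Env:
        env = gym.make(env_id)
        env.reset(seed=seed + rank)
        return env
    set_random_seed(seed)
    return _init


if __name__ == "__main__":
    n_envs = 4
    # One VecMonitor around the whole vectorized env replaces n_envs Monitors.
    train_env = VecMonitor(SubprocVecEnv([make_env("CartPole-v1", i) for i in range(n_envs)]))
    obs = train_env.reset()
    train_env.close()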

freqtrade/freqai/RL/TensorboardCallback.py

@@ -3,8 +3,9 @@ from typing import Any, Dict, Type, Union
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.logger import HParam
+from stable_baselines3.common.vec_env import VecEnv

-from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
+from freqtrade.freqai.RL.BaseEnvironment import BaseActions


 class TensorboardCallback(BaseCallback):
@@ -12,11 +13,13 @@ class TensorboardCallback(BaseCallback):
     Custom callback for plotting additional values in tensorboard and
     episodic summary reports.
     """
+    # Override training_env type to fix type errors
+    training_env: Union[VecEnv, None] = None

     def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
         super().__init__(verbose)
         self.model: Any = None
-        self.logger = None  # type: Any
-        self.training_env: BaseEnvironment = None  # type: ignore
+        self.logger: Any = None
         self.actions: Type[Enum] = actions

     def _on_training_start(self) -> None:
@@ -44,6 +47,8 @@ class TensorboardCallback(BaseCallback):

     def _on_step(self) -> bool:
         local_info = self.locals["infos"][0]
+        if self.training_env is None:
+            return True
         tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
         for metric in local_info:
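
For reference, training_env on an SB3 callback is the vectorized training env, so get_attr(...) returns one value per worker and indexing with [0] picks the first worker. A minimal custom callback sketch under that assumption; "my_metric" is a made-up attribute name, not part of freqtrade or stable-baselines3:

# Sketch of an SB3 callback reading an attribute from the vectorized training env.
from stable_baselines3.common.callbacks import BaseCallback


class MetricLoggerCallback(BaseCallback):
    def __init__(self, verbose: int = 1):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        if self.training_env is None:
            return True
        # get_attr() returns a list with one entry per worker env.
        values = self.training_env.get_attr("my_metric")
        self.logger.record("custom/my_metric", values[0])
        return True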

freqtrade/freqai/freqai_interface.py

@@ -242,8 +242,8 @@ class IFreqaiModel(ABC):
                     new_trained_timerange, pair, strategy, dk, data_load_timerange
                 )
             except Exception as msg:
-                logger.warning(f"Training {pair} raised exception {msg.__class__.__name__}. "
-                               f"Message: {msg}, skipping.")
+                logger.exception(f"Training {pair} raised exception {msg.__class__.__name__}. "
+                                 f"Message: {msg}, skipping.")

             self.train_timer('stop', pair)
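
The switch from logger.warning to logger.exception means the full traceback is written to the log, not just the formatted message, which makes skipped training failures debuggable. A small standalone illustration using the standard logging module, not freqtrade code:

# logger.exception() behaves like logger.error() but also appends the current
# exception's traceback when called from an except block.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    raise ValueError("bad training data")
except Exception as msg:
    logger.warning(f"Raised {msg.__class__.__name__}. Message: {msg}")    # message only
    logger.exception(f"Raised {msg.__class__.__name__}. Message: {msg}")  # message + traceback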

freqtrade/freqai/prediction_models/ReinforcementLearner.py

@@ -1,11 +1,12 @@
 import logging
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, Type

 import torch as th

 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
+from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
@@ -84,7 +85,9 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
         return model

-    class MyRLEnv(Base5ActionRLEnv):
+    MyRLEnv: Type[BaseEnvironment]
+
+    class MyRLEnv(Base5ActionRLEnv):  # type: ignore[no-redef]
         """
         User can override any function in BaseRLEnv and gym.Env. Here the user
         sets a custom reward based on profit and trade duration.
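
The MyRLEnv: Type[BaseEnvironment] annotation followed by the redefining class is a typing pattern: declare the attribute's type once, then bind the concrete nested class, silencing mypy's redefinition check with the targeted ignore. A generic sketch of the idea with placeholder names, not the freqtrade classes:

# Sketch of the "declare the attribute type, then redefine the class" pattern.
from typing import Type


class BaseEnv:
    pass


class BaseModel:
    # Subclasses are expected to bind their own MyEnv class here.
    MyEnv: Type[BaseEnv] = BaseEnv


class ConcreteModel(BaseModel):
    MyEnv: Type[BaseEnv]

    class MyEnv(BaseEnv):  # type: ignore[no-redef]
        """Custom environment bound as a class attribute."""
        pass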

freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py

@@ -3,7 +3,7 @@ from typing import Any, Dict
 from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
-from stable_baselines3.common.vec_env import SubprocVecEnv
+from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor

 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
@@ -41,22 +41,25 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):

         env_info = self.pack_env_dict(dk.pair)

+        eval_freq = len(train_df) // self.max_threads
+
         env_id = "train_env"
-        self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
-                                                 train_df, prices_train,
-                                                 monitor=True,
-                                                 env_info=env_info) for i
-                                        in range(self.max_threads)])
+        self.train_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
+                                                             train_df, prices_train,
+                                                             env_info=env_info) for i
+                                                   in range(self.max_threads)]))

         eval_env_id = 'eval_env'
-        self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
-                                                test_df, prices_test,
-                                                monitor=True,
-                                                env_info=env_info) for i
-                                       in range(self.max_threads)])
+        self.eval_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
+                                                            test_df, prices_test,
+                                                            env_info=env_info) for i
+                                                  in range(self.max_threads)]))

         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                          render=False, eval_freq=len(train_df),
+                                          render=False, eval_freq=eval_freq,
                                           best_model_save_path=str(dk.data_path))

+        # THE TENSORBOARD CALLBACK IS NOT RECOMMENDED WITH MULTIPLE ENVS:
+        # IT WILL REPORT FALSE INFORMATION AND IS NOT THREAD SAFE WITH SB3!
         actions = self.train_env.env_method("get_actions")[0]
         self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
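
The eval_freq change accounts for vectorization: SB3 counts eval_freq in calls to the vectorized env's step(), and each such call advances all workers by one step, so dividing by the number of workers keeps the evaluation cadence roughly constant in total environment steps. A small worked sketch of the arithmetic with illustrative numbers:

# Illustrative arithmetic for eval_freq with vectorized envs (values made up).
train_df_len = 10_000      # number of training candles
max_threads = 4            # number of worker envs

# Each vec-env step() advances all workers, so dividing keeps the evaluation
# interval at roughly one pass over the training data in total env steps.
eval_freq = train_df_len // max_threads
total_env_steps_between_evals = eval_freq * max_threads

print(eval_freq)                          # 2500
print(total_env_steps_between_evals)      # 10000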

requirements-freqai-rl.txt

@@ -3,10 +3,11 @@
 # Required for freqai-rl
 torch==1.13.1; python_version < '3.11'
-stable-baselines3==1.7.0; python_version < '3.11'
-sb3-contrib==1.7.0; python_version < '3.11'
+# Until these branches are released we can use this
+gymnasium==0.28.1
+stable_baselines3==2.0.0a5
+sb3_contrib>=2.0.0a4
 # Gym is forced to this version by stable-baselines3.
 setuptools==65.5.1 # Should be removed when gym is fixed.
-gym==0.21; python_version < '3.11'
 # Progress bar for stable-baselines3 and sb3-contrib
 tqdm==4.65.0; python_version < '3.11'