mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-12 19:23:55 +00:00
Merge pull request #8336 from richardjozsa/develop
Added the latest Gymnasium version 0.28(will be released shortly),
This commit is contained in:
commit
c26099280f
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||||
|
|
||||||
|
@ -94,9 +94,12 @@ class Base3ActionRLEnv(BaseEnvironment):
|
||||||
|
|
||||||
observation = self._get_observation()
|
observation = self._get_observation()
|
||||||
|
|
||||||
|
# user can play with time if they want
|
||||||
|
truncated = False
|
||||||
|
|
||||||
self._update_history(info)
|
self._update_history(info)
|
||||||
|
|
||||||
return observation, step_reward, self._done, info
|
return observation, step_reward, self._done, truncated, info
|
||||||
|
|
||||||
def is_tradesignal(self, action: int) -> bool:
|
def is_tradesignal(self, action: int) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||||
|
|
||||||
|
@ -96,9 +96,12 @@ class Base4ActionRLEnv(BaseEnvironment):
|
||||||
|
|
||||||
observation = self._get_observation()
|
observation = self._get_observation()
|
||||||
|
|
||||||
|
# user can play with time if they want
|
||||||
|
truncated = False
|
||||||
|
|
||||||
self._update_history(info)
|
self._update_history(info)
|
||||||
|
|
||||||
return observation, step_reward, self._done, info
|
return observation, step_reward, self._done, truncated, info
|
||||||
|
|
||||||
def is_tradesignal(self, action: int) -> bool:
|
def is_tradesignal(self, action: int) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||||
|
|
||||||
|
@ -101,10 +101,12 @@ class Base5ActionRLEnv(BaseEnvironment):
|
||||||
)
|
)
|
||||||
|
|
||||||
observation = self._get_observation()
|
observation = self._get_observation()
|
||||||
|
# user can play with time if they want
|
||||||
|
truncated = False
|
||||||
|
|
||||||
self._update_history(info)
|
self._update_history(info)
|
||||||
|
|
||||||
return observation, step_reward, self._done, info
|
return observation, step_reward, self._done, truncated, info
|
||||||
|
|
||||||
def is_tradesignal(self, action: int) -> bool:
|
def is_tradesignal(self, action: int) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -4,11 +4,11 @@ from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Type, Union
|
from typing import Optional, Type, Union
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from gym.utils import seeding
|
from gymnasium.utils import seeding
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,6 +127,14 @@ class BaseEnvironment(gym.Env):
|
||||||
self.history: dict = {}
|
self.history: dict = {}
|
||||||
self.trade_history: list = []
|
self.trade_history: list = []
|
||||||
|
|
||||||
|
def get_attr(self, attr: str):
|
||||||
|
"""
|
||||||
|
Returns the attribute of the environment
|
||||||
|
:param attr: attribute to return
|
||||||
|
:return: attribute
|
||||||
|
"""
|
||||||
|
return getattr(self, attr)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_action_space(self):
|
def set_action_space(self):
|
||||||
"""
|
"""
|
||||||
|
@ -203,7 +211,7 @@ class BaseEnvironment(gym.Env):
|
||||||
self.close_trade_profit = []
|
self.close_trade_profit = []
|
||||||
self._total_unrealized_profit = 1
|
self._total_unrealized_profit = 1
|
||||||
|
|
||||||
return self._get_observation()
|
return self._get_observation(), self.history
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def step(self, action: int):
|
def step(self, action: int):
|
||||||
|
|
|
@ -6,7 +6,7 @@ from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -16,13 +16,13 @@ from pandas import DataFrame
|
||||||
from stable_baselines3.common.callbacks import EvalCallback
|
from stable_baselines3.common.callbacks import EvalCallback
|
||||||
from stable_baselines3.common.monitor import Monitor
|
from stable_baselines3.common.monitor import Monitor
|
||||||
from stable_baselines3.common.utils import set_random_seed
|
from stable_baselines3.common.utils import set_random_seed
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
|
||||||
|
|
||||||
from freqtrade.exceptions import OperationalException
|
from freqtrade.exceptions import OperationalException
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
|
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions
|
||||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||||
from freqtrade.persistence import Trade
|
from freqtrade.persistence import Trade
|
||||||
|
|
||||||
|
@ -46,8 +46,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||||
'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
|
'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
|
||||||
th.set_num_threads(self.max_threads)
|
th.set_num_threads(self.max_threads)
|
||||||
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
|
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
|
||||||
self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
|
self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
|
||||||
self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
|
self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
|
||||||
self.eval_callback: Optional[EvalCallback] = None
|
self.eval_callback: Optional[EvalCallback] = None
|
||||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||||
self.rl_config = self.freqai_info['rl_config']
|
self.rl_config = self.freqai_info['rl_config']
|
||||||
|
@ -431,9 +431,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||||
return 0.
|
return 0.
|
||||||
|
|
||||||
|
|
||||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
def make_env(MyRLEnv: Type[BaseEnvironment], env_id: str, rank: int,
|
||||||
seed: int, train_df: DataFrame, price: DataFrame,
|
seed: int, train_df: DataFrame, price: DataFrame,
|
||||||
monitor: bool = False,
|
|
||||||
env_info: Dict[str, Any] = {}) -> Callable:
|
env_info: Dict[str, Any] = {}) -> Callable:
|
||||||
"""
|
"""
|
||||||
Utility function for multiprocessed env.
|
Utility function for multiprocessed env.
|
||||||
|
@ -450,8 +449,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||||
|
|
||||||
env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
|
env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
|
||||||
**env_info)
|
**env_info)
|
||||||
if monitor:
|
|
||||||
env = Monitor(env)
|
|
||||||
return env
|
return env
|
||||||
set_random_seed(seed)
|
set_random_seed(seed)
|
||||||
return _init
|
return _init
|
||||||
|
|
|
@ -3,8 +3,9 @@ from typing import Any, Dict, Type, Union
|
||||||
|
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.logger import HParam
|
from stable_baselines3.common.logger import HParam
|
||||||
|
from stable_baselines3.common.vec_env import VecEnv
|
||||||
|
|
||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
|
from freqtrade.freqai.RL.BaseEnvironment import BaseActions
|
||||||
|
|
||||||
|
|
||||||
class TensorboardCallback(BaseCallback):
|
class TensorboardCallback(BaseCallback):
|
||||||
|
@ -12,11 +13,13 @@ class TensorboardCallback(BaseCallback):
|
||||||
Custom callback for plotting additional values in tensorboard and
|
Custom callback for plotting additional values in tensorboard and
|
||||||
episodic summary reports.
|
episodic summary reports.
|
||||||
"""
|
"""
|
||||||
|
# Override training_env type to fix type errors
|
||||||
|
training_env: Union[VecEnv, None] = None
|
||||||
|
|
||||||
def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
|
def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
|
||||||
super().__init__(verbose)
|
super().__init__(verbose)
|
||||||
self.model: Any = None
|
self.model: Any = None
|
||||||
self.logger = None # type: Any
|
self.logger: Any = None
|
||||||
self.training_env: BaseEnvironment = None # type: ignore
|
|
||||||
self.actions: Type[Enum] = actions
|
self.actions: Type[Enum] = actions
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
def _on_training_start(self) -> None:
|
||||||
|
@ -44,6 +47,8 @@ class TensorboardCallback(BaseCallback):
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
|
|
||||||
local_info = self.locals["infos"][0]
|
local_info = self.locals["infos"][0]
|
||||||
|
if self.training_env is None:
|
||||||
|
return True
|
||||||
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
||||||
|
|
||||||
for metric in local_info:
|
for metric in local_info:
|
||||||
|
|
|
@ -242,8 +242,8 @@ class IFreqaiModel(ABC):
|
||||||
new_trained_timerange, pair, strategy, dk, data_load_timerange
|
new_trained_timerange, pair, strategy, dk, data_load_timerange
|
||||||
)
|
)
|
||||||
except Exception as msg:
|
except Exception as msg:
|
||||||
logger.warning(f"Training {pair} raised exception {msg.__class__.__name__}. "
|
logger.exception(f"Training {pair} raised exception {msg.__class__.__name__}. "
|
||||||
f"Message: {msg}, skipping.")
|
f"Message: {msg}, skipping.")
|
||||||
|
|
||||||
self.train_timer('stop', pair)
|
self.train_timer('stop', pair)
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Type
|
||||||
|
|
||||||
import torch as th
|
import torch as th
|
||||||
|
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
||||||
|
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment
|
||||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
|
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -84,7 +85,9 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
class MyRLEnv(Base5ActionRLEnv):
|
MyRLEnv: Type[BaseEnvironment]
|
||||||
|
|
||||||
|
class MyRLEnv(Base5ActionRLEnv): # type: ignore[no-redef]
|
||||||
"""
|
"""
|
||||||
User can override any function in BaseRLEnv and gym.Env. Here the user
|
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||||
sets a custom reward based on profit and trade duration.
|
sets a custom reward based on profit and trade duration.
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import Any, Dict
|
||||||
|
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from stable_baselines3.common.callbacks import EvalCallback
|
from stable_baselines3.common.callbacks import EvalCallback
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
|
||||||
|
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||||
|
@ -41,22 +41,25 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
||||||
|
|
||||||
env_info = self.pack_env_dict(dk.pair)
|
env_info = self.pack_env_dict(dk.pair)
|
||||||
|
|
||||||
|
eval_freq = len(train_df) // self.max_threads
|
||||||
|
|
||||||
env_id = "train_env"
|
env_id = "train_env"
|
||||||
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
|
self.train_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
|
||||||
train_df, prices_train,
|
train_df, prices_train,
|
||||||
monitor=True,
|
env_info=env_info) for i
|
||||||
env_info=env_info) for i
|
in range(self.max_threads)]))
|
||||||
in range(self.max_threads)])
|
|
||||||
|
|
||||||
eval_env_id = 'eval_env'
|
eval_env_id = 'eval_env'
|
||||||
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
self.eval_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
||||||
test_df, prices_test,
|
test_df, prices_test,
|
||||||
monitor=True,
|
env_info=env_info) for i
|
||||||
env_info=env_info) for i
|
in range(self.max_threads)]))
|
||||||
in range(self.max_threads)])
|
|
||||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||||
render=False, eval_freq=len(train_df),
|
render=False, eval_freq=eval_freq,
|
||||||
best_model_save_path=str(dk.data_path))
|
best_model_save_path=str(dk.data_path))
|
||||||
|
|
||||||
|
# TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS,
|
||||||
|
# IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!!
|
||||||
actions = self.train_env.env_method("get_actions")[0]
|
actions = self.train_env.env_method("get_actions")[0]
|
||||||
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
||||||
|
|
|
@ -3,10 +3,11 @@
|
||||||
|
|
||||||
# Required for freqai-rl
|
# Required for freqai-rl
|
||||||
torch==1.13.1; python_version < '3.11'
|
torch==1.13.1; python_version < '3.11'
|
||||||
stable-baselines3==1.7.0; python_version < '3.11'
|
#until these branches will be released we can use this
|
||||||
sb3-contrib==1.7.0; python_version < '3.11'
|
gymnasium==0.28.1
|
||||||
|
stable_baselines3==2.0.0a5
|
||||||
|
sb3_contrib>=2.0.0a4
|
||||||
# Gym is forced to this version by stable-baselines3.
|
# Gym is forced to this version by stable-baselines3.
|
||||||
setuptools==65.5.1 # Should be removed when gym is fixed.
|
setuptools==65.5.1 # Should be removed when gym is fixed.
|
||||||
gym==0.21; python_version < '3.11'
|
|
||||||
# Progress bar for stable-baselines3 and sb3-contrib
|
# Progress bar for stable-baselines3 and sb3-contrib
|
||||||
tqdm==4.65.0; python_version < '3.11'
|
tqdm==4.65.0; python_version < '3.11'
|
||||||
|
|
Loading…
Reference in New Issue
Block a user