mirror of https://github.com/freqtrade/freqtrade.git
synced 2024-11-13 03:33:55 +00:00

add maskable eval callback

This commit is contained in:
parent afd54d39a5
commit c36547a563
@@ -13,7 +13,8 @@ import pandas as pd
 import torch as th
 import torch.multiprocessing
 from pandas import DataFrame
-from stable_baselines3.common.callbacks import EvalCallback
+from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
+from sb3_contrib.common.maskable.utils import is_masking_supported
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.utils import set_random_seed
 from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
@@ -48,7 +49,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
         self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
         self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
-        self.eval_callback: Optional[EvalCallback] = None
+        self.eval_callback: Optional[MaskableEvalCallback] = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
         self.df_raw: DataFrame = DataFrame()
@@ -151,9 +152,11 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, **env_info)
         self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test, **env_info))
-        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                          render=False, eval_freq=len(train_df),
-                                          best_model_save_path=str(dk.data_path))
+        self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True,
+                                                  render=False, eval_freq=len(train_df),
+                                                  best_model_save_path=str(dk.data_path),
+                                                  use_masking=(self.model_type == 'MaskablePPO' and
+                                                               is_masking_supported(self.eval_env)))

         actions = self.train_env.get_actions()
         self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
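Note (not part of the diff): a minimal standalone sketch of the pattern this commit wires into FreqAI, assuming the gym-era SB3 1.x API in use by freqtrade at the time. CartPole and the always-true mask function are illustrative stand-ins for the trading environment; a real mask would disallow invalid transitions. Passing use_masking=True against an env that lacks an action_masks method fails at evaluation time, which is presumably why the commit gates it on both the model type and is_masking_supported().

import gym
import numpy as np
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from sb3_contrib.common.maskable.utils import is_masking_supported
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.monitor import Monitor


def mask_fn(env: gym.Env) -> np.ndarray:
    # Illustrative mask that allows every action; a trading env would
    # instead forbid invalid moves (e.g. exiting with no open position).
    return np.ones(env.action_space.n, dtype=bool)


train_env = ActionMasker(gym.make("CartPole-v1"), mask_fn)
eval_env = Monitor(ActionMasker(gym.make("CartPole-v1"), mask_fn))

model = MaskablePPO("MlpPolicy", train_env, verbose=0)
eval_callback = MaskableEvalCallback(
    eval_env,
    deterministic=True,
    render=False,
    eval_freq=500,
    # Same guard as the commit; the model-type check is implicit here
    # because the model is MaskablePPO by construction.
    use_masking=is_masking_supported(eval_env),
)
model.learn(total_timesteps=2_000, callback=eval_callback)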