add maskable eval callback

This commit is contained in:
steam 2023-06-11 20:05:53 +03:00
parent afd54d39a5
commit c36547a563

View File

@@ -13,7 +13,8 @@ import pandas as pd
import torch as th
import torch.multiprocessing
from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from sb3_contrib.common.maskable.utils import is_masking_supported
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
@@ -48,7 +49,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
self.eval_callback: Optional[EvalCallback] = None
self.eval_callback: Optional[MaskableEvalCallback] = None
self.model_type = self.freqai_info['rl_config']['model_type']
self.rl_config = self.freqai_info['rl_config']
self.df_raw: DataFrame = DataFrame()
@@ -151,9 +152,11 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, **env_info)
self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test, **env_info))
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df),
best_model_save_path=str(dk.data_path))
self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df),
best_model_save_path=str(dk.data_path),
use_masking=(self.model_type == 'MaskablePPO' and
is_masking_supported(self.eval_env)))
actions = self.train_env.get_actions()
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)