Merge pull request #8760 from initrv/rl-action-masks

Add MaskablePPO support
Robert Caulk 2023-06-17 16:28:43 +02:00 committed by GitHub
commit 402a247c92
3 changed files with 19 additions and 10 deletions


@@ -2,7 +2,7 @@ import logging
 import random
 from abc import abstractmethod
 from enum import Enum
-from typing import Optional, Type, Union
+from typing import List, Optional, Type, Union

 import gymnasium as gym
 import numpy as np
@@ -141,6 +141,9 @@ class BaseEnvironment(gym.Env):
         Unique to the environment action count. Must be inherited.
         """

+    def action_masks(self) -> List[bool]:
+        return [self._is_valid(action.value) for action in self.actions]
+
     def seed(self, seed: int = 1):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]
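
The new action_masks() hook is what sb3-contrib's masking utilities look for: is_masking_supported(env) only checks that the method exists, and MaskablePPO queries it during rollouts to exclude invalid actions. A minimal, hedged sketch follows (not freqtrade code; ToyMaskedEnv and its spaces are invented) showing the same hook on a toy gymnasium environment trained with MaskablePPO.

# Hedged sketch, not part of this PR: a toy env exposing the same action_masks()
# contract that BaseEnvironment gains above. Only the sb3-contrib calls are real;
# ToyMaskedEnv and its spaces are invented for illustration.
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported


class ToyMaskedEnv(gym.Env):
    """Three discrete actions; action 2 is only valid on even steps."""

    def __init__(self):
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
        self._step_count = 0

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._step_count = 0
        return self.observation_space.sample(), {}

    def step(self, action):
        self._step_count += 1
        terminated = self._step_count >= 100
        return self.observation_space.sample(), 0.0, terminated, False, {}

    def action_masks(self):
        # Same contract as BaseEnvironment.action_masks(): one bool per action.
        return [True, True, self._step_count % 2 == 0]


env = ToyMaskedEnv()
assert is_masking_supported(env)       # the same check used by the callbacks below
model = MaskablePPO("MlpPolicy", env, verbose=0)
model.learn(total_timesteps=1_000)     # masks are fetched automatically during rollouts
obs, _ = env.reset()
action, _ = model.predict(obs, action_masks=get_action_masks(env))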


@@ -13,7 +13,8 @@ import pandas as pd
 import torch as th
 import torch.multiprocessing
 from pandas import DataFrame
-from stable_baselines3.common.callbacks import EvalCallback
+from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
+from sb3_contrib.common.maskable.utils import is_masking_supported
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.utils import set_random_seed
 from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
@@ -48,7 +49,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
         self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
         self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
-        self.eval_callback: Optional[EvalCallback] = None
+        self.eval_callback: Optional[MaskableEvalCallback] = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
         self.df_raw: DataFrame = DataFrame()
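
For context, the use_masking guard introduced further down keys off the user's rl_config. A hedged illustration of the only two keys this diff reads; the reward-parameter values are placeholders and every other freqai setting is omitted.

# Hedged illustration, not freqtrade source: only the rl_config keys referenced in
# this diff are shown, with placeholder values.
freqai_info = {
    "rl_config": {
        "model_type": "MaskablePPO",                                 # compared against 'MaskablePPO' below
        "model_reward_parameters": {"rr": 1, "profit_aim": 0.025},   # placeholder values
    }
}

model_type = freqai_info["rl_config"]["model_type"]
# First half of the guard; the second half (is_masking_supported) needs the eval env.
masking_requested = model_type == "MaskablePPO"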
@@ -151,9 +152,11 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, **env_info)
         self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test, **env_info))

-        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                          render=False, eval_freq=len(train_df),
-                                          best_model_save_path=str(dk.data_path))
+        self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True,
+                                                  render=False, eval_freq=len(train_df),
+                                                  best_model_save_path=str(dk.data_path),
+                                                  use_masking=(self.model_type == 'MaskablePPO' and
+                                                               is_masking_supported(self.eval_env)))

         actions = self.train_env.get_actions()
         self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
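
A standalone sketch of the callback wiring added above, outside freqtrade: InvalidActionEnvDiscrete is a test environment shipped with sb3-contrib, the eval_freq and save-path values are made up, and the use_masking expression mirrors the diff.

# Hedged sketch: MaskableEvalCallback with the same use_masking guard as the diff,
# wired to sb3-contrib's own InvalidActionEnvDiscrete instead of a freqtrade env.
from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from sb3_contrib.common.maskable.utils import is_masking_supported
from stable_baselines3.common.monitor import Monitor

train_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
eval_env = Monitor(InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10))

model_type = "MaskablePPO"                       # stands in for self.model_type
eval_callback = MaskableEvalCallback(
    eval_env,
    deterministic=True,
    render=False,
    eval_freq=500,                               # placeholder; the diff uses len(train_df)
    best_model_save_path="./best_model",         # placeholder; the diff uses dk.data_path
    # Only mask during evaluation when the algorithm is MaskablePPO and the
    # eval env actually exposes action_masks().
    use_masking=(model_type == "MaskablePPO" and is_masking_supported(eval_env)),
)

model = MaskablePPO("MlpPolicy", train_env, verbose=0)
model.learn(total_timesteps=2_000, callback=eval_callback)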


@@ -2,7 +2,8 @@ import logging
 from typing import Any, Dict

 from pandas import DataFrame
-from stable_baselines3.common.callbacks import EvalCallback
+from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
+from sb3_contrib.common.maskable.utils import is_masking_supported
 from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor

 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
@@ -55,9 +56,11 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
                                                             env_info=env_info) for i
                                                   in range(self.max_threads)]))

-        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                          render=False, eval_freq=eval_freq,
-                                          best_model_save_path=str(dk.data_path))
+        self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True,
+                                                  render=False, eval_freq=eval_freq,
+                                                  best_model_save_path=str(dk.data_path),
+                                                  use_masking=(self.model_type == 'MaskablePPO' and
+                                                               is_masking_supported(self.eval_env)))

         # USING THE TENSORBOARD CALLBACK WITH MULTIPLE ENVS IS NOT RECOMMENDED:
         # IT REPORTS MISLEADING VALUES AND IS NOT THREAD SAFE WITH SB3.
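
For the multiprocess learner the same guard still works, because sb3-contrib forwards the masking probe to each worker of a VecEnv. A hedged sketch follows; the environment factory is illustrative, not freqtrade code.

# Hedged sketch: masking detection through SubprocVecEnv, as used by the guard above.
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
from stable_baselines3.common.vec_env import SubprocVecEnv


def make_env():
    return InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)


if __name__ == "__main__":                   # required for SubprocVecEnv on spawn platforms
    vec_env = SubprocVecEnv([make_env for _ in range(4)])
    print(is_masking_supported(vec_env))     # True: each worker exposes action_masks()
    print(get_action_masks(vec_env).shape)   # (4, 20): one mask row per sub-process env
    vec_env.close()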