mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 18:23:55 +00:00
Merge pull request #8760 from initrv/rl-action-masks
Add MaskablePPO support
This commit is contained in:
commit
402a247c92
|
@ -2,7 +2,7 @@ import logging
|
|||
import random
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, Union
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
@ -141,6 +141,9 @@ class BaseEnvironment(gym.Env):
|
|||
Unique to the environment action count. Must be inherited.
|
||||
"""
|
||||
|
||||
def action_masks(self) -> List[bool]:
|
||||
return [self._is_valid(action.value) for action in self.actions]
|
||||
|
||||
def seed(self, seed: int = 1):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
|
|
@ -13,7 +13,8 @@ import pandas as pd
|
|||
import torch as th
|
||||
import torch.multiprocessing
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
|
||||
from sb3_contrib.common.maskable.utils import is_masking_supported
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
|
||||
|
@ -48,7 +49,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
|
||||
self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
|
||||
self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
|
||||
self.eval_callback: Optional[EvalCallback] = None
|
||||
self.eval_callback: Optional[MaskableEvalCallback] = None
|
||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||
self.rl_config = self.freqai_info['rl_config']
|
||||
self.df_raw: DataFrame = DataFrame()
|
||||
|
@ -151,9 +152,11 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||
|
||||
self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, **env_info)
|
||||
self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test, **env_info))
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path))
|
||||
self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path),
|
||||
use_masking=(self.model_type == 'MaskablePPO' and
|
||||
is_masking_supported(self.eval_env)))
|
||||
|
||||
actions = self.train_env.get_actions()
|
||||
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
||||
|
|
|
@ -2,7 +2,8 @@ import logging
|
|||
from typing import Any, Dict
|
||||
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
|
||||
from sb3_contrib.common.maskable.utils import is_masking_supported
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
|
||||
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
|
@ -55,9 +56,11 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
|||
env_info=env_info) for i
|
||||
in range(self.max_threads)]))
|
||||
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=eval_freq,
|
||||
best_model_save_path=str(dk.data_path))
|
||||
self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=eval_freq,
|
||||
best_model_save_path=str(dk.data_path),
|
||||
use_masking=(self.model_type == 'MaskablePPO' and
|
||||
is_masking_supported(self.eval_env)))
|
||||
|
||||
# TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS,
|
||||
# IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!!
|
||||
|
|
Loading…
Reference in New Issue
Block a user