freqtrade_origin/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py

86 lines
3.0 KiB
Python
Raw Normal View History

import logging
from typing import Any

from pandas import DataFrame
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from sb3_contrib.common.maskable.utils import is_masking_supported
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor

from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback


# Module-level logger, following the standard logging.getLogger(__name__) convention.
logger = logging.getLogger(__name__)
class ReinforcementLearner_multiproc(ReinforcementLearner):
    """
    Demonstration of how to build vectorized environments.

    ``set_train_and_eval_environments`` is overridden so that both the training and
    the evaluation environment are ``SubprocVecEnv`` instances holding
    ``self.max_threads`` copies of ``self.MyRLEnv``, each wrapped in ``VecMonitor``.
    """

    def set_train_and_eval_environments(
        self,
        data_dictionary: dict[str, Any],
        prices_train: DataFrame,
        prices_test: DataFrame,
        dk: FreqaiDataKitchen,
    ) -> None:
        """
        Build the vectorized train/eval environments and the callbacks that use them.

        User can override this if they are using a custom MyRLEnv.

        Any existing ``self.train_env`` / ``self.eval_env`` is closed first. Then
        ``self.train_env``, ``self.eval_env``, ``self.eval_callback`` and
        ``self.tensorboard_callback`` are (re)assigned.

        :param data_dictionary: dict = common data dictionary containing train and test
            features/labels/weights. Only ``"train_features"`` and ``"test_features"``
            are read here.
        :param prices_train/test: DataFrame = dataframe comprised of the prices to be
            used in the environment during training or testing
        :param dk: FreqaiDataKitchen = the datakitchen for the current pair. ``dk.pair``
            selects the env configuration and ``dk.data_path`` is where the eval
            callback saves the best model.
        """
        train_df = data_dictionary["train_features"]
        test_df = data_dictionary["test_features"]

        # Close environments left over from a previous call before replacing them.
        if self.train_env:
            self.train_env.close()
        if self.eval_env:
            self.eval_env.close()

        env_info = self.pack_env_dict(dk.pair)

        # Each callback call steps all max_threads envs at once. Dividing by max_threads
        # therefore evaluates roughly once per len(train_df) transitions. The clamp to 1
        # follows the SB3 EvalCallback guidance: an eval_freq of 0 (train set smaller than
        # the number of envs) would silently disable evaluation and best-model saving.
        eval_freq = max(len(train_df) // self.max_threads, 1)

        self.train_env = self._build_subproc_vec_env(
            "train_env", train_df, prices_train, env_info
        )
        self.eval_env = self._build_subproc_vec_env(
            "eval_env", test_df, prices_test, env_info
        )

        self.eval_callback = MaskableEvalCallback(
            self.eval_env,
            deterministic=True,
            render=False,
            eval_freq=eval_freq,
            best_model_save_path=str(dk.data_path),
            # Masking is only requested for MaskablePPO and only when the env supports it.
            use_masking=(self.model_type == "MaskablePPO" and is_masking_supported(self.eval_env)),
        )

        # NOTE: TensorboardCallback is NOT recommended with multiple envs. It will report
        # incorrect information, and it is also not thread safe with SB3.
        # The action set is read from the first sub-environment only.
        actions = self.train_env.env_method("get_actions")[0]
        self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)

    def _build_subproc_vec_env(
        self,
        env_id: str,
        features: Any,
        prices: DataFrame,
        env_info: Any,
    ) -> VecMonitor:
        """
        Build a ``VecMonitor``-wrapped ``SubprocVecEnv`` of ``self.max_threads`` envs.

        :param env_id: str = identifier forwarded to every ``make_env`` call
        :param features: feature data the environments are built on
        :param prices: DataFrame = prices used by the environments
        :param env_info: environment configuration returned by ``self.pack_env_dict``
        :return: the vectorized environment wrapped in ``VecMonitor``
        """
        env_fns = [
            # i is the worker index. The literal 1 is forwarded unchanged to make_env
            # (presumably its seed argument - confirm against make_env's signature).
            make_env(self.MyRLEnv, env_id, i, 1, features, prices, env_info=env_info)
            for i in range(self.max_threads)
        ]
        return VecMonitor(SubprocVecEnv(env_fns))