mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
add state/action info to callbacks
This commit is contained in:
parent
075c8c23c8
commit
469aa0d43f
|
@ -13,9 +13,11 @@ import torch as th
|
|||
import torch.multiprocessing
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.logger import HParam
|
||||
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
|
@ -155,6 +157,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path))
|
||||
|
||||
self.tensorboard_callback = TensorboardCallback()
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||
|
@ -398,3 +402,48 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
|||
return env
|
||||
set_random_seed(seed)
|
||||
return _init
|
||||
|
||||
class TensorboardCallback(BaseCallback):
|
||||
"""
|
||||
Custom callback for plotting additional values in tensorboard.
|
||||
"""
|
||||
def __init__(self, verbose=1):
|
||||
super(TensorboardCallback, self).__init__(verbose)
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
hparam_dict = {
|
||||
"algorithm": self.model.__class__.__name__,
|
||||
"learning_rate": self.model.learning_rate,
|
||||
"gamma": self.model.gamma,
|
||||
"gae_lambda": self.model.gae_lambda,
|
||||
"batch_size": self.model.batch_size,
|
||||
"n_steps": self.model.n_steps,
|
||||
}
|
||||
metric_dict = {
|
||||
"eval/mean_reward": 0,
|
||||
"rollout/ep_rew_mean": 0,
|
||||
"rollout/ep_len_mean":0 ,
|
||||
"train/value_loss": 0,
|
||||
"train/explained_variance": 0,
|
||||
}
|
||||
self.logger.record(
|
||||
"hparams",
|
||||
HParam(hparam_dict, metric_dict),
|
||||
exclude=("stdout", "log", "json", "csv"),
|
||||
)
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
custom_info = self.training_env.get_attr("custom_info")[0]
|
||||
self.logger.record(f"_state/position", self.locals["infos"][0]["position"])
|
||||
self.logger.record(f"_state/trade_duration", self.locals["infos"][0]["trade_duration"])
|
||||
self.logger.record(f"_state/current_profit_pct", self.locals["infos"][0]["current_profit_pct"])
|
||||
self.logger.record(f"_reward/total_profit", self.locals["infos"][0]["total_profit"])
|
||||
self.logger.record(f"_reward/total_reward", self.locals["infos"][0]["total_reward"])
|
||||
self.logger.record_mean(f"_reward/mean_trade_duration", self.locals["infos"][0]["trade_duration"])
|
||||
self.logger.record(f"_actions/action", self.locals["infos"][0]["action"])
|
||||
self.logger.record(f"_actions/_Invalid", custom_info["Invalid"])
|
||||
self.logger.record(f"_actions/_Unknown", custom_info["Unknown"])
|
||||
self.logger.record(f"_actions/Hold", custom_info["Hold"])
|
||||
for action in Actions:
|
||||
self.logger.record(f"_actions/{action.name}", custom_info[action.name])
|
||||
return True
|
Loading…
Reference in New Issue
Block a user