from enum import Enum
from typing import Any, Dict, Type, Union

from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam

from freqtrade.freqai.RL.BaseEnvironment import BaseActions


class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard and
    episodic summary reports.
    """

    def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
        super().__init__(verbose)
        self.model: Any = None
        self.actions: Type[Enum] = actions

    def _on_training_start(self) -> None:
        """
        Record the run's hyperparameters so they appear in TensorBoard's
        HPARAMS tab, paired with the metrics listed in `metric_dict`.
        """
        hparam_dict = {
            "algorithm": self.model.__class__.__name__,
            "learning_rate": self.model.learning_rate,
            # Not every algorithm exposes the attributes below, hence commented out:
            # "gamma": self.model.gamma,
            # "gae_lambda": self.model.gae_lambda,
            # "batch_size": self.model.batch_size,
            # "n_steps": self.model.n_steps,
        }
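        # Illustrative sketch (not active here): optional hyperparameters could
        # instead be collected only when the algorithm defines them, e.g.:
        #
        #     for name in ("gamma", "gae_lambda", "batch_size", "n_steps"):
        #         if hasattr(self.model, name):
        #             hparam_dict[name] = getattr(self.model, name)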
        metric_dict: Dict[str, Union[float, int]] = {
            "eval/mean_reward": 0,
            "rollout/ep_rew_mean": 0,
            "rollout/ep_len_mean": 0,
            "train/value_loss": 0,
            "train/explained_variance": 0,
        }
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        """
        Forward the env's per-step info dict and its accumulated tensorboard
        metrics to the SB3 logger. Called after every call to `env.step()`.
        """
        local_info = self.locals["infos"][0]

        if hasattr(self.training_env, "envs"):
            tensorboard_metrics = self.training_env.envs[0].unwrapped.tensorboard_metrics
        else:
            # For RL-multiproc - usage of [0] might need to be evaluated
            tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
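
        # Assumed structure (illustrative keys only): `tensorboard_metrics` is
        # expected to be a nested mapping of category -> metric -> value, e.g.
        # {"info": {"total_profit": 1.02}}, flattened below into
        # "category/metric" logger entries alongside the per-step `local_info`.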

        for metric in local_info:
            if metric not in ["episode", "terminal_observation"]:
                self.logger.record(f"info/{metric}", local_info[metric])

        for category in tensorboard_metrics:
            for metric in tensorboard_metrics[category]:
                self.logger.record(f"{category}/{metric}", tensorboard_metrics[category][metric])

        return True
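

# Usage sketch (illustrative, not part of this module): freqtrade's RL model
# classes wire this callback up internally; standalone, it would be passed to a
# Stable-Baselines3 `learn()` call. `env` below is a placeholder for a FreqAI
# RL environment that exposes `tensorboard_metrics`.
#
#     from stable_baselines3 import PPO
#
#     model = PPO("MlpPolicy", env, tensorboard_log="./tensorboard/")
#     callback = TensorboardCallback(verbose=1, actions=BaseActions)
#     model.learn(total_timesteps=10_000, callback=callback)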