freqtrade_origin/freqtrade/freqai/tensorboard/TensorboardCallback.py

62 lines
2.1 KiB
Python
Raw Normal View History

from enum import Enum
from typing import Any, Union
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
from freqtrade.freqai.RL.BaseEnvironment import BaseActions
class TensorboardCallback(BaseCallback):
    """
    Callback that records extra values to tensorboard during training.

    At training start it logs a hyperparameter summary; on every step it
    logs the entries of the first environment's ``info`` dict and the
    environment's ``tensorboard_metrics``.
    """

    def __init__(self, verbose=1, actions: type[Enum] = BaseActions):
        """
        :param verbose: verbosity level, passed through to ``BaseCallback``.
        :param actions: enum class describing the available actions; stored
            on the instance as ``self.actions``.
        """
        super().__init__(verbose)
        self.model: Any = None
        self.actions: type[Enum] = actions

    def _on_training_start(self) -> None:
        """
        Record the model's hyperparameters under the ``hparams`` key.

        The record is excluded from the stdout, log, json and csv outputs.
        """
        hyperparameters = {
            "algorithm": self.model.__class__.__name__,
            "learning_rate": self.model.learning_rate,
            # Further hyperparameters, currently disabled:
            # "gamma": self.model.gamma,
            # "gae_lambda": self.model.gae_lambda,
            # "batch_size": self.model.batch_size,
            # "n_steps": self.model.n_steps,
        }
        # Every metric that accompanies the hyperparameters starts at 0.
        placeholder_metrics: dict[str, Union[float, int]] = dict.fromkeys(
            (
                "eval/mean_reward",
                "rollout/ep_rew_mean",
                "rollout/ep_len_mean",
                "train/value_loss",
                "train/explained_variance",
            ),
            0,
        )
        hparam_record = HParam(hyperparameters, placeholder_metrics)
        self.logger.record(
            "hparams",
            hparam_record,
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        """
        Record per-step info values and the environment's tensorboard metrics.

        Only the first entry of ``self.locals["infos"]`` and the first
        environment are consulted.

        :return: always ``True``.
        """
        step_info = self.locals["infos"][0]

        env = self.training_env
        if not hasattr(env, "envs"):
            # For RL-multiproc - usage of [0] might need to be evaluated
            metrics_by_category = env.get_attr("tensorboard_metrics")[0]
        else:
            metrics_by_category = env.envs[0].unwrapped.tensorboard_metrics

        # These info keys are not logged under "info/".
        skipped_keys = ("episode", "terminal_observation")
        for key, value in step_info.items():
            if key in skipped_keys:
                continue
            self.logger.record(f"info/{key}", value)

        # presumably a two-level mapping: category -> {metric name -> value}
        for category, metrics in metrics_by_category.items():
            for name, value in metrics.items():
                self.logger.record(f"{category}/{name}", value)

        return True