2023-05-12 07:56:44 +00:00
|
|
|
import logging
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from xgboost import callback
|
|
|
|
|
2024-05-12 13:18:32 +00:00
|
|
|
from freqtrade.freqai.tensorboard.base_tensorboard import (
|
|
|
|
BaseTensorBoardCallback,
|
|
|
|
BaseTensorboardLogger,
|
|
|
|
)
|
2023-05-12 07:56:44 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class TensorboardLogger(BaseTensorboardLogger):
    """Write scalar values to a TensorBoard ``SummaryWriter``.

    The writer targets ``<logdir>/tensorboard``. When ``activate`` is False no
    writer is created, and ``log_scalar`` / ``close`` return without doing
    anything.
    """

    def __init__(self, logdir: Path, activate: bool = True):
        """Store the activation flag and, if active, open the writer.

        :param logdir: directory under which the ``tensorboard`` subdirectory is used.
        :param activate: when False, skip creating the writer entirely.
        """
        self.activate = activate
        if not self.activate:
            # Inactive: no ``self.writer`` attribute is created.
            return
        event_dir = f"{logdir!s}/tensorboard"
        self.writer: SummaryWriter = SummaryWriter(event_dir)

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Record ``scalar_value`` under ``tag`` at ``step`` (skipped when inactive)."""
        if not self.activate:
            return
        self.writer.add_scalar(tag, scalar_value, step)

    def close(self):
        """Flush pending events, then close the writer (skipped when inactive)."""
        if not self.activate:
            return
        writer = self.writer
        writer.flush()
        writer.close()
|
2023-05-12 07:56:44 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TensorBoardCallback(BaseTensorBoardCallback):
    """XGBoost training callback that writes evaluation metrics to TensorBoard.

    The writer targets ``<logdir>/tensorboard``. When ``activate`` is False no
    writer is created and every hook returns immediately without logging.
    """

    def __init__(self, logdir: Path, activate: bool = True):
        """Store the activation flag and, if active, open the writer.

        :param logdir: directory under which the ``tensorboard`` subdirectory is used.
        :param activate: when False, skip creating the writer entirely.
        """
        self.activate = activate
        if self.activate:
            self.writer: SummaryWriter = SummaryWriter(f"{logdir!s}/tensorboard")

    def after_iteration(
        self, model, epoch: int, evals_log: callback.TrainingCallback.EvalsLog
    ) -> bool:
        """Log the most recent value of every metric of every eval set.

        Each value is written with tag ``"<label>-<metric_name>"`` at step ``epoch``.
        The first eval set is labelled ``"validation"`` and the second ``"train"``.
        Any further eval sets are labelled with their own key from ``evals_log``
        instead of being silently dropped.

        :return: always False; xgboost treats True as a request to stop training.
        """
        if not self.activate or not evals_log:
            return False

        # NOTE(review): positional labels assume the caller passes the validation
        # set first and the training set second; confirm against the code that
        # builds the model's eval_set.
        eval_labels = ("validation", "train")

        for index, (eval_key, metrics) in enumerate(evals_log.items()):
            label = eval_labels[index] if index < len(eval_labels) else eval_key
            for metric_name, history in metrics.items():
                if not history:
                    # Nothing recorded yet for this metric; avoid IndexError.
                    continue
                latest = history[-1]
                # Entries may be tuples (presumably (mean, std) from cross-
                # validation); only the first element is logged.
                score = latest[0] if isinstance(latest, tuple) else latest
                self.writer.add_scalar(f"{label}-{metric_name}", score, epoch)

        return False

    def after_training(self, model):
        """Flush and close the writer when active; return ``model`` unchanged."""
        if not self.activate:
            return model

        self.writer.flush()
        self.writer.close()

        return model
|