2023-05-12 07:56:44 +00:00
|
|
|
import logging
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Any
|
|
|
|
|
2023-05-17 07:21:48 +00:00
|
|
|
from xgboost.callback import TrainingCallback
|
2023-05-12 07:56:44 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class BaseTensorboardLogger:
    """Logger whose methods do nothing, so no Tensorboard logs are written.

    Constructing an instance emits a warning that Tensorboard is not
    installed; ``log_scalar`` and ``close`` then return immediately.
    """

    def __init__(self, logdir: Path, activate: bool = True) -> None:
        """Warn that no Tensorboard logs will be written.

        :param logdir: Accepted but not used by this class.
        :param activate: Accepted but not used by this class.
        """
        # Adjacent string literals are joined without any separator, so the
        # space between the two sentences has to live inside the first one.
        logger.warning(
            "Tensorboard is not installed, no logs will be written. "
            "Ensure torch is installed, or use the torch/RL docker images"
        )

    def log_scalar(self, tag: str, scalar_value: Any, step: int) -> None:
        """Discard the scalar; nothing is recorded.

        :param tag: Accepted but not used.
        :param scalar_value: Accepted but not used.
        :param step: Accepted but not used.
        """
        return

    def close(self) -> None:
        """Do nothing; this class holds no resources to release."""
        return
|
|
|
|
|
|
|
|
|
2023-05-17 07:21:48 +00:00
|
|
|
class BaseTensorBoardCallback(TrainingCallback):
    """Training callback that records nothing.

    Constructing an instance emits a warning that Tensorboard is not
    installed; the hooks below leave training untouched.
    """

    def __init__(self, logdir: Path, activate: bool = True) -> None:
        """Initialise the base callback and warn that no logs will be written.

        :param logdir: Accepted but not used by this class.
        :param activate: Accepted but not used by this class.
        """
        # Let the base class run its own initialisation before anything else.
        super().__init__()
        # Adjacent string literals are joined without any separator, so the
        # space between the two sentences has to live inside the first one.
        logger.warning(
            "Tensorboard is not installed, no logs will be written. "
            "Ensure torch is installed, or use the torch/RL docker images"
        )

    def after_iteration(
        self, model, epoch: int, evals_log: TrainingCallback.EvalsLog
    ) -> bool:
        """Ignore the iteration and always return ``False``.

        NOTE(review): in xgboost's callback API a ``True`` return is
        presumably the signal to stop training early - confirm against the
        xgboost version in use.

        :param model: Accepted but not used.
        :param epoch: Accepted but not used.
        :param evals_log: Accepted but not used.
        :return: Always ``False``.
        """
        return False

    def after_training(self, model):
        """Return ``model`` exactly as it was passed in.

        :param model: Object handed to the hook at the end of training.
        :return: The same ``model`` object, unmodified.
        """
        return model
|