Improve logic for progressbarcallback handling

This commit is contained in:
Matthias 2023-10-15 11:20:25 +02:00
parent ba674fc796
commit 2d9d8dc976

View File

@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Any, Dict, Type
from typing import Any, Dict, List, Optional, Type
import torch as th
from stable_baselines3.common.callbacks import ProgressBarCallback
@ -74,10 +74,11 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
'trained agent.')
model = self.dd.model_dictionary[dk.pair]
model.set_env(self.train_env)
callbacks = [self.eval_callback, self.tensorboard_callback]
use_progressbar = self.rl_config.get('progress_bar', False)
if use_progressbar:
callbacks.insert(0, ProgressBarCallback())
callbacks: List[Any] = [self.eval_callback, self.tensorboard_callback]
progressbar_callback: Optional[ProgressBarCallback] = None
if self.rl_config.get('progress_bar', False):
progressbar_callback = ProgressBarCallback()
callbacks.insert(0, progressbar_callback)
try:
model.learn(
@ -85,8 +86,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
callback=callbacks,
)
finally:
if use_progressbar:
callbacks[0].on_training_end()
if progressbar_callback:
progressbar_callback.on_training_end()
if Path(dk.data_path / "best_model.zip").is_file():
logger.info('Callback found a best model.')