mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
Improve logic for progressbarcallback handling
This commit is contained in:
parent
ba674fc796
commit
2d9d8dc976
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import torch as th
|
||||
from stable_baselines3.common.callbacks import ProgressBarCallback
|
||||
|
@ -74,10 +74,11 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||
'trained agent.')
|
||||
model = self.dd.model_dictionary[dk.pair]
|
||||
model.set_env(self.train_env)
|
||||
callbacks = [self.eval_callback, self.tensorboard_callback]
|
||||
use_progressbar = self.rl_config.get('progress_bar', False)
|
||||
if use_progressbar:
|
||||
callbacks.insert(0, ProgressBarCallback())
|
||||
callbacks: List[Any] = [self.eval_callback, self.tensorboard_callback]
|
||||
progressbar_callback: Optional[ProgressBarCallback] = None
|
||||
if self.rl_config.get('progress_bar', False):
|
||||
progressbar_callback = ProgressBarCallback()
|
||||
callbacks.insert(0, progressbar_callback)
|
||||
|
||||
try:
|
||||
model.learn(
|
||||
|
@ -85,8 +86,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||
callback=callbacks,
|
||||
)
|
||||
finally:
|
||||
if use_progressbar:
|
||||
callbacks[0].on_training_end()
|
||||
if progressbar_callback:
|
||||
progressbar_callback.on_training_end()
|
||||
|
||||
if Path(dk.data_path / "best_model.zip").is_file():
|
||||
logger.info('Callback found a best model.')
|
||||
|
|
Loading…
Reference in New Issue
Block a user