diff --git a/freqtrade/freqai/utils.py b/freqtrade/freqai/utils.py index e40b143d7..8afc87870 100644 --- a/freqtrade/freqai/utils.py +++ b/freqtrade/freqai/utils.py @@ -1,7 +1,7 @@ import logging from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict import numpy as np import pandas as pd @@ -16,8 +16,6 @@ from freqtrade.exchange import timeframe_to_seconds from freqtrade.exchange.exchange import market_is_active from freqtrade.freqai.data_drawer import FreqaiDataDrawer from freqtrade.freqai.data_kitchen import FreqaiDataKitchen -from freqtrade.freqai.tensorboard import TBLogger -from freqtrade.freqai.tensorboard.base_tensorboard import BaseTensorboardLogger from freqtrade.plugins.pairlist.pairlist_helpers import dynamic_expand_pairlist @@ -188,12 +186,13 @@ def get_timerange_backtest_live_models(config: Config) -> str: return timerange.timerange_str -def get_tb_logger(model_type: str, path: Path, activate: bool) -> Union[TBLogger, - BaseTensorboardLogger]: - tb_logger: Union[TBLogger, BaseTensorboardLogger] +def get_tb_logger(model_type: str, path: Path, activate: bool) -> Any: + if model_type == "pytorch" and activate: + from freqtrade.freqai.tensorboard import TBLogger tb_logger = TBLogger(path, activate) else: - tb_logger = BaseTensorboardLogger(path, activate) + from freqtrade.freqai.tensorboard.base_tensorboard import BaseTensorboardLogger + tb_logger = BaseTensorboardLogger(path, activate) # type: ignore return tb_logger