mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
Merge pull request #8623 from freqtrade/feat/tensorboard-logger
Add Tensorboard logger for PyTorch and XGBoost
This commit is contained in:
commit
c54f28ada8
|
@ -21,6 +21,7 @@ Mandatory parameters are marked as **Required** and have to be set in one of the
|
|||
| `continual_learning` | Use the final state of the most recently trained model as starting point for the new model, allowing for incremental learning (more information can be found [here](freqai-running.md#continual-learning)). Beware that this is currently a naive approach to incremental learning, and it has a high probability of overfitting/getting stuck in local minima while the market moves away from your model. We have the connections here primarily for experimental purposes and so that it is ready for more mature approaches to continual learning in chaotic systems like the crypto market. <br> **Datatype:** Boolean. <br> Default: `False`.
|
||||
| `write_metrics_to_disk` | Collect train timings, inference timings and cpu usage in json file. <br> **Datatype:** Boolean. <br> Default: `False`
|
||||
| `data_kitchen_thread_count` | <br> Designate the number of threads you want to use for data processing (outlier methods, normalization, etc.). This has no impact on the number of threads used for training. If user does not set it (default), FreqAI will use max number of threads - 2 (leaving 1 physical core available for Freqtrade bot and FreqUI) <br> **Datatype:** Positive integer.
|
||||
| `activate_tensorboard` | <br> Indicate whether or not to activate tensorboard for the tensorboard enabled modules (currently Reinforcment Learning, XGBoost, Catboost, and PyTorch). Tensorboard needs Torch installed, which means you will need the torch/RL docker image or you need to answer "yes" to the install question about whether or not you wish to install Torch. <br> **Datatype:** Boolean. <br> Default: `True`.
|
||||
|
||||
### Feature parameters
|
||||
|
||||
|
|
|
@ -161,7 +161,14 @@ This specific hyperopt would help you understand the appropriate `DI_values` for
|
|||
|
||||
## Using Tensorboard
|
||||
|
||||
CatBoost models benefit from tracking training metrics via Tensorboard. You can take advantage of the FreqAI integration to track training and evaluation performance across all coins and across all retrainings. Tensorboard is activated via the following command:
|
||||
!!! note "Availability"
|
||||
FreqAI includes tensorboard for a variety of models, including XGBoost, all PyTorch models, Reinforcement Learning, and Catboost. If you would like to see Tensorboard integrated into another model type, please open an issue on the [Freqtrade GitHub](https://github.com/freqtrade/freqtrade/issues)
|
||||
|
||||
!!! danger "Requirements"
|
||||
Tensorboard logging requires the FreqAI torch installation/docker image.
|
||||
|
||||
|
||||
The easiest way to use tensorboard is to ensure `freqai.activate_tensorboard` is set to `True` (default setting) in your configuration file, run FreqAI, then open a separate shell and run:
|
||||
|
||||
```bash
|
||||
cd freqtrade
|
||||
|
@ -171,3 +178,7 @@ tensorboard --logdir user_data/models/unique-id
|
|||
where `unique-id` is the `identifier` set in the `freqai` configuration file. This command must be run in a separate shell if you wish to view the output in your browser at 127.0.0.1:6060 (6060 is the default port used by Tensorboard).
|
||||
|
||||
![tensorboard](assets/tensorboard.jpg)
|
||||
|
||||
|
||||
!!! note "Deactivate for improved performance"
|
||||
Tensorboard logging can slow down training and should be deactivated for production use.
|
||||
|
|
|
@ -23,7 +23,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
|||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions
|
||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||
from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback
|
||||
from freqtrade.persistence import Trade
|
||||
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from freqtrade.exceptions import OperationalException
|
|||
from freqtrade.exchange import timeframe_to_seconds
|
||||
from freqtrade.freqai.data_drawer import FreqaiDataDrawer
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.utils import plot_feature_importance, record_params
|
||||
from freqtrade.freqai.utils import get_tb_logger, plot_feature_importance, record_params
|
||||
from freqtrade.strategy.interface import IStrategy
|
||||
|
||||
|
||||
|
@ -110,6 +110,7 @@ class IFreqaiModel(ABC):
|
|||
if self.ft_params.get('principal_component_analysis', False) and self.continual_learning:
|
||||
self.ft_params.update({'principal_component_analysis': False})
|
||||
logger.warning('User tried to use PCA with continual learning. Deactivating PCA.')
|
||||
self.activate_tensorboard: bool = self.freqai_info.get('activate_tensorboard', True)
|
||||
|
||||
record_params(config, self.full_path)
|
||||
|
||||
|
@ -344,7 +345,10 @@ class IFreqaiModel(ABC):
|
|||
dk.find_labels(dataframe_train)
|
||||
|
||||
try:
|
||||
self.tb_logger = get_tb_logger(self.dd.model_type, dk.data_path,
|
||||
self.activate_tensorboard)
|
||||
self.model = self.train(dataframe_train, pair, dk)
|
||||
self.tb_logger.close()
|
||||
except Exception as msg:
|
||||
logger.warning(
|
||||
f"Training {pair} raised exception {msg.__class__.__name__}. "
|
||||
|
@ -632,7 +636,10 @@ class IFreqaiModel(ABC):
|
|||
dk.find_features(unfiltered_dataframe)
|
||||
dk.find_labels(unfiltered_dataframe)
|
||||
|
||||
self.tb_logger = get_tb_logger(self.dd.model_type, dk.data_path,
|
||||
self.activate_tensorboard)
|
||||
model = self.train(unfiltered_dataframe, pair, dk)
|
||||
self.tb_logger.close()
|
||||
|
||||
self.dd.pair_dict[pair]["trained_timestamp"] = trained_timestamp
|
||||
dk.set_new_model_names(pair, trained_timestamp)
|
||||
|
|
|
@ -84,6 +84,7 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
|
|||
model_meta_data={"class_names": class_names},
|
||||
device=self.device,
|
||||
data_convertor=self.data_convertor,
|
||||
tb_logger=self.tb_logger,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
trainer.fit(data_dictionary, self.splits)
|
||||
|
|
|
@ -78,6 +78,7 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
|
|||
criterion=criterion,
|
||||
device=self.device,
|
||||
data_convertor=self.data_convertor,
|
||||
tb_logger=self.tb_logger,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
trainer.fit(data_dictionary, self.splits)
|
||||
|
|
|
@ -32,8 +32,7 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
|
|||
"trainer_kwargs": {
|
||||
"max_iters": 5000,
|
||||
"batch_size": 64,
|
||||
"max_n_eval_batches": null,
|
||||
"window_size": 10
|
||||
"max_n_eval_batches": null
|
||||
},
|
||||
"model_kwargs": {
|
||||
"hidden_dim": 512,
|
||||
|
@ -85,6 +84,7 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
|
|||
device=self.device,
|
||||
data_convertor=self.data_convertor,
|
||||
window_size=self.window_size,
|
||||
tb_logger=self.tb_logger,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
trainer.fit(data_dictionary, self.splits)
|
||||
|
|
|
@ -58,10 +58,14 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
||||
net_arch=self.net_arch)
|
||||
|
||||
if self.activate_tensorboard:
|
||||
tb_path = Path(dk.full_path / "tensorboard" / dk.pair.split('/')[0])
|
||||
else:
|
||||
tb_path = None
|
||||
|
||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
||||
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=Path(
|
||||
dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
|
||||
tensorboard_log=tb_path,
|
||||
**self.freqai_info.get('model_training_parameters', {})
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
|
|||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
|
||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||
from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
@ -5,6 +5,7 @@ from xgboost import XGBRegressor
|
|||
|
||||
from freqtrade.freqai.base_models.BaseRegressionModel import BaseRegressionModel
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.tensorboard import TBCallback
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -44,7 +45,10 @@ class XGBoostRegressor(BaseRegressionModel):
|
|||
|
||||
model = XGBRegressor(**self.model_training_parameters)
|
||||
|
||||
model.set_params(callbacks=[TBCallback(dk.data_path)], activate=self.activate_tensorboard)
|
||||
model.fit(X=X, y=y, sample_weight=sample_weight, eval_set=eval_set,
|
||||
sample_weight_eval_set=eval_weights, xgb_model=xgb_model)
|
||||
# set the callbacks to empty so that we can serialize to disk later
|
||||
model.set_params(callbacks=[])
|
||||
|
||||
return model
|
||||
|
|
15
freqtrade/freqai/tensorboard/__init__.py
Normal file
15
freqtrade/freqai/tensorboard/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# ensure users can still use a non-torch freqai version
|
||||
try:
|
||||
from freqtrade.freqai.tensorboard.tensorboard import TensorBoardCallback, TensorboardLogger
|
||||
TBLogger = TensorboardLogger
|
||||
TBCallback = TensorBoardCallback
|
||||
except ModuleNotFoundError:
|
||||
from freqtrade.freqai.tensorboard.base_tensorboard import (BaseTensorBoardCallback,
|
||||
BaseTensorboardLogger)
|
||||
TBLogger = BaseTensorboardLogger # type: ignore
|
||||
TBCallback = BaseTensorBoardCallback # type: ignore
|
||||
|
||||
__all__ = (
|
||||
"TBLogger",
|
||||
"TBCallback"
|
||||
)
|
35
freqtrade/freqai/tensorboard/base_tensorboard.py
Normal file
35
freqtrade/freqai/tensorboard/base_tensorboard.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from xgboost.callback import TrainingCallback
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTensorboardLogger:
|
||||
def __init__(self, logdir: Path, activate: bool = True):
|
||||
logger.warning("Tensorboard is not installed, no logs will be written."
|
||||
"Ensure torch is installed, or use the torch/RL docker images")
|
||||
|
||||
def log_scalar(self, tag: str, scalar_value: Any, step: int):
|
||||
return
|
||||
|
||||
def close(self):
|
||||
return
|
||||
|
||||
|
||||
class BaseTensorBoardCallback(TrainingCallback):
|
||||
|
||||
def __init__(self, logdir: Path, activate: bool = True):
|
||||
logger.warning("Tensorboard is not installed, no logs will be written."
|
||||
"Ensure torch is installed, or use the torch/RL docker images")
|
||||
|
||||
def after_iteration(
|
||||
self, model, epoch: int, evals_log: TrainingCallback.EvalsLog
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
def after_training(self, model):
|
||||
return model
|
62
freqtrade/freqai/tensorboard/tensorboard.py
Normal file
62
freqtrade/freqai/tensorboard/tensorboard.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from xgboost import callback
|
||||
|
||||
from freqtrade.freqai.tensorboard.base_tensorboard import (BaseTensorBoardCallback,
|
||||
BaseTensorboardLogger)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TensorboardLogger(BaseTensorboardLogger):
|
||||
def __init__(self, logdir: Path, activate: bool = True):
|
||||
self.activate = activate
|
||||
if self.activate:
|
||||
self.writer: SummaryWriter = SummaryWriter(f"{str(logdir)}/tensorboard")
|
||||
|
||||
def log_scalar(self, tag: str, scalar_value: Any, step: int):
|
||||
if self.activate:
|
||||
self.writer.add_scalar(tag, scalar_value, step)
|
||||
|
||||
def close(self):
|
||||
if self.activate:
|
||||
self.writer.flush()
|
||||
self.writer.close()
|
||||
|
||||
|
||||
class TensorBoardCallback(BaseTensorBoardCallback):
|
||||
|
||||
def __init__(self, logdir: Path, activate: bool = True):
|
||||
self.activate = activate
|
||||
if self.activate:
|
||||
self.writer: SummaryWriter = SummaryWriter(f"{str(logdir)}/tensorboard")
|
||||
|
||||
def after_iteration(
|
||||
self, model, epoch: int, evals_log: callback.TrainingCallback.EvalsLog
|
||||
) -> bool:
|
||||
if not self.activate:
|
||||
return False
|
||||
if not evals_log:
|
||||
return False
|
||||
|
||||
for data, metric in evals_log.items():
|
||||
for metric_name, log in metric.items():
|
||||
score = log[-1][0] if isinstance(log[-1], tuple) else log[-1]
|
||||
if data == "train":
|
||||
self.writer.add_scalar("train_loss", score, epoch)
|
||||
else:
|
||||
self.writer.add_scalar("valid_loss", score, epoch)
|
||||
|
||||
return False
|
||||
|
||||
def after_training(self, model):
|
||||
if not self.activate:
|
||||
return model
|
||||
self.writer.flush()
|
||||
self.writer.close()
|
||||
|
||||
return model
|
|
@ -28,6 +28,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
data_convertor: PyTorchDataConvertor,
|
||||
model_meta_data: Dict[str, Any] = {},
|
||||
window_size: int = 1,
|
||||
tb_logger: Any = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
|
@ -55,6 +56,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
|
||||
self.data_convertor = data_convertor
|
||||
self.window_size: int = window_size
|
||||
self.tb_logger = tb_logger
|
||||
|
||||
def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]):
|
||||
"""
|
||||
|
@ -78,8 +80,6 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
)
|
||||
self.model.train()
|
||||
for epoch in range(1, epochs + 1):
|
||||
# training
|
||||
losses = []
|
||||
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
|
||||
|
||||
xb, yb = batch_data
|
||||
|
@ -91,20 +91,15 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
self.optimizer.zero_grad(set_to_none=True)
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
losses.append(loss.item())
|
||||
train_loss = sum(losses) / len(losses)
|
||||
log_message = f"epoch {epoch}/{epochs}: train loss {train_loss:.4f}"
|
||||
self.tb_logger.log_scalar("train_loss", loss.item(), i)
|
||||
|
||||
# evaluation
|
||||
if "test" in splits:
|
||||
test_loss = self.estimate_loss(
|
||||
self.estimate_loss(
|
||||
data_loaders_dictionary,
|
||||
self.max_n_eval_batches,
|
||||
"test"
|
||||
)
|
||||
log_message += f" ; test loss {test_loss:.4f}"
|
||||
|
||||
logger.info(log_message)
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_loss(
|
||||
|
@ -112,10 +107,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
data_loader_dictionary: Dict[str, DataLoader],
|
||||
max_n_eval_batches: Optional[int],
|
||||
split: str,
|
||||
) -> float:
|
||||
) -> None:
|
||||
self.model.eval()
|
||||
n_batches = 0
|
||||
losses = []
|
||||
for i, batch_data in enumerate(data_loader_dictionary[split]):
|
||||
if max_n_eval_batches and i > max_n_eval_batches:
|
||||
n_batches += 1
|
||||
|
@ -126,10 +120,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
|
||||
yb_pred = self.model(xb)
|
||||
loss = self.criterion(yb_pred, yb)
|
||||
losses.append(loss.item())
|
||||
self.tb_logger.log_scalar(f"{split}_loss", loss.item(), i)
|
||||
|
||||
self.model.train()
|
||||
return sum(losses) / len(losses)
|
||||
|
||||
def create_data_loaders_dictionary(
|
||||
self,
|
||||
|
|
|
@ -92,55 +92,6 @@ def get_required_data_timerange(config: Config) -> TimeRange:
|
|||
return data_load_timerange
|
||||
|
||||
|
||||
# Keep below for when we wish to download heterogeneously lengthed data for FreqAI.
|
||||
# def download_all_data_for_training(dp: DataProvider, config: Config) -> None:
|
||||
# """
|
||||
# Called only once upon start of bot to download the necessary data for
|
||||
# populating indicators and training a FreqAI model.
|
||||
# :param timerange: TimeRange = The full data timerange for populating the indicators
|
||||
# and training the model.
|
||||
# :param dp: DataProvider instance attached to the strategy
|
||||
# """
|
||||
|
||||
# if dp._exchange is not None:
|
||||
# markets = [p for p, m in dp._exchange.markets.items() if market_is_active(m)
|
||||
# or config.get('include_inactive')]
|
||||
# else:
|
||||
# # This should not occur:
|
||||
# raise OperationalException('No exchange object found.')
|
||||
|
||||
# all_pairs = dynamic_expand_pairlist(config, markets)
|
||||
|
||||
# if not dp._exchange:
|
||||
# # Not realistic - this is only called in live mode.
|
||||
# raise OperationalException("Dataprovider did not have an exchange attached.")
|
||||
|
||||
# time = datetime.now(tz=timezone.utc).timestamp()
|
||||
|
||||
# for tf in config["freqai"]["feature_parameters"].get("include_timeframes"):
|
||||
# timerange = TimeRange()
|
||||
# timerange.startts = int(time)
|
||||
# timerange.stopts = int(time)
|
||||
# startup_candles = dp.get_required_startup(str(tf))
|
||||
# tf_seconds = timeframe_to_seconds(str(tf))
|
||||
# timerange.subtract_start(tf_seconds * startup_candles)
|
||||
# new_pairs_days = int((timerange.stopts - timerange.startts) / 86400)
|
||||
# # FIXME: now that we are looping on `refresh_backtest_ohlcv_data`, the function
|
||||
# # redownloads the funding rate for each pair.
|
||||
# refresh_backtest_ohlcv_data(
|
||||
# dp._exchange,
|
||||
# pairs=all_pairs,
|
||||
# timeframes=[tf],
|
||||
# datadir=config["datadir"],
|
||||
# timerange=timerange,
|
||||
# new_pairs_days=new_pairs_days,
|
||||
# erase=False,
|
||||
# data_format=config.get("dataformat_ohlcv", "json"),
|
||||
# trading_mode=config.get("trading_mode", "spot"),
|
||||
# prepend=config.get("prepend_data", False),
|
||||
# )
|
||||
|
||||
|
||||
def plot_feature_importance(model: Any, pair: str, dk: FreqaiDataKitchen,
|
||||
count_max: int = 25) -> None:
|
||||
"""
|
||||
|
@ -233,3 +184,13 @@ def get_timerange_backtest_live_models(config: Config) -> str:
|
|||
dd = FreqaiDataDrawer(models_path, config)
|
||||
timerange = dd.get_timerange_from_live_historic_predictions()
|
||||
return timerange.timerange_str
|
||||
|
||||
|
||||
def get_tb_logger(model_type: str, path: Path, activate: bool) -> Any:
|
||||
|
||||
if model_type == "pytorch" and activate:
|
||||
from freqtrade.freqai.tensorboard import TBLogger
|
||||
return TBLogger(path, activate)
|
||||
else:
|
||||
from freqtrade.freqai.tensorboard.base_tensorboard import BaseTensorboardLogger
|
||||
return BaseTensorboardLogger(path, activate)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import platform
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
@ -14,6 +15,11 @@ from freqtrade.resolvers.freqaimodel_resolver import FreqaiModelResolver
|
|||
from tests.conftest import get_patched_exchange
|
||||
|
||||
|
||||
def is_mac() -> bool:
|
||||
machine = platform.system()
|
||||
return "Darwin" in machine
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def freqai_conf(default_conf, tmpdir):
|
||||
freqaiconf = deepcopy(default_conf)
|
||||
|
@ -36,6 +42,7 @@ def freqai_conf(default_conf, tmpdir):
|
|||
"identifier": "uniqe-id100",
|
||||
"live_trained_timestamp": 0,
|
||||
"data_kitchen_thread_count": 2,
|
||||
"activate_tensorboard": False,
|
||||
"feature_parameters": {
|
||||
"include_timeframes": ["5m"],
|
||||
"include_corr_pairlist": ["ADA/BTC"],
|
||||
|
|
|
@ -12,6 +12,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
|||
from tests.conftest import get_patched_exchange, log_has_re
|
||||
from tests.freqai.conftest import (get_patched_data_kitchen, get_patched_freqai_strategy,
|
||||
make_data_dictionary, make_unfiltered_dataframe)
|
||||
from tests.freqai.test_freqai_interface import is_mac
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -173,6 +174,9 @@ def test_get_full_model_path(mocker, freqai_conf, model):
|
|||
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||
freqai_conf.update({"strategy": "freqai_test_strat"})
|
||||
|
||||
if is_mac():
|
||||
pytest.skip("Mac is confused during this test for unknown reasons")
|
||||
|
||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
||||
|
@ -188,7 +192,7 @@ def test_get_full_model_path(mocker, freqai_conf, model):
|
|||
|
||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
||||
|
||||
freqai.dk.set_paths('ADA/BTC', None)
|
||||
freqai.extract_data_and_train_model(
|
||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ from freqtrade.optimize.backtesting import Backtesting
|
|||
from freqtrade.persistence import Trade
|
||||
from freqtrade.plugins.pairlistmanager import PairListManager
|
||||
from tests.conftest import EXMS, create_mock_trades, get_patched_exchange, log_has_re
|
||||
from tests.freqai.conftest import (get_patched_freqai_strategy, make_rl_config,
|
||||
from tests.freqai.conftest import (get_patched_freqai_strategy, is_mac, make_rl_config,
|
||||
mock_pytorch_mlp_model_training_parameters)
|
||||
|
||||
|
||||
|
@ -28,11 +28,6 @@ def is_arm() -> bool:
|
|||
return "arm" in machine or "aarch64" in machine
|
||||
|
||||
|
||||
def is_mac() -> bool:
|
||||
machine = platform.system()
|
||||
return "Darwin" in machine
|
||||
|
||||
|
||||
def can_run_model(model: str) -> None:
|
||||
if is_arm() and "Catboost" in model:
|
||||
pytest.skip("CatBoost is not supported on ARM.")
|
||||
|
@ -59,6 +54,11 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
|||
dbscan, float32, can_short, shuffle, buffer):
|
||||
|
||||
can_run_model(model)
|
||||
|
||||
test_tb = True
|
||||
if is_mac():
|
||||
test_tb = False
|
||||
|
||||
model_save_ext = 'joblib'
|
||||
freqai_conf.update({"freqaimodel": model})
|
||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||
|
@ -94,6 +94,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
|||
strategy.freqai_info = freqai_conf.get("freqai", {})
|
||||
freqai = strategy.freqai
|
||||
freqai.live = True
|
||||
freqai.activate_tensorboard = test_tb
|
||||
freqai.can_short = can_short
|
||||
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
||||
freqai.dk.live = True
|
||||
|
@ -239,6 +240,9 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model):
|
|||
)
|
||||
def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog):
|
||||
can_run_model(model)
|
||||
test_tb = True
|
||||
if is_mac():
|
||||
test_tb = False
|
||||
|
||||
freqai_conf.get("freqai", {}).update({"save_backtest_models": True})
|
||||
freqai_conf['runmode'] = RunMode.BACKTEST
|
||||
|
@ -271,6 +275,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog)
|
|||
strategy.freqai_info = freqai_conf.get("freqai", {})
|
||||
freqai = strategy.freqai
|
||||
freqai.live = False
|
||||
freqai.activate_tensorboard = test_tb
|
||||
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
||||
timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
|
||||
|
@ -282,6 +287,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog)
|
|||
df[f'%-constant_{i}'] = i
|
||||
|
||||
metadata = {"pair": "LTC/BTC"}
|
||||
freqai.dk.set_paths('LTC/BTC', None)
|
||||
freqai.start_backtesting(df, metadata, freqai.dk, strategy)
|
||||
model_folders = [x for x in freqai.dd.full_path.iterdir() if x.is_dir()]
|
||||
|
||||
|
@ -439,6 +445,7 @@ def test_principal_component_analysis(mocker, freqai_conf):
|
|||
|
||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
||||
freqai.dk.set_paths('ADA/BTC', None)
|
||||
|
||||
freqai.extract_data_and_train_model(
|
||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
||||
|
@ -472,6 +479,7 @@ def test_plot_feature_importance(mocker, freqai_conf):
|
|||
|
||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
||||
freqai.dk.set_paths('ADA/BTC', None)
|
||||
|
||||
freqai.extract_data_and_train_model(
|
||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
||||
|
|
Loading…
Reference in New Issue
Block a user