Add hypeorpt cloudpickle magic

closes #7078
This commit is contained in:
Matthias 2022-07-16 11:05:58 +02:00
parent e53e530874
commit 40e2da10f3

View File

@ -6,6 +6,7 @@ This module contains the hyperopt logic
import logging import logging
import random import random
import sys
import warnings import warnings
from datetime import datetime, timezone from datetime import datetime, timezone
from math import ceil from math import ceil
@ -17,6 +18,7 @@ import rapidjson
from colorama import Fore, Style from colorama import Fore, Style
from colorama import init as colorama_init from colorama import init as colorama_init
from joblib import Parallel, cpu_count, delayed, dump, load, wrap_non_picklable_objects from joblib import Parallel, cpu_count, delayed, dump, load, wrap_non_picklable_objects
from joblib.externals import cloudpickle
from pandas import DataFrame from pandas import DataFrame
from freqtrade.constants import DATETIME_PRINT_FORMAT, FTHYPT_FILEVERSION, LAST_BT_RESULT_FN from freqtrade.constants import DATETIME_PRINT_FORMAT, FTHYPT_FILEVERSION, LAST_BT_RESULT_FN
@ -87,6 +89,7 @@ class Hyperopt:
self.backtesting._set_strategy(self.backtesting.strategylist[0]) self.backtesting._set_strategy(self.backtesting.strategylist[0])
self.custom_hyperopt.strategy = self.backtesting.strategy self.custom_hyperopt.strategy = self.backtesting.strategy
self.hyperopt_pickle_magic(self.backtesting.strategy.__class__.__bases__)
self.custom_hyperoptloss: IHyperOptLoss = HyperOptLossResolver.load_hyperoptloss( self.custom_hyperoptloss: IHyperOptLoss = HyperOptLossResolver.load_hyperoptloss(
self.config) self.config)
self.calculate_loss = self.custom_hyperoptloss.hyperopt_loss_function self.calculate_loss = self.custom_hyperoptloss.hyperopt_loss_function
@ -137,6 +140,17 @@ class Hyperopt:
logger.info(f"Removing `{p}`.") logger.info(f"Removing `{p}`.")
p.unlink() p.unlink()
def hyperopt_pickle_magic(self, bases) -> None:
"""
Hyperopt magic to allow strategy inheritance across files.
For this to properly work, we need to register the module of the imported class
to pickle as value.
"""
for modules in bases:
if modules.__name__ != 'IStrategy':
cloudpickle.register_pickle_by_value(sys.modules[modules.__module__])
self.hyperopt_pickle_magic(modules.__bases__)
def _get_params_dict(self, dimensions: List[Dimension], raw_params: List[Any]) -> Dict: def _get_params_dict(self, dimensions: List[Dimension], raw_params: List[Any]) -> Dict:
# Ensure the number of dimensions match # Ensure the number of dimensions match