mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 02:12:01 +00:00
Add ground work for TensorFlow models, add protections from common mistakes
This commit is contained in:
parent
fea63fba12
commit
ef409dd345
|
@ -57,6 +57,7 @@ class FreqaiDataKitchen:
|
|||
self.live = live
|
||||
self.pair = pair
|
||||
self.svm_model: linear_model.SGDOneClassSVM = None
|
||||
self.keras = self.freqai_config.get("keras", False)
|
||||
self.set_all_pairs()
|
||||
if not self.live:
|
||||
self.full_timerange = self.create_fulltimerange(
|
||||
|
@ -92,7 +93,7 @@ class FreqaiDataKitchen:
|
|||
|
||||
return
|
||||
|
||||
def save_data(self, model: Any, coin: str = "", keras_model=False, label=None) -> None:
|
||||
def save_data(self, model: Any, coin: str = "", label=None) -> None:
|
||||
"""
|
||||
Saves all data associated with a model for a single sub-train time range
|
||||
:params:
|
||||
|
@ -106,7 +107,7 @@ class FreqaiDataKitchen:
|
|||
save_path = Path(self.data_path)
|
||||
|
||||
# Save the trained model
|
||||
if not keras_model:
|
||||
if not self.keras:
|
||||
dump(model, save_path / f"{self.model_filename}_model.joblib")
|
||||
else:
|
||||
model.save(save_path / f"{self.model_filename}_model.h5")
|
||||
|
@ -140,7 +141,7 @@ class FreqaiDataKitchen:
|
|||
|
||||
return
|
||||
|
||||
def load_data(self, coin: str = "", keras_model=False) -> Any:
|
||||
def load_data(self, coin: str = "") -> Any:
|
||||
"""
|
||||
loads all data required to make a prediction on a sub-train time range
|
||||
:returns:
|
||||
|
@ -174,7 +175,7 @@ class FreqaiDataKitchen:
|
|||
# try to access model in memory instead of loading object from disk to save time
|
||||
if self.live and self.model_filename in self.dd.model_dictionary:
|
||||
model = self.dd.model_dictionary[self.model_filename]
|
||||
elif not keras_model:
|
||||
elif not self.keras:
|
||||
model = load(self.data_path / str(self.model_filename + "_model.joblib"))
|
||||
else:
|
||||
from tensorflow import keras
|
||||
|
@ -559,6 +560,13 @@ class FreqaiDataKitchen:
|
|||
predict: bool = If true, inference an existing SVM model, else construct one
|
||||
"""
|
||||
|
||||
if self.keras:
|
||||
logger.warning("SVM outlier removal not currently supported for Keras based models. "
|
||||
"Skipping user requested function.")
|
||||
if predict:
|
||||
self.do_predict = np.ones(len(self.data_dictionary["prediction_features"]))
|
||||
return
|
||||
|
||||
if predict:
|
||||
assert self.svm_model, "No svm model available for outlier removal"
|
||||
y_pred = self.svm_model.predict(self.data_dictionary["prediction_features"])
|
||||
|
|
|
@ -69,6 +69,9 @@ class IFreqaiModel(ABC):
|
|||
self.ready_to_scan = False
|
||||
self.first = True
|
||||
self.keras = self.freqai_info.get("keras", False)
|
||||
if self.keras and self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
|
||||
self.freqai_info["feature_parameters"]["DI_threshold"] = 0
|
||||
logger.warning("DI threshold is not configured for Keras models yet. Deactivating.")
|
||||
self.CONV_WIDTH = self.freqai_info.get("conv_width", 2)
|
||||
|
||||
def assert_config(self, config: Dict[str, Any]) -> None:
|
||||
|
@ -197,9 +200,9 @@ class IFreqaiModel(ABC):
|
|||
self.model = self.train(dataframe_train, metadata["pair"], dk)
|
||||
self.dd.pair_dict[metadata["pair"]]["trained_timestamp"] = trained_timestamp.stopts
|
||||
dk.set_new_model_names(metadata["pair"], trained_timestamp)
|
||||
dk.save_data(self.model, metadata["pair"], keras_model=self.keras)
|
||||
dk.save_data(self.model, metadata["pair"])
|
||||
else:
|
||||
self.model = dk.load_data(metadata["pair"], keras_model=self.keras)
|
||||
self.model = dk.load_data(metadata["pair"])
|
||||
|
||||
self.check_if_feature_list_matches_strategy(dataframe_train, dk)
|
||||
|
||||
|
@ -276,7 +279,7 @@ class IFreqaiModel(ABC):
|
|||
)
|
||||
|
||||
# load the model and associated data into the data kitchen
|
||||
self.model = dk.load_data(coin=metadata["pair"], keras_model=self.keras)
|
||||
self.model = dk.load_data(coin=metadata["pair"])
|
||||
|
||||
if not self.model:
|
||||
logger.warning(
|
||||
|
@ -353,13 +356,15 @@ class IFreqaiModel(ABC):
|
|||
of how outlier data points are dropped from the dataframe used for training.
|
||||
"""
|
||||
|
||||
if self.freqai_info.get("feature_parameters", {}).get("principal_component_analysis"):
|
||||
if self.freqai_info.get("feature_parameters", {}).get(
|
||||
"principal_component_analysis", False
|
||||
):
|
||||
dk.principal_component_analysis()
|
||||
|
||||
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers"):
|
||||
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers", False):
|
||||
dk.use_SVM_to_remove_outliers(predict=False)
|
||||
|
||||
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold"):
|
||||
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
|
||||
dk.data["avg_mean_dist"] = dk.compute_distances()
|
||||
|
||||
# if self.feature_parameters["determine_statistical_distributions"]:
|
||||
|
@ -378,13 +383,15 @@ class IFreqaiModel(ABC):
|
|||
of how the do_predict vector is modified. do_predict is ultimately passed back to strategy
|
||||
for buy signals.
|
||||
"""
|
||||
if self.freqai_info.get("feature_parameters", {}).get("principal_component_analysis"):
|
||||
if self.freqai_info.get("feature_parameters", {}).get(
|
||||
"principal_component_analysis", False
|
||||
):
|
||||
dk.pca_transform(dataframe)
|
||||
|
||||
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers"):
|
||||
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers", False):
|
||||
dk.use_SVM_to_remove_outliers(predict=True)
|
||||
|
||||
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold"):
|
||||
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
|
||||
dk.check_if_pred_in_training_spaces()
|
||||
|
||||
# if self.feature_parameters["determine_statistical_distributions"]:
|
||||
|
@ -479,14 +486,15 @@ class IFreqaiModel(ABC):
|
|||
if self.dd.pair_dict[pair]["priority"] == 1 and self.scanning:
|
||||
with self.lock:
|
||||
self.dd.pair_to_end_of_training_queue(pair)
|
||||
dk.save_data(model, coin=pair, keras_model=self.keras)
|
||||
dk.save_data(model, coin=pair)
|
||||
|
||||
if self.freqai_info.get("purge_old_models", False):
|
||||
self.dd.purge_old_models()
|
||||
# self.retrain = False
|
||||
|
||||
def set_initial_historic_predictions(self, df: DataFrame, model: Any,
|
||||
dk: FreqaiDataKitchen, pair: str) -> None:
|
||||
def set_initial_historic_predictions(
|
||||
self, df: DataFrame, model: Any, dk: FreqaiDataKitchen, pair: str
|
||||
) -> None:
|
||||
trained_predictions = model.predict(df)
|
||||
pred_df = DataFrame(trained_predictions, columns=dk.label_list)
|
||||
for label in dk.label_list:
|
||||
|
|
|
@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class BaseRegressionModel(IFreqaiModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
Base class for regression type models (e.g. Catboost, LightGBM, XGboost etc.).
|
||||
User *must* inherit from this class and set fit() and predict(). See example scripts
|
||||
such as prediction_models/CatboostPredictionModel.py for guidance.
|
||||
"""
|
||||
|
||||
def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame:
|
||||
|
|
|
@ -24,8 +24,9 @@ class FreqaiModelResolver(IResolver):
|
|||
object_type = IFreqaiModel
|
||||
object_type_str = "FreqaiModel"
|
||||
user_subdir = USERPATH_FREQAIMODELS
|
||||
initial_search_path = Path(__file__).parent.parent.joinpath(
|
||||
"freqai/prediction_models").resolve()
|
||||
initial_search_path = (
|
||||
Path(__file__).parent.parent.joinpath("freqai/prediction_models").resolve()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_freqaimodel(config: Dict) -> IFreqaiModel:
|
||||
|
@ -33,6 +34,7 @@ class FreqaiModelResolver(IResolver):
|
|||
Load the custom class from config parameter
|
||||
:param config: configuration dictionary
|
||||
"""
|
||||
disallowed_models = ["BaseRegressionModel", "BaseTensorFlowModel"]
|
||||
|
||||
freqaimodel_name = config.get("freqaimodel")
|
||||
if not freqaimodel_name:
|
||||
|
@ -40,6 +42,11 @@ class FreqaiModelResolver(IResolver):
|
|||
"No freqaimodel set. Please use `--freqaimodel` to "
|
||||
"specify the FreqaiModel class to use.\n"
|
||||
)
|
||||
if freqaimodel_name in disallowed_models:
|
||||
raise OperationalException(
|
||||
f"{freqaimodel_name} is a baseclass and cannot be used directly. User must choose "
|
||||
"an existing child class or inherit from this baseclass.\n"
|
||||
)
|
||||
freqaimodel = FreqaiModelResolver.load_object(
|
||||
freqaimodel_name,
|
||||
config,
|
||||
|
|
Loading…
Reference in New Issue
Block a user