diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifier.py b/freqtrade/freqai/prediction_models/PyTorchClassifier.py index 01432e0fe..b14a89b38 100644 --- a/freqtrade/freqai/prediction_models/PyTorchClassifier.py +++ b/freqtrade/freqai/prediction_models/PyTorchClassifier.py @@ -20,6 +20,18 @@ class PyTorchClassifier(BasePyTorchModel): """ A PyTorch implementation of a classifier. User must implement fit method + + Important! + User must declare the target class names in the strategy, under + IStrategy.set_freqai_targets method. + ``` + def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs): + self.freqai.class_names = ["down", "up"] + dataframe['&s-up_or_down'] = np.where(dataframe["close"].shift(-100) > + dataframe["close"], 'up', 'down') + + return dataframe + ``` """ def __init__(self, **kwargs): super().__init__(**kwargs) @@ -127,7 +139,7 @@ class PyTorchClassifier(BasePyTorchModel): if not hasattr(self, "class_names"): raise ValueError( "Missing attribute: self.class_names " - "set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] " + "set self.freqai.class_names = ['class a', 'class b', 'class c'] " "inside IStrategy.set_freqai_targets method." ) return self.class_names