diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py index a98643b3f..a5b8b1591 100644 --- a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py +++ b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py @@ -32,6 +32,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): dict: A dictionary mapping class names to their corresponding indices. dict: A dictionary mapping indices to their corresponding class names. """ + super().__init__(**kwargs) model_training_parameters = self.freqai_info["model_training_parameters"] self.n_hidden = model_training_parameters.get("n_hidden", 1024) @@ -50,6 +51,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): :raises ValueError: If self.class_names is not defined in the parent class. """ + if not hasattr(self, "class_names"): raise ValueError( "Missing attribute: self.class_names " @@ -93,7 +95,9 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): :pred_df: dataframe containing the predictions :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove data (NaNs) or felt uncertain about data (PCA and DI index) + :raises ValueError: if 'class_name' doesn't exist in model meta_data. """ + class_names = self.model.model_meta_data.get("class_names", None) if not class_names: raise ValueError( @@ -128,6 +132,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): encode class name str -> int assuming first column of *_labels data frame to contain class names """ + target_column_name = dk.label_list[0] for split in ["train", "test"]: label_df = data_dictionary[f"{split}_labels"] @@ -148,6 +153,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): """ decode class name int -> str """ + return list(map(lambda x: self.index_to_class_name[x.item()], classes)) def init_class_names_to_index_mapping(self, class_names):