improve documentation

This commit is contained in:
Yinon Polak 2023-03-09 14:55:52 +02:00
parent e88a0d5248
commit 8a9f2aedbb

View File

@ -32,6 +32,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
dict: A dictionary mapping class names to their corresponding indices. dict: A dictionary mapping class names to their corresponding indices.
dict: A dictionary mapping indices to their corresponding class names. dict: A dictionary mapping indices to their corresponding class names.
""" """
super().__init__(**kwargs) super().__init__(**kwargs)
model_training_parameters = self.freqai_info["model_training_parameters"] model_training_parameters = self.freqai_info["model_training_parameters"]
self.n_hidden = model_training_parameters.get("n_hidden", 1024) self.n_hidden = model_training_parameters.get("n_hidden", 1024)
@ -50,6 +51,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:raises ValueError: If self.class_names is not defined in the parent class. :raises ValueError: If self.class_names is not defined in the parent class.
""" """
if not hasattr(self, "class_names"): if not hasattr(self, "class_names"):
raise ValueError( raise ValueError(
"Missing attribute: self.class_names " "Missing attribute: self.class_names "
@ -93,7 +95,9 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:pred_df: dataframe containing the predictions :pred_df: dataframe containing the predictions
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
data (NaNs) or felt uncertain about data (PCA and DI index) data (NaNs) or felt uncertain about data (PCA and DI index)
:raises ValueError: if 'class_name' doesn't exist in model meta_data.
""" """
class_names = self.model.model_meta_data.get("class_names", None) class_names = self.model.model_meta_data.get("class_names", None)
if not class_names: if not class_names:
raise ValueError( raise ValueError(
@ -128,6 +132,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
encode class name str -> int encode class name str -> int
assuming first column of *_labels data frame to contain class names assuming first column of *_labels data frame to contain class names
""" """
target_column_name = dk.label_list[0] target_column_name = dk.label_list[0]
for split in ["train", "test"]: for split in ["train", "test"]:
label_df = data_dictionary[f"{split}_labels"] label_df = data_dictionary[f"{split}_labels"]
@ -148,6 +153,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
""" """
decode class name int -> str decode class name int -> str
""" """
return list(map(lambda x: self.index_to_class_name[x.item()], classes)) return list(map(lambda x: self.index_to_class_name[x.item()], classes))
def init_class_names_to_index_mapping(self, class_names): def init_class_names_to_index_mapping(self, class_names):