improve documentation

This commit is contained in:
Yinon Polak 2023-03-09 14:55:52 +02:00
parent e88a0d5248
commit 8a9f2aedbb

View File

@ -32,6 +32,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
dict: A dictionary mapping class names to their corresponding indices.
dict: A dictionary mapping indices to their corresponding class names.
"""
super().__init__(**kwargs)
model_training_parameters = self.freqai_info["model_training_parameters"]
self.n_hidden = model_training_parameters.get("n_hidden", 1024)
@ -50,6 +51,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:raises ValueError: If self.class_names is not defined in the parent class.
"""
if not hasattr(self, "class_names"):
raise ValueError(
"Missing attribute: self.class_names "
@ -93,7 +95,9 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:pred_df: dataframe containing the predictions
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
data (NaNs) or felt uncertain about data (PCA and DI index)
:raises ValueError: if 'class_name' doesn't exist in model meta_data.
"""
class_names = self.model.model_meta_data.get("class_names", None)
if not class_names:
raise ValueError(
@ -128,6 +132,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
encode class name str -> int
assuming first column of *_labels data frame to contain class names
"""
target_column_name = dk.label_list[0]
for split in ["train", "test"]:
label_df = data_dictionary[f"{split}_labels"]
@ -148,6 +153,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
"""
decode class name int -> str
"""
return list(map(lambda x: self.index_to_class_name[x.item()], classes))
def init_class_names_to_index_mapping(self, class_names):