refactor(BasePyTorchClassifier.py): convert tensor to list before creating DataFrame to avoid TypeError.

docs(BasePyTorchClassifier.py): add missing parameter description in predict method
This commit is contained in:
Tommaso Falchi 2023-05-05 13:04:53 +02:00
parent e3ff2ccc97
commit 306dfc4ae8

View File

@ -45,6 +45,7 @@ class BasePyTorchClassifier(BasePyTorchModel):
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
"""
Filter the prediction features data and predict with it.
:param dk: dk: The datakitchen object
:param unfiltered_df: Full dataframe for the current backtest period.
:return:
:pred_df: dataframe containing the predictions
@ -78,7 +79,9 @@ class BasePyTorchClassifier(BasePyTorchModel):
probs = F.softmax(logits, dim=-1)
predicted_classes = torch.argmax(probs, dim=-1)
predicted_classes_str = self.decode_class_names(predicted_classes)
pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names)
# used .tolist to convert probs into an iterable, in this way Tensors
# are automatically moved to the CPU first if necessary.
pred_df_prob = DataFrame(probs.detach().tolist(), columns=class_names)
pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
return (pred_df, dk.do_predict)