diff --git a/freqtrade/freqai/base_models/BasePyTorchClassifier.py b/freqtrade/freqai/base_models/BasePyTorchClassifier.py index 977152cc5..3a4de4df0 100644 --- a/freqtrade/freqai/base_models/BasePyTorchClassifier.py +++ b/freqtrade/freqai/base_models/BasePyTorchClassifier.py @@ -45,6 +45,7 @@ class BasePyTorchClassifier(BasePyTorchModel): ) -> Tuple[DataFrame, npt.NDArray[np.int_]]: """ Filter the prediction features data and predict with it. + :param dk: dk: The datakitchen object :param unfiltered_df: Full dataframe for the current backtest period. :return: :pred_df: dataframe containing the predictions @@ -78,7 +79,9 @@ class BasePyTorchClassifier(BasePyTorchModel): probs = F.softmax(logits, dim=-1) predicted_classes = torch.argmax(probs, dim=-1) predicted_classes_str = self.decode_class_names(predicted_classes) - pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names) + # used .tolist to convert probs into an iterable, in this way Tensors + # are automatically moved to the CPU first if necessary. + pred_df_prob = DataFrame(probs.detach().tolist(), columns=class_names) pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]]) pred_df = pd.concat([pred_df, pred_df_prob], axis=1) return (pred_df, dk.do_predict)