Update BasePyTorchRegressor.py

Denormalization of prediction added to te PytorchMLP Model
This commit is contained in:
vinistation 2023-04-28 14:48:16 -05:00 committed by GitHub
parent 2a9e50a6a9
commit d1eb6d4fed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -47,4 +47,5 @@ class BasePyTorchRegressor(BasePyTorchModel):
y = self.model.model(x)
y = y.cpu()
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
pred_df = dk.denormalize_labels_from_metadata(pred_df)
return (pred_df, dk.do_predict)