Merge pull request #8494 from freqtrade/bug-fix-pytorch

Bug fix: ensure data is on same device as model
This commit is contained in:
Robert Caulk 2023-04-14 00:31:43 +02:00 committed by GitHub
commit daa9f6cc19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 2 deletions

View File

@ -45,5 +45,6 @@ class BasePyTorchRegressor(BasePyTorchModel):
device=self.device
)
y = self.model.model(x)
y = y.cpu()
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
return (pred_df, dk.do_predict)

View File

@ -143,8 +143,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
"""
data_loader_dictionary = {}
for split in splits:
x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"])
y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"])
x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"], self.device)
y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"], self.device)
dataset = TensorDataset(*x, *y)
data_loader = DataLoader(
dataset,