ensure data is on same device as the model

This commit is contained in:
robcaulk 2023-04-13 12:19:34 +02:00
parent 0afd5a7385
commit dcf9bbdaea
2 changed files with 3 additions and 2 deletions

View File

@ -45,5 +45,6 @@ class BasePyTorchRegressor(BasePyTorchModel):
device=self.device
)
y = self.model.model(x)
y = y.cpu()
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
return (pred_df, dk.do_predict)

View File

@ -143,8 +143,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
"""
data_loader_dictionary = {}
for split in splits:
x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"])
y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"])
x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"], self.device)
y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"], self.device)
dataset = TensorDataset(*x, *y)
data_loader = DataLoader(
dataset,