pytorch - bugfix - explicitly assign tensor to var as .to() is not inplace operation

This commit is contained in:
yinon 2023-08-04 12:51:42 +00:00
parent 836d7b885a
commit 777d25192c

View File

@ -113,8 +113,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
self.model.eval()
for _, batch_data in enumerate(data_loader_dictionary[split]):
xb, yb = batch_data
xb.to(self.device)
yb.to(self.device)
xb = xb.to(self.device)
yb = yb.to(self.device)
yb_pred = self.model(xb)
loss = self.criterion(yb_pred, yb)