type hints fixes

This commit is contained in:
Yinon Polak 2023-03-06 19:37:08 +02:00
parent 8acdd0b47c
commit 5dd60eda36

View File

@ -84,7 +84,7 @@ class PyTorchModelTrainer:
loss = self.criterion(yb_pred, yb) loss = self.criterion(yb_pred, yb)
losses[i] = loss.item() losses[i] = loss.item()
loss_dictionary[split] = losses.mean() loss_dictionary[split] = losses.mean().item()
self.model.train() self.model.train()
return loss_dictionary return loss_dictionary