pytorch - trainer - add device arg to load method

This commit is contained in:
yinon 2023-07-13 15:39:47 +00:00
parent 0c9aa86885
commit 49a7de4ebd

View File

@ -182,8 +182,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
"pytrainer": self "pytrainer": self
}, path) }, path)
def load(self, path: Path): def load(self, path: Path, device: str = None):
checkpoint = torch.load(path) checkpoint = torch.load(path, map_location=device)
return self.load_from_checkpoint(checkpoint) return self.load_from_checkpoint(checkpoint)
def load_from_checkpoint(self, checkpoint: Dict): def load_from_checkpoint(self, checkpoint: Dict):