diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py index e74b572fd..b49e16196 100644 --- a/freqtrade/freqai/torch/PyTorchModelTrainer.py +++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py @@ -182,8 +182,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): "pytrainer": self }, path) - def load(self, path: Path): - checkpoint = torch.load(path) + def load(self, path: Path, device: str = None): + checkpoint = torch.load(path, map_location=device) return self.load_from_checkpoint(checkpoint) def load_from_checkpoint(self, checkpoint: Dict):