From 92b2a6fa24e2d21499cd40b69902ea25f66130cf Mon Sep 17 00:00:00 2001 From: Matthias Date: Tue, 8 Oct 2024 07:20:49 +0200 Subject: [PATCH] fix: Support mps device where available --- freqtrade/freqai/base_models/BasePyTorchModel.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/freqtrade/freqai/base_models/BasePyTorchModel.py b/freqtrade/freqai/base_models/BasePyTorchModel.py index 50b023021..51de5fc00 100644 --- a/freqtrade/freqai/base_models/BasePyTorchModel.py +++ b/freqtrade/freqai/base_models/BasePyTorchModel.py @@ -20,7 +20,11 @@ class BasePyTorchModel(IFreqaiModel, ABC): def __init__(self, **kwargs): super().__init__(config=kwargs["config"]) self.dd.model_type = "pytorch" - self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = ( + "mps" + if torch.backends.mps.is_available() and torch.backends.mps.is_built() + else ("cuda" if torch.cuda.is_available() else "cpu") + ) test_size = self.freqai_info.get("data_split_parameters", {}).get("test_size") self.splits = ["train", "test"] if test_size != 0 else ["train"] self.window_size = self.freqai_info.get("conv_width", 1)