Don't force-patch torch if it ain't installed.

This commit is contained in:
Matthias 2024-01-14 15:18:10 +01:00
parent c79502cb4b
commit 59cc607761
2 changed files with 7 additions and 7 deletions

View File

@ -1,4 +1,5 @@
import platform
import sys
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict
@ -15,6 +16,10 @@ from freqtrade.resolvers.freqaimodel_resolver import FreqaiModelResolver
from tests.conftest import get_patched_exchange
def is_py12() -> bool:
return sys.version_info >= (3, 12)
def is_mac() -> bool:
machine = platform.system()
return "Darwin" in machine
@ -31,7 +36,7 @@ def patch_torch_initlogs(mocker) -> None:
module_name = 'torch'
mocked_module = types.ModuleType(module_name)
sys.modules[module_name] = mocked_module
else:
elif not is_py12():
mocker.patch("torch._logging._init_logs")

View File

@ -1,7 +1,6 @@
import logging
import platform
import shutil
import sys
from pathlib import Path
from unittest.mock import MagicMock
@ -16,14 +15,10 @@ from freqtrade.optimize.backtesting import Backtesting
from freqtrade.persistence import Trade
from freqtrade.plugins.pairlistmanager import PairListManager
from tests.conftest import EXMS, create_mock_trades, get_patched_exchange, log_has_re
from tests.freqai.conftest import (get_patched_freqai_strategy, is_mac, make_rl_config,
from tests.freqai.conftest import (get_patched_freqai_strategy, is_mac, is_py12, make_rl_config,
mock_pytorch_mlp_model_training_parameters)
def is_py12() -> bool:
return sys.version_info >= (3, 12)
def is_arm() -> bool:
machine = platform.machine()
return "arm" in machine or "aarch64" in machine