freqtrade_origin/freqtrade/freqai/base_models/BasePyTorchModel.py

40 lines
1.2 KiB
Python
Raw Permalink Normal View History

2023-03-05 14:59:24 +00:00
import logging
2023-04-03 12:19:10 +00:00
from abc import ABC, abstractmethod
2023-03-05 14:59:24 +00:00
import torch
from freqtrade.freqai.freqai_interface import IFreqaiModel
2023-04-03 13:03:15 +00:00
from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
2023-03-05 14:59:24 +00:00
2023-03-08 14:03:36 +00:00
2023-03-05 14:59:24 +00:00
logger = logging.getLogger(__name__)


class BasePyTorchModel(IFreqaiModel, ABC):
    """
    Common base for FreqAI models implemented with PyTorch.

    Subclasses are required to implement ``fit()`` and ``predict()`` and to
    provide the abstract ``data_convertor`` property.

    On construction this class selects a torch device (preferring "mps",
    then "cuda", falling back to "cpu"), decides which data splits to use,
    and reads ``conv_width`` from ``self.freqai_info``.
    """

    def __init__(self, **kwargs):
        # Only ``config`` is forwarded to the parent constructor; other
        # keyword arguments are not used here. A missing "config" raises
        # KeyError.
        config = kwargs["config"]
        super().__init__(config=config)
        # NOTE(review): ``self.dd`` and ``self.freqai_info`` are assumed to be
        # populated by IFreqaiModel.__init__ — confirm against the parent.
        self.dd.model_type = "pytorch"

        # Device preference: Apple MPS first (must be both available and
        # built), then CUDA, otherwise CPU.
        mps_backend = torch.backends.mps
        if mps_backend.is_available() and mps_backend.is_built():
            self.device = "mps"
        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        # Keep a "test" split unless test_size is explicitly 0. A missing
        # test_size (None) therefore still produces both splits.
        split_params = self.freqai_info.get("data_split_parameters", {})
        wants_test_split = split_params.get("test_size") != 0
        self.splits = ["train", "test"] if wants_test_split else ["train"]

        # Taken from the "conv_width" key; defaults to 1 when absent.
        self.window_size = self.freqai_info.get("conv_width", 1)

    @property
    @abstractmethod
    def data_convertor(self) -> PyTorchDataConvertor:
        """
        Converter that turns the ``*_features`` and ``*_labels`` pandas
        dataframes into PyTorch tensors.

        Subclasses must override this property.
        """
        raise NotImplementedError("Abstract property")