# freqtrade_origin/freqtrade/freqai/torch/PyTorchTrainerInterface.py
from abc import ABC, abstractmethod
from pathlib import Path
import pandas as pd
import torch
from torch import nn


class PyTorchTrainerInterface(ABC):
    """
    Abstract interface for FreqAI PyTorch trainers.

    Subclasses implement ``fit``, ``save`` and ``load_from_checkpoint``.
    ``load`` is implemented here: it reads a checkpoint file with ``torch.load``
    and hands the resulting object to ``load_from_checkpoint``.
    """

    @abstractmethod
    def fit(self, data_dictionary: dict[str, pd.DataFrame], splits: list[str]) -> None:
        """
        Train the model on the requested data splits.

        :param data_dictionary: the dictionary constructed by DataHandler to hold
            all the training and test data/labels.
        :param splits: splits to use in training. Must contain "train"; an optional
            "test" split can be added by setting
            freqai.data_split_parameters.test_size > 0 in the config file.

        Implementations are expected to:
         - calculate the predicted output for the batch using the PyTorch model;
         - calculate the loss between the predicted and actual output using a
           loss function;
         - compute the gradients of the loss with respect to the model's
           parameters using backpropagation;
         - update the model's parameters using an optimizer.
        """

    @abstractmethod
    def save(self, path: Path) -> None:
        """
        Persist the trainer state to ``path``.

        Implementations are expected to save:
         - the state_dict of any nn.Module;
         - model_meta_data, a dict holding any additional data the user needs
           to store, e.g. class_names for classification models.

        :param path: destination file path.
        """

    def load(self, path: Path, *, map_location=None) -> nn.Module:
        """
        Read a checkpoint file and rebuild the model from it.

        :param path: path to the checkpoint (zip) file.
        :param map_location: optional value forwarded to ``torch.load``'s
            ``map_location`` argument, e.g. "cpu" to load a checkpoint that was
            saved on a GPU onto a CPU-only machine. ``None`` (the default) keeps
            torch's default device mapping, i.e. the previous behaviour.
        :returns: the pytorch model produced by ``load_from_checkpoint``.
        """
        # NOTE: torch.load deserializes the file with pickle (restricted or not
        # depending on the installed torch version's ``weights_only`` default);
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=map_location)
        return self.load_from_checkpoint(checkpoint)

    @abstractmethod
    def load_from_checkpoint(self, checkpoint: dict) -> nn.Module:
        """
        Rebuild and return the model from an already-loaded checkpoint dict.

        When using continual_learning, DataDrawer will load the dictionary
        (containing state dicts and model_meta_data) by calling torch.load(path).
        You can access this dict from any class that inherits IFreqaiModel by
        calling the get_init_model method.

        :param checkpoint: dict containing the model & optimizer state dicts,
            model_meta_data, etc.
        :returns: the restored pytorch model.
        """