freqtrade_origin/freqtrade/freqai/torch/PyTorchModelTrainer.py

225 lines
8.6 KiB
Python
Raw Normal View History

import logging
from pathlib import Path
2023-03-28 11:40:23 +00:00
from typing import Any, Dict, List, Optional
2023-03-08 14:03:36 +00:00
import pandas as pd
import torch
from torch import nn
2023-03-08 14:03:36 +00:00
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset
2023-04-03 12:19:10 +00:00
from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchTrainerInterface import PyTorchTrainerInterface
from .datasets import WindowDataset
2023-04-03 13:03:15 +00:00
logger = logging.getLogger(__name__)
2023-04-03 12:19:10 +00:00
class PyTorchModelTrainer(PyTorchTrainerInterface):
    """
    Generic FreqAI trainer for a PyTorch model.

    Builds shuffled mini-batch DataLoaders from the FreqAI data dictionary, runs the
    optimisation loop, optionally evaluates on the "test" split after every epoch,
    and saves/loads model and optimizer state.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: nn.Module,
        device: str,
        data_convertor: PyTorchDataConvertor,
        model_meta_data: Optional[Dict[str, Any]] = None,
        window_size: int = 1,
        tb_logger: Any = None,
        **kwargs,
    ):
        """
        :param model: The PyTorch model to be trained.
        :param optimizer: The optimizer to use for training.
        :param criterion: The loss function to use for training.
        :param device: The device to use for training (e.g. 'cpu', 'cuda').
        :param data_convertor: converter from pd.DataFrame to torch.Tensor.
        :param model_meta_data: Additional metadata about the model (optional).
            Stored next to the state dicts by ``save()``.
        :param window_size: number of consecutive rows per sample; used by windowed
            trainers such as ``PyTorchTransformerTrainer``.
        :param tb_logger: object exposing ``log_scalar(tag, value, step)`` used to
            record losses. When None, loss logging is skipped.
        :param kwargs: optional training parameters:
            n_epochs: number of full passes over the training data (default 10).
            n_steps: used to calculate n_epochs. The number of training iterations
                to run, where an iteration is one optimizer.step() call. Ignored if
                n_epochs is set (i.e. truthy).
            batch_size: the size of the batches to use during training (default 64).
        :raises ValueError: if neither ``n_steps`` nor ``n_epochs`` is set.
        """
        if model_meta_data is None:
            model_meta_data = {}
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.model_meta_data = model_meta_data
        self.device = device
        self.n_epochs: Optional[int] = kwargs.get("n_epochs", 10)
        self.n_steps: Optional[int] = kwargs.get("n_steps", None)
        if self.n_steps is None and not self.n_epochs:
            # ValueError (still an Exception subclass) matches calc_n_epochs.
            raise ValueError("Either `n_steps` or `n_epochs` should be set.")
        self.batch_size: int = kwargs.get("batch_size", 64)
        self.data_convertor = data_convertor
        self.window_size: int = window_size
        self.tb_logger = tb_logger
        # Step counter for test-loss logging; keeps growing across estimate_loss calls.
        self.test_batch_counter = 0

    def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]):
        """
        Train the model on the "train" split.

        :param data_dictionary: the dictionary constructed by DataHandler to hold
            all the training and test data/labels.
        :param splits: splits to use in training, splits must contain "train",
            optional "test" could be added by setting
            freqai.data_split_parameters.test_size > 0 in the config file.

        For every batch of every epoch:
         - Calculates the predicted output for the batch using the PyTorch model.
         - Calculates the loss between the predicted and actual output using a loss
           function.
         - Computes the gradients of the loss with respect to the model's parameters
           using backpropagation.
         - Updates the model's parameters using an optimizer.
        After each epoch, the "test" split (if requested) is evaluated.
        """
        self.model.train()
        data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
        n_obs = len(data_dictionary["train_features"])
        # n_steps only drives the epoch count when n_epochs is falsy.
        n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs)
        batch_counter = 0

        for _ in range(n_epochs):
            for xb, yb in data_loaders_dictionary["train"]:
                xb = xb.to(self.device)
                yb = yb.to(self.device)
                yb_pred = self.model(xb)
                loss = self.criterion(yb_pred, yb)

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                self.optimizer.step()
                # tb_logger defaults to None; guard so training works without one.
                if self.tb_logger is not None:
                    self.tb_logger.log_scalar("train_loss", loss.item(), batch_counter)
                batch_counter += 1

            # evaluation
            if "test" in splits:
                self.estimate_loss(data_loaders_dictionary, "test")

    @torch.no_grad()
    def estimate_loss(
        self,
        data_loader_dictionary: Dict[str, DataLoader],
        split: str,
    ) -> None:
        """
        Evaluate the model on one split without tracking gradients.

        Each batch loss is logged under ``f"{split}_loss"`` using
        ``self.test_batch_counter`` as the step. The model is switched to eval mode
        for the pass and put back into train mode afterwards.

        :param data_loader_dictionary: mapping of split name to DataLoader.
        :param split: key of the split to evaluate (e.g. "test").
        """
        self.model.eval()
        for xb, yb in data_loader_dictionary[split]:
            xb = xb.to(self.device)
            yb = yb.to(self.device)
            yb_pred = self.model(xb)
            loss = self.criterion(yb_pred, yb)
            if self.tb_logger is not None:
                self.tb_logger.log_scalar(f"{split}_loss", loss.item(), self.test_batch_counter)
            self.test_batch_counter += 1

        self.model.train()

    def create_data_loaders_dictionary(
        self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]
    ) -> Dict[str, DataLoader]:
        """
        Converts the input data to PyTorch tensors using a data loader.

        Builds one DataLoader per split from ``{split}_features``/``{split}_labels``.
        Batches are shuffled and the last incomplete batch is dropped.

        :return: mapping of split name to its DataLoader.
        """
        data_loader_dictionary = {}
        for split in splits:
            x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"], self.device)
            y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"], self.device)
            dataset = TensorDataset(x, y)
            data_loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=True,
                drop_last=True,
                num_workers=0,
            )
            data_loader_dictionary[split] = data_loader

        return data_loader_dictionary

    def calc_n_epochs(self, n_obs: int) -> int:
        """
        Calculates the number of epochs required to reach the maximum number
        of iterations specified in the model training parameters.

        The motivation here is that `n_steps` is easier to optimize and keep stable
        across different n_obs - the number of data points.

        :param n_obs: number of rows in the training split.
        :return: ``max(n_steps // n_batches, 1)``.
        :raises ValueError: if ``self.n_steps`` is not an int.
        """
        if not isinstance(self.n_steps, int):
            raise ValueError("Either `n_steps` or `n_epochs` should be set.")

        # Floor division mirrors drop_last=True in the DataLoaders. Clamp to 1 so a
        # dataset smaller than one batch does not raise ZeroDivisionError.
        n_batches = max(n_obs // self.batch_size, 1)
        n_epochs = max(self.n_steps // n_batches, 1)
        if n_epochs <= 10:
            logger.warning(
                f"Setting low n_epochs: {n_epochs}. "
                f"Please consider increasing `n_steps` hyper-parameter."
            )

        return n_epochs

    def save(self, path: Path):
        """
        Write a checkpoint to ``path`` with torch.save. It contains:
         - the model and optimizer state_dicts,
         - model_meta_data, a dict for any additional data the user needs to
           store, e.g. class_names for classification models,
         - the trainer object itself under "pytrainer".
        """
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "model_meta_data": self.model_meta_data,
                "pytrainer": self,
            },
            path,
        )

    def load(self, path: Path):
        """
        Load a checkpoint written by ``save()`` and restore model/optimizer state.

        Tensors are mapped onto ``self.device`` so a checkpoint saved on a GPU can
        be restored on a CPU-only host.

        :return: self, with state restored.
        """
        # NOTE(review): torch.load unpickles arbitrary objects (the checkpoint also
        # holds the pickled trainer under "pytrainer"); only load trusted files.
        # Newer torch releases default to weights_only=True, which may reject that
        # entry; confirm against the torch version in use.
        checkpoint = torch.load(path, map_location=self.device)
        return self.load_from_checkpoint(checkpoint)

    def load_from_checkpoint(self, checkpoint: Dict):
        """
        Restore model/optimizer state and model_meta_data from a checkpoint dict.

        When using continual_learning, DataDrawer will load the dictionary
        (containing state dicts and model_meta_data) by calling torch.load(path).
        You can access this dict from any class that inherits IFreqaiModel by
        calling the get_init_model method.

        :return: self, with state restored.
        """
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.model_meta_data = checkpoint["model_meta_data"]
        return self


class PyTorchTransformerTrainer(PyTorchModelTrainer):
    """
    Trainer variant for the Transformer model.

    It differs from the base trainer only in how batches are built: samples come
    from a WindowDataset of ``self.window_size`` and are served in their original
    order (no shuffling).
    """

    def create_data_loaders_dictionary(
        self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]
    ) -> Dict[str, DataLoader]:
        """
        Turn each requested split into a DataLoader over windowed tensors.

        For every name in ``splits``, ``{name}_features`` and ``{name}_labels`` are
        converted to tensors on ``self.device`` by the data convertor and wrapped in
        a WindowDataset (presumably yielding ``window_size``-row windows; see
        WindowDataset). Order is preserved (``shuffle=False``) and the trailing
        partial batch is dropped.

        :return: mapping of split name to its DataLoader.
        """
        loaders: Dict[str, DataLoader] = {}
        convertor = self.data_convertor

        for split_name in splits:
            features_tensor = convertor.convert_x(
                data_dictionary[f"{split_name}_features"], self.device
            )
            labels_tensor = convertor.convert_y(
                data_dictionary[f"{split_name}_labels"], self.device
            )
            loaders[split_name] = DataLoader(
                WindowDataset(features_tensor, labels_tensor, self.window_size),
                batch_size=self.batch_size,
                shuffle=False,  # keep temporal order of the windows
                drop_last=True,
                num_workers=0,
            )

        return loaders