From f9fdf1c31b7d437f269aed33626d05c6bae6bf10 Mon Sep 17 00:00:00 2001
From: Yinon Polak
Date: Sun, 12 Mar 2023 14:31:08 +0200
Subject: [PATCH] generalize mlp model

---
 .../PyTorchClassifierMultiTarget.py           | 19 +++++-----
 .../prediction_models/PyTorchMLPModel.py      | 38 ++++++++++++++++---
 2 files changed, 43 insertions(+), 14 deletions(-)

diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py
index f951778bf..be42fd8e6 100644
--- a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py
+++ b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Optional
 
 import numpy as np
 import numpy.typing as npt
@@ -34,12 +34,13 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         """
 
         super().__init__(**kwargs)
-        model_training_parameters = self.freqai_info["model_training_parameters"]
-        self.n_hidden = model_training_parameters.get("n_hidden", 1024)
-        self.max_iters = model_training_parameters.get("max_iters", 100)
-        self.batch_size = model_training_parameters.get("batch_size", 64)
-        self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
-        self.max_n_eval_batches = model_training_parameters.get("max_n_eval_batches", None)
+        trainer_kwargs = self.freqai_info.get("trainer_kwargs", {})
+        self.n_hidden: int = trainer_kwargs.get("n_hidden", 1024)
+        self.max_iters: int = trainer_kwargs.get("max_iters", 100)
+        self.batch_size: int = trainer_kwargs.get("batch_size", 64)
+        self.learning_rate: float = trainer_kwargs.get("learning_rate", 3e-4)
+        self.max_n_eval_batches: Optional[int] = trainer_kwargs.get("max_n_eval_batches", None)
+        self.model_kwargs: Dict = trainer_kwargs.get("model_kwargs", {})
         self.class_name_to_index = None
         self.index_to_class_name = None
 
@@ -64,8 +65,8 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         n_features = data_dictionary["train_features"].shape[-1]
         model = PyTorchMLPModel(
             input_dim=n_features,
-            hidden_dim=self.n_hidden,
-            output_dim=len(self.class_names)
+            output_dim=len(self.class_names),
+            **self.model_kwargs
         )
         model.to(self.device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPModel.py b/freqtrade/freqai/prediction_models/PyTorchMLPModel.py
index 1f13ca069..91e496c5d 100644
--- a/freqtrade/freqai/prediction_models/PyTorchMLPModel.py
+++ b/freqtrade/freqai/prediction_models/PyTorchMLPModel.py
@@ -8,18 +8,46 @@ logger = logging.getLogger(__name__)
 
 
 class PyTorchMLPModel(nn.Module):
-    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+    def __init__(self, input_dim: int, output_dim: int, **kwargs):
         super(PyTorchMLPModel, self).__init__()
+        hidden_dim: int = kwargs.get("hidden_dim", 1024)
+        dropout_percent: int = kwargs.get("dropout_percent", 0.2)
+        n_layer: int = kwargs.get("n_layer", 1)
         self.input_layer = nn.Linear(input_dim, hidden_dim)
-        self.hidden_layer = nn.Linear(hidden_dim, hidden_dim)
+        self.blocks = nn.Sequential(*[Block(hidden_dim, dropout_percent) for _ in range(n_layer)])
         self.output_layer = nn.Linear(hidden_dim, output_dim)
         self.relu = nn.ReLU()
-        self.dropout = nn.Dropout(p=0.2)
+        self.dropout = nn.Dropout(p=dropout_percent)
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.relu(self.input_layer(x))
         x = self.dropout(x)
-        x = self.relu(self.hidden_layer(x))
-        x = self.dropout(x)
+        x = self.relu(self.blocks(x))
         logits = self.output_layer(x)
         return logits
+
+
+class Block(nn.Module):
+    def __init__(self, hidden_dim: int, dropout_percent: int):
+        super(Block, self).__init__()
+        self.ff = FeedForward(hidden_dim)
+        self.dropout = nn.Dropout(p=dropout_percent)
+        self.ln = nn.LayerNorm(hidden_dim)
+
+    def forward(self, x):
+        x = self.dropout(self.ff(x))
+        x = self.ln(x)
+        return x
+
+
+class FeedForward(nn.Module):
+    def __init__(self, hidden_dim: int):
+        super(FeedForward, self).__init__()
+        self.net = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, hidden_dim),
+        )
+
+    def forward(self, x):
+        return self.net(x)
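Note for reviewers: with this patch, trainer settings move from "model_training_parameters" to a single "trainer_kwargs" entry in the FreqAI config, and network hyperparameters nest under "model_kwargs", which is forwarded verbatim to the PyTorchMLPModel constructor. A rough sketch of the resulting config surface, written as a Python dict; only the "trainer_kwargs"/"model_kwargs" key names and the defaults come from this patch, the surrounding structure and values are illustrative:

# Hypothetical freqai_info fragment mirroring the keys read in
# PyTorchClassifierMultiTarget.__init__ above; values are examples.
freqai_info = {
    "trainer_kwargs": {
        "max_iters": 100,            # default 100
        "batch_size": 64,            # default 64
        "learning_rate": 3e-4,       # default 3e-4
        "max_n_eval_batches": None,  # optional cap on evaluation batches
        "model_kwargs": {
            "hidden_dim": 512,       # default 1024
            "dropout_percent": 0.2,  # default 0.2
            "n_layer": 2,            # number of stacked Blocks, default 1
        },
    },
}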
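And a minimal smoke test of the generalized model on its own, assuming the import path shown in the diff; the dimensions are arbitrary:

# Instantiate the MLP with model_kwargs-style options and run one
# forward pass to confirm the stacked Blocks preserve tensor shapes.
import torch

from freqtrade.freqai.prediction_models.PyTorchMLPModel import PyTorchMLPModel

model = PyTorchMLPModel(
    input_dim=10,         # n_features in the classifier
    output_dim=3,         # len(self.class_names)
    hidden_dim=256,       # read from **kwargs, default 1024
    dropout_percent=0.1,  # default 0.2
    n_layer=2,            # default 1
)
model.eval()              # disable dropout for a deterministic pass
x = torch.randn(64, 10)   # a batch of 64 feature vectors
logits = model(x)
print(logits.shape)       # torch.Size([64, 3])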