From 3bbb7e38ead75996f5ba1bf23098067f57ed313b Mon Sep 17 00:00:00 2001 From: robcaulk Date: Sat, 6 May 2023 16:12:10 +0000 Subject: [PATCH] improve transformer architecture, remove 3.10 install constraint, add documentation for torch.compile() --- docs/freqai-configuration.md | 18 +++++++++++ .../freqai/torch/PyTorchTransformerModel.py | 32 ++++++++++--------- requirements-freqai-rl.txt | 2 +- 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/docs/freqai-configuration.md b/docs/freqai-configuration.md index e7aca20be..ad7cafd3d 100644 --- a/docs/freqai-configuration.md +++ b/docs/freqai-configuration.md @@ -395,3 +395,21 @@ Here we create a `PyTorchMLPRegressor` class that implements the `fit` method. T return dataframe ``` To see a full example, you can refer to the [classifier test strategy class](https://github.com/freqtrade/freqtrade/blob/develop/tests/strategy/strats/freqai_test_classifier.py). + + +#### Improving performance with `torch.compile()` + +Torch provides a `torch.compile()` method that can be used to improve performance for specific GPU hardware. More details can be found [here](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In brief, you simply wrap your `model` in `torch.compile()`: + + +```python + model = PyTorchMLPModel( + input_dim=n_features, + output_dim=1, + **self.model_kwargs + ) + model.to(self.device) + model = torch.compile(model) +``` + +Then proceed to use the model as normal. Keep in mind that doing this will remove eager execution, which means errors and tracebacks will not be informative. diff --git a/freqtrade/freqai/torch/PyTorchTransformerModel.py b/freqtrade/freqai/torch/PyTorchTransformerModel.py index 0a252112a..2ab3ea434 100644 --- a/freqtrade/freqai/torch/PyTorchTransformerModel.py +++ b/freqtrade/freqai/torch/PyTorchTransformerModel.py @@ -20,32 +20,35 @@ class PyTorchTransformerModel(nn.Module): """ def __init__(self, input_dim: int = 7, output_dim: int = 7, hidden_dim=1024, - n_layer=2, dropout_percent=0.1, time_window=10): + n_layer=2, dropout_percent=0.1, time_window=10, nhead=8): super().__init__() self.time_window = time_window + # ensure the input dimension to the transformer is divisible by nhead + self.dim_val = input_dim - (input_dim % nhead) self.input_net = nn.Sequential( - nn.Dropout(dropout_percent), nn.Linear(input_dim, hidden_dim) + nn.Dropout(dropout_percent), nn.Linear(input_dim, self.dim_val) ) # Encode the timeseries with Positional encoding - self.positional_encoding = PositionalEncoding(d_model=hidden_dim, max_len=hidden_dim) + self.positional_encoding = PositionalEncoding(d_model=self.dim_val, max_len=self.dim_val) # Define the encoder block of the Transformer self.encoder_layer = nn.TransformerEncoderLayer( - d_model=hidden_dim, nhead=8, dropout=dropout_percent, batch_first=True) + d_model=self.dim_val, nhead=nhead, dropout=dropout_percent, batch_first=True) self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=n_layer) - # Pseudo decoder + # the pseudo decoding FC self.output_net = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.Tanh(), + nn.Linear(hidden_dim * time_window, int(hidden_dim)), + nn.ReLU(), nn.Dropout(dropout_percent), - ) - - self.output_layer = nn.Sequential( - nn.Linear(hidden_dim * time_window, output_dim), - nn.Tanh() + nn.Linear(int(hidden_dim), int(hidden_dim / 2)), + nn.ReLU(), + nn.Dropout(dropout_percent), + nn.Linear(int(hidden_dim / 2), int(hidden_dim / 4)), + nn.ReLU(), + nn.Dropout(dropout_percent), + nn.Linear(int(hidden_dim / 4), output_dim) ) def forward(self, x, mask=None, add_positional_encoding=True): @@ -60,9 +63,8 @@ class PyTorchTransformerModel(nn.Module): if add_positional_encoding: x = self.positional_encoding(x) x = self.transformer(x, mask=mask) - x = self.output_net(x) x = x.reshape(-1, 1, self.time_window * x.shape[-1]) - x = self.output_layer(x) + x = self.output_net(x) return x diff --git a/requirements-freqai-rl.txt b/requirements-freqai-rl.txt index 525c25229..6b9c1c298 100644 --- a/requirements-freqai-rl.txt +++ b/requirements-freqai-rl.txt @@ -2,7 +2,7 @@ -r requirements-freqai.txt # Required for freqai-rl -torch==2.0.0; python_version < '3.11' +torch==2.0.0 #until these branches will be released we can use this gymnasium==0.28.1 stable_baselines3==2.0.0a5