Repository: https://github.com/freqtrade/freqtrade.git
Commit: 3bbb7e38ea (parent: af139ffbab)

improve transformer architecture, remove 3.10 install constraint, add documentation for torch.compile()
In the FreqAI documentation, a new section documents `torch.compile()`:

````diff
@@ -395,3 +395,21 @@ Here we create a `PyTorchMLPRegressor` class that implements the `fit` method.
         return dataframe
 ```
 
 To see a full example, you can refer to the [classifier test strategy class](https://github.com/freqtrade/freqtrade/blob/develop/tests/strategy/strats/freqai_test_classifier.py).
+
+
+#### Improving performance with `torch.compile()`
+
+Torch provides a `torch.compile()` method that can be used to improve performance for specific GPU hardware. More details can be found [here](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In brief, you simply wrap your `model` in `torch.compile()`:
+
+```python
+model = PyTorchMLPModel(
+    input_dim=n_features,
+    output_dim=1,
+    **self.model_kwargs
+)
+model.to(self.device)
+model = torch.compile(model)
+```
+
+Then proceed to use the model as normal. Keep in mind that doing this will remove eager execution, which means errors and tracebacks will not be informative.
````
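For context beyond the snippet above (not part of this commit): `torch.compile()` accepts any `nn.Module` and returns an optimized callable with the same interface, so the wrapped model drops into an existing training loop unchanged. A minimal self-contained sketch, with `TinyRegressor` as a hypothetical stand-in for `PyTorchMLPModel`:

```python
import torch
from torch import nn


class TinyRegressor(nn.Module):
    """Hypothetical stand-in for PyTorchMLPModel; any nn.Module works the same way."""

    def __init__(self, input_dim: int, output_dim: int = 1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
        )

    def forward(self, x):
        return self.net(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyRegressor(input_dim=16).to(device)
model = torch.compile(model)  # requires torch >= 2.0

x = torch.randn(32, 16, device=device)
y = model(x)  # the first call triggers compilation; later calls reuse the compiled graph
print(y.shape)  # torch.Size([32, 1])
```

As the documentation notes, compiling removes eager execution, so it can make sense to enable `torch.compile()` only once the model itself is known to work.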
In `PyTorchTransformerModel.__init__`, the model width is derived from the new `nhead` parameter and the fully connected head is deepened:

```diff
@@ -20,32 +20,35 @@ class PyTorchTransformerModel(nn.Module):
     """
 
     def __init__(self, input_dim: int = 7, output_dim: int = 7, hidden_dim=1024,
-                 n_layer=2, dropout_percent=0.1, time_window=10):
+                 n_layer=2, dropout_percent=0.1, time_window=10, nhead=8):
         super().__init__()
         self.time_window = time_window
+        # ensure the input dimension to the transformer is divisible by nhead
+        self.dim_val = input_dim - (input_dim % nhead)
         self.input_net = nn.Sequential(
-            nn.Dropout(dropout_percent), nn.Linear(input_dim, hidden_dim)
+            nn.Dropout(dropout_percent), nn.Linear(input_dim, self.dim_val)
         )
 
         # Encode the timeseries with Positional encoding
-        self.positional_encoding = PositionalEncoding(d_model=hidden_dim, max_len=hidden_dim)
+        self.positional_encoding = PositionalEncoding(d_model=self.dim_val, max_len=self.dim_val)
 
         # Define the encoder block of the Transformer
         self.encoder_layer = nn.TransformerEncoderLayer(
-            d_model=hidden_dim, nhead=8, dropout=dropout_percent, batch_first=True)
+            d_model=self.dim_val, nhead=nhead, dropout=dropout_percent, batch_first=True)
         self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=n_layer)
 
-        # Pseudo decoder
+        # the pseudo decoding FC
         self.output_net = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim),
-            nn.LayerNorm(hidden_dim),
-            nn.Tanh(),
+            nn.Linear(hidden_dim * time_window, int(hidden_dim)),
+            nn.ReLU(),
+            nn.Dropout(dropout_percent),
         )
 
         self.output_layer = nn.Sequential(
-            nn.Linear(hidden_dim * time_window, output_dim),
-            nn.Tanh()
+            nn.Linear(int(hidden_dim), int(hidden_dim / 2)),
+            nn.ReLU(),
+            nn.Dropout(dropout_percent),
+            nn.Linear(int(hidden_dim / 2), int(hidden_dim / 4)),
+            nn.ReLU(),
+            nn.Dropout(dropout_percent),
+            nn.Linear(int(hidden_dim / 4), output_dim)
         )
 
     def forward(self, x, mask=None, add_positional_encoding=True):
```
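The `self.dim_val` computation is what makes the `nhead` parameter safe: `nn.TransformerEncoderLayer` requires `d_model` to be divisible by the number of attention heads, since multi-head attention splits the embedding evenly across heads. A small illustrative sketch (the values `input_dim=70` and `nhead=8` are assumptions, not from the commit):

```python
import torch
from torch import nn

input_dim, nhead = 70, 8
# truncate the width to the nearest multiple of nhead: 70 -> 64
dim_val = input_dim - (input_dim % nhead)

# project the raw feature count onto the truncated width, as the commit does
input_net = nn.Sequential(nn.Dropout(0.1), nn.Linear(input_dim, dim_val))
encoder_layer = nn.TransformerEncoderLayer(
    d_model=dim_val, nhead=nhead, dropout=0.1, batch_first=True
)

x = torch.randn(4, 10, input_dim)  # [batch, time_window, features]
out = encoder_layer(input_net(x))  # valid because 64 % 8 == 0
print(out.shape)  # torch.Size([4, 10, 64])
```

With `d_model=70` instead, PyTorch would reject the layer because the embedding cannot be split evenly across 8 heads.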
In `PyTorchTransformerModel.forward`, the encoder output is now flattened across the time window before the fully connected head:

```diff
@@ -60,9 +63,8 @@ class PyTorchTransformerModel(nn.Module):
         if add_positional_encoding:
             x = self.positional_encoding(x)
         x = self.transformer(x, mask=mask)
-        x = self.output_net(x)
         x = x.reshape(-1, 1, self.time_window * x.shape[-1])
-        x = self.output_layer(x)
+        x = self.output_net(x)
         return x
```
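Shape-wise, the reworked `forward` flattens the per-timestep encoder outputs into a single vector per sample before the fully connected head, instead of decoding each timestep separately. An illustrative sketch with assumed sizes (not from the commit):

```python
import torch

batch, time_window, dim_val = 4, 10, 64
x = torch.randn(batch, time_window, dim_val)     # transformer encoder output
x = x.reshape(-1, 1, time_window * x.shape[-1])  # flatten the time axis
print(x.shape)  # torch.Size([4, 1, 640]): one flat vector per sample
```

This means the first `nn.Linear` of the head sees all timesteps at once, which is why its input size is a `time_window` multiple rather than the per-timestep width.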
In the FreqAI-RL requirements, the Python version marker is dropped from the `torch` pin:

```diff
@@ -2,7 +2,7 @@
 -r requirements-freqai.txt
 
 # Required for freqai-rl
-torch==2.0.0; python_version < '3.11'
+torch==2.0.0
 #until these branches will be released we can use this
 gymnasium==0.28.1
 stable_baselines3==2.0.0a5
```
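For context (not part of the commit): the removed `; python_version < '3.11'` suffix is a PEP 508 environment marker, so the old pin was silently skipped on Python 3.11 and newer, meaning `torch` was not installed there. Removing the marker applies the pin on every interpreter version. A sketch of how such a marker evaluates, using the `packaging` library (the same library pip vendors for this purpose):

```python
from packaging.markers import Marker

marker = Marker("python_version < '3.11'")
print(marker.evaluate({"python_version": "3.10"}))  # True  -> requirement installed
print(marker.evaluate({"python_version": "3.11"}))  # False -> requirement skipped
```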