import math

import torch
from torch import nn

"""
The architecture is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017.
"""


class PyTorchTransformerModel(nn.Module):
    """
    A transformer approach to time series modeling using positional encoding.
    The architecture is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017.
    """

    def __init__(self, input_dim: int = 7, output_dim: int = 7, hidden_dim=1024,
                 n_layer=2, dropout_percent=0.1, time_window=10, nhead=8):
        super().__init__()
        self.time_window = time_window
        # Ensure the input dimension to the transformer is divisible by nhead,
        # e.g. input_dim=70 with nhead=8 gives dim_val=64
        self.dim_val = input_dim - (input_dim % nhead)
        self.input_net = nn.Sequential(
            nn.Dropout(dropout_percent), nn.Linear(input_dim, self.dim_val)
        )

        # Encode the timeseries with positional encoding
        self.positional_encoding = PositionalEncoding(d_model=self.dim_val, max_len=self.dim_val)

        # Define the encoder block of the Transformer
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.dim_val, nhead=nhead, dropout=dropout_percent, batch_first=True)
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=n_layer)

        # The fully connected head that acts as a pseudo decoder
        self.output_net = nn.Sequential(
            nn.Linear(self.dim_val * time_window, int(hidden_dim)),
            nn.ReLU(),
            nn.Dropout(dropout_percent),
            nn.Linear(int(hidden_dim), int(hidden_dim / 2)),
            nn.ReLU(),
            nn.Dropout(dropout_percent),
            nn.Linear(int(hidden_dim / 2), int(hidden_dim / 4)),
            nn.ReLU(),
            nn.Dropout(dropout_percent),
            nn.Linear(int(hidden_dim / 4), output_dim)
        )

    def forward(self, x, mask=None, add_positional_encoding=True):
        """
        Args:
            x: Input features of shape [Batch, SeqLen, input_dim]
            mask: Mask to apply on the attention outputs (optional)
            add_positional_encoding: If True, add the positional encoding to the input.
                Might not be desired for some tasks.
        """
        # [Batch, SeqLen, input_dim] -> [Batch, SeqLen, dim_val]
        x = self.input_net(x)
        if add_positional_encoding:
            x = self.positional_encoding(x)
        x = self.transformer(x, mask=mask)
        # Flatten the encoded sequence before the fully connected head
        # (assumes SeqLen == time_window):
        # [Batch, time_window, dim_val] -> [Batch, 1, time_window * dim_val]
        x = x.reshape(-1, 1, self.time_window * x.shape[-1])
        x = self.output_net(x)
        return x
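
# A shape sketch of the forward pass under hypothetical settings (input_dim=70,
# nhead=8, time_window=10, hidden_dim=1024, output_dim=7, so dim_val = 70 - (70 % 8) = 64):
#   x: [batch, 10, 70] --input_net--> [batch, 10, 64]
#      --positional_encoding--> [batch, 10, 64]
#      --transformer encoder--> [batch, 10, 64]
#      --reshape--> [batch, 1, 640] --output_net--> [batch, 1, 7]
# Note: PositionalEncoding is built with max_len=dim_val, so this sketch assumes
# time_window <= dim_val.
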
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        """
        Args:
            d_model: Hidden dimensionality of the input.
            max_len: Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create a matrix of [SeqLen, HiddenDim] representing the positional encoding
        # for max_len inputs, using the standard sinusoidal scheme:
        # PE(pos, 2i) = sin(pos / 10000^(2i / d_model)),
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        # Add the positional encoding for the first x.size(1) positions to the input
        x = x + self.pe[:, : x.size(1)]
        return x
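

# A minimal usage sketch (not part of the original module): the parameter values
# below are illustrative assumptions, chosen so that input_dim - (input_dim % nhead)
# yields a dim_val of 64.
if __name__ == "__main__":
    batch_size, time_window, input_dim, output_dim = 4, 10, 70, 7
    model = PyTorchTransformerModel(
        input_dim=input_dim, output_dim=output_dim, hidden_dim=1024,
        n_layer=2, dropout_percent=0.1, time_window=time_window, nhead=8,
    )
    model.eval()  # disable dropout for a deterministic smoke test
    dummy = torch.randn(batch_size, time_window, input_dim)
    with torch.no_grad():
        prediction = model(dummy)
    print(prediction.shape)  # expected: torch.Size([4, 1, 7])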