mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-11 02:33:55 +00:00
modify feedforward net, move layer norm to start of thr block
This commit is contained in:
parent
719faab4b8
commit
8bee499328
|
@ -22,7 +22,7 @@ class PyTorchMLPModel(nn.Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.relu(self.input_layer(x))
|
||||
x = self.dropout(x)
|
||||
x = self.relu(self.blocks(x))
|
||||
x = self.blocks(x)
|
||||
logits = self.output_layer(x)
|
||||
return logits
|
||||
|
||||
|
@ -35,8 +35,8 @@ class Block(nn.Module):
|
|||
self.ln = nn.LayerNorm(hidden_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dropout(self.ff(x))
|
||||
x = self.ln(x)
|
||||
x = self.ff(self.ln(x))
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -46,7 +46,6 @@ class FeedForward(nn.Module):
|
|||
self.net = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
Loading…
Reference in New Issue
Block a user