pytorch mlp add explicit annotation to fix mypy error

This commit is contained in:
Yinon Polak 2023-04-04 12:12:02 +03:00
parent 6b204c97ed
commit 26738370c7

View File

@ -48,7 +48,7 @@ class PyTorchMLPModel(nn.Module):
self.dropout = nn.Dropout(p=dropout_percent) self.dropout = nn.Dropout(p=dropout_percent)
def forward(self, x: List[torch.Tensor]) -> torch.Tensor: def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
x = x[0] x: torch.Tensor = x[0]
x = self.relu(self.input_layer(x)) x = self.relu(self.input_layer(x))
x = self.dropout(x) x = self.dropout(x)
x = self.blocks(x) x = self.blocks(x)