type hints fixes

This commit is contained in:
Yinon Polak 2023-03-06 19:14:54 +02:00
parent 125085fbaf
commit 8acdd0b47c
2 changed files with 3 additions and 2 deletions

View File

@ -51,7 +51,7 @@ class PyTorchModelTrainer:
# training
for batch_data in data_loaders_dictionary['train']:
xb, yb = batch_data
xb = xb.to(self.device) # type: ignore
xb = xb.to(self.device)
yb = yb.to(self.device)
yb_pred = self.model(xb)
loss = self.criterion(yb_pred, yb)

View File

@ -3,6 +3,7 @@ import logging
import torch
import torch.nn as nn
from torch import Tensor
logger = logging.getLogger(__name__)
@ -16,7 +17,7 @@ class PyTorchMLPModel(nn.Module):
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
def forward(self, x: torch.tensor) -> torch.tensor:
def forward(self, x: Tensor) -> Tensor:
x = self.relu(self.input_layer(x))
x = self.dropout(x)
x = self.relu(self.hidden_layer(x))