pytorch - trainer - reomve max_n_eval_batches arg from estimate loss method

This commit is contained in:
yinon 2023-07-13 15:40:40 +00:00
parent 49a7de4ebd
commit 588ffeedc1

View File

@ -1,5 +1,4 @@
import logging
import math
from pathlib import Path
from typing import Any, Dict, List, Optional
@ -53,7 +52,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
self.device = device
self.max_iters: int = kwargs.get("max_iters", 100)
self.batch_size: int = kwargs.get("batch_size", 64)
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None) # TODO change this to n_batches
self.data_convertor = data_convertor
self.window_size: int = window_size
self.tb_logger = tb_logger
@ -95,25 +94,16 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
# evaluation
if "test" in splits:
self.estimate_loss(
data_loaders_dictionary,
self.max_n_eval_batches,
"test"
)
self.estimate_loss(data_loaders_dictionary, "test")
@torch.no_grad()
def estimate_loss(
self,
data_loader_dictionary: Dict[str, DataLoader],
max_n_eval_batches: Optional[int],
split: str,
) -> None:
self.model.eval()
n_batches = 0
for i, batch_data in enumerate(data_loader_dictionary[split]):
if max_n_eval_batches and i > max_n_eval_batches:
n_batches += 1
break
xb, yb = batch_data
xb.to(self.device)
yb.to(self.device)
@ -158,8 +148,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
across different n_obs - the number of data points.
"""
n_batches = math.ceil(n_obs // batch_size)
epochs = math.ceil(n_iters // n_batches)
n_batches = n_obs // batch_size
epochs = n_iters // n_batches
if epochs <= 10:
logger.warning("User set `max_iters` in such a way that the trainer will only perform "
f" {epochs} epochs. Please consider increasing this value accordingly")