pytorch - trainer - add assertion that either n_epochs or max_iters has been set.

This commit is contained in:
Yinon Polak 2023-07-13 20:59:33 +03:00
parent 7d28dad209
commit 5734358d91

View File

@ -39,9 +39,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
state_dict and model_meta_data saved by self.save() method.
:param model_meta_data: Additional metadata about the model (optional).
:param data_convertor: convertor from pd.DataFrame to torch.tensor.
:param max_iters: The number of training iterations to run.
iteration here refers to the number of times optimizer.step() is called,
used to calculate n_epochs. ignored if n_epochs is set.
:param max_iters: used to calculate n_epochs. The number of training iterations to run.
iteration here refers to the number of times optimizer.step() is called.
ignored if n_epochs is set.
:param n_epochs: The number of epochs (full passes over the training data) to run.
:param batch_size: The size of the batches to use during training.
"""
@ -52,6 +52,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
self.device = device
self.max_iters: int = kwargs.get("max_iters", 100)
self.n_epochs: Optional[int] = kwargs.get("n_epochs", None)
if not self.max_iters and not self.n_epochs:
raise Exception("Either `max_iters` or `n_epochs` should be set.")
self.batch_size: int = kwargs.get("batch_size", 64)
self.data_convertor = data_convertor
self.window_size: int = window_size
@ -75,8 +78,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
n_obs = len(data_dictionary["train_features"])
epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters)
for epoch in range(1, epochs + 1):
n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters)
for epoch in range(1, n_epochs + 1):
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
xb, yb = batch_data
xb = xb.to(self.device)
@ -146,14 +149,14 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
"""
n_batches = n_obs // batch_size
epochs = n_iters // n_batches
if epochs <= 10:
logger.warning("User set `max_iters` in such a way that the trainer will only perform "
f" {epochs} epochs. Please consider increasing this value accordingly")
if epochs <= 1:
logger.warning("Epochs set to 1. Please review your `max_iters` value")
epochs = 1
return epochs
n_epochs = min(n_iters // n_batches, 1)
if n_epochs <= 10:
logger.warning(
f"Setting low n_epochs. {n_epochs} = n_epochs = n_iters // n_batches = {n_iters} // {n_batches}. "
f"Please consider increasing `max_iters` hyper-parameter."
)
return n_epochs
def save(self, path: Path):
"""