Update PyTorchModelTrainer.py

The n_epochs should be defined using the `max` not the `min` function.
This commit is contained in:
Robert Caulk 2024-03-25 09:21:32 +01:00 committed by GitHub
parent 526d7fad62
commit 18e34632d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -152,7 +152,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
"""
assert isinstance(self.n_steps, int), "Either `n_steps` or `n_epochs` should be set."
n_batches = n_obs // self.batch_size
n_epochs = min(self.n_steps // n_batches, 1)
n_epochs = max(self.n_steps // n_batches, 1)
if n_epochs <= 10:
logger.warning(
f"Setting low n_epochs: {n_epochs}. "