pytorch - add n_epochs param to trainer

Yinon Polak 2023-07-13 19:41:39 +03:00
parent 588ffeedc1
commit 7d28dad209
5 changed files with 12 additions and 15 deletions

View File

@@ -28,7 +28,7 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
"trainer_kwargs": {
"max_iters": 5000,
"batch_size": 64,
"max_n_eval_batches": null,
"n_epochs": null,
},
"model_kwargs": {
"hidden_dim": 512,

View File

@@ -29,7 +29,7 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
"trainer_kwargs": {
"max_iters": 5000,
"batch_size": 64,
"max_n_eval_batches": null,
"n_epochs": null,
},
"model_kwargs": {
"hidden_dim": 512,

View File

@@ -32,7 +32,7 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
"trainer_kwargs": {
"max_iters": 5000,
"batch_size": 64,
"max_n_eval_batches": null
"n_epochs": null
},
"model_kwargs": {
"hidden_dim": 512,

View File

@@ -40,10 +40,10 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
:param model_meta_data: Additional metadata about the model (optional).
:param data_convertor: convertor from pd.DataFrame to torch.tensor.
:param max_iters: The number of training iterations to run.
- iteration here refers to the number of times we call
- self.optimizer.step(). used to calculate n_epochs.
+ iteration here refers to the number of times optimizer.step() is called,
+ used to calculate n_epochs. Ignored if n_epochs is set.
+ :param n_epochs: The number of training epochs to run; overrides max_iters when set.
:param batch_size: The size of the batches to use during training.
- :param max_n_eval_batches: The maximum number batches to use for evaluation.
"""
self.model = model
self.optimizer = optimizer
@@ -51,8 +51,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
self.model_meta_data = model_meta_data
self.device = device
self.max_iters: int = kwargs.get("max_iters", 100)
+ self.n_epochs: Optional[int] = kwargs.get("n_epochs", None)
self.batch_size: int = kwargs.get("batch_size", 64)
- self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None) # TODO change this to n_batches
self.data_convertor = data_convertor
self.window_size: int = window_size
self.tb_logger = tb_logger
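For clarity, a minimal sketch of how these trainer_kwargs are read (it only mirrors the kwargs.get defaults shown in this hunk; the surrounding class is omitted):

from typing import Optional

def read_trainer_kwargs(**kwargs):
    # Defaults copied from the hunk above: n_epochs is optional and stays None unless set.
    max_iters: int = kwargs.get("max_iters", 100)
    n_epochs: Optional[int] = kwargs.get("n_epochs", None)
    batch_size: int = kwargs.get("batch_size", 64)
    return max_iters, n_epochs, batch_size

print(read_trainer_kwargs(max_iters=5000, batch_size=64))  # -> (5000, None, 64): epochs derived from max_iters
print(read_trainer_kwargs(n_epochs=10))                    # -> (100, 10, 64): explicit epoch count wins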
@@ -71,16 +71,13 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
backpropagation.
- Updates the model's parameters using an optimizer.
"""
- data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
- epochs = self.calc_n_epochs(
-     n_obs=len(data_dictionary["train_features"]),
-     batch_size=self.batch_size,
-     n_iters=self.max_iters
- )
self.model.train()
+ data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
+ n_obs = len(data_dictionary["train_features"])
+ epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters)
for epoch in range(1, epochs + 1):
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
xb, yb = batch_data
xb = xb.to(self.device)
yb = yb.to(self.device)
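The behavioral core of the change is the line epochs = self.n_epochs or self.calc_n_epochs(...): an explicit n_epochs wins, and only when it is unset (None) is the epoch count derived from max_iters. A standalone sketch of that precedence, assuming a plausible body for calc_n_epochs (its actual implementation is not shown in this hunk):

import math
from typing import Optional

def resolve_epochs(n_epochs: Optional[int], n_obs: int, batch_size: int, max_iters: int) -> int:
    """Return the explicit epoch count if given, otherwise derive one from max_iters."""
    if n_epochs:                      # mirrors the `or` fallback in fit()
        return n_epochs
    # Assumed derivation: enough passes over the data to reach roughly max_iters optimizer.step() calls.
    batches_per_epoch = max(1, math.ceil(n_obs / batch_size))
    return max(1, math.ceil(max_iters / batches_per_epoch))

print(resolve_epochs(None, n_obs=10_000, batch_size=64, max_iters=5_000))  # -> 32 epochs (157 batches/epoch)
print(resolve_epochs(10, n_obs=10_000, batch_size=64, max_iters=5_000))    # -> 10 epochs, max_iters ignored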

View File

@@ -99,7 +99,7 @@ def mock_pytorch_mlp_model_training_parameters() -> Dict[str, Any]:
"trainer_kwargs": {
"max_iters": 1,
"batch_size": 64,
"max_n_eval_batches": 1,
"n_epochs": None,
},
"model_kwargs": {
"hidden_dim": 32,