From 7d28dad209784b48799ce9099bdd442b243e4632 Mon Sep 17 00:00:00 2001
From: Yinon Polak
Date: Thu, 13 Jul 2023 19:41:39 +0300
Subject: [PATCH] pytorch - add n_epochs param to trainer

---
 .../prediction_models/PyTorchMLPClassifier.py |  2 +-
 .../prediction_models/PyTorchMLPRegressor.py  |  2 +-
 .../PyTorchTransformerRegressor.py            |  2 +-
 freqtrade/freqai/torch/PyTorchModelTrainer.py | 19 ++++++++-----------
 tests/freqai/conftest.py                      |  2 +-
 5 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py
index 71279dba9..ca333d9cf 100644
--- a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py
+++ b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py
@@ -28,7 +28,7 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
             "trainer_kwargs": {
                 "max_iters": 5000,
                 "batch_size": 64,
-                "max_n_eval_batches": null,
+                "n_epochs": null,
             },
             "model_kwargs": {
                 "hidden_dim": 512,
diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py b/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py
index 9f4534487..42fddf8ff 100644
--- a/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py
+++ b/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py
@@ -29,7 +29,7 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
             "trainer_kwargs": {
                 "max_iters": 5000,
                 "batch_size": 64,
-                "max_n_eval_batches": null,
+                "n_epochs": null,
             },
             "model_kwargs": {
                 "hidden_dim": 512,
diff --git a/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py b/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py
index a76bab05c..32663c86b 100644
--- a/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py
+++ b/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py
@@ -32,7 +32,7 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
             "trainer_kwargs": {
                 "max_iters": 5000,
                 "batch_size": 64,
-                "max_n_eval_batches": null
+                "n_epochs": null
             },
             "model_kwargs": {
                 "hidden_dim": 512,
diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py
index fe9919810..a34d673b4 100644
--- a/freqtrade/freqai/torch/PyTorchModelTrainer.py
+++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py
@@ -40,10 +40,10 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
         :param model_meta_data: Additional metadata about the model (optional).
         :param data_convertor: convertor from pd.DataFrame to torch.tensor.
         :param max_iters: The number of training iterations to run.
-            iteration here refers to the number of times we call
-            self.optimizer.step(). used to calculate n_epochs.
+            iteration here refers to the number of times optimizer.step() is called,
+            used to calculate n_epochs. ignored if n_epochs is set.
+        :param n_epochs: The number of epochs to train the model for.
         :param batch_size: The size of the batches to use during training.
-        :param max_n_eval_batches: The maximum number batches to use for evaluation.
""" self.model = model self.optimizer = optimizer @@ -51,8 +51,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): self.model_meta_data = model_meta_data self.device = device self.max_iters: int = kwargs.get("max_iters", 100) + self.n_epochs: Optional[int] = kwargs.get("n_epochs", None) self.batch_size: int = kwargs.get("batch_size", 64) - self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None) # TODO change this to n_batches self.data_convertor = data_convertor self.window_size: int = window_size self.tb_logger = tb_logger @@ -71,16 +71,13 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): backpropagation. - Updates the model's parameters using an optimizer. """ - data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits) - epochs = self.calc_n_epochs( - n_obs=len(data_dictionary["train_features"]), - batch_size=self.batch_size, - n_iters=self.max_iters - ) self.model.train() + + data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits) + n_obs = len(data_dictionary["train_features"]) + epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters) for epoch in range(1, epochs + 1): for i, batch_data in enumerate(data_loaders_dictionary["train"]): - xb, yb = batch_data xb = xb.to(self.device) yb = yb.to(self.device) diff --git a/tests/freqai/conftest.py b/tests/freqai/conftest.py index 4c4891ceb..96716e83f 100644 --- a/tests/freqai/conftest.py +++ b/tests/freqai/conftest.py @@ -99,7 +99,7 @@ def mock_pytorch_mlp_model_training_parameters() -> Dict[str, Any]: "trainer_kwargs": { "max_iters": 1, "batch_size": 64, - "max_n_eval_batches": 1, + "n_epochs": None, }, "model_kwargs": { "hidden_dim": 32,