diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py index ca333d9cf..9aabdf7ad 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py @@ -26,7 +26,7 @@ class PyTorchMLPClassifier(BasePyTorchClassifier): "model_training_parameters" : { "learning_rate": 3e-4, "trainer_kwargs": { - "max_iters": 5000, + "n_steps": 5000, "batch_size": 64, "n_epochs": null, }, diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py b/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py index 42fddf8ff..dc8dc4b61 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py @@ -27,7 +27,7 @@ class PyTorchMLPRegressor(BasePyTorchRegressor): "model_training_parameters" : { "learning_rate": 3e-4, "trainer_kwargs": { - "max_iters": 5000, + "n_steps": 5000, "batch_size": 64, "n_epochs": null, }, diff --git a/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py b/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py index 32663c86b..846d6df2e 100644 --- a/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py +++ b/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py @@ -30,7 +30,7 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor): "model_training_parameters" : { "learning_rate": 3e-4, "trainer_kwargs": { - "max_iters": 5000, + "n_steps": 5000, "batch_size": 64, "n_epochs": null }, diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py index 2b0090c78..44f7dec4e 100644 --- a/freqtrade/freqai/torch/PyTorchModelTrainer.py +++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py @@ -39,7 +39,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): state_dict and model_meta_data saved by self.save() method. 
:param model_meta_data: Additional metadata about the model (optional). :param data_convertor: convertor from pd.DataFrame to torch.tensor. - :param max_iters: used to calculate n_epochs. The number of training iterations to run. + :param n_steps: used to calculate n_epochs. The number of training iterations to run. iteration here refers to the number of times optimizer.step() is called. ignored if n_epochs is set. :param n_epochs: The maximum number batches to use for evaluation. @@ -50,10 +50,10 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): self.criterion = criterion self.model_meta_data = model_meta_data self.device = device - self.max_iters: int = kwargs.get("max_iters", None) + self.n_steps: Optional[int] = kwargs.get("n_steps", None) self.n_epochs: Optional[int] = kwargs.get("n_epochs", 10) - if not self.max_iters and not self.n_epochs: - raise Exception("Either `max_iters` or `n_epochs` should be set.") + if not self.n_steps and not self.n_epochs: + raise Exception("Either `n_steps` or `n_epochs` should be set.") self.batch_size: int = kwargs.get("batch_size", 64) self.data_convertor = data_convertor @@ -82,7 +82,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): n_epochs = self.n_epochs or self.calc_n_epochs( n_obs=n_obs, batch_size=self.batch_size, - n_iters=self.max_iters, + n_iters=self.n_steps, ) batch_counter = 0 @@ -153,7 +153,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): Calculates the number of epochs required to reach the maximum number of iterations specified in the model training parameters. - the motivation here is that `max_iters` is easier to optimize and keep stable, + the motivation here is that `n_steps` is easier to optimize and keep stable, across different n_obs - the number of data points. """ @@ -162,7 +162,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): if n_epochs <= 10: logger.warning( f"Setting low n_epochs: {n_epochs}. " - f"Please consider increasing `max_iters` hyper-parameter." 
+ f"Please consider increasing `n_steps` hyper-parameter." ) return n_epochs diff --git a/tests/freqai/conftest.py b/tests/freqai/conftest.py index 96716e83f..9c7a950e7 100644 --- a/tests/freqai/conftest.py +++ b/tests/freqai/conftest.py @@ -97,9 +97,9 @@ def mock_pytorch_mlp_model_training_parameters() -> Dict[str, Any]: return { "learning_rate": 3e-4, "trainer_kwargs": { - "max_iters": 1, + "n_steps": None, "batch_size": 64, - "n_epochs": None, + "n_epochs": 1, }, "model_kwargs": { "hidden_dim": 32,