mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
pytorch - trainer - add assertion that either n_epochs or max_iters is set.
This commit is contained in:
parent
7d28dad209
commit
5734358d91
|
@ -39,9 +39,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
state_dict and model_meta_data saved by self.save() method.
|
||||
:param model_meta_data: Additional metadata about the model (optional).
|
||||
:param data_convertor: converter from pd.DataFrame to torch.tensor.
|
||||
:param max_iters: The number of training iterations to run.
|
||||
iteration here refers to the number of times optimizer.step() is called,
|
||||
used to calculate n_epochs. ignored if n_epochs is set.
|
||||
:param max_iters: used to calculate n_epochs. The number of training iterations to run.
|
||||
iteration here refers to the number of times optimizer.step() is called.
|
||||
ignored if n_epochs is set.
|
||||
:param n_epochs: The maximum number of epochs to run.
|
||||
:param batch_size: The size of the batches to use during training.
|
||||
"""
|
||||
|
@ -52,6 +52,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
self.device = device
|
||||
self.max_iters: int = kwargs.get("max_iters", 100)
|
||||
self.n_epochs: Optional[int] = kwargs.get("n_epochs", None)
|
||||
if not self.max_iters and not self.n_epochs:
|
||||
raise Exception("Either `max_iters` or `n_epochs` should be set.")
|
||||
|
||||
self.batch_size: int = kwargs.get("batch_size", 64)
|
||||
self.data_convertor = data_convertor
|
||||
self.window_size: int = window_size
|
||||
|
@ -75,8 +78,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
|
||||
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
|
||||
n_obs = len(data_dictionary["train_features"])
|
||||
epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters)
|
||||
for epoch in range(1, epochs + 1):
|
||||
n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters)
|
||||
for epoch in range(1, n_epochs + 1):
|
||||
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
|
||||
xb, yb = batch_data
|
||||
xb = xb.to(self.device)
|
||||
|
@ -146,14 +149,14 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
"""
|
||||
|
||||
n_batches = n_obs // batch_size
|
||||
epochs = n_iters // n_batches
|
||||
if epochs <= 10:
|
||||
logger.warning("User set `max_iters` in such a way that the trainer will only perform "
|
||||
f" {epochs} epochs. Please consider increasing this value accordingly")
|
||||
if epochs <= 1:
|
||||
logger.warning("Epochs set to 1. Please review your `max_iters` value")
|
||||
epochs = 1
|
||||
return epochs
|
||||
n_epochs = max(n_iters // n_batches, 1)
|
||||
if n_epochs <= 10:
|
||||
logger.warning(
|
||||
f"Setting low n_epochs. {n_epochs} = n_epochs = n_iters // n_batches = {n_iters} // {n_batches}. "
|
||||
f"Please consider increasing `max_iters` hyper-parameter."
|
||||
)
|
||||
|
||||
return n_epochs
|
||||
|
||||
def save(self, path: Path):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user