mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
pytorch - trainer - reomve max_n_eval_batches arg from estimate loss method
This commit is contained in:
parent
49a7de4ebd
commit
588ffeedc1
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
@ -53,7 +52,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
self.device = device
|
||||
self.max_iters: int = kwargs.get("max_iters", 100)
|
||||
self.batch_size: int = kwargs.get("batch_size", 64)
|
||||
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
|
||||
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None) # TODO change this to n_batches
|
||||
self.data_convertor = data_convertor
|
||||
self.window_size: int = window_size
|
||||
self.tb_logger = tb_logger
|
||||
|
@ -95,25 +94,16 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
|
||||
# evaluation
|
||||
if "test" in splits:
|
||||
self.estimate_loss(
|
||||
data_loaders_dictionary,
|
||||
self.max_n_eval_batches,
|
||||
"test"
|
||||
)
|
||||
self.estimate_loss(data_loaders_dictionary, "test")
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_loss(
|
||||
self,
|
||||
data_loader_dictionary: Dict[str, DataLoader],
|
||||
max_n_eval_batches: Optional[int],
|
||||
split: str,
|
||||
) -> None:
|
||||
self.model.eval()
|
||||
n_batches = 0
|
||||
for i, batch_data in enumerate(data_loader_dictionary[split]):
|
||||
if max_n_eval_batches and i > max_n_eval_batches:
|
||||
n_batches += 1
|
||||
break
|
||||
xb, yb = batch_data
|
||||
xb.to(self.device)
|
||||
yb.to(self.device)
|
||||
|
@ -158,8 +148,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||
across different n_obs - the number of data points.
|
||||
"""
|
||||
|
||||
n_batches = math.ceil(n_obs // batch_size)
|
||||
epochs = math.ceil(n_iters // n_batches)
|
||||
n_batches = n_obs // batch_size
|
||||
epochs = n_iters // n_batches
|
||||
if epochs <= 10:
|
||||
logger.warning("User set `max_iters` in such a way that the trainer will only perform "
|
||||
f" {epochs} epochs. Please consider increasing this value accordingly")
|
||||
|
|
Loading…
Reference in New Issue
Block a user