mirror of https://github.com/freqtrade/freqtrade.git

commit 1cf0e7be24
parent 8a9f2aedbb

    use one iteration on all test and train data for evaluation
@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import pandas as pd
 import torch
@@ -21,7 +21,7 @@ class PyTorchModelTrainer:
             device: str,
             batch_size: int,
             max_iters: int,
-            eval_iters: int,
+            max_n_eval_batches: int,
             init_model: Dict,
             model_meta_data: Dict[str, Any] = {},
     ):
@@ -34,7 +34,7 @@ class PyTorchModelTrainer:
         :param max_iters: The number of training iterations to run.
          iteration here refers to the number of times we call
          self.optimizer.step(). used to calculate n_epochs.
-        :param eval_iters: The number of iterations used to estimate the loss.
+        :param max_n_eval_batches: The maximum number of batches to use for evaluation.
         :param init_model: A dictionary containing the initial model/optimizer
          state_dict and model_meta_data saved by self.save() method.
         :param model_meta_data: Additional metadata about the model (optional).
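The docstring's remark that max_iters is "used to calculate n_epochs" refers to the trainer's calc_n_epochs helper, whose call site appears in the estimate_loss hunk below. A minimal sketch of that conversion, assuming plain floor division; the rounding in the real helper may differ:

    def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
        # One pass over the data yields about n_obs // batch_size optimizer
        # steps, so n_iters steps take n_iters / (n_obs // batch_size) passes.
        batches_per_epoch = max(n_obs // batch_size, 1)
        return max(n_iters // batches_per_epoch, 1)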
@@ -46,7 +46,7 @@ class PyTorchModelTrainer:
         self.device = device
         self.max_iters = max_iters
         self.batch_size = batch_size
-        self.eval_iters = eval_iters
+        self.max_n_eval_batches = max_n_eval_batches
 
         if init_model:
             self.load_from_checkpoint(init_model)
@@ -67,7 +67,7 @@ class PyTorchModelTrainer:
         )
         for epoch in range(epochs):
             # evaluation
-            losses = self.estimate_loss(data_loaders_dictionary, data_dictionary)
+            losses = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches)
             logger.info(
                 f"epoch ({epoch}/{epochs}):"
                 f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}"
@@ -88,27 +88,27 @@
     def estimate_loss(
             self,
             data_loader_dictionary: Dict[str, DataLoader],
-            data_dictionary: Dict[str, pd.DataFrame]
+            max_n_eval_batches: Optional[int]
     ) -> Dict[str, float]:
 
         self.model.eval()
-        epochs = self.calc_n_epochs(
-            n_obs=len(data_dictionary["test_features"]),
-            batch_size=self.batch_size,
-            n_iters=self.eval_iters
-        )
         loss_dictionary = {}
+        n_batches = 0
         for split in ["train", "test"]:
-            losses = torch.zeros(epochs)
+            losses = []
             for i, batch in enumerate(data_loader_dictionary[split]):
+                if max_n_eval_batches and i > max_n_eval_batches:
+                    n_batches += 1
+                    break
+
                 xb, yb = batch
                 xb = xb.to(self.device)
                 yb = yb.to(self.device)
                 yb_pred = self.model(xb)
                 loss = self.criterion(yb_pred, yb)
-                losses[i] = loss.item()
+                losses.append(loss.item())
 
-            loss_dictionary[split] = losses.mean().item()
+            loss_dictionary[split] = sum(losses) / len(losses)
 
         self.model.train()
         return loss_dictionary
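This hunk is the heart of the commit: rather than sizing the evaluation loop from eval_iters via calc_n_epochs, estimate_loss now makes a single pass over each DataLoader, optionally capped at max_n_eval_batches. A self-contained sketch of the same pattern, with stand-in model/criterion/loader names; the no_grad guard and the >= cap check are small departures from the committed code, which uses a bare > comparison:

    from typing import Optional

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    def estimate_split_loss(model: nn.Module, criterion: nn.Module,
                            loader: DataLoader,
                            max_n_eval_batches: Optional[int] = None) -> float:
        """Mean loss over one pass of `loader`, stopping at the batch cap."""
        model.eval()
        losses = []
        with torch.no_grad():  # inference only; gradients are not needed
            for i, (xb, yb) in enumerate(loader):
                if max_n_eval_batches is not None and i >= max_n_eval_batches:
                    break
                losses.append(criterion(model(xb), yb).item())
        model.train()
        return sum(losses) / len(losses)

    # usage: evaluate a toy regression model on random data
    model = nn.Linear(4, 1)
    data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    print(estimate_split_loss(model, nn.MSELoss(), DataLoader(data, batch_size=16)))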
The commit's second file patches PyTorchClassifierMultiTarget:
@@ -39,7 +39,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         self.max_iters = model_training_parameters.get("max_iters", 100)
         self.batch_size = model_training_parameters.get("batch_size", 64)
         self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
-        self.eval_iters = model_training_parameters.get("eval_iters", 10)
+        self.max_n_eval_batches = model_training_parameters.get("max_n_eval_batches", None)
         self.class_name_to_index = None
         self.index_to_class_name = None
 
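On the classifier side, eval_iters leaves model_training_parameters and max_n_eval_batches takes its place, defaulting to None so that evaluation covers every batch. A hypothetical parameter dict mirroring the .get() defaults above:

    model_training_parameters = {
        "max_iters": 100,            # optimizer steps; drives n_epochs
        "batch_size": 64,
        "learning_rate": 3e-4,
        "max_n_eval_batches": None,  # None -> evaluate on the full loaders
    }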
@@ -79,7 +79,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
             device=self.device,
             batch_size=self.batch_size,
             max_iters=self.max_iters,
-            eval_iters=self.eval_iters,
+            max_n_eval_batches=self.max_n_eval_batches,
             init_model=init_model
         )
         trainer.fit(data_dictionary)
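Taken together, the classifier now threads the configured cap straight into the trainer. A rough wiring sketch: PyTorchModelTrainer is the class patched above, and the model/optimizer/criterion parameter names are assumptions, since the truncated signature hunk implies but does not show them:

    import torch
    from torch import nn

    # PyTorchModelTrainer is the class patched above; the model/optimizer/
    # criterion parameter names are assumed from context, not shown here.
    net = nn.Linear(10, 3)
    trainer = PyTorchModelTrainer(
        model=net,
        optimizer=torch.optim.AdamW(net.parameters(), lr=3e-4),
        criterion=nn.CrossEntropyLoss(),
        device="cpu",
        batch_size=64,
        max_iters=100,
        max_n_eval_batches=None,   # evaluate on every batch each epoch
        init_model=None,           # no checkpoint to restore
    )
    trainer.fit(data_dictionary)   # FreqAI's dict of train/test splits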