refactor classifiers class names

This commit is contained in:
Yinon Polak 2023-03-20 11:54:17 +02:00
parent 1c11a5f048
commit 2f386913ac
3 changed files with 4 additions and 4 deletions

View File

@ -4,11 +4,11 @@ import torch
from freqtrade.freqai.base_models.PyTorchModelTrainer import PyTorchModelTrainer from freqtrade.freqai.base_models.PyTorchModelTrainer import PyTorchModelTrainer
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.PyTorchClassifierClassifier import PyTorchClassifier from freqtrade.freqai.prediction_models.PyTorchClassifier import PyTorchClassifier
from freqtrade.freqai.prediction_models.PyTorchMLPModel import PyTorchMLPModel from freqtrade.freqai.prediction_models.PyTorchMLPModel import PyTorchMLPModel
class MLPPyTorchClassifier(PyTorchClassifier): class PyTorchMLPClassifier(PyTorchClassifier):
""" """
This class implements the fit method of IFreqaiModel. This class implements the fit method of IFreqaiModel.
int the fit method we initialize the model and trainer objects. int the fit method we initialize the model and trainer objects.

View File

@ -49,8 +49,8 @@ class PyTorchMLPModel(nn.Module):
x = self.relu(self.input_layer(x)) x = self.relu(self.input_layer(x))
x = self.dropout(x) x = self.dropout(x)
x = self.blocks(x) x = self.blocks(x)
logits = self.output_layer(x) x = self.output_layer(x)
return logits return x
class Block(nn.Module): class Block(nn.Module):