refactor classifiers class names

This commit is contained in:
Yinon Polak 2023-03-20 11:54:17 +02:00
parent 1c11a5f048
commit 2f386913ac
3 changed files with 4 additions and 4 deletions

View File

@ -4,11 +4,11 @@ import torch
from freqtrade.freqai.base_models.PyTorchModelTrainer import PyTorchModelTrainer
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.PyTorchClassifierClassifier import PyTorchClassifier
from freqtrade.freqai.prediction_models.PyTorchClassifier import PyTorchClassifier
from freqtrade.freqai.prediction_models.PyTorchMLPModel import PyTorchMLPModel
class MLPPyTorchClassifier(PyTorchClassifier):
class PyTorchMLPClassifier(PyTorchClassifier):
"""
This class implements the fit method of IFreqaiModel.
int the fit method we initialize the model and trainer objects.

View File

@ -49,8 +49,8 @@ class PyTorchMLPModel(nn.Module):
x = self.relu(self.input_layer(x))
x = self.dropout(x)
x = self.blocks(x)
logits = self.output_layer(x)
return logits
x = self.output_layer(x)
return x
class Block(nn.Module):