mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
simplified predict and predict_proba using super(). Added duplicate class label check.
This commit is contained in:
parent
6ef82dd8b6
commit
7053f81fa8
|
@ -4,7 +4,9 @@ from sklearn.base import is_classifier
|
|||
from sklearn.multioutput import MultiOutputClassifier, _fit_estimator
|
||||
from sklearn.utils.fixes import delayed
|
||||
from sklearn.utils.multiclass import check_classification_targets
|
||||
from sklearn.utils.validation import check_is_fitted, has_fit_parameter
|
||||
from sklearn.utils.validation import has_fit_parameter
|
||||
|
||||
from freqtrade.exceptions import OperationalException
|
||||
|
||||
|
||||
class FreqaiMultiOutputClassifier(MultiOutputClassifier):
|
||||
|
@ -65,6 +67,9 @@ class FreqaiMultiOutputClassifier(MultiOutputClassifier):
|
|||
self.classes_ = []
|
||||
for estimator in self.estimators_:
|
||||
self.classes_.extend(estimator.classes_)
|
||||
if len(set(self.classes_)) != len(self.classes_):
|
||||
raise OperationalException(f"Class labels must be unique across targets: "
|
||||
f"{self.classes_}")
|
||||
|
||||
if hasattr(self.estimators_[0], "n_features_in_"):
|
||||
self.n_features_in_ = self.estimators_[0].n_features_in_
|
||||
|
@ -74,56 +79,15 @@ class FreqaiMultiOutputClassifier(MultiOutputClassifier):
|
|||
return self
|
||||
|
||||
def predict_proba(self, X):
|
||||
"""Return prediction probabilities for each class of each output.
|
||||
|
||||
This method will raise a ``ValueError`` if any of the
|
||||
estimators do not have ``predict_proba``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like of shape (n_samples, n_features)
|
||||
The input data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
p : array of shape (n_samples, n_classes), or a list of n_outputs \
|
||||
such arrays if n_outputs > 1.
|
||||
The class probabilities of the input samples. The order of the
|
||||
classes corresponds to that in the attribute :term:`classes_`.
|
||||
|
||||
.. versionchanged:: 0.19
|
||||
This function now returns a list of arrays where the length of
|
||||
the list is ``n_outputs``, and each array is (``n_samples``,
|
||||
``n_classes``) for that particular output.
|
||||
"""
|
||||
check_is_fitted(self)
|
||||
results = np.squeeze(np.hstack(
|
||||
[estimator.predict_proba(X) for estimator in self.estimators_]
|
||||
))
|
||||
return results
|
||||
"""
|
||||
Get predict_proba and stack arrays horizontally
|
||||
"""
|
||||
results = np.hstack(super().predict_proba(X))
|
||||
return np.squeeze(results)
|
||||
|
||||
def predict(self, X):
|
||||
"""Predict multi-output variable using model for each target variable.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
||||
The input data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
y : {array-like, sparse matrix} of shape (n_samples, n_outputs)
|
||||
Multi-output targets predicted across multiple predictors.
|
||||
Note: Separate models are generated for each predictor.
|
||||
"""
|
||||
check_is_fitted(self)
|
||||
if not hasattr(self.estimators_[0], "predict"):
|
||||
raise ValueError("The base estimator should implement a predict method")
|
||||
|
||||
y = Parallel(n_jobs=self.n_jobs)(
|
||||
delayed(e.predict)(X) for e in self.estimators_
|
||||
)
|
||||
|
||||
results = np.squeeze(np.asarray(y).T)
|
||||
|
||||
return results
|
||||
Get predict and squeeze into 2D array
|
||||
"""
|
||||
results = super().predict(X)
|
||||
return np.squeeze(results)
|
||||
|
|
Loading…
Reference in New Issue
Block a user