mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
Merge pull request #7883 from freqtrade/fix/multioutput-bug
fix bug in MultiOutput* with conv_width = 1
This commit is contained in:
commit
42afdbb0e5
|
@ -95,9 +95,14 @@ class BaseClassifierModel(IFreqaiModel):
|
|||
self.data_cleaning_predict(dk)
|
||||
|
||||
predictions = self.model.predict(dk.data_dictionary["prediction_features"])
|
||||
if self.CONV_WIDTH == 1:
|
||||
predictions = np.reshape(predictions, (-1, len(dk.label_list)))
|
||||
|
||||
pred_df = DataFrame(predictions, columns=dk.label_list)
|
||||
|
||||
predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"])
|
||||
if self.CONV_WIDTH == 1:
|
||||
predictions_prob = np.reshape(predictions_prob, (-1, len(self.model.classes_)))
|
||||
pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_)
|
||||
|
||||
pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
|
||||
|
|
|
@ -95,6 +95,9 @@ class BaseRegressionModel(IFreqaiModel):
|
|||
self.data_cleaning_predict(dk)
|
||||
|
||||
predictions = self.model.predict(dk.data_dictionary["prediction_features"])
|
||||
if self.CONV_WIDTH == 1:
|
||||
predictions = np.reshape(predictions, (-1, len(dk.label_list)))
|
||||
|
||||
pred_df = DataFrame(predictions, columns=dk.label_list)
|
||||
|
||||
pred_df = dk.denormalize_labels_from_metadata(pred_df)
|
||||
|
|
Loading…
Reference in New Issue
Block a user