mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-15 04:33:57 +00:00
Merge pull request #10711 from freqtrade/fix/pytorch-scaling
Some checks failed
Build Documentation / Deploy Docs through mike (push) Has been cancelled
Some checks failed
Build Documentation / Deploy Docs through mike (push) Has been cancelled
fix: Update BasePyTorchRegressor.py
This commit is contained in:
commit
91d9c9b4d5
|
@ -86,9 +86,6 @@ class BasePyTorchRegressor(BasePyTorchModel):
|
||||||
dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)
|
dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)
|
||||||
dk.label_pipeline = self.define_label_pipeline(threads=dk.thread_count)
|
dk.label_pipeline = self.define_label_pipeline(threads=dk.thread_count)
|
||||||
|
|
||||||
dd["train_labels"], _, _ = dk.label_pipeline.fit_transform(dd["train_labels"])
|
|
||||||
dd["test_labels"], _, _ = dk.label_pipeline.transform(dd["test_labels"])
|
|
||||||
|
|
||||||
(dd["train_features"], dd["train_labels"], dd["train_weights"]) = (
|
(dd["train_features"], dd["train_labels"], dd["train_weights"]) = (
|
||||||
dk.feature_pipeline.fit_transform(
|
dk.feature_pipeline.fit_transform(
|
||||||
dd["train_features"], dd["train_labels"], dd["train_weights"]
|
dd["train_features"], dd["train_labels"], dd["train_weights"]
|
||||||
|
|
|
@ -141,7 +141,7 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
|
||||||
pred_df = pd.DataFrame(yb.detach().numpy(), columns=dk.label_list)
|
pred_df = pd.DataFrame(yb.detach().numpy(), columns=dk.label_list)
|
||||||
pred_df, _, _ = dk.label_pipeline.inverse_transform(pred_df)
|
pred_df, _, _ = dk.label_pipeline.inverse_transform(pred_df)
|
||||||
|
|
||||||
if self.freqai_info.get("DI_threshold", 0) > 0:
|
if self.ft_params.get("DI_threshold", 0) > 0:
|
||||||
dk.DI_values = dk.feature_pipeline["di"].di_values
|
dk.DI_values = dk.feature_pipeline["di"].di_values
|
||||||
else:
|
else:
|
||||||
dk.DI_values = np.zeros(outliers.shape[0])
|
dk.DI_values = np.zeros(outliers.shape[0])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user