fix: Update BasePyTorchRegressor.py

This commit is contained in:
Robert Caulk 2024-09-26 16:31:43 +02:00 committed by GitHub
parent f0eaccc6ac
commit 123909cdac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -86,9 +86,6 @@ class BasePyTorchRegressor(BasePyTorchModel):
dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)
dk.label_pipeline = self.define_label_pipeline(threads=dk.thread_count)
dd["train_labels"], _, _ = dk.label_pipeline.fit_transform(dd["train_labels"])
dd["test_labels"], _, _ = dk.label_pipeline.transform(dd["test_labels"])
(dd["train_features"], dd["train_labels"], dd["train_weights"]) = (
dk.feature_pipeline.fit_transform(
dd["train_features"], dd["train_labels"], dd["train_weights"]