Merge pull request #9897 from freqtrade/fix/xgboosttensorboard

fix: try plotting as much info in xgboost tensorboard as possible
This commit is contained in:
Matthias 2024-03-04 06:38:22 +01:00 committed by GitHub
commit 426bc4c97b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 8 deletions

View File

@ -36,8 +36,15 @@ class XGBoostRegressor(BaseRegressionModel):
eval_set = None
eval_weights = None
else:
eval_set = [(data_dictionary["test_features"], data_dictionary["test_labels"])]
eval_weights = [data_dictionary['test_weights']]
eval_set = [
(data_dictionary["test_features"],
data_dictionary["test_labels"]),
(X, y)
]
eval_weights = [
data_dictionary['test_weights'],
data_dictionary['train_weights']
]
sample_weight = data_dictionary["train_weights"]

View File

@ -43,13 +43,11 @@ class TensorBoardCallback(BaseTensorBoardCallback):
if not evals_log:
return False
for data, metric in evals_log.items():
for metric_name, log in metric.items():
evals = ["validation", "train"]
for metric, eval in zip(evals_log.items(), evals):
for metric_name, log in metric[1].items():
score = log[-1][0] if isinstance(log[-1], tuple) else log[-1]
if data == "train":
self.writer.add_scalar("train_loss", score, epoch)
else:
self.writer.add_scalar("valid_loss", score, epoch)
self.writer.add_scalar(f"{eval}-{metric_name}", score, epoch)
return False