mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
Merge pull request #9897 from freqtrade/fix/xgboosttensorboard
fix: try plotting as much info in xgboost tensorboard as possible
This commit is contained in:
commit
426bc4c97b
|
@ -36,8 +36,15 @@ class XGBoostRegressor(BaseRegressionModel):
|
|||
eval_set = None
|
||||
eval_weights = None
|
||||
else:
|
||||
eval_set = [(data_dictionary["test_features"], data_dictionary["test_labels"])]
|
||||
eval_weights = [data_dictionary['test_weights']]
|
||||
eval_set = [
|
||||
(data_dictionary["test_features"],
|
||||
data_dictionary["test_labels"]),
|
||||
(X, y)
|
||||
]
|
||||
eval_weights = [
|
||||
data_dictionary['test_weights'],
|
||||
data_dictionary['train_weights']
|
||||
]
|
||||
|
||||
sample_weight = data_dictionary["train_weights"]
|
||||
|
||||
|
|
|
@ -43,13 +43,11 @@ class TensorBoardCallback(BaseTensorBoardCallback):
|
|||
if not evals_log:
|
||||
return False
|
||||
|
||||
for data, metric in evals_log.items():
|
||||
for metric_name, log in metric.items():
|
||||
evals = ["validation", "train"]
|
||||
for metric, eval in zip(evals_log.items(), evals):
|
||||
for metric_name, log in metric[1].items():
|
||||
score = log[-1][0] if isinstance(log[-1], tuple) else log[-1]
|
||||
if data == "train":
|
||||
self.writer.add_scalar("train_loss", score, epoch)
|
||||
else:
|
||||
self.writer.add_scalar("valid_loss", score, epoch)
|
||||
self.writer.add_scalar(f"{eval}-{metric_name}", score, epoch)
|
||||
|
||||
return False
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user