diff --git a/config_examples/config_freqai.example.json b/config_examples/config_freqai.example.json index 12eb30128..9494ba0e1 100644 --- a/config_examples/config_freqai.example.json +++ b/config_examples/config_freqai.example.json @@ -77,7 +77,8 @@ "indicator_periods_candles": [ 10, 20 - ] + ], + "plot_feature_importance": true }, "data_split_parameters": { "test_size": 0.33, diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index 78931bed4..0cc51fdab 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -20,6 +20,7 @@ from freqtrade.exceptions import OperationalException from freqtrade.exchange import timeframe_to_seconds from freqtrade.freqai.data_drawer import FreqaiDataDrawer from freqtrade.freqai.data_kitchen import FreqaiDataKitchen +from freqtrade.freqai.utils import plot_feature_importance from freqtrade.strategy.interface import IStrategy @@ -555,6 +556,14 @@ class IFreqaiModel(ABC): model = self.train(unfiltered_dataframe, pair, dk) + if self.freqai_info["feature_parameters"].get("plot_feature_importance", False): + plot_feature_importance( + model=model, + feature_names=dk.training_features_list, + pair=pair, + train_dir=dk.data_path + ) + self.dd.pair_dict[pair]["trained_timestamp"] = new_trained_timerange.stopts dk.set_new_model_names(pair, new_trained_timerange) self.dd.pair_dict[pair]["first"] = False diff --git a/freqtrade/freqai/utils.py b/freqtrade/freqai/utils.py index 6a70f050f..86d89d4b0 100644 --- a/freqtrade/freqai/utils.py +++ b/freqtrade/freqai/utils.py @@ -1,5 +1,13 @@ import logging from datetime import datetime, timezone +# for plot_feature_importance +from pathlib import Path + +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import plotly.io as pio +from plotly.subplots import make_subplots from freqtrade.configuration import TimeRange from freqtrade.data.dataprovider import DataProvider @@ -132,3 +140,66 @@ def get_required_data_timerange( # trading_mode=config.get("trading_mode", "spot"), # prepend=config.get("prepend_data", False), # ) + + +def plot_feature_importance(model, feature_names, pair, train_dir, count_max=25) -> None: + """ + Plot Best and Worst Features by importance for CatBoost model. + Called once per sub-train. + + Required: pip install kaleido + + Usage: plot_feature_importance( + model=model, + feature_names=dk.training_features_list, + pair=pair, + train_dir=dk.data_path) + """ + + # Gather feature importance from model + if "catboost.core" in str(model.__class__): + fi = model.get_feature_importance() + + elif "lightgbm.sklearn" in str(model.__class__): + fi = model.feature_importances_ + + else: + raise NotImplementedError(f"Cannot extract feature importance for {model.__class__}") + + # Data preparation + fi_df = pd.DataFrame({ + "feature_names": np.array(feature_names), + "feature_importance": np.array(fi) + }) + fi_df_top = fi_df.nlargest(count_max, "feature_importance")[::-1] + fi_df_worst = fi_df.nsmallest(count_max, "feature_importance")[::-1] + + # Plotting + fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.5) + fig.add_trace( + go.Bar( + x=fi_df_top["feature_importance"], + y=fi_df_top["feature_names"], + orientation='h', showlegend=False + ), row=1, col=1 + ) + fig.add_trace( + go.Bar( + x=fi_df_worst["feature_importance"], + y=fi_df_worst["feature_names"], + orientation='h', showlegend=False + ), row=1, col=2 + ) + fig.update_layout( + title_text=f"Best and Worst Features {pair}", + width=1000, height=600 + ) + + # Create directory and save image + model_dir, train_name = str(train_dir).rsplit("/", 1) + fi_dir = Path(f"{model_dir}/feature_importance/{pair.split('/')[0]}") + fi_dir.mkdir(parents=True, exist_ok=True) + + pio.write_image(fig, f"{fi_dir}/{train_name}.png", format="png") + + logger.info(f"Freqai saving feature importance plot {fi_dir}/{train_name}.png") diff --git a/requirements-plot.txt b/requirements-plot.txt index 80cd3f4f2..ef3cf9f24 100644 --- a/requirements-plot.txt +++ b/requirements-plot.txt @@ -2,3 +2,4 @@ -r requirements.txt plotly==5.10.0 +kaleido==0.2.1