diff --git a/freqtrade/freqai/data_kitchen.py b/freqtrade/freqai/data_kitchen.py index d43f569d8..04f7322dc 100644 --- a/freqtrade/freqai/data_kitchen.py +++ b/freqtrade/freqai/data_kitchen.py @@ -214,7 +214,7 @@ class FreqaiDataKitchen: self, unfiltered_df: DataFrame, training_feature_list: List, - label_list: List = list(), + label_list: Optional[List] = None, training_filter: bool = True, ) -> Tuple[DataFrame, DataFrame]: """ @@ -244,7 +244,7 @@ class FreqaiDataKitchen: # we don't care about total row number (total no. datapoints) in training, we only care # about removing any row with NaNs # if labels has multiple columns (user wants to train multiple modelEs), we detect here - labels = unfiltered_df.filter(label_list, axis=1) + labels = unfiltered_df.filter(label_list or [], axis=1) drop_index_labels = pd.isnull(labels).any(axis=1) drop_index_labels = ( drop_index_labels.replace(True, 1).replace(False, 0).infer_objects(copy=False) @@ -654,8 +654,8 @@ class FreqaiDataKitchen: pair: str, tf: str, strategy: IStrategy, - corr_dataframes: dict = {}, - base_dataframes: dict = {}, + corr_dataframes: dict, + base_dataframes: dict, is_corr_pairs: bool = False, ) -> DataFrame: """ diff --git a/tests/freqai/test_freqai_datakitchen.py b/tests/freqai/test_freqai_datakitchen.py index 27efc3a66..5b7ec3ef1 100644 --- a/tests/freqai/test_freqai_datakitchen.py +++ b/tests/freqai/test_freqai_datakitchen.py @@ -151,7 +151,9 @@ def test_get_pair_data_for_features_with_prealoaded_data(mocker, freqai_conf): freqai.dd.load_all_pair_histories(timerange, freqai.dk) _, base_df = freqai.dd.get_base_and_corr_dataframes(timerange, "LTC/BTC", freqai.dk) - df = freqai.dk.get_pair_data_for_features("LTC/BTC", "5m", strategy, base_dataframes=base_df) + df = freqai.dk.get_pair_data_for_features( + "LTC/BTC", "5m", strategy, {}, base_dataframes=base_df + ) assert df is base_df["5m"] assert not df.empty @@ -171,7 +173,9 @@ def test_get_pair_data_for_features_without_preloaded_data(mocker, freqai_conf): freqai.dd.load_all_pair_histories(timerange, freqai.dk) base_df = {"5m": pd.DataFrame()} - df = freqai.dk.get_pair_data_for_features("LTC/BTC", "5m", strategy, base_dataframes=base_df) + df = freqai.dk.get_pair_data_for_features( + "LTC/BTC", "5m", strategy, {}, base_dataframes=base_df + ) assert df is not base_df["5m"] assert not df.empty