diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index c35d1afb4..a5fe9b90b 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -210,8 +210,16 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): freqai.extract_data_and_train_model(new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange) + if freqai.dd.model_type == 'joblib': + model_file_extension = ".joblib" + elif freqai.dd.model_type == "pytorch": + model_file_extension = ".zip" + else: + raise Exception(f"Unsupported model type: {freqai.dd.model_type}," + f" can't assign model_file_extension") - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists() + assert Path(freqai.dk.data_path / + f"{freqai.dk.model_filename}_model{model_file_extension}").exists() assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists() assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists() assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists()