This commit is contained in:
Matthias 2024-09-13 23:18:35 +02:00 committed by GitHub
commit c2f08b1d1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 22 additions and 17 deletions

View File

@ -47,19 +47,20 @@ class BaseEnvironment(gym.Env):
def __init__(
self,
df: DataFrame = DataFrame(),
prices: DataFrame = DataFrame(),
reward_kwargs: dict = {},
*,
df: DataFrame,
prices: DataFrame,
reward_kwargs: dict,
window_size=10,
starting_point=True,
id: str = "baseenv-1", # noqa: A002
seed: int = 1,
config: dict = {},
config: dict,
live: bool = False,
fee: float = 0.0015,
can_short: bool = False,
pair: str = "",
df_raw: DataFrame = DataFrame(),
df_raw: DataFrame,
):
"""
Initializes the training/eval environment.

View File

@ -488,7 +488,7 @@ def make_env(
seed: int,
train_df: DataFrame,
price: DataFrame,
env_info: Dict[str, Any] = {},
env_info: Dict[str, Any],
) -> Callable:
"""
Utility function for multiprocessed env.

View File

@ -214,7 +214,7 @@ class FreqaiDataKitchen:
self,
unfiltered_df: DataFrame,
training_feature_list: List,
label_list: List = list(),
label_list: Optional[List] = None,
training_filter: bool = True,
) -> Tuple[DataFrame, DataFrame]:
"""
@ -244,7 +244,7 @@ class FreqaiDataKitchen:
# we don't care about total row number (total no. datapoints) in training, we only care
# about removing any row with NaNs
# if labels has multiple columns (user wants to train multiple modelEs), we detect here
labels = unfiltered_df.filter(label_list, axis=1)
labels = unfiltered_df.filter(label_list or [], axis=1)
drop_index_labels = pd.isnull(labels).any(axis=1)
drop_index_labels = (
drop_index_labels.replace(True, 1).replace(False, 0).infer_objects(copy=False)
@ -654,8 +654,8 @@ class FreqaiDataKitchen:
pair: str,
tf: str,
strategy: IStrategy,
corr_dataframes: dict = {},
base_dataframes: dict = {},
corr_dataframes: dict,
base_dataframes: dict,
is_corr_pairs: bool = False,
) -> DataFrame:
"""
@ -776,7 +776,7 @@ class FreqaiDataKitchen:
corr_dataframes: dict = {},
base_dataframes: dict = {},
pair: str = "",
prediction_dataframe: DataFrame = pd.DataFrame(),
prediction_dataframe: Optional[DataFrame] = None,
do_corr_pairs: bool = True,
) -> DataFrame:
"""
@ -822,7 +822,7 @@ class FreqaiDataKitchen:
if tf not in corr_dataframes[p]:
corr_dataframes[p][tf] = pd.DataFrame()
if not prediction_dataframe.empty:
if prediction_dataframe is not None and not prediction_dataframe.empty:
dataframe = prediction_dataframe.copy()
base_dataframes[self.config["timeframe"]] = dataframe.copy()
else:

View File

@ -25,7 +25,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
criterion: nn.Module,
device: str,
data_convertor: PyTorchDataConvertor,
model_meta_data: Dict[str, Any] = {},
model_meta_data: Optional[Dict[str, Any]] = None,
window_size: int = 1,
tb_logger: Any = None,
**kwargs,
@ -45,6 +45,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
:param n_epochs: The maximum number batches to use for evaluation.
:param batch_size: The size of the batches to use during training.
"""
if model_meta_data is None:
model_meta_data = {}
self.model = model
self.optimizer = optimizer
self.criterion = criterion

View File

@ -168,8 +168,6 @@ max-complexity = 12
[tool.ruff.lint.per-file-ignores]
"freqtrade/freqai/**/*.py" = [
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"B006", # Bugbear - mutable default argument
"B008", # bugbear - Do not perform function calls in argument defaults
]
"tests/**/*.py" = [
"S101", # allow assert in tests

View File

@ -151,7 +151,9 @@ def test_get_pair_data_for_features_with_prealoaded_data(mocker, freqai_conf):
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
_, base_df = freqai.dd.get_base_and_corr_dataframes(timerange, "LTC/BTC", freqai.dk)
df = freqai.dk.get_pair_data_for_features("LTC/BTC", "5m", strategy, base_dataframes=base_df)
df = freqai.dk.get_pair_data_for_features(
"LTC/BTC", "5m", strategy, {}, base_dataframes=base_df
)
assert df is base_df["5m"]
assert not df.empty
@ -171,7 +173,9 @@ def test_get_pair_data_for_features_without_preloaded_data(mocker, freqai_conf):
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
base_df = {"5m": pd.DataFrame()}
df = freqai.dk.get_pair_data_for_features("LTC/BTC", "5m", strategy, base_dataframes=base_df)
df = freqai.dk.get_pair_data_for_features(
"LTC/BTC", "5m", strategy, {}, base_dataframes=base_df
)
assert df is not base_df["5m"]
assert not df.empty