mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 02:12:01 +00:00
Compare commits
17 Commits
0f62a2c115
...
4857d8c1ef
Author | SHA1 | Date | |
---|---|---|---|
|
4857d8c1ef | ||
|
ae41ab101a | ||
|
f4881e7c6f | ||
|
94ef4380d4 | ||
|
7ebe1b8c14 | ||
|
79020bba28 | ||
|
95c250ebcc | ||
|
bfb14614cc | ||
|
12299d4810 | ||
|
c67a9d4e84 | ||
|
bb62b0fc5a | ||
|
3436e8aa1d | ||
|
1d5abe5b75 | ||
|
ffd828b6ad | ||
|
c83dd2d806 | ||
|
dc5766fb10 | ||
|
07fba3abb0 |
|
@ -222,7 +222,6 @@ Mandatory parameters are marked as **Required**, which means that they are requi
|
||||||
| `exchange.ccxt_async_config` | Additional CCXT parameters passed to the async ccxt instance. Parameters may differ from exchange to exchange and are documented in the [ccxt documentation](https://docs.ccxt.com/#/README?id=overriding-exchange-properties-upon-instantiation) <br> **Datatype:** Dict
|
| `exchange.ccxt_async_config` | Additional CCXT parameters passed to the async ccxt instance. Parameters may differ from exchange to exchange and are documented in the [ccxt documentation](https://docs.ccxt.com/#/README?id=overriding-exchange-properties-upon-instantiation) <br> **Datatype:** Dict
|
||||||
| `exchange.enable_ws` | Enable the usage of Websockets for the exchange. <br>[More information](#consuming-exchange-websockets).<br>*Defaults to `true`.* <br> **Datatype:** Boolean
|
| `exchange.enable_ws` | Enable the usage of Websockets for the exchange. <br>[More information](#consuming-exchange-websockets).<br>*Defaults to `true`.* <br> **Datatype:** Boolean
|
||||||
| `exchange.markets_refresh_interval` | The interval in minutes in which markets are reloaded. <br>*Defaults to `60` minutes.* <br> **Datatype:** Positive Integer
|
| `exchange.markets_refresh_interval` | The interval in minutes in which markets are reloaded. <br>*Defaults to `60` minutes.* <br> **Datatype:** Positive Integer
|
||||||
| `exchange.skip_pair_validation` | Skip pairlist validation on startup.<br>*Defaults to `false`*<br> **Datatype:** Boolean
|
|
||||||
| `exchange.skip_open_order_update` | Skips open order updates on startup should the exchange cause problems. Only relevant in live conditions.<br>*Defaults to `false`*<br> **Datatype:** Boolean
|
| `exchange.skip_open_order_update` | Skips open order updates on startup should the exchange cause problems. Only relevant in live conditions.<br>*Defaults to `false`*<br> **Datatype:** Boolean
|
||||||
| `exchange.unknown_fee_rate` | Fallback value to use when calculating trading fees. This can be useful for exchanges which have fees in non-tradable currencies. The value provided here will be multiplied with the "fee cost".<br>*Defaults to `None`<br> **Datatype:** float
|
| `exchange.unknown_fee_rate` | Fallback value to use when calculating trading fees. This can be useful for exchanges which have fees in non-tradable currencies. The value provided here will be multiplied with the "fee cost".<br>*Defaults to `None`<br> **Datatype:** float
|
||||||
| `exchange.log_responses` | Log relevant exchange responses. For debug mode only - use with care.<br>*Defaults to `false`*<br> **Datatype:** Boolean
|
| `exchange.log_responses` | Log relevant exchange responses. For debug mode only - use with care.<br>*Defaults to `false`*<br> **Datatype:** Boolean
|
||||||
|
|
|
@ -205,7 +205,7 @@ This is called with each iteration of the bot (only if the Pairlist Handler is a
|
||||||
|
|
||||||
It must return the resulting pairlist (which may then be passed into the chain of Pairlist Handlers).
|
It must return the resulting pairlist (which may then be passed into the chain of Pairlist Handlers).
|
||||||
|
|
||||||
Validations are optional, the parent class exposes a `_verify_blacklist(pairlist)` and `_whitelist_for_active_markets(pairlist)` to do default filtering. Use this if you limit your result to a certain number of pairs - so the end-result is not shorter than expected.
|
Validations are optional, the parent class exposes a `verify_blacklist(pairlist)` and `_whitelist_for_active_markets(pairlist)` to do default filtering. Use this if you limit your result to a certain number of pairs - so the end-result is not shorter than expected.
|
||||||
|
|
||||||
#### filter_pairlist
|
#### filter_pairlist
|
||||||
|
|
||||||
|
@ -219,7 +219,7 @@ The default implementation in the base class simply calls the `_validate_pair()`
|
||||||
|
|
||||||
If overridden, it must return the resulting pairlist (which may then be passed into the next Pairlist Handler in the chain).
|
If overridden, it must return the resulting pairlist (which may then be passed into the next Pairlist Handler in the chain).
|
||||||
|
|
||||||
Validations are optional, the parent class exposes a `_verify_blacklist(pairlist)` and `_whitelist_for_active_markets(pairlist)` to do default filters. Use this if you limit your result to a certain number of pairs - so the end result is not shorter than expected.
|
Validations are optional, the parent class exposes a `verify_blacklist(pairlist)` and `_whitelist_for_active_markets(pairlist)` to do default filters. Use this if you limit your result to a certain number of pairs - so the end result is not shorter than expected.
|
||||||
|
|
||||||
In `VolumePairList`, this implements different methods of sorting, does early validation so only the expected number of pairs is returned.
|
In `VolumePairList`, this implements different methods of sorting, does early validation so only the expected number of pairs is returned.
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,6 @@ It uses configuration from `exchange.pair_whitelist` and `exchange.pair_blacklis
|
||||||
By default, only currently enabled pairs are allowed.
|
By default, only currently enabled pairs are allowed.
|
||||||
To skip pair validation against active markets, set `"allow_inactive": true` within the `StaticPairList` configuration.
|
To skip pair validation against active markets, set `"allow_inactive": true` within the `StaticPairList` configuration.
|
||||||
This can be useful for backtesting expired pairs (like quarterly spot-markets).
|
This can be useful for backtesting expired pairs (like quarterly spot-markets).
|
||||||
This option must be configured along with `exchange.skip_pair_validation` in the exchange configuration.
|
|
||||||
|
|
||||||
When used in a "follow-up" position (e.g. after VolumePairlist), all pairs in `'pair_whitelist'` will be added to the end of the pairlist.
|
When used in a "follow-up" position (e.g. after VolumePairlist), all pairs in `'pair_whitelist'` will be added to the end of the pairlist.
|
||||||
|
|
||||||
|
|
|
@ -610,9 +610,6 @@ def download_data_main(config: Config) -> None:
|
||||||
if "timeframes" not in config:
|
if "timeframes" not in config:
|
||||||
config["timeframes"] = DL_DATA_TIMEFRAMES
|
config["timeframes"] = DL_DATA_TIMEFRAMES
|
||||||
|
|
||||||
# Manual validations of relevant settings
|
|
||||||
if not config["exchange"].get("skip_pair_validation", False):
|
|
||||||
exchange.validate_pairs(expanded_pairs)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"About to download pairs: {expanded_pairs}, "
|
f"About to download pairs: {expanded_pairs}, "
|
||||||
f"intervals: {config['timeframes']} to {config['datadir']}"
|
f"intervals: {config['timeframes']} to {config['datadir']}"
|
||||||
|
|
|
@ -104,7 +104,6 @@ from freqtrade.misc import (
|
||||||
file_load_json,
|
file_load_json,
|
||||||
safe_value_fallback2,
|
safe_value_fallback2,
|
||||||
)
|
)
|
||||||
from freqtrade.plugins.pairlist.pairlist_helpers import expand_pairlist
|
|
||||||
from freqtrade.util import dt_from_ts, dt_now
|
from freqtrade.util import dt_from_ts, dt_now
|
||||||
from freqtrade.util.datetime_helpers import dt_humanize_delta, dt_ts, format_ms_time
|
from freqtrade.util.datetime_helpers import dt_humanize_delta, dt_ts, format_ms_time
|
||||||
from freqtrade.util.periodic_cache import PeriodicCache
|
from freqtrade.util.periodic_cache import PeriodicCache
|
||||||
|
@ -331,8 +330,6 @@ class Exchange:
|
||||||
|
|
||||||
# Check if all pairs are available
|
# Check if all pairs are available
|
||||||
self.validate_stakecurrency(config["stake_currency"])
|
self.validate_stakecurrency(config["stake_currency"])
|
||||||
if not config["exchange"].get("skip_pair_validation"):
|
|
||||||
self.validate_pairs(config["exchange"]["pair_whitelist"])
|
|
||||||
self.validate_ordertypes(config.get("order_types", {}))
|
self.validate_ordertypes(config.get("order_types", {}))
|
||||||
self.validate_order_time_in_force(config.get("order_time_in_force", {}))
|
self.validate_order_time_in_force(config.get("order_time_in_force", {}))
|
||||||
self.validate_trading_mode_and_margin_mode(self.trading_mode, self.margin_mode)
|
self.validate_trading_mode_and_margin_mode(self.trading_mode, self.margin_mode)
|
||||||
|
@ -702,54 +699,6 @@ class Exchange:
|
||||||
f"Available currencies are: {', '.join(quote_currencies)}"
|
f"Available currencies are: {', '.join(quote_currencies)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_pairs(self, pairs: List[str]) -> None:
|
|
||||||
"""
|
|
||||||
Checks if all given pairs are tradable on the current exchange.
|
|
||||||
:param pairs: list of pairs
|
|
||||||
:raise: OperationalException if one pair is not available
|
|
||||||
:return: None
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not self.markets:
|
|
||||||
logger.warning("Unable to validate pairs (assuming they are correct).")
|
|
||||||
return
|
|
||||||
extended_pairs = expand_pairlist(pairs, list(self.markets), keep_invalid=True)
|
|
||||||
invalid_pairs = []
|
|
||||||
for pair in extended_pairs:
|
|
||||||
# Note: ccxt has BaseCurrency/QuoteCurrency format for pairs
|
|
||||||
if self.markets and pair not in self.markets:
|
|
||||||
raise OperationalException(
|
|
||||||
f"Pair {pair} is not available on {self.name} {self.trading_mode}. "
|
|
||||||
f"Please remove {pair} from your whitelist."
|
|
||||||
)
|
|
||||||
|
|
||||||
# From ccxt Documentation:
|
|
||||||
# markets.info: An associative array of non-common market properties,
|
|
||||||
# including fees, rates, limits and other general market information.
|
|
||||||
# The internal info array is different for each particular market,
|
|
||||||
# its contents depend on the exchange.
|
|
||||||
# It can also be a string or similar ... so we need to verify that first.
|
|
||||||
elif isinstance(self.markets[pair].get("info"), dict) and self.markets[pair].get(
|
|
||||||
"info", {}
|
|
||||||
).get("prohibitedIn", False):
|
|
||||||
# Warn users about restricted pairs in whitelist.
|
|
||||||
# We cannot determine reliably if Users are affected.
|
|
||||||
logger.warning(
|
|
||||||
f"Pair {pair} is restricted for some users on this exchange."
|
|
||||||
f"Please check if you are impacted by this restriction "
|
|
||||||
f"on the exchange and eventually remove {pair} from your whitelist."
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
self._config["stake_currency"]
|
|
||||||
and self.get_pair_quote_currency(pair) != self._config["stake_currency"]
|
|
||||||
):
|
|
||||||
invalid_pairs.append(pair)
|
|
||||||
if invalid_pairs:
|
|
||||||
raise OperationalException(
|
|
||||||
f"Stake-currency '{self._config['stake_currency']}' not compatible with "
|
|
||||||
f"pair-whitelist. Please remove the following pairs: {invalid_pairs}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_valid_pair_combination(self, curr_1: str, curr_2: str) -> str:
|
def get_valid_pair_combination(self, curr_1: str, curr_2: str) -> str:
|
||||||
"""
|
"""
|
||||||
Get valid pair combination of curr_1 and curr_2 by trying both combinations.
|
Get valid pair combination of curr_1 and curr_2 by trying both combinations.
|
||||||
|
|
|
@ -22,12 +22,18 @@ class Base5ActionRLEnv(BaseEnvironment):
|
||||||
Base class for a 5 action environment
|
Base class for a 5 action environment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, *args, action_space_type: str = "Discrete", **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self.action_space_type = action_space_type
|
||||||
self.actions = Actions
|
self.actions = Actions
|
||||||
|
|
||||||
def set_action_space(self):
|
def set_action_space(self):
|
||||||
self.action_space = spaces.Discrete(len(Actions))
|
if self.action_space_type == "Discrete":
|
||||||
|
self.action_space = spaces.Discrete(len(Actions))
|
||||||
|
elif self.action_space_type == "Box":
|
||||||
|
self.action_space = spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown action space type: {self.action_space_type}")
|
||||||
|
|
||||||
def step(self, action: int):
|
def step(self, action: int):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -60,6 +60,7 @@ class BaseEnvironment(gym.Env):
|
||||||
can_short: bool = False,
|
can_short: bool = False,
|
||||||
pair: str = "",
|
pair: str = "",
|
||||||
df_raw: DataFrame = DataFrame(),
|
df_raw: DataFrame = DataFrame(),
|
||||||
|
action_space_type: str = "Discrete"
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the training/eval environment.
|
Initializes the training/eval environment.
|
||||||
|
@ -93,6 +94,7 @@ class BaseEnvironment(gym.Env):
|
||||||
self.tensorboard_metrics: dict = {}
|
self.tensorboard_metrics: dict = {}
|
||||||
self.can_short: bool = can_short
|
self.can_short: bool = can_short
|
||||||
self.live: bool = live
|
self.live: bool = live
|
||||||
|
self.action_space_type: str = action_space_type
|
||||||
if not self.live and self.add_state_info:
|
if not self.live and self.add_state_info:
|
||||||
raise OperationalException(
|
raise OperationalException(
|
||||||
"`add_state_info` is not available in backtesting. Change "
|
"`add_state_info` is not available in backtesting. Change "
|
||||||
|
|
|
@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||||
|
|
||||||
SB3_MODELS = ["PPO", "A2C", "DQN"]
|
SB3_MODELS = ["PPO", "A2C", "DQN", "DDPG", "TD3"]
|
||||||
SB3_CONTRIB_MODELS = ["TRPO", "ARS", "RecurrentPPO", "MaskablePPO", "QRDQN"]
|
SB3_CONTRIB_MODELS = ["TRPO", "ARS", "RecurrentPPO", "MaskablePPO", "QRDQN"]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,556 @@
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Type, Callable, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch as th
|
||||||
|
import pandas as pd
|
||||||
|
from pandas import DataFrame
|
||||||
|
from gymnasium import spaces
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.transforms as mtransforms
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
||||||
|
from stable_baselines3.common.logger import HParam, Figure
|
||||||
|
|
||||||
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
|
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
||||||
|
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, BaseActions
|
||||||
|
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
|
||||||
|
from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReinforcementLearner_DDPG_TD3(BaseReinforcementLearningModel):
|
||||||
|
"""
|
||||||
|
Reinforcement Learning Model prediction model for DDPG and TD3.
|
||||||
|
|
||||||
|
Users can inherit from this class to make their own RL model with custom
|
||||||
|
environment/training controls. Define the file as follows:
|
||||||
|
|
||||||
|
```
|
||||||
|
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||||
|
|
||||||
|
class MyCoolRLModel(ReinforcementLearner):
|
||||||
|
```
|
||||||
|
|
||||||
|
Save the file to `user_data/freqaimodels`, then run it with:
|
||||||
|
|
||||||
|
freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat
|
||||||
|
|
||||||
|
Here the users can override any of the functions
|
||||||
|
available in the `IFreqaiModel` inheritance tree. Most importantly for RL, this
|
||||||
|
is where the user overrides `MyRLEnv` (see below), to define custom
|
||||||
|
`calculate_reward()` function, or to override any other parts of the environment.
|
||||||
|
|
||||||
|
This class also allows users to override any other part of the IFreqaiModel tree.
|
||||||
|
For example, the user can override `def fit()` or `def train()` or `def predict()`
|
||||||
|
to take fine-tuned control over these processes.
|
||||||
|
|
||||||
|
Another common override may be `def data_cleaning_predict()` where the user can
|
||||||
|
take fine-tuned control over the data handling pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Model specific config
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
# Enable learning rate linear schedule
|
||||||
|
self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
|
||||||
|
|
||||||
|
# Enable tensorboard logging
|
||||||
|
self.activate_tensorboard: bool = self.rl_config.get("activate_tensorboard", True)
|
||||||
|
# TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS,
|
||||||
|
# IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!!
|
||||||
|
|
||||||
|
# Enable tensorboard rollout plot
|
||||||
|
self.tensorboard_plot: bool = self.rl_config.get("tensorboard_plot", False)
|
||||||
|
|
||||||
|
def get_model_params(self):
|
||||||
|
"""
|
||||||
|
Get the model specific parameters
|
||||||
|
"""
|
||||||
|
model_params = copy.deepcopy(self.freqai_info["model_training_parameters"])
|
||||||
|
|
||||||
|
if self.lr_schedule:
|
||||||
|
_lr = model_params.get('learning_rate', 0.0003)
|
||||||
|
model_params["learning_rate"] = linear_schedule(_lr)
|
||||||
|
logger.info(f"Learning rate linear schedule enabled, initial value: {_lr}")
|
||||||
|
|
||||||
|
model_params["policy_kwargs"] = dict(
|
||||||
|
net_arch=dict(vf=self.net_arch, pi=self.net_arch),
|
||||||
|
activation_fn=th.nn.ReLU,
|
||||||
|
optimizer_class=th.optim.Adam
|
||||||
|
|
||||||
|
return model_params
|
||||||
|
|
||||||
|
def get_callbacks(self, eval_freq, data_path) -> list:
|
||||||
|
"""
|
||||||
|
Get the model specific callbacks
|
||||||
|
"""
|
||||||
|
callbacks = []
|
||||||
|
callbacks.append(self.eval_callback)
|
||||||
|
if self.activate_tensorboard:
|
||||||
|
callbacks.append(CustomTensorboardCallback())
|
||||||
|
if self.tensorboard_plot:
|
||||||
|
callbacks.append(FigureRecorderCallback())
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||||
|
"""
|
||||||
|
User customizable fit method
|
||||||
|
:param data_dictionary: dict = common data dictionary containing all train/test
|
||||||
|
features/labels/weights.
|
||||||
|
:param dk: FreqaiDatakitchen = data kitchen for current pair.
|
||||||
|
:return:
|
||||||
|
model Any = trained model to be used for inference in dry/live/backtesting
|
||||||
|
"""
|
||||||
|
train_df = data_dictionary["train_features"]
|
||||||
|
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
||||||
|
|
||||||
|
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
||||||
|
net_arch=self.net_arch)
|
||||||
|
|
||||||
|
if self.activate_tensorboard:
|
||||||
|
tb_path = Path(dk.full_path / "tensorboard" / dk.pair.split('/')[0])
|
||||||
|
else:
|
||||||
|
tb_path = None
|
||||||
|
|
||||||
|
model_params = self.get_model_params()
|
||||||
|
logger.info(f"Params: {model_params}")
|
||||||
|
|
||||||
|
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
||||||
|
model = self.MODELCLASS(self.policy_type, self.train_env,
|
||||||
|
tensorboard_log=tb_path,
|
||||||
|
**model_params)
|
||||||
|
else:
|
||||||
|
logger.info("Continual training activated - starting training from previously "
|
||||||
|
"trained agent.")
|
||||||
|
model = self.dd.model_dictionary[dk.pair]
|
||||||
|
model.set_env(self.train_env)
|
||||||
|
|
||||||
|
model.learn(
|
||||||
|
total_timesteps=int(total_timesteps),
|
||||||
|
#callback=[self.eval_callback, self.tensorboard_callback],
|
||||||
|
callback=self.get_callbacks(len(train_df), str(dk.data_path)),
|
||||||
|
progress_bar=self.rl_config.get("progress_bar", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
if Path(dk.data_path / "best_model.zip").is_file():
|
||||||
|
logger.info("Callback found a best model.")
|
||||||
|
best_model = self.MODELCLASS.load(dk.data_path / "best_model")
|
||||||
|
return best_model
|
||||||
|
|
||||||
|
logger.info("Couldnt find best model, using final model instead.")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
MyRLEnv: Type[BaseEnvironment]
|
||||||
|
|
||||||
|
class MyRLEnv(Base5ActionRLEnv): # type: ignore[no-redef]
|
||||||
|
"""
|
||||||
|
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||||
|
sets a custom reward based on profit and trade duration.
|
||||||
|
"""
|
||||||
|
def __init__(self, df, prices, reward_kwargs, window_size=10, starting_point=True, id="boxenv-1", seed=1, config={}, live=False, fee=0.0015, can_short=False, pair="", df_raw=None, action_space_type="Box"):
|
||||||
|
super().__init__(df, prices, reward_kwargs, window_size, starting_point, id, seed, config, live, fee, can_short, pair, df_raw)
|
||||||
|
|
||||||
|
# Define the action space as a continuous space between -1 and 1 for a single action dimension
|
||||||
|
self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
|
||||||
|
|
||||||
|
# Define the observation space as before
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf,
|
||||||
|
high=np.inf,
|
||||||
|
shape=(window_size, self.total_features),
|
||||||
|
dtype=np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
def calculate_reward(self, action: int) -> float:
|
||||||
|
"""
|
||||||
|
An example reward function. This is the one function that users will likely
|
||||||
|
wish to inject their own creativity into.
|
||||||
|
|
||||||
|
Warning!
|
||||||
|
This is function is a showcase of functionality designed to show as many possible
|
||||||
|
environment control features as possible. It is also designed to run quickly
|
||||||
|
on small computers. This is a benchmark, it is *not* for live production.
|
||||||
|
|
||||||
|
:param action: int = The action made by the agent for the current candle.
|
||||||
|
:return:
|
||||||
|
float = the reward to give to the agent for current step (used for optimization
|
||||||
|
of weights in NN)
|
||||||
|
"""
|
||||||
|
# first, penalize if the action is not valid
|
||||||
|
if not self._is_valid(action):
|
||||||
|
self.tensorboard_log("invalid", category="actions")
|
||||||
|
return -2
|
||||||
|
|
||||||
|
pnl = self.get_unrealized_profit()
|
||||||
|
factor = 100.
|
||||||
|
|
||||||
|
# reward agent for entering trades
|
||||||
|
if (action == Actions.Long_enter.value
|
||||||
|
and self._position == Positions.Neutral):
|
||||||
|
return 25
|
||||||
|
if (action == Actions.Short_enter.value
|
||||||
|
and self._position == Positions.Neutral):
|
||||||
|
return 25
|
||||||
|
# discourage agent from not entering trades
|
||||||
|
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||||
|
trade_duration = self._current_tick - self._last_trade_tick # type: ignore
|
||||||
|
|
||||||
|
if trade_duration <= max_trade_duration:
|
||||||
|
factor *= 1.5
|
||||||
|
elif trade_duration > max_trade_duration:
|
||||||
|
factor *= 0.5
|
||||||
|
|
||||||
|
# discourage sitting in position
|
||||||
|
if (self._position in (Positions.Short, Positions.Long) and
|
||||||
|
action == Actions.Neutral.value):
|
||||||
|
return -1 * trade_duration / max_trade_duration
|
||||||
|
|
||||||
|
# close long
|
||||||
|
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
||||||
|
if pnl > self.profit_aim * self.rr:
|
||||||
|
factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
|
||||||
|
return float(pnl * factor)
|
||||||
|
|
||||||
|
# close short
|
||||||
|
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
||||||
|
if pnl > self.profit_aim * self.rr:
|
||||||
|
factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
|
||||||
|
return float(pnl * factor)
|
||||||
|
|
||||||
|
return 0.
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
"""
|
||||||
|
Logic for a single step (incrementing one candle in time)
|
||||||
|
by the agent
|
||||||
|
:param: action: int = the action type that the agent plans
|
||||||
|
to take for the current step.
|
||||||
|
:returns:
|
||||||
|
observation = current state of environment
|
||||||
|
step_reward = the reward from `calculate_reward()`
|
||||||
|
_done = if the agent "died" or if the candles finished
|
||||||
|
info = dict passed back to openai gym lib
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Ensure action is within the range [-1, 1]
|
||||||
|
action = np.clip(action, -1, 1)
|
||||||
|
|
||||||
|
# Apply noise for exploration
|
||||||
|
self.noise_std = 0.3 # Standard deviation for exploration noise
|
||||||
|
noise = np.random.normal(0, self.noise_std, size=action.shape)
|
||||||
|
action = np.tanh(action + noise) # Ensure action is within -1 to 1
|
||||||
|
|
||||||
|
# Map the continuous action to one of the five discrete actions
|
||||||
|
discrete_action = self._map_continuous_to_discrete(action)
|
||||||
|
|
||||||
|
#print(f"{self._current_tick} Action!!!: {action}")
|
||||||
|
#print(f"{self._current_tick} Discrete Action!!!: {discrete_action}")
|
||||||
|
|
||||||
|
self._done = False
|
||||||
|
self._current_tick += 1
|
||||||
|
|
||||||
|
if self._current_tick == self._end_tick:
|
||||||
|
self._done = True
|
||||||
|
|
||||||
|
self._update_unrealized_total_profit()
|
||||||
|
step_reward = self.calculate_reward(discrete_action)
|
||||||
|
self.total_reward += step_reward
|
||||||
|
|
||||||
|
self.tensorboard_log(self.actions._member_names_[discrete_action], category="actions")
|
||||||
|
|
||||||
|
trade_type = None
|
||||||
|
if self.is_tradesignal(discrete_action):
|
||||||
|
|
||||||
|
if discrete_action == Actions.Neutral.value:
|
||||||
|
self._position = Positions.Neutral
|
||||||
|
trade_type = "neutral"
|
||||||
|
self._last_trade_tick = None
|
||||||
|
elif discrete_action == Actions.Long_enter.value:
|
||||||
|
self._position = Positions.Long
|
||||||
|
trade_type = "enter_long"
|
||||||
|
self._last_trade_tick = self._current_tick
|
||||||
|
elif discrete_action == Actions.Short_enter.value:
|
||||||
|
self._position = Positions.Short
|
||||||
|
trade_type = "enter_short"
|
||||||
|
self._last_trade_tick = self._current_tick
|
||||||
|
elif discrete_action == Actions.Long_exit.value:
|
||||||
|
self._update_total_profit()
|
||||||
|
self._position = Positions.Neutral
|
||||||
|
trade_type = "exit_long"
|
||||||
|
self._last_trade_tick = None
|
||||||
|
elif discrete_action == Actions.Short_exit.value:
|
||||||
|
self._update_total_profit()
|
||||||
|
self._position = Positions.Neutral
|
||||||
|
trade_type = "exit_short"
|
||||||
|
self._last_trade_tick = None
|
||||||
|
else:
|
||||||
|
print("case not defined")
|
||||||
|
|
||||||
|
if trade_type is not None:
|
||||||
|
self.trade_history.append(
|
||||||
|
{"price": self.current_price(), "index": self._current_tick,
|
||||||
|
"type": trade_type, "profit": self.get_unrealized_profit()})
|
||||||
|
|
||||||
|
if (self._total_profit < self.max_drawdown or
|
||||||
|
self._total_unrealized_profit < self.max_drawdown):
|
||||||
|
self._done = True
|
||||||
|
|
||||||
|
self._position_history.append(self._position)
|
||||||
|
|
||||||
|
info = dict(
|
||||||
|
tick=self._current_tick,
|
||||||
|
action=discrete_action,
|
||||||
|
total_reward=self.total_reward,
|
||||||
|
total_profit=self._total_profit,
|
||||||
|
position=self._position.value,
|
||||||
|
trade_duration=self.get_trade_duration(),
|
||||||
|
current_profit_pct=self.get_unrealized_profit()
|
||||||
|
)
|
||||||
|
|
||||||
|
observation = self._get_observation()
|
||||||
|
# user can play with time if they want
|
||||||
|
truncated = False
|
||||||
|
|
||||||
|
self._update_history(info)
|
||||||
|
|
||||||
|
return observation, step_reward, self._done, truncated, info
|
||||||
|
|
||||||
|
def _map_continuous_to_discrete(self, action):
|
||||||
|
"""
|
||||||
|
Map the continuous action (a value between -1 and 1) to one of the discrete actions.
|
||||||
|
"""
|
||||||
|
action_value = action[0] # Extract the single continuous action value
|
||||||
|
|
||||||
|
# Define the number of discrete actions
|
||||||
|
num_discrete_actions = 5
|
||||||
|
|
||||||
|
# Calculate the step size for each interval
|
||||||
|
step_size = 2 / num_discrete_actions # (2 because range is from -1 to 1)
|
||||||
|
|
||||||
|
# Generate the boundaries dynamically
|
||||||
|
boundaries = th.linspace(-1 + step_size, 1 - step_size, steps=num_discrete_actions - 1)
|
||||||
|
|
||||||
|
# Find the bucket index for the action value
|
||||||
|
bucket_index = th.bucketize(th.tensor(action_value), boundaries, right=True)
|
||||||
|
|
||||||
|
# Map the bucket index to discrete actions
|
||||||
|
discrete_actions = [
|
||||||
|
BaseActions.Neutral,
|
||||||
|
BaseActions.Long_enter,
|
||||||
|
BaseActions.Long_exit,
|
||||||
|
BaseActions.Short_enter,
|
||||||
|
BaseActions.Short_exit
|
||||||
|
]
|
||||||
|
|
||||||
|
return discrete_actions[bucket_index].value
|
||||||
|
|
||||||
|
def get_rollout_history(self) -> DataFrame:
|
||||||
|
"""
|
||||||
|
Get environment data from the first to the last trade
|
||||||
|
"""
|
||||||
|
_history_df = pd.DataFrame.from_dict(self.history)
|
||||||
|
_trade_history_df = pd.DataFrame.from_dict(self.trade_history)
|
||||||
|
_rollout_history = _history_df.merge(_trade_history_df, left_on="tick", right_on="index", how="left")
|
||||||
|
|
||||||
|
_price_history = self.prices.iloc[_rollout_history.tick].copy().reset_index()
|
||||||
|
|
||||||
|
history = pd.merge(
|
||||||
|
_rollout_history,
|
||||||
|
_price_history,
|
||||||
|
left_index=True, right_index=True
|
||||||
|
)
|
||||||
|
return history
|
||||||
|
|
||||||
|
def get_rollout_plot(self):
|
||||||
|
"""
|
||||||
|
Plot trades and environment data
|
||||||
|
"""
|
||||||
|
def transform_y_offset(ax, offset):
|
||||||
|
return mtransforms.offset_copy(ax.transData, fig=fig, x=0, y=offset, units="inches")
|
||||||
|
|
||||||
|
def plot_markers(ax, ticks, marker, color, size, offset):
|
||||||
|
ax.plot(ticks, marker=marker, color=color, markersize=size, fillstyle="full",
|
||||||
|
transform=transform_y_offset(ax, offset), linestyle="none")
|
||||||
|
|
||||||
|
plt.style.use("dark_background")
|
||||||
|
fig, axs = plt.subplots(
|
||||||
|
nrows=5, ncols=1,
|
||||||
|
figsize=(16, 9),
|
||||||
|
height_ratios=[6, 1, 1, 1, 1],
|
||||||
|
sharex=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return empty fig if no trades
|
||||||
|
if len(self.trade_history) == 0:
|
||||||
|
return fig
|
||||||
|
|
||||||
|
history = self.get_rollout_history()
|
||||||
|
enter_long_prices = history.loc[history["type"] == "enter_long"]["price"]
|
||||||
|
enter_short_prices = history.loc[history["type"] == "enter_short"]["price"]
|
||||||
|
exit_long_prices = history.loc[history["type"] == "exit_long"]["price"]
|
||||||
|
exit_short_prices = history.loc[history["type"] == "exit_short"]["price"]
|
||||||
|
|
||||||
|
axs[0].plot(history["open"], linewidth=1, color="#c28ce3")
|
||||||
|
plot_markers(axs[0], enter_long_prices, "^", "#4ae747", 5, -0.05)
|
||||||
|
plot_markers(axs[0], enter_short_prices, "v", "#f53580", 5, 0.05)
|
||||||
|
plot_markers(axs[0], exit_long_prices, "o", "#4ae747", 3, 0)
|
||||||
|
plot_markers(axs[0], exit_short_prices, "o", "#f53580", 3, 0)
|
||||||
|
|
||||||
|
axs[1].set_ylabel("pnl")
|
||||||
|
axs[1].plot(history["current_profit_pct"], linewidth=1, color="#a29db9")
|
||||||
|
axs[1].axhline(y=0, label='0', alpha=0.33)
|
||||||
|
axs[2].set_ylabel("duration")
|
||||||
|
axs[2].plot(history["trade_duration"], linewidth=1, color="#a29db9")
|
||||||
|
axs[3].set_ylabel("total_reward")
|
||||||
|
axs[3].plot(history["total_reward"], linewidth=1, color="#a29db9")
|
||||||
|
axs[3].axhline(y=0, label='0', alpha=0.33)
|
||||||
|
axs[4].set_ylabel("total_profit")
|
||||||
|
axs[4].set_xlabel("tick")
|
||||||
|
axs[4].plot(history["total_profit"], linewidth=1, color="#a29db9")
|
||||||
|
axs[4].axhline(y=1, label='1', alpha=0.33)
|
||||||
|
|
||||||
|
for _ax in axs:
|
||||||
|
for _border in ["top", "right", "bottom", "left"]:
|
||||||
|
_ax.spines[_border].set_color("#5b5e4b")
|
||||||
|
|
||||||
|
fig.suptitle(
|
||||||
|
"Total Reward: %.6f" % self.total_reward + " ~ " +
|
||||||
|
"Total Profit: %.6f" % self._total_profit
|
||||||
|
)
|
||||||
|
fig.tight_layout()
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
gc.collect()
|
||||||
|
th.cuda.empty_cache()
|
||||||
|
|
||||||
|
def linear_schedule(initial_value: float) -> Callable[[float], float]:
|
||||||
|
def func(progress_remaining: float) -> float:
|
||||||
|
return progress_remaining * initial_value
|
||||||
|
return func
|
||||||
|
|
||||||
|
class CustomTensorboardCallback(TensorboardCallback):
|
||||||
|
"""
|
||||||
|
Tensorboard callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _on_training_start(self) -> None:
|
||||||
|
_lr = self.model.learning_rate
|
||||||
|
|
||||||
|
if self.model.__class__.__name__ == "DDPG":
|
||||||
|
hparam_dict = {
|
||||||
|
"algorithm": self.model.__class__.__name__,
|
||||||
|
"buffer_size": self.model.buffer_size,
|
||||||
|
"learning_rate": _lr if isinstance(_lr, float) else "lr_schedule",
|
||||||
|
"learning_starts": self.model.learning_starts,
|
||||||
|
"batch_size": self.model.batch_size,
|
||||||
|
"tau": self.model.tau,
|
||||||
|
"gamma": self.model.gamma,
|
||||||
|
"train_freq": self.model.train_freq,
|
||||||
|
"gradient_steps": self.model.gradient_steps,
|
||||||
|
}
|
||||||
|
|
||||||
|
elif self.model.__class__.__name__ == "TD3":
|
||||||
|
hparam_dict = {
|
||||||
|
"algorithm": self.model.__class__.__name__,
|
||||||
|
"learning_rate": _lr if isinstance(_lr, float) else "lr_schedule",
|
||||||
|
"buffer_size": self.model.buffer_size,
|
||||||
|
"learning_starts": self.model.learning_starts,
|
||||||
|
"batch_size": self.model.batch_size,
|
||||||
|
"tau": self.model.tau,
|
||||||
|
"gamma": self.model.gamma,
|
||||||
|
"train_freq": self.model.train_freq,
|
||||||
|
"gradient_steps": self.model.gradient_steps,
|
||||||
|
"policy_delay": self.model.policy_delay,
|
||||||
|
"target_policy_noise": self.model.target_policy_noise,
|
||||||
|
"target_noise_clip": self.model.target_noise_clip,
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
hparam_dict = {
|
||||||
|
"algorithm": self.model.__class__.__name__,
|
||||||
|
"learning_rate": _lr if isinstance(_lr, float) else "lr_schedule",
|
||||||
|
"gamma": self.model.gamma,
|
||||||
|
"gae_lambda": self.model.gae_lambda,
|
||||||
|
"n_steps": self.model.n_steps,
|
||||||
|
"batch_size": self.model.batch_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert hparam_dict values to str if they are not of type int, float, str, bool, or torch.Tensor
|
||||||
|
hparam_dict = {k: (str(v) if not isinstance(v, (int, float, str, bool, th.Tensor)) else v) for k, v in hparam_dict.items()}
|
||||||
|
|
||||||
|
metric_dict = {
|
||||||
|
"eval/mean_reward": 0,
|
||||||
|
"rollout/ep_rew_mean": 0,
|
||||||
|
"rollout/ep_len_mean": 0,
|
||||||
|
"info/total_profit": 1,
|
||||||
|
"info/trades_count": 0,
|
||||||
|
"info/trade_duration": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.logger.record(
|
||||||
|
"hparams",
|
||||||
|
HParam(hparam_dict, metric_dict),
|
||||||
|
exclude=("stdout", "log", "json", "csv"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_step(self) -> bool:
|
||||||
|
|
||||||
|
local_info = self.locals["infos"][0]
|
||||||
|
if self.training_env is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
tensorboard_metrics = self.training_env.env_method("get_wrapper_attr", "tensorboard_metrics")[0]
|
||||||
|
|
||||||
|
for metric in local_info:
|
||||||
|
if metric not in ["episode", "terminal_observation", "TimeLimit.truncated"]:
|
||||||
|
self.logger.record(f"info/{metric}", local_info[metric])
|
||||||
|
|
||||||
|
for category in tensorboard_metrics:
|
||||||
|
for metric in tensorboard_metrics[category]:
|
||||||
|
self.logger.record(f"{category}/{metric}", tensorboard_metrics[category][metric])
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
class FigureRecorderCallback(BaseCallback):
|
||||||
|
"""
|
||||||
|
Tensorboard figures callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, verbose=0):
|
||||||
|
super().__init__(verbose)
|
||||||
|
|
||||||
|
def _on_step(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _on_rollout_end(self):
|
||||||
|
try:
|
||||||
|
# Access the rollout plot directly from the base environment
|
||||||
|
figures = [env.unwrapped.get_rollout_plot() for env in self.training_env.envs]
|
||||||
|
except AttributeError:
|
||||||
|
# If the above fails, try getting it from the wrappers
|
||||||
|
figures = self.training_env.env_method("get_wrapper_attr", "get_rollout_plot")
|
||||||
|
|
||||||
|
for i, fig in enumerate(figures):
|
||||||
|
self.logger.record(
|
||||||
|
f"rollout/env_{i}",
|
||||||
|
Figure(fig, close=True),
|
||||||
|
exclude=("stdout", "log", "json", "csv")
|
||||||
|
)
|
||||||
|
plt.close(fig)
|
||||||
|
return True
|
|
@ -61,14 +61,15 @@ class StaticPairList(IPairList):
|
||||||
:param tickers: Tickers (from exchange.get_tickers). May be cached.
|
:param tickers: Tickers (from exchange.get_tickers). May be cached.
|
||||||
:return: List of pairs
|
:return: List of pairs
|
||||||
"""
|
"""
|
||||||
|
wl = self.verify_whitelist(
|
||||||
|
self._config["exchange"]["pair_whitelist"], logger.info, keep_invalid=True
|
||||||
|
)
|
||||||
if self._allow_inactive:
|
if self._allow_inactive:
|
||||||
return self.verify_whitelist(
|
return wl
|
||||||
self._config["exchange"]["pair_whitelist"], logger.info, keep_invalid=True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return self._whitelist_for_active_markets(
|
# Avoid implicit filtering of "verify_whitelist" to keep
|
||||||
self.verify_whitelist(self._config["exchange"]["pair_whitelist"], logger.info)
|
# proper warnings in the log
|
||||||
)
|
return self._whitelist_for_active_markets(wl)
|
||||||
|
|
||||||
def filter_pairlist(self, pairlist: List[str], tickers: Tickers) -> List[str]:
|
def filter_pairlist(self, pairlist: List[str], tickers: Tickers) -> List[str]:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -28,6 +28,7 @@ def expand_pairlist(
|
||||||
except re.error as err:
|
except re.error as err:
|
||||||
raise ValueError(f"Wildcard error in {pair_wc}, {err}")
|
raise ValueError(f"Wildcard error in {pair_wc}, {err}")
|
||||||
|
|
||||||
|
# Remove wildcard pairs that didn't have a match.
|
||||||
result = [element for element in result if re.fullmatch(r"^[A-Za-z0-9:/-]+$", element)]
|
result = [element for element in result if re.fullmatch(r"^[A-Za-z0-9:/-]+$", element)]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -255,7 +255,6 @@ def test_init_exception(default_conf, mocker):
|
||||||
def test_exchange_resolver(default_conf, mocker, caplog):
|
def test_exchange_resolver(default_conf, mocker, caplog):
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=MagicMock()))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=MagicMock()))
|
||||||
mocker.patch(f"{EXMS}._load_async_markets")
|
mocker.patch(f"{EXMS}._load_async_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
mocker.patch(f"{EXMS}.validate_timeframes")
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
mocker.patch(f"{EXMS}.validate_pricing")
|
||||||
|
@ -555,7 +554,6 @@ def test_get_min_pair_stake_amount_real_data(mocker, default_conf) -> None:
|
||||||
|
|
||||||
def test__load_async_markets(default_conf, mocker, caplog):
|
def test__load_async_markets(default_conf, mocker, caplog):
|
||||||
mocker.patch(f"{EXMS}._init_ccxt")
|
mocker.patch(f"{EXMS}._init_ccxt")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
mocker.patch(f"{EXMS}.validate_timeframes")
|
||||||
mocker.patch(f"{EXMS}.reload_markets")
|
mocker.patch(f"{EXMS}.reload_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
|
@ -584,7 +582,6 @@ def test__load_markets(default_conf, mocker, caplog):
|
||||||
api_mock = MagicMock()
|
api_mock = MagicMock()
|
||||||
api_mock.load_markets = get_mock_coro(side_effect=ccxt.BaseError("SomeError"))
|
api_mock.load_markets = get_mock_coro(side_effect=ccxt.BaseError("SomeError"))
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
mocker.patch(f"{EXMS}.validate_timeframes")
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
mocker.patch(f"{EXMS}.validate_pricing")
|
||||||
|
@ -684,7 +681,6 @@ def test_validate_stakecurrency(default_conf, stake_currency, mocker, caplog):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
mocker.patch(f"{EXMS}.validate_timeframes")
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
mocker.patch(f"{EXMS}.validate_pricing")
|
||||||
Exchange(default_conf)
|
Exchange(default_conf)
|
||||||
|
@ -702,7 +698,6 @@ def test_validate_stakecurrency_error(default_conf, mocker, caplog):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
mocker.patch(f"{EXMS}.validate_timeframes")
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ConfigurationError,
|
ConfigurationError,
|
||||||
|
@ -755,147 +750,6 @@ def test_get_pair_base_currency(default_conf, mocker, pair, expected):
|
||||||
assert ex.get_pair_base_currency(pair) == expected
|
assert ex.get_pair_base_currency(pair) == expected
|
||||||
|
|
||||||
|
|
||||||
def test_validate_pairs(default_conf, mocker):
|
|
||||||
api_mock = MagicMock()
|
|
||||||
id_mock = PropertyMock(return_value="test_exchange")
|
|
||||||
type(api_mock).id = id_mock
|
|
||||||
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
|
||||||
mocker.patch(
|
|
||||||
f"{EXMS}._load_async_markets",
|
|
||||||
return_value={
|
|
||||||
"ETH/BTC": {"quote": "BTC"},
|
|
||||||
"LTC/BTC": {"quote": "BTC"},
|
|
||||||
"XRP/BTC": {"quote": "BTC"},
|
|
||||||
"NEO/BTC": {"quote": "BTC"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
|
||||||
# test exchange.validate_pairs directly
|
|
||||||
# No assert - but this should not fail (!)
|
|
||||||
Exchange(default_conf)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_pairs_not_available(default_conf, mocker):
|
|
||||||
api_mock = MagicMock()
|
|
||||||
type(api_mock).markets = PropertyMock(
|
|
||||||
return_value={"XRP/BTC": {"inactive": True, "base": "XRP", "quote": "BTC"}}
|
|
||||||
)
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
|
||||||
mocker.patch(f"{EXMS}._load_async_markets")
|
|
||||||
|
|
||||||
with pytest.raises(OperationalException, match=r"not available"):
|
|
||||||
Exchange(default_conf)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_pairs_exception(default_conf, mocker, caplog):
|
|
||||||
caplog.set_level(logging.INFO)
|
|
||||||
api_mock = MagicMock()
|
|
||||||
mocker.patch(f"{EXMS}.name", PropertyMock(return_value="Binance"))
|
|
||||||
|
|
||||||
type(api_mock).markets = PropertyMock(return_value={})
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", api_mock)
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
|
||||||
mocker.patch(f"{EXMS}._load_async_markets")
|
|
||||||
|
|
||||||
with pytest.raises(OperationalException, match=r"Pair ETH/BTC is not available on Binance"):
|
|
||||||
Exchange(default_conf)
|
|
||||||
|
|
||||||
mocker.patch(f"{EXMS}.markets", PropertyMock(return_value={}))
|
|
||||||
Exchange(default_conf)
|
|
||||||
assert log_has("Unable to validate pairs (assuming they are correct).", caplog)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_pairs_restricted(default_conf, mocker, caplog):
|
|
||||||
api_mock = MagicMock()
|
|
||||||
type(api_mock).load_markets = get_mock_coro(
|
|
||||||
return_value={
|
|
||||||
"ETH/BTC": {"quote": "BTC"},
|
|
||||||
"LTC/BTC": {"quote": "BTC"},
|
|
||||||
"XRP/BTC": {"quote": "BTC", "info": {"prohibitedIn": ["US"]}},
|
|
||||||
"NEO/BTC": {"quote": "BTC", "info": "TestString"}, # info can also be a string ...
|
|
||||||
}
|
|
||||||
)
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
|
||||||
|
|
||||||
Exchange(default_conf)
|
|
||||||
assert log_has(
|
|
||||||
"Pair XRP/BTC is restricted for some users on this exchange."
|
|
||||||
"Please check if you are impacted by this restriction "
|
|
||||||
"on the exchange and eventually remove XRP/BTC from your whitelist.",
|
|
||||||
caplog,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_pairs_stakecompatibility(default_conf, mocker):
|
|
||||||
api_mock = MagicMock()
|
|
||||||
type(api_mock).load_markets = get_mock_coro(
|
|
||||||
return_value={
|
|
||||||
"ETH/BTC": {"quote": "BTC"},
|
|
||||||
"LTC/BTC": {"quote": "BTC"},
|
|
||||||
"XRP/BTC": {"quote": "BTC"},
|
|
||||||
"NEO/BTC": {"quote": "BTC"},
|
|
||||||
"HELLO-WORLD": {"quote": "BTC"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
|
||||||
|
|
||||||
Exchange(default_conf)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_pairs_stakecompatibility_downloaddata(default_conf, mocker):
|
|
||||||
api_mock = MagicMock()
|
|
||||||
default_conf["stake_currency"] = ""
|
|
||||||
type(api_mock).load_markets = get_mock_coro(
|
|
||||||
return_value={
|
|
||||||
"ETH/BTC": {"quote": "BTC"},
|
|
||||||
"LTC/BTC": {"quote": "BTC"},
|
|
||||||
"XRP/BTC": {"quote": "BTC"},
|
|
||||||
"NEO/BTC": {"quote": "BTC"},
|
|
||||||
"HELLO-WORLD": {"quote": "BTC"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
|
||||||
|
|
||||||
Exchange(default_conf)
|
|
||||||
assert type(api_mock).load_markets.call_count == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_pairs_stakecompatibility_fail(default_conf, mocker):
|
|
||||||
default_conf["exchange"]["pair_whitelist"].append("HELLO-WORLD")
|
|
||||||
api_mock = MagicMock()
|
|
||||||
type(api_mock).load_markets = get_mock_coro(
|
|
||||||
return_value={
|
|
||||||
"ETH/BTC": {"quote": "BTC"},
|
|
||||||
"LTC/BTC": {"quote": "BTC"},
|
|
||||||
"XRP/BTC": {"quote": "BTC"},
|
|
||||||
"NEO/BTC": {"quote": "BTC"},
|
|
||||||
"HELLO-WORLD": {"quote": "USDT"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
|
||||||
|
|
||||||
with pytest.raises(OperationalException, match=r"Stake-currency 'BTC' not compatible with.*"):
|
|
||||||
Exchange(default_conf)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("timeframe", [("5m"), ("1m"), ("15m"), ("1h")])
|
@pytest.mark.parametrize("timeframe", [("5m"), ("1m"), ("15m"), ("1h")])
|
||||||
def test_validate_timeframes(default_conf, mocker, timeframe):
|
def test_validate_timeframes(default_conf, mocker, timeframe):
|
||||||
default_conf["timeframe"] = timeframe
|
default_conf["timeframe"] = timeframe
|
||||||
|
@ -907,7 +761,6 @@ def test_validate_timeframes(default_conf, mocker, timeframe):
|
||||||
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.reload_markets")
|
mocker.patch(f"{EXMS}.reload_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
mocker.patch(f"{EXMS}.validate_pricing")
|
||||||
Exchange(default_conf)
|
Exchange(default_conf)
|
||||||
|
@ -925,7 +778,6 @@ def test_validate_timeframes_failed(default_conf, mocker):
|
||||||
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.reload_markets")
|
mocker.patch(f"{EXMS}.reload_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
mocker.patch(f"{EXMS}.validate_pricing")
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
|
@ -955,7 +807,6 @@ def test_validate_timeframes_emulated_ohlcv_1(default_conf, mocker):
|
||||||
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.reload_markets")
|
mocker.patch(f"{EXMS}.reload_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
OperationalException,
|
OperationalException,
|
||||||
|
@ -977,7 +828,6 @@ def test_validate_timeframes_emulated_ohlcvi_2(default_conf, mocker):
|
||||||
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.reload_markets")
|
mocker.patch(f"{EXMS}.reload_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs", MagicMock())
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
OperationalException,
|
OperationalException,
|
||||||
|
@ -999,7 +849,6 @@ def test_validate_timeframes_not_in_config(default_conf, mocker):
|
||||||
|
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.reload_markets")
|
mocker.patch(f"{EXMS}.reload_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
mocker.patch(f"{EXMS}.validate_pricing")
|
||||||
mocker.patch(f"{EXMS}.validate_required_startup_candles")
|
mocker.patch(f"{EXMS}.validate_required_startup_candles")
|
||||||
|
@ -1016,7 +865,6 @@ def test_validate_pricing(default_conf, mocker):
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.reload_markets")
|
mocker.patch(f"{EXMS}.reload_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_trading_mode_and_margin_mode")
|
mocker.patch(f"{EXMS}.validate_trading_mode_and_margin_mode")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
mocker.patch(f"{EXMS}.validate_timeframes")
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
mocker.patch(f"{EXMS}.name", "Binance")
|
mocker.patch(f"{EXMS}.name", "Binance")
|
||||||
|
@ -1051,7 +899,6 @@ def test_validate_ordertypes(default_conf, mocker):
|
||||||
type(api_mock).has = PropertyMock(return_value={"createMarketOrder": True})
|
type(api_mock).has = PropertyMock(return_value={"createMarketOrder": True})
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.reload_markets")
|
mocker.patch(f"{EXMS}.reload_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
mocker.patch(f"{EXMS}.validate_timeframes")
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
mocker.patch(f"{EXMS}.validate_pricing")
|
||||||
|
@ -1110,7 +957,6 @@ def test_validate_ordertypes_stop_advanced(default_conf, mocker, exchange_name,
|
||||||
type(api_mock).has = PropertyMock(return_value={"createMarketOrder": True})
|
type(api_mock).has = PropertyMock(return_value={"createMarketOrder": True})
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.reload_markets")
|
mocker.patch(f"{EXMS}.reload_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
mocker.patch(f"{EXMS}.validate_timeframes")
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
mocker.patch(f"{EXMS}.validate_pricing")
|
||||||
|
@ -1135,7 +981,6 @@ def test_validate_order_types_not_in_config(default_conf, mocker):
|
||||||
api_mock = MagicMock()
|
api_mock = MagicMock()
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
|
||||||
mocker.patch(f"{EXMS}.reload_markets")
|
mocker.patch(f"{EXMS}.reload_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
mocker.patch(f"{EXMS}.validate_timeframes")
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
mocker.patch(f"{EXMS}.validate_pricing")
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
|
@ -1151,7 +996,6 @@ def test_validate_required_startup_candles(default_conf, mocker, caplog):
|
||||||
mocker.patch(f"{EXMS}._init_ccxt", api_mock)
|
mocker.patch(f"{EXMS}._init_ccxt", api_mock)
|
||||||
mocker.patch(f"{EXMS}.validate_timeframes")
|
mocker.patch(f"{EXMS}.validate_timeframes")
|
||||||
mocker.patch(f"{EXMS}._load_async_markets")
|
mocker.patch(f"{EXMS}._load_async_markets")
|
||||||
mocker.patch(f"{EXMS}.validate_pairs")
|
|
||||||
mocker.patch(f"{EXMS}.validate_pricing")
|
mocker.patch(f"{EXMS}.validate_pricing")
|
||||||
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
mocker.patch(f"{EXMS}.validate_stakecurrency")
|
||||||
|
|
||||||
|
@ -4185,7 +4029,6 @@ def test_merge_ft_has_dict(default_conf, mocker):
|
||||||
EXMS,
|
EXMS,
|
||||||
_init_ccxt=MagicMock(return_value=MagicMock()),
|
_init_ccxt=MagicMock(return_value=MagicMock()),
|
||||||
_load_async_markets=MagicMock(),
|
_load_async_markets=MagicMock(),
|
||||||
validate_pairs=MagicMock(),
|
|
||||||
validate_timeframes=MagicMock(),
|
validate_timeframes=MagicMock(),
|
||||||
validate_stakecurrency=MagicMock(),
|
validate_stakecurrency=MagicMock(),
|
||||||
validate_pricing=MagicMock(),
|
validate_pricing=MagicMock(),
|
||||||
|
@ -4220,7 +4063,6 @@ def test_get_valid_pair_combination(default_conf, mocker, markets):
|
||||||
EXMS,
|
EXMS,
|
||||||
_init_ccxt=MagicMock(return_value=MagicMock()),
|
_init_ccxt=MagicMock(return_value=MagicMock()),
|
||||||
_load_async_markets=MagicMock(),
|
_load_async_markets=MagicMock(),
|
||||||
validate_pairs=MagicMock(),
|
|
||||||
validate_timeframes=MagicMock(),
|
validate_timeframes=MagicMock(),
|
||||||
validate_pricing=MagicMock(),
|
validate_pricing=MagicMock(),
|
||||||
markets=PropertyMock(return_value=markets),
|
markets=PropertyMock(return_value=markets),
|
||||||
|
@ -4500,7 +4342,6 @@ def test_get_markets(
|
||||||
EXMS,
|
EXMS,
|
||||||
_init_ccxt=MagicMock(return_value=MagicMock()),
|
_init_ccxt=MagicMock(return_value=MagicMock()),
|
||||||
_load_async_markets=MagicMock(),
|
_load_async_markets=MagicMock(),
|
||||||
validate_pairs=MagicMock(),
|
|
||||||
validate_timeframes=MagicMock(),
|
validate_timeframes=MagicMock(),
|
||||||
validate_pricing=MagicMock(),
|
validate_pricing=MagicMock(),
|
||||||
markets=PropertyMock(return_value=markets_static),
|
markets=PropertyMock(return_value=markets_static),
|
||||||
|
|
|
@ -2204,7 +2204,6 @@ def test_manage_open_orders_buy_exception(
|
||||||
patch_exchange(mocker)
|
patch_exchange(mocker)
|
||||||
mocker.patch.multiple(
|
mocker.patch.multiple(
|
||||||
EXMS,
|
EXMS,
|
||||||
validate_pairs=MagicMock(),
|
|
||||||
fetch_ticker=ticker_usdt,
|
fetch_ticker=ticker_usdt,
|
||||||
fetch_order=MagicMock(side_effect=ExchangeError),
|
fetch_order=MagicMock(side_effect=ExchangeError),
|
||||||
cancel_order=cancel_order_mock,
|
cancel_order=cancel_order_mock,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user