Compare commits

17 Commits

Author SHA1 Message Date
Robert Caulk
4857d8c1ef
Merge bb62b0fc5a into ae41ab101a 2024-09-15 19:43:45 +08:00
Matthias
ae41ab101a docs: remove skip_pair_validation - it's no longer used.
2024-09-15 11:28:57 +02:00
Matthias
f4881e7c6f tests: Adjust tests for removed validate_pairlist functionality 2024-09-15 11:28:57 +02:00
Matthias
94ef4380d4 chore: remove validate_pairs from exchange class
Invalid pairs were filtered out before this was called in most cases.
in cases where it's not - regular pairlist-filtering provides proper warnings.
2024-09-15 11:28:57 +02:00
Matthias
7ebe1b8c14 chore: remove pointless validation
pairs are validated through expand_pairlist.
If they're not in markets, they'll no longer be in the
pairlist once this function is hit.
2024-09-15 11:02:49 +02:00
Matthias
79020bba28 chore: Remove "prohibitedIn" check
It's only been used for Bittrex, which no longer exists.
Apparently this was forgotten when decommissioning Bittrex.
2024-09-15 10:49:26 +02:00
Matthias
95c250ebcc chore: add explaining comment 2024-09-15 10:37:28 +02:00
Matthias
bfb14614cc chore: enhance change with comment 2024-09-15 09:48:44 +02:00
Matthias
12299d4810 feat: staticPairlist to warn for invalid pairs
Warnings about invalid pairs were "covered" by the implicit
filtering of `expand_pairlist()`
2024-09-15 09:46:47 +02:00
Matthias
c67a9d4e84 docs: update pairlist creation docs 2024-09-15 09:29:45 +02:00
Shane
bb62b0fc5a
Update ReinforcementLearner_DDPG_TD3.py
Clean up set policy code.
2024-05-26 20:21:16 +10:00
Shane
3436e8aa1d
Update Base5ActionRLEnv.py
Fix init
2024-05-24 22:15:34 +10:00
Shane
1d5abe5b75
Update Base5ActionRLEnv.py
Fix init
2024-05-24 22:05:28 +10:00
Shane
ffd828b6ad
Create ReinforcementLearner_DDPG_TD3.py
Reinforcement Learning Model to support DDPG and TD3.
2024-05-24 21:45:09 +10:00
Shane
c83dd2d806
Update BaseReinforcementLearningModel.py
Add support for DDPG and TD3.
2024-05-24 21:29:38 +10:00
Shane
dc5766fb10
Update Base5ActionRLEnv.py
Addition of action_space_type to support Discrete and Box action spaces.
2024-05-24 21:12:56 +10:00
Shane
07fba3abb0
Update BaseEnvironment.py
Addition of action_space_type.
2024-05-24 21:10:06 +10:00
13 changed files with 578 additions and 228 deletions

View File

@ -222,7 +222,6 @@ Mandatory parameters are marked as **Required**, which means that they are requi
| `exchange.ccxt_async_config` | Additional CCXT parameters passed to the async ccxt instance. Parameters may differ from exchange to exchange and are documented in the [ccxt documentation](https://docs.ccxt.com/#/README?id=overriding-exchange-properties-upon-instantiation) <br> **Datatype:** Dict
| `exchange.enable_ws` | Enable the usage of Websockets for the exchange. <br>[More information](#consuming-exchange-websockets).<br>*Defaults to `true`.* <br> **Datatype:** Boolean
| `exchange.markets_refresh_interval` | The interval in minutes in which markets are reloaded. <br>*Defaults to `60` minutes.* <br> **Datatype:** Positive Integer
| `exchange.skip_pair_validation` | Skip pairlist validation on startup.<br>*Defaults to `false`*<br> **Datatype:** Boolean
| `exchange.skip_open_order_update` | Skips open order updates on startup should the exchange cause problems. Only relevant in live conditions.<br>*Defaults to `false`*<br> **Datatype:** Boolean
| `exchange.unknown_fee_rate` | Fallback value to use when calculating trading fees. This can be useful for exchanges which have fees in non-tradable currencies. The value provided here will be multiplied with the "fee cost".<br>*Defaults to `None`*<br> **Datatype:** float
| `exchange.log_responses` | Log relevant exchange responses. For debug mode only - use with care.<br>*Defaults to `false`*<br> **Datatype:** Boolean

View File

@ -205,7 +205,7 @@ This is called with each iteration of the bot (only if the Pairlist Handler is a
It must return the resulting pairlist (which may then be passed into the chain of Pairlist Handlers).
Validations are optional, the parent class exposes a `_verify_blacklist(pairlist)` and `_whitelist_for_active_markets(pairlist)` to do default filtering. Use this if you limit your result to a certain number of pairs - so the end-result is not shorter than expected.
Validations are optional, the parent class exposes a `verify_blacklist(pairlist)` and `_whitelist_for_active_markets(pairlist)` to do default filtering. Use this if you limit your result to a certain number of pairs - so the end-result is not shorter than expected.
#### filter_pairlist
@ -219,7 +219,7 @@ The default implementation in the base class simply calls the `_validate_pair()`
If overridden, it must return the resulting pairlist (which may then be passed into the next Pairlist Handler in the chain).
Validations are optional, the parent class exposes a `_verify_blacklist(pairlist)` and `_whitelist_for_active_markets(pairlist)` to do default filters. Use this if you limit your result to a certain number of pairs - so the end result is not shorter than expected.
Validations are optional, the parent class exposes a `verify_blacklist(pairlist)` and `_whitelist_for_active_markets(pairlist)` to do default filters. Use this if you limit your result to a certain number of pairs - so the end result is not shorter than expected.
In `VolumePairList`, this implements different methods of sorting, does early validation so only the expected number of pairs is returned.

View File

@ -55,7 +55,6 @@ It uses configuration from `exchange.pair_whitelist` and `exchange.pair_blacklis
By default, only currently enabled pairs are allowed.
To skip pair validation against active markets, set `"allow_inactive": true` within the `StaticPairList` configuration.
This can be useful for backtesting expired pairs (like quarterly spot-markets).
This option must be configured along with `exchange.skip_pair_validation` in the exchange configuration.
When used in a "follow-up" position (e.g. after VolumePairlist), all pairs in `'pair_whitelist'` will be added to the end of the pairlist.

View File

@ -610,9 +610,6 @@ def download_data_main(config: Config) -> None:
if "timeframes" not in config:
config["timeframes"] = DL_DATA_TIMEFRAMES
# Manual validations of relevant settings
if not config["exchange"].get("skip_pair_validation", False):
exchange.validate_pairs(expanded_pairs)
logger.info(
f"About to download pairs: {expanded_pairs}, "
f"intervals: {config['timeframes']} to {config['datadir']}"

View File

@ -104,7 +104,6 @@ from freqtrade.misc import (
file_load_json,
safe_value_fallback2,
)
from freqtrade.plugins.pairlist.pairlist_helpers import expand_pairlist
from freqtrade.util import dt_from_ts, dt_now
from freqtrade.util.datetime_helpers import dt_humanize_delta, dt_ts, format_ms_time
from freqtrade.util.periodic_cache import PeriodicCache
@ -331,8 +330,6 @@ class Exchange:
# Check if all pairs are available
self.validate_stakecurrency(config["stake_currency"])
if not config["exchange"].get("skip_pair_validation"):
self.validate_pairs(config["exchange"]["pair_whitelist"])
self.validate_ordertypes(config.get("order_types", {}))
self.validate_order_time_in_force(config.get("order_time_in_force", {}))
self.validate_trading_mode_and_margin_mode(self.trading_mode, self.margin_mode)
@ -702,54 +699,6 @@ class Exchange:
f"Available currencies are: {', '.join(quote_currencies)}"
)
def validate_pairs(self, pairs: List[str]) -> None:
"""
Checks if all given pairs are tradable on the current exchange.
:param pairs: list of pairs
:raise: OperationalException if one pair is not available
:return: None
"""
if not self.markets:
logger.warning("Unable to validate pairs (assuming they are correct).")
return
extended_pairs = expand_pairlist(pairs, list(self.markets), keep_invalid=True)
invalid_pairs = []
for pair in extended_pairs:
# Note: ccxt has BaseCurrency/QuoteCurrency format for pairs
if self.markets and pair not in self.markets:
raise OperationalException(
f"Pair {pair} is not available on {self.name} {self.trading_mode}. "
f"Please remove {pair} from your whitelist."
)
# From ccxt Documentation:
# markets.info: An associative array of non-common market properties,
# including fees, rates, limits and other general market information.
# The internal info array is different for each particular market,
# its contents depend on the exchange.
# It can also be a string or similar ... so we need to verify that first.
elif isinstance(self.markets[pair].get("info"), dict) and self.markets[pair].get(
"info", {}
).get("prohibitedIn", False):
# Warn users about restricted pairs in whitelist.
# We cannot determine reliably if Users are affected.
logger.warning(
f"Pair {pair} is restricted for some users on this exchange."
f"Please check if you are impacted by this restriction "
f"on the exchange and eventually remove {pair} from your whitelist."
)
if (
self._config["stake_currency"]
and self.get_pair_quote_currency(pair) != self._config["stake_currency"]
):
invalid_pairs.append(pair)
if invalid_pairs:
raise OperationalException(
f"Stake-currency '{self._config['stake_currency']}' not compatible with "
f"pair-whitelist. Please remove the following pairs: {invalid_pairs}"
)
def get_valid_pair_combination(self, curr_1: str, curr_2: str) -> str:
"""
Get valid pair combination of curr_1 and curr_2 by trying both combinations.

View File

@ -22,12 +22,18 @@ class Base5ActionRLEnv(BaseEnvironment):
Base class for a 5 action environment
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init__(self, *args, action_space_type: str = "Discrete", **kwargs):
super().__init__(*args, **kwargs)
self.action_space_type = action_space_type
self.actions = Actions
def set_action_space(self):
self.action_space = spaces.Discrete(len(Actions))
if self.action_space_type == "Discrete":
self.action_space = spaces.Discrete(len(Actions))
elif self.action_space_type == "Box":
self.action_space = spaces.Box(low=-1, high=1, shape=(1,))
else:
raise ValueError(f"Unknown action space type: {self.action_space_type}")
def step(self, action: int):
"""

View File

@ -60,6 +60,7 @@ class BaseEnvironment(gym.Env):
can_short: bool = False,
pair: str = "",
df_raw: DataFrame = DataFrame(),
action_space_type: str = "Discrete"
):
"""
Initializes the training/eval environment.
@ -93,6 +94,7 @@ class BaseEnvironment(gym.Env):
self.tensorboard_metrics: dict = {}
self.can_short: bool = can_short
self.live: bool = live
self.action_space_type: str = action_space_type
if not self.live and self.add_state_info:
raise OperationalException(
"`add_state_info` is not available in backtesting. Change "

View File

@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
torch.multiprocessing.set_sharing_strategy("file_system")
SB3_MODELS = ["PPO", "A2C", "DQN"]
SB3_MODELS = ["PPO", "A2C", "DQN", "DDPG", "TD3"]
SB3_CONTRIB_MODELS = ["TRPO", "ARS", "RecurrentPPO", "MaskablePPO", "QRDQN"]

View File

@ -0,0 +1,556 @@
import copy
import logging
import gc
from pathlib import Path
from typing import Any, Dict, Type, Callable, List, Optional, Union
import numpy as np
import torch as th
import pandas as pd
from pandas import DataFrame
from gymnasium import spaces
import matplotlib
import matplotlib.transforms as mtransforms
import matplotlib.pyplot as plt
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.logger import HParam, Figure
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, BaseActions
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback
logger = logging.getLogger(__name__)
class ReinforcementLearner_DDPG_TD3(BaseReinforcementLearningModel):
"""
Reinforcement learning prediction model for DDPG and TD3.
Users can inherit from this class to make their own RL model with custom
environment/training controls. Define the file as follows:
```
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
class MyCoolRLModel(ReinforcementLearner):
```
Save the file to `user_data/freqaimodels`, then run it with:
freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat
Here the users can override any of the functions
available in the `IFreqaiModel` inheritance tree. Most importantly for RL, this
is where the user overrides `MyRLEnv` (see below), to define custom
`calculate_reward()` function, or to override any other parts of the environment.
This class also allows users to override any other part of the IFreqaiModel tree.
For example, the user can override `def fit()` or `def train()` or `def predict()`
to take fine-tuned control over these processes.
Another common override may be `def data_cleaning_predict()` where the user can
take fine-tuned control over the data handling pipeline.
"""
def __init__(self, **kwargs) -> None:
"""
Model specific config
"""
super().__init__(**kwargs)
# Enable learning rate linear schedule
self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
# Enable tensorboard logging
self.activate_tensorboard: bool = self.rl_config.get("activate_tensorboard", True)
# NOTE: the Tensorboard callback is not recommended with multiple envs:
# it can report incorrect information and is not thread safe with SB3.
# Enable tensorboard rollout plot
self.tensorboard_plot: bool = self.rl_config.get("tensorboard_plot", False)
def get_model_params(self):
"""
Get the model specific parameters
"""
model_params = copy.deepcopy(self.freqai_info["model_training_parameters"])
if self.lr_schedule:
_lr = model_params.get('learning_rate', 0.0003)
model_params["learning_rate"] = linear_schedule(_lr)
logger.info(f"Learning rate linear schedule enabled, initial value: {_lr}")
model_params["policy_kwargs"] = dict(
net_arch=dict(vf=self.net_arch, pi=self.net_arch),
activation_fn=th.nn.ReLU,
optimizer_class=th.optim.Adam
)
return model_params
def get_callbacks(self, eval_freq, data_path) -> list:
"""
Get the model specific callbacks
"""
callbacks = []
callbacks.append(self.eval_callback)
if self.activate_tensorboard:
callbacks.append(CustomTensorboardCallback())
if self.tensorboard_plot:
callbacks.append(FigureRecorderCallback())
return callbacks
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
"""
User customizable fit method
:param data_dictionary: dict = common data dictionary containing all train/test
features/labels/weights.
:param dk: FreqaiDatakitchen = data kitchen for current pair.
:return:
model Any = trained model to be used for inference in dry/live/backtesting
"""
train_df = data_dictionary["train_features"]
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=self.net_arch)
if self.activate_tensorboard:
tb_path = Path(dk.full_path / "tensorboard" / dk.pair.split('/')[0])
else:
tb_path = None
model_params = self.get_model_params()
logger.info(f"Params: {model_params}")
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
model = self.MODELCLASS(self.policy_type, self.train_env,
tensorboard_log=tb_path,
**model_params)
else:
logger.info("Continual training activated - starting training from previously "
"trained agent.")
model = self.dd.model_dictionary[dk.pair]
model.set_env(self.train_env)
model.learn(
total_timesteps=int(total_timesteps),
#callback=[self.eval_callback, self.tensorboard_callback],
callback=self.get_callbacks(len(train_df), str(dk.data_path)),
progress_bar=self.rl_config.get("progress_bar", False)
)
if Path(dk.data_path / "best_model.zip").is_file():
logger.info("Callback found a best model.")
best_model = self.MODELCLASS.load(dk.data_path / "best_model")
return best_model
logger.info("Couldn't find best model, using final model instead.")
return model
MyRLEnv: Type[BaseEnvironment]
class MyRLEnv(Base5ActionRLEnv): # type: ignore[no-redef]
"""
User can override any function in BaseRLEnv and gym.Env. Here the user
sets a custom reward based on profit and trade duration.
"""
def __init__(self, df, prices, reward_kwargs, window_size=10, starting_point=True, id="boxenv-1", seed=1, config={}, live=False, fee=0.0015, can_short=False, pair="", df_raw=None, action_space_type="Box"):
super().__init__(df, prices, reward_kwargs, window_size, starting_point, id, seed, config, live, fee, can_short, pair, df_raw)
# Define the action space as a continuous space between -1 and 1 for a single action dimension
self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
# Define the observation space as before
self.observation_space = spaces.Box(
low=-np.inf,
high=np.inf,
shape=(window_size, self.total_features),
dtype=np.float32
)
def calculate_reward(self, action: int) -> float:
"""
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
Warning!
This function is a showcase designed to demonstrate as many
environment control features as possible. It is also designed to run quickly
on small computers. This is a benchmark, it is *not* for live production.
:param action: int = The action made by the agent for the current candle.
:return:
float = the reward to give to the agent for current step (used for optimization
of weights in NN)
"""
# first, penalize if the action is not valid
if not self._is_valid(action):
self.tensorboard_log("invalid", category="actions")
return -2
pnl = self.get_unrealized_profit()
factor = 100.
# reward agent for entering trades
if (action == Actions.Long_enter.value
and self._position == Positions.Neutral):
return 25
if (action == Actions.Short_enter.value
and self._position == Positions.Neutral):
return 25
# discourage agent from not entering trades
if action == Actions.Neutral.value and self._position == Positions.Neutral:
return -1
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
trade_duration = self._current_tick - self._last_trade_tick # type: ignore
if trade_duration <= max_trade_duration:
factor *= 1.5
elif trade_duration > max_trade_duration:
factor *= 0.5
# discourage sitting in position
if (self._position in (Positions.Short, Positions.Long) and
action == Actions.Neutral.value):
return -1 * trade_duration / max_trade_duration
# close long
if action == Actions.Long_exit.value and self._position == Positions.Long:
if pnl > self.profit_aim * self.rr:
factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
return float(pnl * factor)
# close short
if action == Actions.Short_exit.value and self._position == Positions.Short:
if pnl > self.profit_aim * self.rr:
factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
return float(pnl * factor)
return 0.
def step(self, action):
"""
Logic for a single step (incrementing one candle in time)
by the agent
:param: action: int = the action type that the agent plans
to take for the current step.
:returns:
observation = current state of environment
step_reward = the reward from `calculate_reward()`
_done = if the agent "died" or if the candles finished
info = dict passed back to openai gym lib
"""
# Ensure action is within the range [-1, 1]
action = np.clip(action, -1, 1)
# Apply noise for exploration
self.noise_std = 0.3 # Standard deviation for exploration noise
noise = np.random.normal(0, self.noise_std, size=action.shape)
action = np.tanh(action + noise) # Ensure action is within -1 to 1
# Map the continuous action to one of the five discrete actions
discrete_action = self._map_continuous_to_discrete(action)
#print(f"{self._current_tick} Action!!!: {action}")
#print(f"{self._current_tick} Discrete Action!!!: {discrete_action}")
self._done = False
self._current_tick += 1
if self._current_tick == self._end_tick:
self._done = True
self._update_unrealized_total_profit()
step_reward = self.calculate_reward(discrete_action)
self.total_reward += step_reward
self.tensorboard_log(self.actions._member_names_[discrete_action], category="actions")
trade_type = None
if self.is_tradesignal(discrete_action):
if discrete_action == Actions.Neutral.value:
self._position = Positions.Neutral
trade_type = "neutral"
self._last_trade_tick = None
elif discrete_action == Actions.Long_enter.value:
self._position = Positions.Long
trade_type = "enter_long"
self._last_trade_tick = self._current_tick
elif discrete_action == Actions.Short_enter.value:
self._position = Positions.Short
trade_type = "enter_short"
self._last_trade_tick = self._current_tick
elif discrete_action == Actions.Long_exit.value:
self._update_total_profit()
self._position = Positions.Neutral
trade_type = "exit_long"
self._last_trade_tick = None
elif discrete_action == Actions.Short_exit.value:
self._update_total_profit()
self._position = Positions.Neutral
trade_type = "exit_short"
self._last_trade_tick = None
else:
print("case not defined")
if trade_type is not None:
self.trade_history.append(
{"price": self.current_price(), "index": self._current_tick,
"type": trade_type, "profit": self.get_unrealized_profit()})
if (self._total_profit < self.max_drawdown or
self._total_unrealized_profit < self.max_drawdown):
self._done = True
self._position_history.append(self._position)
info = dict(
tick=self._current_tick,
action=discrete_action,
total_reward=self.total_reward,
total_profit=self._total_profit,
position=self._position.value,
trade_duration=self.get_trade_duration(),
current_profit_pct=self.get_unrealized_profit()
)
observation = self._get_observation()
# user can play with time if they want
truncated = False
self._update_history(info)
return observation, step_reward, self._done, truncated, info
def _map_continuous_to_discrete(self, action):
"""
Map the continuous action (a value between -1 and 1) to one of the discrete actions.
"""
action_value = action[0] # Extract the single continuous action value
# Define the number of discrete actions
num_discrete_actions = 5
# Calculate the step size for each interval
step_size = 2 / num_discrete_actions # (2 because range is from -1 to 1)
# Generate the boundaries dynamically
boundaries = th.linspace(-1 + step_size, 1 - step_size, steps=num_discrete_actions - 1)
# Find the bucket index for the action value
bucket_index = th.bucketize(th.tensor(action_value), boundaries, right=True)
# Map the bucket index to discrete actions
discrete_actions = [
BaseActions.Neutral,
BaseActions.Long_enter,
BaseActions.Long_exit,
BaseActions.Short_enter,
BaseActions.Short_exit
]
return discrete_actions[bucket_index].value
def get_rollout_history(self) -> DataFrame:
"""
Get environment data from the first to the last trade
"""
_history_df = pd.DataFrame.from_dict(self.history)
_trade_history_df = pd.DataFrame.from_dict(self.trade_history)
_rollout_history = _history_df.merge(_trade_history_df, left_on="tick", right_on="index", how="left")
_price_history = self.prices.iloc[_rollout_history.tick].copy().reset_index()
history = pd.merge(
_rollout_history,
_price_history,
left_index=True, right_index=True
)
return history
def get_rollout_plot(self):
"""
Plot trades and environment data
"""
def transform_y_offset(ax, offset):
return mtransforms.offset_copy(ax.transData, fig=fig, x=0, y=offset, units="inches")
def plot_markers(ax, ticks, marker, color, size, offset):
ax.plot(ticks, marker=marker, color=color, markersize=size, fillstyle="full",
transform=transform_y_offset(ax, offset), linestyle="none")
plt.style.use("dark_background")
fig, axs = plt.subplots(
nrows=5, ncols=1,
figsize=(16, 9),
height_ratios=[6, 1, 1, 1, 1],
sharex=True
)
# Return empty fig if no trades
if len(self.trade_history) == 0:
return fig
history = self.get_rollout_history()
enter_long_prices = history.loc[history["type"] == "enter_long"]["price"]
enter_short_prices = history.loc[history["type"] == "enter_short"]["price"]
exit_long_prices = history.loc[history["type"] == "exit_long"]["price"]
exit_short_prices = history.loc[history["type"] == "exit_short"]["price"]
axs[0].plot(history["open"], linewidth=1, color="#c28ce3")
plot_markers(axs[0], enter_long_prices, "^", "#4ae747", 5, -0.05)
plot_markers(axs[0], enter_short_prices, "v", "#f53580", 5, 0.05)
plot_markers(axs[0], exit_long_prices, "o", "#4ae747", 3, 0)
plot_markers(axs[0], exit_short_prices, "o", "#f53580", 3, 0)
axs[1].set_ylabel("pnl")
axs[1].plot(history["current_profit_pct"], linewidth=1, color="#a29db9")
axs[1].axhline(y=0, label='0', alpha=0.33)
axs[2].set_ylabel("duration")
axs[2].plot(history["trade_duration"], linewidth=1, color="#a29db9")
axs[3].set_ylabel("total_reward")
axs[3].plot(history["total_reward"], linewidth=1, color="#a29db9")
axs[3].axhline(y=0, label='0', alpha=0.33)
axs[4].set_ylabel("total_profit")
axs[4].set_xlabel("tick")
axs[4].plot(history["total_profit"], linewidth=1, color="#a29db9")
axs[4].axhline(y=1, label='1', alpha=0.33)
for _ax in axs:
for _border in ["top", "right", "bottom", "left"]:
_ax.spines[_border].set_color("#5b5e4b")
fig.suptitle(
"Total Reward: %.6f" % self.total_reward + " ~ " +
"Total Profit: %.6f" % self._total_profit
)
fig.tight_layout()
return fig
def close(self) -> None:
gc.collect()
th.cuda.empty_cache()
def linear_schedule(initial_value: float) -> Callable[[float], float]:
def func(progress_remaining: float) -> float:
return progress_remaining * initial_value
return func
class CustomTensorboardCallback(TensorboardCallback):
"""
Tensorboard callback
"""
def _on_training_start(self) -> None:
_lr = self.model.learning_rate
if self.model.__class__.__name__ == "DDPG":
hparam_dict = {
"algorithm": self.model.__class__.__name__,
"buffer_size": self.model.buffer_size,
"learning_rate": _lr if isinstance(_lr, float) else "lr_schedule",
"learning_starts": self.model.learning_starts,
"batch_size": self.model.batch_size,
"tau": self.model.tau,
"gamma": self.model.gamma,
"train_freq": self.model.train_freq,
"gradient_steps": self.model.gradient_steps,
}
elif self.model.__class__.__name__ == "TD3":
hparam_dict = {
"algorithm": self.model.__class__.__name__,
"learning_rate": _lr if isinstance(_lr, float) else "lr_schedule",
"buffer_size": self.model.buffer_size,
"learning_starts": self.model.learning_starts,
"batch_size": self.model.batch_size,
"tau": self.model.tau,
"gamma": self.model.gamma,
"train_freq": self.model.train_freq,
"gradient_steps": self.model.gradient_steps,
"policy_delay": self.model.policy_delay,
"target_policy_noise": self.model.target_policy_noise,
"target_noise_clip": self.model.target_noise_clip,
}
else:
hparam_dict = {
"algorithm": self.model.__class__.__name__,
"learning_rate": _lr if isinstance(_lr, float) else "lr_schedule",
"gamma": self.model.gamma,
"gae_lambda": self.model.gae_lambda,
"n_steps": self.model.n_steps,
"batch_size": self.model.batch_size,
}
# Convert hparam_dict values to str if they are not of type int, float, str, bool, or torch.Tensor
hparam_dict = {k: (str(v) if not isinstance(v, (int, float, str, bool, th.Tensor)) else v) for k, v in hparam_dict.items()}
metric_dict = {
"eval/mean_reward": 0,
"rollout/ep_rew_mean": 0,
"rollout/ep_len_mean": 0,
"info/total_profit": 1,
"info/trades_count": 0,
"info/trade_duration": 0,
}
self.logger.record(
"hparams",
HParam(hparam_dict, metric_dict),
exclude=("stdout", "log", "json", "csv"),
)
def _on_step(self) -> bool:
local_info = self.locals["infos"][0]
if self.training_env is None:
return True
tensorboard_metrics = self.training_env.env_method("get_wrapper_attr", "tensorboard_metrics")[0]
for metric in local_info:
if metric not in ["episode", "terminal_observation", "TimeLimit.truncated"]:
self.logger.record(f"info/{metric}", local_info[metric])
for category in tensorboard_metrics:
for metric in tensorboard_metrics[category]:
self.logger.record(f"{category}/{metric}", tensorboard_metrics[category][metric])
return True
class FigureRecorderCallback(BaseCallback):
"""
Tensorboard figures callback
"""
def __init__(self, verbose=0):
super().__init__(verbose)
def _on_step(self) -> bool:
return True
def _on_rollout_end(self):
try:
# Access the rollout plot directly from the base environment
figures = [env.unwrapped.get_rollout_plot() for env in self.training_env.envs]
except AttributeError:
# If the above fails, try getting it from the wrappers
figures = self.training_env.env_method("get_wrapper_attr", "get_rollout_plot")
for i, fig in enumerate(figures):
self.logger.record(
f"rollout/env_{i}",
Figure(fig, close=True),
exclude=("stdout", "log", "json", "csv")
)
plt.close(fig)
return True
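
The bucketing done by `_map_continuous_to_discrete` in the file above can be illustrated without torch. This is a minimal sketch, assuming the same five equal-width buckets over [-1, 1]. The `Action` enum and the `map_continuous_to_discrete` function are stand-ins written for this illustration (only the ordering is taken from the `discrete_actions` list in the diff), and `bisect_right` from the standard library is swapped in for `th.bucketize(..., right=True)`.

```python
from bisect import bisect_right
from enum import Enum


class Action(Enum):
    # Stand-in enum; the order is copied from the discrete_actions list in the diff.
    NEUTRAL = 0
    LONG_ENTER = 1
    LONG_EXIT = 2
    SHORT_ENTER = 3
    SHORT_EXIT = 4


def map_continuous_to_discrete(value: float) -> Action:
    """Map a value in [-1, 1] to one of five equal-width buckets."""
    actions = list(Action)
    step = 2.0 / len(actions)  # the range [-1, 1] has width 2
    # Interior boundaries only: roughly -0.6, -0.2, 0.2 and 0.6.
    boundaries = [-1.0 + step * k for k in range(1, len(actions))]
    # bisect_right counts the boundaries that are <= value, which is the bucket index.
    return actions[bisect_right(boundaries, value)]


if __name__ == "__main__":
    for v in (-0.9, -0.4, 0.0, 0.4, 0.9):
        print(v, map_continuous_to_discrete(v).name)
```

Note that neighbouring buckets do not correspond to related actions in this ordering: `LONG_EXIT` sits between `LONG_ENTER` and `SHORT_ENTER`, so a small change in the continuous output can select a very different action.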

View File

@ -61,14 +61,15 @@ class StaticPairList(IPairList):
:param tickers: Tickers (from exchange.get_tickers). May be cached.
:return: List of pairs
"""
wl = self.verify_whitelist(
self._config["exchange"]["pair_whitelist"], logger.info, keep_invalid=True
)
if self._allow_inactive:
return self.verify_whitelist(
self._config["exchange"]["pair_whitelist"], logger.info, keep_invalid=True
)
return wl
else:
return self._whitelist_for_active_markets(
self.verify_whitelist(self._config["exchange"]["pair_whitelist"], logger.info)
)
# Avoid implicit filtering of "verify_whitelist" to keep
# proper warnings in the log
return self._whitelist_for_active_markets(wl)
def filter_pairlist(self, pairlist: List[str], tickers: Tickers) -> List[str]:
"""

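The reordered `gen_pairlist` logic in the hunk above can be read as two steps: resolve the whitelist while keeping unknown pairs, then filter against active markets so that each removal is logged. The sketch below illustrates that flow under stated assumptions. `gen_static_pairlist`, its parameters, the `markets` layout and the warning texts are all hypothetical stand-ins, not freqtrade's API; the real code calls `verify_whitelist(..., keep_invalid=True)` and `_whitelist_for_active_markets(...)`.

```python
from typing import Callable, Dict, List


def gen_static_pairlist(
    whitelist: List[str],
    markets: Dict[str, dict],
    allow_inactive: bool,
    warn: Callable[[str], None],
) -> List[str]:
    """Return the static pairlist, warning about every pair that gets dropped."""
    # Step 1: keep every entry, even if the exchange does not list it
    # (stands in for verify_whitelist(..., keep_invalid=True)).
    pairs = list(whitelist)
    if allow_inactive:
        return pairs
    # Step 2: filter explicitly so each removal produces a warning
    # (stands in for _whitelist_for_active_markets).
    result: List[str] = []
    for pair in pairs:
        if pair not in markets:
            warn(f"Pair {pair} is not available on the exchange; removing it.")
        elif not markets[pair].get("active", False):
            warn(f"Pair {pair} is not active; ignoring it.")
        else:
            result.append(pair)
    return result
```

If step 1 dropped unknown pairs silently, step 2 would never see them and no warning would be emitted, which appears to be the situation the new comment in the diff describes.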
View File

@ -28,6 +28,7 @@ def expand_pairlist(
except re.error as err:
raise ValueError(f"Wildcard error in {pair_wc}, {err}")
# Remove wildcard pairs that didn't have a match.
result = [element for element in result if re.fullmatch(r"^[A-Za-z0-9:/-]+$", element)]
else:
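
The effect of the added `re.fullmatch` filter can be seen in isolation. The following is a simplified sketch of the keep-invalid branch and not the library's implementation: `expand_keep_invalid` and `VALID_PAIR` are hypothetical names, and only the character-class regex and the error message are taken from the diff. The assumption is that an entry without any match is kept as-is, so plain pair names survive while unmatched wildcards (which still contain regex syntax such as `.` or `*`) are removed.

```python
import re
from typing import List

# Characters allowed in a concrete pair name, taken from the regex in the diff.
VALID_PAIR = re.compile(r"^[A-Za-z0-9:/-]+$")


def expand_keep_invalid(wildcards: List[str], available: List[str]) -> List[str]:
    """Expand wildcard entries, keep unknown plain pairs, drop unmatched wildcards."""
    result: List[str] = []
    for entry in wildcards:
        try:
            pattern = re.compile(entry, re.IGNORECASE)
        except re.error as err:
            raise ValueError(f"Wildcard error in {entry}, {err}")
        matches = [pair for pair in available if pattern.fullmatch(pair)]
        # No match: keep the entry itself so a later step can warn about it.
        result += matches or [entry]
    # Entries that still contain regex syntax were wildcards without a match.
    return [pair for pair in result if VALID_PAIR.fullmatch(pair)]


if __name__ == "__main__":
    available = ["BTC/USDT", "ETH/USDT"]
    print(expand_keep_invalid(["BTC/USDT", ".*/EUR", "DOGE/USDT"], available))
```

In this sketch `DOGE/USDT` is kept even though the exchange does not list it, while `.*/EUR` disappears because it matched nothing and is not a valid pair name.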

View File

@ -255,7 +255,6 @@ def test_init_exception(default_conf, mocker):
def test_exchange_resolver(default_conf, mocker, caplog):
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=MagicMock()))
mocker.patch(f"{EXMS}._load_async_markets")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
@ -555,7 +554,6 @@ def test_get_min_pair_stake_amount_real_data(mocker, default_conf) -> None:
def test__load_async_markets(default_conf, mocker, caplog):
mocker.patch(f"{EXMS}._init_ccxt")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.reload_markets")
mocker.patch(f"{EXMS}.validate_stakecurrency")
@ -584,7 +582,6 @@ def test__load_markets(default_conf, mocker, caplog):
api_mock = MagicMock()
api_mock.load_markets = get_mock_coro(side_effect=ccxt.BaseError("SomeError"))
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
@ -684,7 +681,6 @@ def test_validate_stakecurrency(default_conf, stake_currency, mocker, caplog):
}
)
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_pricing")
Exchange(default_conf)
@ -702,7 +698,6 @@ def test_validate_stakecurrency_error(default_conf, mocker, caplog):
}
)
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_timeframes")
with pytest.raises(
ConfigurationError,
@ -755,147 +750,6 @@ def test_get_pair_base_currency(default_conf, mocker, pair, expected):
assert ex.get_pair_base_currency(pair) == expected
def test_validate_pairs(default_conf, mocker):
api_mock = MagicMock()
id_mock = PropertyMock(return_value="test_exchange")
type(api_mock).id = id_mock
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(
f"{EXMS}._load_async_markets",
return_value={
"ETH/BTC": {"quote": "BTC"},
"LTC/BTC": {"quote": "BTC"},
"XRP/BTC": {"quote": "BTC"},
"NEO/BTC": {"quote": "BTC"},
},
)
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
# test exchange.validate_pairs directly
# No assert - but this should not fail (!)
Exchange(default_conf)
def test_validate_pairs_not_available(default_conf, mocker):
api_mock = MagicMock()
type(api_mock).markets = PropertyMock(
return_value={"XRP/BTC": {"inactive": True, "base": "XRP", "quote": "BTC"}}
)
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}._load_async_markets")
with pytest.raises(OperationalException, match=r"not available"):
Exchange(default_conf)
def test_validate_pairs_exception(default_conf, mocker, caplog):
caplog.set_level(logging.INFO)
api_mock = MagicMock()
mocker.patch(f"{EXMS}.name", PropertyMock(return_value="Binance"))
type(api_mock).markets = PropertyMock(return_value={})
mocker.patch(f"{EXMS}._init_ccxt", api_mock)
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
mocker.patch(f"{EXMS}._load_async_markets")
with pytest.raises(OperationalException, match=r"Pair ETH/BTC is not available on Binance"):
Exchange(default_conf)
mocker.patch(f"{EXMS}.markets", PropertyMock(return_value={}))
Exchange(default_conf)
assert log_has("Unable to validate pairs (assuming they are correct).", caplog)
def test_validate_pairs_restricted(default_conf, mocker, caplog):
api_mock = MagicMock()
type(api_mock).load_markets = get_mock_coro(
return_value={
"ETH/BTC": {"quote": "BTC"},
"LTC/BTC": {"quote": "BTC"},
"XRP/BTC": {"quote": "BTC", "info": {"prohibitedIn": ["US"]}},
"NEO/BTC": {"quote": "BTC", "info": "TestString"}, # info can also be a string ...
}
)
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_pricing")
mocker.patch(f"{EXMS}.validate_stakecurrency")
Exchange(default_conf)
assert log_has(
"Pair XRP/BTC is restricted for some users on this exchange."
"Please check if you are impacted by this restriction "
"on the exchange and eventually remove XRP/BTC from your whitelist.",
caplog,
)
def test_validate_pairs_stakecompatibility(default_conf, mocker):
api_mock = MagicMock()
type(api_mock).load_markets = get_mock_coro(
return_value={
"ETH/BTC": {"quote": "BTC"},
"LTC/BTC": {"quote": "BTC"},
"XRP/BTC": {"quote": "BTC"},
"NEO/BTC": {"quote": "BTC"},
"HELLO-WORLD": {"quote": "BTC"},
}
)
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
Exchange(default_conf)
def test_validate_pairs_stakecompatibility_downloaddata(default_conf, mocker):
api_mock = MagicMock()
default_conf["stake_currency"] = ""
type(api_mock).load_markets = get_mock_coro(
return_value={
"ETH/BTC": {"quote": "BTC"},
"LTC/BTC": {"quote": "BTC"},
"XRP/BTC": {"quote": "BTC"},
"NEO/BTC": {"quote": "BTC"},
"HELLO-WORLD": {"quote": "BTC"},
}
)
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
Exchange(default_conf)
assert type(api_mock).load_markets.call_count == 1
def test_validate_pairs_stakecompatibility_fail(default_conf, mocker):
default_conf["exchange"]["pair_whitelist"].append("HELLO-WORLD")
api_mock = MagicMock()
type(api_mock).load_markets = get_mock_coro(
return_value={
"ETH/BTC": {"quote": "BTC"},
"LTC/BTC": {"quote": "BTC"},
"XRP/BTC": {"quote": "BTC"},
"NEO/BTC": {"quote": "BTC"},
"HELLO-WORLD": {"quote": "USDT"},
}
)
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_stakecurrency")
with pytest.raises(OperationalException, match=r"Stake-currency 'BTC' not compatible with.*"):
Exchange(default_conf)
@pytest.mark.parametrize("timeframe", [("5m"), ("1m"), ("15m"), ("1h")])
def test_validate_timeframes(default_conf, mocker, timeframe):
default_conf["timeframe"] = timeframe
@ -907,7 +761,6 @@ def test_validate_timeframes(default_conf, mocker, timeframe):
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.reload_markets")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
Exchange(default_conf)
@ -925,7 +778,6 @@ def test_validate_timeframes_failed(default_conf, mocker):
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.reload_markets")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
with pytest.raises(
@ -955,7 +807,6 @@ def test_validate_timeframes_emulated_ohlcv_1(default_conf, mocker):
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.reload_markets")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_stakecurrency")
with pytest.raises(
OperationalException,
@ -977,7 +828,6 @@ def test_validate_timeframes_emulated_ohlcvi_2(default_conf, mocker):
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.reload_markets")
mocker.patch(f"{EXMS}.validate_pairs", MagicMock())
mocker.patch(f"{EXMS}.validate_stakecurrency")
with pytest.raises(
OperationalException,
@ -999,7 +849,6 @@ def test_validate_timeframes_not_in_config(default_conf, mocker):
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.reload_markets")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
mocker.patch(f"{EXMS}.validate_required_startup_candles")
@ -1016,7 +865,6 @@ def test_validate_pricing(default_conf, mocker):
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.reload_markets")
mocker.patch(f"{EXMS}.validate_trading_mode_and_margin_mode")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.name", "Binance")
@ -1051,7 +899,6 @@ def test_validate_ordertypes(default_conf, mocker):
type(api_mock).has = PropertyMock(return_value={"createMarketOrder": True})
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.reload_markets")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
@ -1110,7 +957,6 @@ def test_validate_ordertypes_stop_advanced(default_conf, mocker, exchange_name,
type(api_mock).has = PropertyMock(return_value={"createMarketOrder": True})
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.reload_markets")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_stakecurrency")
mocker.patch(f"{EXMS}.validate_pricing")
@ -1135,7 +981,6 @@ def test_validate_order_types_not_in_config(default_conf, mocker):
api_mock = MagicMock()
mocker.patch(f"{EXMS}._init_ccxt", MagicMock(return_value=api_mock))
mocker.patch(f"{EXMS}.reload_markets")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}.validate_pricing")
mocker.patch(f"{EXMS}.validate_stakecurrency")
@ -1151,7 +996,6 @@ def test_validate_required_startup_candles(default_conf, mocker, caplog):
mocker.patch(f"{EXMS}._init_ccxt", api_mock)
mocker.patch(f"{EXMS}.validate_timeframes")
mocker.patch(f"{EXMS}._load_async_markets")
mocker.patch(f"{EXMS}.validate_pairs")
mocker.patch(f"{EXMS}.validate_pricing")
mocker.patch(f"{EXMS}.validate_stakecurrency")
@ -4185,7 +4029,6 @@ def test_merge_ft_has_dict(default_conf, mocker):
EXMS,
_init_ccxt=MagicMock(return_value=MagicMock()),
_load_async_markets=MagicMock(),
validate_pairs=MagicMock(),
validate_timeframes=MagicMock(),
validate_stakecurrency=MagicMock(),
validate_pricing=MagicMock(),
@ -4220,7 +4063,6 @@ def test_get_valid_pair_combination(default_conf, mocker, markets):
EXMS,
_init_ccxt=MagicMock(return_value=MagicMock()),
_load_async_markets=MagicMock(),
validate_pairs=MagicMock(),
validate_timeframes=MagicMock(),
validate_pricing=MagicMock(),
markets=PropertyMock(return_value=markets),
@ -4500,7 +4342,6 @@ def test_get_markets(
EXMS,
_init_ccxt=MagicMock(return_value=MagicMock()),
_load_async_markets=MagicMock(),
validate_pairs=MagicMock(),
validate_timeframes=MagicMock(),
validate_pricing=MagicMock(),
markets=PropertyMock(return_value=markets_static),


@ -2204,7 +2204,6 @@ def test_manage_open_orders_buy_exception(
patch_exchange(mocker)
mocker.patch.multiple(
EXMS,
validate_pairs=MagicMock(),
fetch_ticker=ticker_usdt,
fetch_order=MagicMock(side_effect=ExchangeError),
cancel_order=cancel_order_mock,