mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-13 03:33:55 +00:00
Merge pull request #7908 from freqtrade/add-3action-rl-env
Add 3 Action RL Env
This commit is contained in:
commit
cc30210b3f
|
@ -275,12 +275,12 @@ FreqAI also provides a built in episodic summary logger called `self.tensorboard
|
||||||
|
|
||||||
### Choosing a base environment
|
### Choosing a base environment
|
||||||
|
|
||||||
FreqAI provides two base environments, `Base4ActionEnvironment` and `Base5ActionEnvironment`. As the names imply, the environments are customized for agents that can select from 4 or 5 actions. In the `Base4ActionEnvironment`, the agent can enter long, enter short, hold neutral, or exit position. Meanwhile, in the `Base5ActionEnvironment`, the agent has the same actions as Base4, but instead of a single exit action, it separates exit long and exit short. The main changes stemming from the environment selection include:
|
FreqAI provides three base environments, `Base3ActionRLEnvironment`, `Base4ActionEnvironment` and `Base5ActionEnvironment`. As the names imply, the environments are customized for agents that can select from 3, 4 or 5 actions. The `Base3ActionEnvironment` is the simplest, the agent can select from hold, long, or short. This environment can also be used for long-only bots (it automatically follows the `can_short` flag from the strategy), where long is the enter condition and short is the exit condition. Meanwhile, in the `Base4ActionEnvironment`, the agent can enter long, enter short, hold neutral, or exit position. Finally, in the `Base5ActionEnvironment`, the agent has the same actions as Base4, but instead of a single exit action, it separates exit long and exit short. The main changes stemming from the environment selection include:
|
||||||
|
|
||||||
* the actions available in the `calculate_reward`
|
* the actions available in the `calculate_reward`
|
||||||
* the actions consumed by the user strategy
|
* the actions consumed by the user strategy
|
||||||
|
|
||||||
Both of the FreqAI provided environments inherit from an action/position agnostic environment object called the `BaseEnvironment`, which contains all shared logic. The architecture is designed to be easily customized. The simplest customization is the `calculate_reward()` (see details [here](#creating-a-custom-reward-function)). However, the customizations can be further extended into any of the functions inside the environment. You can do this by simply overriding those functions inside your `MyRLEnv` in the prediction model file. Or for more advanced customizations, it is encouraged to create an entirely new environment inherited from `BaseEnvironment`.
|
All of the FreqAI provided environments inherit from an action/position agnostic environment object called the `BaseEnvironment`, which contains all shared logic. The architecture is designed to be easily customized. The simplest customization is the `calculate_reward()` (see details [here](#creating-a-custom-reward-function)). However, the customizations can be further extended into any of the functions inside the environment. You can do this by simply overriding those functions inside your `MyRLEnv` in the prediction model file. Or for more advanced customizations, it is encouraged to create an entirely new environment inherited from `BaseEnvironment`.
|
||||||
|
|
||||||
!!! Note
|
!!! Note
|
||||||
FreqAI does not provide by default, a long-only training environment. However, creating one should be as simple as copy-pasting one of the built in environments and removing the `short` actions (and all associated references to those).
|
Only the `Base3ActionRLEnv` can do long-only training/trading (set the user strategy attribute `can_short = False`).
|
||||||
|
|
125
freqtrade/freqai/RL/Base3ActionRLEnv.py
Normal file
125
freqtrade/freqai/RL/Base3ActionRLEnv.py
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from gym import spaces
|
||||||
|
|
||||||
|
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(Enum):
|
||||||
|
Neutral = 0
|
||||||
|
Buy = 1
|
||||||
|
Sell = 2
|
||||||
|
|
||||||
|
|
||||||
|
class Base3ActionRLEnv(BaseEnvironment):
|
||||||
|
"""
|
||||||
|
Base class for a 3 action environment
|
||||||
|
"""
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.actions = Actions
|
||||||
|
|
||||||
|
def set_action_space(self):
|
||||||
|
self.action_space = spaces.Discrete(len(Actions))
|
||||||
|
|
||||||
|
def step(self, action: int):
|
||||||
|
"""
|
||||||
|
Logic for a single step (incrementing one candle in time)
|
||||||
|
by the agent
|
||||||
|
:param: action: int = the action type that the agent plans
|
||||||
|
to take for the current step.
|
||||||
|
:returns:
|
||||||
|
observation = current state of environment
|
||||||
|
step_reward = the reward from `calculate_reward()`
|
||||||
|
_done = if the agent "died" or if the candles finished
|
||||||
|
info = dict passed back to openai gym lib
|
||||||
|
"""
|
||||||
|
self._done = False
|
||||||
|
self._current_tick += 1
|
||||||
|
|
||||||
|
if self._current_tick == self._end_tick:
|
||||||
|
self._done = True
|
||||||
|
|
||||||
|
self._update_unrealized_total_profit()
|
||||||
|
step_reward = self.calculate_reward(action)
|
||||||
|
self.total_reward += step_reward
|
||||||
|
self.tensorboard_log(self.actions._member_names_[action])
|
||||||
|
|
||||||
|
trade_type = None
|
||||||
|
if self.is_tradesignal(action):
|
||||||
|
if action == Actions.Buy.value:
|
||||||
|
if self._position == Positions.Short:
|
||||||
|
self._update_total_profit()
|
||||||
|
self._position = Positions.Long
|
||||||
|
trade_type = "long"
|
||||||
|
self._last_trade_tick = self._current_tick
|
||||||
|
elif action == Actions.Sell.value and self.can_short:
|
||||||
|
if self._position == Positions.Long:
|
||||||
|
self._update_total_profit()
|
||||||
|
self._position = Positions.Short
|
||||||
|
trade_type = "short"
|
||||||
|
self._last_trade_tick = self._current_tick
|
||||||
|
elif action == Actions.Sell.value and not self.can_short:
|
||||||
|
self._update_total_profit()
|
||||||
|
self._position = Positions.Neutral
|
||||||
|
trade_type = "neutral"
|
||||||
|
self._last_trade_tick = None
|
||||||
|
else:
|
||||||
|
print("case not defined")
|
||||||
|
|
||||||
|
if trade_type is not None:
|
||||||
|
self.trade_history.append(
|
||||||
|
{'price': self.current_price(), 'index': self._current_tick,
|
||||||
|
'type': trade_type})
|
||||||
|
|
||||||
|
if (self._total_profit < self.max_drawdown or
|
||||||
|
self._total_unrealized_profit < self.max_drawdown):
|
||||||
|
self._done = True
|
||||||
|
|
||||||
|
self._position_history.append(self._position)
|
||||||
|
|
||||||
|
info = dict(
|
||||||
|
tick=self._current_tick,
|
||||||
|
action=action,
|
||||||
|
total_reward=self.total_reward,
|
||||||
|
total_profit=self._total_profit,
|
||||||
|
position=self._position.value,
|
||||||
|
trade_duration=self.get_trade_duration(),
|
||||||
|
current_profit_pct=self.get_unrealized_profit()
|
||||||
|
)
|
||||||
|
|
||||||
|
observation = self._get_observation()
|
||||||
|
|
||||||
|
self._update_history(info)
|
||||||
|
|
||||||
|
return observation, step_reward, self._done, info
|
||||||
|
|
||||||
|
def is_tradesignal(self, action: int) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if the signal is a trade signal
|
||||||
|
e.g.: agent wants a Actions.Buy while it is in a Positions.short
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
(action == Actions.Buy.value and self._position == Positions.Neutral)
|
||||||
|
or (action == Actions.Sell.value and self._position == Positions.Long)
|
||||||
|
or (action == Actions.Sell.value and self._position == Positions.Neutral
|
||||||
|
and self.can_short)
|
||||||
|
or (action == Actions.Buy.value and self._position == Positions.Short
|
||||||
|
and self.can_short)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_valid(self, action: int) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if the signal is valid.
|
||||||
|
e.g.: agent wants a Actions.Sell while it is in a Positions.Long
|
||||||
|
"""
|
||||||
|
if self.can_short:
|
||||||
|
return action in [Actions.Buy.value, Actions.Sell.value, Actions.Neutral.value]
|
||||||
|
else:
|
||||||
|
if action == Actions.Sell.value and self._position != Positions.Long:
|
||||||
|
return False
|
||||||
|
return True
|
|
@ -45,7 +45,7 @@ class BaseEnvironment(gym.Env):
|
||||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
|
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
|
||||||
fee: float = 0.0015):
|
fee: float = 0.0015, can_short: bool = False):
|
||||||
"""
|
"""
|
||||||
Initializes the training/eval environment.
|
Initializes the training/eval environment.
|
||||||
:param df: dataframe of features
|
:param df: dataframe of features
|
||||||
|
@ -58,6 +58,7 @@ class BaseEnvironment(gym.Env):
|
||||||
:param config: Typical user configuration file
|
:param config: Typical user configuration file
|
||||||
:param live: Whether or not this environment is active in dry/live/backtesting
|
:param live: Whether or not this environment is active in dry/live/backtesting
|
||||||
:param fee: The fee to use for environmental interactions.
|
:param fee: The fee to use for environmental interactions.
|
||||||
|
:param can_short: Whether or not the environment can short
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rl_config = config['freqai']['rl_config']
|
self.rl_config = config['freqai']['rl_config']
|
||||||
|
@ -73,6 +74,7 @@ class BaseEnvironment(gym.Env):
|
||||||
# set here to default 5Ac, but all children envs can override this
|
# set here to default 5Ac, but all children envs can override this
|
||||||
self.actions: Type[Enum] = BaseActions
|
self.actions: Type[Enum] = BaseActions
|
||||||
self.tensorboard_metrics: dict = {}
|
self.tensorboard_metrics: dict = {}
|
||||||
|
self.can_short = can_short
|
||||||
self.live = live
|
self.live = live
|
||||||
if not self.live and self.add_state_info:
|
if not self.live and self.add_state_info:
|
||||||
self.add_state_info = False
|
self.add_state_info = False
|
||||||
|
|
|
@ -165,7 +165,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||||
env_info = {"window_size": self.CONV_WIDTH,
|
env_info = {"window_size": self.CONV_WIDTH,
|
||||||
"reward_kwargs": self.reward_params,
|
"reward_kwargs": self.reward_params,
|
||||||
"config": self.config,
|
"config": self.config,
|
||||||
"live": self.live}
|
"live": self.live,
|
||||||
|
"can_short": self.can_short}
|
||||||
if self.data_provider:
|
if self.data_provider:
|
||||||
env_info["fee"] = self.data_provider._exchange \
|
env_info["fee"] = self.data_provider._exchange \
|
||||||
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
||||||
|
|
|
@ -104,6 +104,7 @@ class IFreqaiModel(ABC):
|
||||||
self.metadata: Dict[str, Any] = self.dd.load_global_metadata_from_disk()
|
self.metadata: Dict[str, Any] = self.dd.load_global_metadata_from_disk()
|
||||||
self.data_provider: Optional[DataProvider] = None
|
self.data_provider: Optional[DataProvider] = None
|
||||||
self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
|
self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
|
||||||
|
self.can_short = True # overridden in start() with strategy.can_short
|
||||||
|
|
||||||
record_params(config, self.full_path)
|
record_params(config, self.full_path)
|
||||||
|
|
||||||
|
@ -133,6 +134,7 @@ class IFreqaiModel(ABC):
|
||||||
self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
||||||
self.dd.set_pair_dict_info(metadata)
|
self.dd.set_pair_dict_info(metadata)
|
||||||
self.data_provider = strategy.dp
|
self.data_provider = strategy.dp
|
||||||
|
self.can_short = strategy.can_short
|
||||||
|
|
||||||
if self.live:
|
if self.live:
|
||||||
self.inference_timer('start')
|
self.inference_timer('start')
|
||||||
|
|
|
@ -27,20 +27,23 @@ def is_mac() -> bool:
|
||||||
return "Darwin" in machine
|
return "Darwin" in machine
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('model, pca, dbscan, float32', [
|
@pytest.mark.parametrize('model, pca, dbscan, float32, can_short', [
|
||||||
('LightGBMRegressor', True, False, True),
|
('LightGBMRegressor', True, False, True, True),
|
||||||
('XGBoostRegressor', False, True, False),
|
('XGBoostRegressor', False, True, False, True),
|
||||||
('XGBoostRFRegressor', False, False, False),
|
('XGBoostRFRegressor', False, False, False, True),
|
||||||
('CatboostRegressor', False, False, False),
|
('CatboostRegressor', False, False, False, True),
|
||||||
('ReinforcementLearner', False, True, False),
|
('ReinforcementLearner', False, True, False, True),
|
||||||
('ReinforcementLearner_multiproc', False, False, False),
|
('ReinforcementLearner_multiproc', False, False, False, True),
|
||||||
('ReinforcementLearner_test_4ac', False, False, False)
|
('ReinforcementLearner_test_3ac', False, False, False, False),
|
||||||
|
('ReinforcementLearner_test_3ac', False, False, False, True),
|
||||||
|
('ReinforcementLearner_test_4ac', False, False, False, True)
|
||||||
])
|
])
|
||||||
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, dbscan, float32):
|
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
||||||
|
dbscan, float32, can_short):
|
||||||
if is_arm() and model == 'CatboostRegressor':
|
if is_arm() and model == 'CatboostRegressor':
|
||||||
pytest.skip("CatBoost is not supported on ARM")
|
pytest.skip("CatBoost is not supported on ARM")
|
||||||
|
|
||||||
if is_mac() and 'Reinforcement' in model:
|
if is_mac() and not is_arm() and 'Reinforcement' in model:
|
||||||
pytest.skip("Reinforcement learning module not available on intel based Mac OS")
|
pytest.skip("Reinforcement learning module not available on intel based Mac OS")
|
||||||
|
|
||||||
model_save_ext = 'joblib'
|
model_save_ext = 'joblib'
|
||||||
|
@ -58,9 +61,6 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
||||||
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
|
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
|
||||||
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
|
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
|
||||||
|
|
||||||
if 'test_4ac' in model:
|
|
||||||
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
|
|
||||||
|
|
||||||
if 'ReinforcementLearner' in model:
|
if 'ReinforcementLearner' in model:
|
||||||
model_save_ext = 'zip'
|
model_save_ext = 'zip'
|
||||||
freqai_conf = make_rl_config(freqai_conf)
|
freqai_conf = make_rl_config(freqai_conf)
|
||||||
|
@ -68,7 +68,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
||||||
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
|
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
|
||||||
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
|
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
|
||||||
|
|
||||||
if 'test_4ac' in model:
|
if 'test_3ac' in model or 'test_4ac' in model:
|
||||||
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
|
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
|
||||||
|
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||||
|
@ -77,6 +77,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
||||||
strategy.freqai_info = freqai_conf.get("freqai", {})
|
strategy.freqai_info = freqai_conf.get("freqai", {})
|
||||||
freqai = strategy.freqai
|
freqai = strategy.freqai
|
||||||
freqai.live = True
|
freqai.live = True
|
||||||
|
freqai.can_short = can_short
|
||||||
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
||||||
freqai.dk.set_paths('ADA/BTC', 10000)
|
freqai.dk.set_paths('ADA/BTC', 10000)
|
||||||
timerange = TimeRange.parse_timerange("20180110-20180130")
|
timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||||
|
|
65
tests/freqai/test_models/ReinforcementLearner_test_3ac.py
Normal file
65
tests/freqai/test_models/ReinforcementLearner_test_3ac.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||||
|
from freqtrade.freqai.RL.Base3ActionRLEnv import Actions, Base3ActionRLEnv, Positions
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReinforcementLearner_test_3ac(ReinforcementLearner):
|
||||||
|
"""
|
||||||
|
User created Reinforcement Learning Model prediction model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MyRLEnv(Base3ActionRLEnv):
|
||||||
|
"""
|
||||||
|
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||||
|
sets a custom reward based on profit and trade duration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def calculate_reward(self, action: int) -> float:
|
||||||
|
|
||||||
|
# first, penalize if the action is not valid
|
||||||
|
if not self._is_valid(action):
|
||||||
|
return -2
|
||||||
|
|
||||||
|
pnl = self.get_unrealized_profit()
|
||||||
|
rew = np.sign(pnl) * (pnl + 1)
|
||||||
|
factor = 100.
|
||||||
|
|
||||||
|
# reward agent for entering trades
|
||||||
|
if (action in (Actions.Buy.value, Actions.Sell.value)
|
||||||
|
and self._position == Positions.Neutral):
|
||||||
|
return 25
|
||||||
|
# discourage agent from not entering trades
|
||||||
|
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||||
|
trade_duration = self._current_tick - self._last_trade_tick # type: ignore
|
||||||
|
|
||||||
|
if trade_duration <= max_trade_duration:
|
||||||
|
factor *= 1.5
|
||||||
|
elif trade_duration > max_trade_duration:
|
||||||
|
factor *= 0.5
|
||||||
|
|
||||||
|
# discourage sitting in position
|
||||||
|
if self._position in (Positions.Short, Positions.Long) and (
|
||||||
|
action == Actions.Neutral.value
|
||||||
|
or (action == Actions.Sell.value and self._position == Positions.Short)
|
||||||
|
or (action == Actions.Buy.value and self._position == Positions.Long)
|
||||||
|
):
|
||||||
|
return -1 * trade_duration / max_trade_duration
|
||||||
|
|
||||||
|
# close position
|
||||||
|
if (action == Actions.Buy.value and self._position == Positions.Short) or (
|
||||||
|
action == Actions.Sell.value and self._position == Positions.Long
|
||||||
|
):
|
||||||
|
if pnl > self.profit_aim * self.rr:
|
||||||
|
factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
|
||||||
|
return float(rew * factor)
|
||||||
|
|
||||||
|
return 0.
|
Loading…
Reference in New Issue
Block a user