mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
Add 3 Action RL env
This commit is contained in:
parent
dde363343c
commit
7727f31507
125
freqtrade/freqai/RL/Base3ActionRLEnv.py
Normal file
125
freqtrade/freqai/RL/Base3ActionRLEnv.py
Normal file
|
@ -0,0 +1,125 @@
|
|||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from gym import spaces
|
||||
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Actions(Enum):
|
||||
Neutral = 0
|
||||
Buy = 1
|
||||
Sell = 2
|
||||
|
||||
|
||||
class Base3ActionRLEnv(BaseEnvironment):
|
||||
"""
|
||||
Base class for a 3 action environment
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.actions = Actions
|
||||
|
||||
def set_action_space(self):
|
||||
self.action_space = spaces.Discrete(len(Actions))
|
||||
|
||||
def step(self, action: int):
|
||||
"""
|
||||
Logic for a single step (incrementing one candle in time)
|
||||
by the agent
|
||||
:param: action: int = the action type that the agent plans
|
||||
to take for the current step.
|
||||
:returns:
|
||||
observation = current state of environment
|
||||
step_reward = the reward from `calculate_reward()`
|
||||
_done = if the agent "died" or if the candles finished
|
||||
info = dict passed back to openai gym lib
|
||||
"""
|
||||
self._done = False
|
||||
self._current_tick += 1
|
||||
|
||||
if self._current_tick == self._end_tick:
|
||||
self._done = True
|
||||
|
||||
self._update_unrealized_total_profit()
|
||||
step_reward = self.calculate_reward(action)
|
||||
self.total_reward += step_reward
|
||||
self.tensorboard_log(self.actions._member_names_[action])
|
||||
|
||||
trade_type = None
|
||||
if self.is_tradesignal(action):
|
||||
if action == Actions.Buy.value:
|
||||
if self._position == Positions.Short:
|
||||
self._update_total_profit()
|
||||
self._position = Positions.Long
|
||||
trade_type = "long"
|
||||
self._last_trade_tick = self._current_tick
|
||||
elif action == Actions.Sell.value and self.can_short:
|
||||
if self._position == Positions.Long:
|
||||
self._update_total_profit()
|
||||
self._position = Positions.Short
|
||||
trade_type = "short"
|
||||
self._last_trade_tick = self._current_tick
|
||||
elif action == Actions.Sell.value and not self.can_short:
|
||||
self._update_total_profit()
|
||||
self._position = Positions.Neutral
|
||||
trade_type = "neutral"
|
||||
self._last_trade_tick = None
|
||||
else:
|
||||
print("case not defined")
|
||||
|
||||
if trade_type is not None:
|
||||
self.trade_history.append(
|
||||
{'price': self.current_price(), 'index': self._current_tick,
|
||||
'type': trade_type})
|
||||
|
||||
if (self._total_profit < self.max_drawdown or
|
||||
self._total_unrealized_profit < self.max_drawdown):
|
||||
self._done = True
|
||||
|
||||
self._position_history.append(self._position)
|
||||
|
||||
info = dict(
|
||||
tick=self._current_tick,
|
||||
action=action,
|
||||
total_reward=self.total_reward,
|
||||
total_profit=self._total_profit,
|
||||
position=self._position.value,
|
||||
trade_duration=self.get_trade_duration(),
|
||||
current_profit_pct=self.get_unrealized_profit()
|
||||
)
|
||||
|
||||
observation = self._get_observation()
|
||||
|
||||
self._update_history(info)
|
||||
|
||||
return observation, step_reward, self._done, info
|
||||
|
||||
def is_tradesignal(self, action: int) -> bool:
|
||||
"""
|
||||
Determine if the signal is a trade signal
|
||||
e.g.: agent wants a Actions.Buy while it is in a Positions.short
|
||||
"""
|
||||
return (
|
||||
(action == Actions.Buy.value and self._position == Positions.Neutral)
|
||||
or (action == Actions.Sell.value and self._position == Positions.Long)
|
||||
or (action == Actions.Sell.value and self._position == Positions.Neutral
|
||||
and self.can_short)
|
||||
or (action == Actions.Buy.value and self._position == Positions.Short
|
||||
and self.can_short)
|
||||
)
|
||||
|
||||
def _is_valid(self, action: int) -> bool:
|
||||
"""
|
||||
Determine if the signal is valid.
|
||||
e.g.: agent wants a Actions.Sell while it is in a Positions.Long
|
||||
"""
|
||||
if self.can_short:
|
||||
return action in [Actions.Buy.value, Actions.Sell.value, Actions.Neutral.value]
|
||||
else:
|
||||
if action == Actions.Sell.value and self._position != Positions.Long:
|
||||
return False
|
||||
return True
|
Loading…
Reference in New Issue
Block a user