freqtrade_origin/freqtrade/freqai/RL/BaseEnvironment.py

278 lines
9.9 KiB
Python

import logging
from abc import abstractmethod
from enum import Enum
from typing import Optional
import gym
import numpy as np
import pandas as pd
from gym import spaces
from gym.utils import seeding
from pandas import DataFrame
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
class Positions(Enum):
    """
    Position held by the agent.

    Neutral (flat) sits numerically between Short (0) and Long (1); the member
    value is what the environment exposes as its ``position`` state feature.
    """
    Short = 0
    Long = 1
    Neutral = 0.5

    def opposite(self):
        """
        Return the opposite side of this position.

        Long maps to Short. Every other member (Short and also Neutral) maps
        to Long.
        """
        if self == Positions.Long:
            return Positions.Short
        return Positions.Long
class BaseEnvironment(gym.Env):
    """
    Base class for environments. This class is agnostic to action count.
    Inherited classes customize this to include varying action counts/types;
    see RL/Base5ActionRLEnv.py and RL/Base4ActionRLEnv.py.

    Subclasses are expected to implement ``set_action_space``, ``step``,
    ``is_tradesignal``, ``_is_valid`` and ``calculate_reward``.
    """

    def __init__(self, df: Optional[DataFrame] = None, prices: Optional[DataFrame] = None,
                 reward_kwargs: Optional[dict] = None, window_size=10, starting_point=True,
                 id: str = 'baseenv-1', seed: int = 1, config: Optional[dict] = None):
        """
        Build the environment from feature/price data and the bot config.

        :param df: Feature dataframe. Rows are sliced into observation windows
            and its column count sets the observation width. Defaults to a new,
            empty DataFrame.
        :param prices: Price dataframe addressed by row position; its ``open``
            column drives every pnl/return calculation. Defaults to a new, empty
            DataFrame.
        :param reward_kwargs: Must provide the keys ``"rr"`` and ``"profit_aim"``;
            a KeyError is raised otherwise (including for the default).
        :param window_size: Rows per observation window; also the first tick of
            every episode.
        :param starting_point: Chooses how ``reset`` pads the position history
            (to ``_start_tick`` when True, to ``window_size`` otherwise).
        :param id: Identifier stored on the instance.
        :param seed: Seed passed to ``gym.utils.seeding.np_random``.
        :param config: Full bot config. ``config['freqai']['rl_config']`` and
            ``config['stake_amount']`` are read, so it is required in practice.
        """
        # None sentinels replace the former DataFrame()/{} defaults, which were
        # single objects evaluated once and shared by every instance using them.
        if df is None:
            df = DataFrame()
        if prices is None:
            prices = DataFrame()
        if reward_kwargs is None:
            reward_kwargs = {}
        if config is None:
            config = {}

        self.rl_config = config['freqai']['rl_config']
        # When True, three state columns are appended to every observation.
        self.add_state_info = self.rl_config.get('add_state_info', False)
        self.id = id
        self.seed(seed)
        self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
        # 1 - max_training_drawdown_pct (default 0.8 -> 0.2). Not read in this
        # class; presumably compared against profit by subclasses — confirm.
        self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
        # With an 'unlimited' stake, profits compound multiplicatively
        # (see _update_total_profit / _update_unrealized_total_profit).
        self.compound_trades = config['stake_amount'] == 'unlimited'

    def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                  reward_kwargs: dict, starting_point=True):
        """
        (Re)load data, rebuild the spaces and initialise episode bookkeeping.

        Called from ``__init__``. NOTE(review): it is also meant to be used when
        the agent fails (drawdown beyond ``max_training_drawdown_pct``); that
        caller is not in this class — confirm.

        :param df: Feature dataframe used as ``signal_features``.
        :param prices: Price dataframe; ``len(prices) - 1`` is the last tick.
        :param window_size: Observation window length and starting tick.
        :param reward_kwargs: Must contain ``"rr"`` and ``"profit_aim"``.
        :param starting_point: Stored for ``reset``.
        """
        self.df = df
        self.signal_features = self.df
        self.prices = prices
        self.window_size = window_size
        self.starting_point = starting_point
        self.rr = reward_kwargs["rr"]
        self.profit_aim = reward_kwargs["profit_aim"]
        # Hard-coded fee rate applied once on entry and once on exit
        # (add_entry_fee / add_exit_fee).
        self.fee = 0.0015

        # spaces
        if self.add_state_info:
            # + current_profit_pct, position, trade_duration (see _get_observation)
            self.total_features = self.signal_features.shape[1] + 3
        else:
            self.total_features = self.signal_features.shape[1]
        self.shape = (window_size, self.total_features)
        self.set_action_space()
        self.observation_space = spaces.Box(
            low=-1, high=1, shape=self.shape, dtype=np.float32)

        # episode
        self._start_tick: int = self.window_size
        self._end_tick: int = len(self.prices) - 1
        self._done: bool = False
        self._current_tick: int = self._start_tick
        self._last_trade_tick: Optional[int] = None
        self._position = Positions.Neutral
        self._position_history: list = [None]
        self.total_reward: float = 0
        self._total_profit: float = 1
        self._total_unrealized_profit: float = 1
        self.history: dict = {}
        self.trade_history: list = []

    @abstractmethod
    def set_action_space(self):
        """
        Unique to the environment action count. Must be inherited.
        """

    def seed(self, seed: int = 1):
        """
        Seed ``self.np_random`` and return the seed in a one-element list.
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """
        Start a new episode and return its first observation.

        Rewinds to ``_start_tick``, sets the position to Neutral (without
        realising any open pnl) and clears reward, profit and history state.
        """
        self._done = False
        # Reset the position *before* recording it so that the history entry at
        # _start_tick is Neutral, not the previous episode's final position.
        self._position = Positions.Neutral

        # Pad with None so that list index == tick index.
        if self.starting_point is True:
            self._position_history = (self._start_tick * [None]) + [self._position]
        else:
            self._position_history = (self.window_size * [None]) + [self._position]

        self._current_tick = self._start_tick
        self._last_trade_tick = None

        self.total_reward = 0.
        self._total_profit = 1.  # unit
        self.history = {}
        self.trade_history = []
        self.portfolio_log_returns = np.zeros(len(self.prices))

        self._profits = [(self._start_tick, 1)]
        self.close_trade_profit = []
        self._total_unrealized_profit = 1

        return self._get_observation()

    @abstractmethod
    def step(self, action: int):
        """
        Step depends on action types; this must be inherited.
        """
        return

    def _get_observation(self):
        """
        Return the ``window_size`` feature rows preceding ``_current_tick``.

        With ``add_state_info`` the columns ``current_profit_pct``, ``position``
        and ``trade_duration`` are appended; each holds the current value,
        repeated on every row of the window. This may or may not be independent
        of action types; users can override it in their custom "MyRLEnv".
        """
        features_window = self.signal_features[(
            self._current_tick - self.window_size):self._current_tick]
        if self.add_state_info:
            features_and_state = DataFrame(np.zeros((len(features_window), 3)),
                                           columns=['current_profit_pct',
                                                    'position',
                                                    'trade_duration'],
                                           index=features_window.index)

            features_and_state['current_profit_pct'] = self.get_unrealized_profit()
            features_and_state['position'] = self._position.value
            features_and_state['trade_duration'] = self.get_trade_duration()
            features_and_state = pd.concat([features_window, features_and_state], axis=1)
            return features_and_state
        else:
            return features_window

    def get_trade_duration(self):
        """
        Ticks elapsed since ``_last_trade_tick``, or 0 when none is recorded.
        """
        if self._last_trade_tick is None:
            return 0
        else:
            return self._current_tick - self._last_trade_tick

    def get_unrealized_profit(self):
        """
        Fractional pnl of the open position, net of entry and exit fees.

        Compares the ``open`` price at ``_last_trade_tick`` with the ``open``
        price at ``_current_tick``. Returns 0. when no trade tick is recorded or
        the position is Neutral.
        """
        if self._last_trade_tick is None:
            return 0.

        if self._position == Positions.Neutral:
            return 0.
        elif self._position == Positions.Short:
            # Short: buying back costs the entry fee, the opening sale nets the exit fee.
            current_price = self.add_entry_fee(self.prices.iloc[self._current_tick].open)
            last_trade_price = self.add_exit_fee(self.prices.iloc[self._last_trade_tick].open)
            return (last_trade_price - current_price) / last_trade_price
        elif self._position == Positions.Long:
            current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
            last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
            return (current_price - last_trade_price) / last_trade_price
        else:
            return 0.

    @abstractmethod
    def is_tradesignal(self, action: int):
        """
        Determine if the signal is a trade signal. This is
        unique to the actions in the environment, and therefore must be
        inherited.
        """
        return

    @abstractmethod
    def _is_valid(self, action: int):
        """
        Determine if the signal is valid. This is
        unique to the actions in the environment, and therefore must be
        inherited.
        """
        return

    def add_entry_fee(self, price):
        """Return ``price`` raised by the fee rate (cost of entering)."""
        return price * (1 + self.fee)

    def add_exit_fee(self, price):
        """Return ``price`` reduced by the fee rate (proceeds of exiting)."""
        return price / (1 + self.fee)

    def _update_history(self, info):
        """
        Append each value of ``info`` to ``self.history[key]``.

        The key set is taken from the first ``info`` seen after a reset; a later
        dict with a new key raises KeyError.
        """
        if not self.history:
            self.history = {key: [] for key in info.keys()}

        for key, value in info.items():
            self.history[key].append(value)

    @abstractmethod
    def calculate_reward(self, action):
        """
        An example reward function. This is the one function that users will likely
        wish to inject their own creativity into.
        :params:
        action: int = The action made by the agent for the current candle.
        :returns:
        float = the reward to give to the agent for current step (used for optimization
            of weights in NN)
        """

    def _update_unrealized_total_profit(self):
        """
        Update the unrealized total profit in case of episode end.

        Only updated while Long or Short; when Neutral the previous value is kept.
        """
        if self._position in (Positions.Long, Positions.Short):
            pnl = self.get_unrealized_profit()
            if self.compound_trades:
                # assumes unit stake and compounding
                unrl_profit = self._total_profit * (1 + pnl)
            else:
                # assumes unit stake and no compounding
                unrl_profit = self._total_profit + pnl
            self._total_unrealized_profit = unrl_profit

    def _update_total_profit(self):
        """
        Realise the current unrealized pnl into ``_total_profit``.
        """
        pnl = self.get_unrealized_profit()
        if self.compound_trades:
            # assumes unit stake and compounding
            self._total_profit = self._total_profit * (1 + pnl)
        else:
            # assumes unit stake and no compounding
            self._total_profit += pnl

    def most_recent_return(self, action: int):
        """
        Calculate the tick-to-tick log return if in a trade.

        Return is generated from rising prices in Long and falling prices in
        Short positions, using consecutive ``open`` prices. A fee is charged on
        the first tick of a position, i.e. when the position recorded for the
        previous tick differs (Neutral or the opposite side). Returns 0 when
        Neutral. ``action`` is not used here.
        """
        if self._position not in (Positions.Long, Positions.Short):
            return 0

        current_price = self.prices.iloc[self._current_tick].open
        previous_price = self.prices.iloc[self._current_tick - 1].open
        previous_position = self._position_history[self._current_tick - 1]

        if self._position == Positions.Long:
            if previous_position in (Positions.Short, Positions.Neutral):
                previous_price = self.add_entry_fee(previous_price)
            return np.log(current_price) - np.log(previous_price)

        # Short position
        if previous_position in (Positions.Long, Positions.Neutral):
            previous_price = self.add_exit_fee(previous_price)
        return np.log(previous_price) - np.log(current_price)

    def update_portfolio_log_returns(self, action):
        """
        Store ``most_recent_return`` at the current tick.

        Requires ``reset`` to have created ``portfolio_log_returns``.
        """
        self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)

    def current_price(self) -> float:
        """Return the ``open`` price at the current tick."""
        return self.prices.iloc[self._current_tick].open