from enum import Enum

import gym
import matplotlib.pyplot as plt
import numpy as np
from gym import spaces
from gym.utils import seeding


class Actions(Enum):
    """Discrete actions accepted by ``GymAnytrading.step`` (compared by value)."""

    Hold = 0
    Buy = 1
    Sell = 2


class Positions(Enum):
    """Position currently held by the environment."""

    Short = 0
    Long = 1

    def opposite(self):
        """Return the other position: Long becomes Short, anything else Long."""
        return Positions.Short if self == Positions.Long else Positions.Long


class GymAnytrading(gym.Env):
    """
    Single-instrument trading environment with Hold/Buy/Sell actions.

    Based on https://github.com/AminHP/gym-anytrading

    Each observation is the slice of ``signal_features`` covering the
    ``window_size`` ticks that end just before the current tick. A Buy while
    Short or a Sell while Long flips the position; every other action leaves
    the position unchanged.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, signal_features, prices, window_size, fee=0.0):
        """
        Set up spaces and episode bounds; call ``reset`` before stepping.

        Args:
            signal_features: 2-D array, sliced along its first axis by tick.
            prices: price per tick, indexed by tick. ``render_all`` indexes it
                with lists of ticks, so presumably a 1-D numpy array -- verify
                against callers.
            window_size: number of ticks in each observation.
            fee: fraction deducted on both the buy and the sell leg when the
                total profit is updated.
        """
        assert signal_features.ndim == 2

        self.seed()
        self.signal_features = signal_features
        self.prices = prices
        self.window_size = window_size
        self.fee = fee
        self.shape = (window_size, self.signal_features.shape[1])

        # spaces
        self.action_space = spaces.Discrete(len(Actions))
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)

        # episode
        self._start_tick = self.window_size
        self._end_tick = len(self.prices) - 1
        self._done = None
        self._current_tick = None
        self._last_trade_tick = None
        self._position = None
        self._position_history = None
        self._total_reward = None
        self._total_profit = None
        self._first_rendering = None
        self.history = None

    def seed(self, seed=None):
        """Seed the environment's random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """Start a new episode in a Short position and return the first observation."""
        self._done = False
        self._current_tick = self._start_tick
        self._last_trade_tick = self._current_tick - 1
        self._position = Positions.Short
        # Pad with None so that list index == tick for every recorded position.
        self._position_history = (self.window_size * [None]) + [self._position]
        self._total_reward = 0.
        self._total_profit = 1.  # unit
        self._first_rendering = True
        self.history = {}
        return self._get_observation()

    def step(self, action):
        """
        Advance one tick and apply ``action``.

        Args:
            action: integer matching an ``Actions`` value.

        Returns:
            Tuple ``(observation, step_reward, done, info)`` where ``info``
            holds ``total_reward``, ``total_profit`` and ``position``.
        """
        self._done = False
        self._current_tick += 1

        if self._current_tick == self._end_tick:
            self._done = True

        # Reward and profit are evaluated against the position held *before*
        # the action takes effect, so they must run ahead of the flip below.
        step_reward = self._calculate_reward(action)
        self._total_reward += step_reward

        self._update_profit(action)

        if self._is_trade(action):
            self._position = self._position.opposite()
            self._last_trade_tick = self._current_tick

        self._position_history.append(self._position)
        observation = self._get_observation()
        info = dict(
            total_reward=self._total_reward,
            total_profit=self._total_profit,
            position=self._position.value
        )
        self._update_history(info)

        return observation, step_reward, self._done, info

    def _is_trade(self, action):
        """Return True when ``action`` flips the current position."""
        return ((action == Actions.Buy.value and self._position == Positions.Short) or
                (action == Actions.Sell.value and self._position == Positions.Long))

    def _get_observation(self):
        """Return the ``window_size`` feature rows preceding the current tick."""
        return self.signal_features[(self._current_tick - self.window_size):self._current_tick]

    def _update_history(self, info):
        """Append each value of ``info`` to the per-key lists in ``self.history``."""
        if not self.history:
            self.history = {key: [] for key in info.keys()}

        for key, value in info.items():
            self.history[key].append(value)

    def render(self, mode='human'):
        """Incrementally plot prices and mark the current position (red=Short, green=Long)."""

        def _plot_position(position, tick):
            color = None
            if position == Positions.Short:
                color = 'red'
            elif position == Positions.Long:
                color = 'green'
            if color:
                plt.scatter(tick, self.prices[tick], color=color)

        if self._first_rendering:
            # First call of the episode: draw the price line and the start marker once.
            self._first_rendering = False
            plt.cla()
            plt.plot(self.prices)
            start_position = self._position_history[self._start_tick]
            _plot_position(start_position, self._start_tick)

        _plot_position(self._position, self._current_tick)

        plt.suptitle(
            "Total Reward: %.6f" % self._total_reward + ' ~ ' +
            "Total Profit: %.6f" % self._total_profit
        )

        plt.pause(0.01)

    def render_all(self, mode='human'):
        """Plot the price series with every recorded position (red=Short, green=Long)."""
        plt.plot(self.prices)

        # Index in the position history is the tick the position was held at.
        short_ticks = [tick for tick, position in enumerate(self._position_history)
                       if position == Positions.Short]
        long_ticks = [tick for tick, position in enumerate(self._position_history)
                      if position == Positions.Long]

        plt.plot(short_ticks, self.prices[short_ticks], 'ro')
        plt.plot(long_ticks, self.prices[long_ticks], 'go')

        plt.suptitle(
            "Total Reward: %.6f" % self._total_reward + ' ~ ' +
            "Total Profit: %.6f" % self._total_profit
        )

    def close(self):
        """Close the current matplotlib figure."""
        plt.close()

    def save_rendering(self, filepath):
        """Save the current matplotlib figure to ``filepath``."""
        plt.savefig(filepath)

    def pause_rendering(self):
        """Show the current matplotlib figure via ``plt.show()``."""
        plt.show()

    def _calculate_reward(self, action):
        """
        Return the reward for ``action`` at the current tick.

        The reward is non-zero only when a Long position is closed: it is the
        current price minus the price at the last trade tick.
        """
        step_reward = 0

        if self._is_trade(action):
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]
            price_diff = current_price - last_trade_price

            if self._position == Positions.Long:
                step_reward += price_diff

        return step_reward

    def _update_profit(self, action):
        """
        Update the running total profit when a Long position is realised.

        Runs on a trade or on the final tick; the fee is deducted once for the
        buy leg and once for the sell leg.
        """
        if self._is_trade(action) or self._done:
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]

            if self._position == Positions.Long:
                shares = (self._total_profit * (1 - self.fee)) / last_trade_price
                self._total_profit = (shares * (1 - self.fee)) * current_price

    def max_possible_profit(self):
        """
        Return the profit of a fee-free strategy that is Long over every
        non-falling price run and out of the market over every falling run,
        starting from a unit of 1.
        """
        current_tick = self._start_tick
        last_trade_tick = current_tick - 1
        profit = 1.

        while current_tick <= self._end_tick:
            # Direction of the run that starts here. Comparing every later step
            # against this flag (rather than using separate `<` / `>=` tests)
            # guarantees at least one tick of progress per iteration, even when
            # adjacent prices are unordered (e.g. NaN).
            falling = self.prices[current_tick] < self.prices[current_tick - 1]
            while (current_tick <= self._end_tick and
                   (self.prices[current_tick] < self.prices[current_tick - 1]) == falling):
                current_tick += 1
            position = Positions.Short if falling else Positions.Long

            if position == Positions.Long:
                current_price = self.prices[current_tick - 1]
                last_trade_price = self.prices[last_trade_tick]
                shares = profit / last_trade_price
                profit = shares * current_price
            last_trade_tick = current_tick - 1

        return profit