mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-13 03:33:55 +00:00
163 lines
5.3 KiB
Python
163 lines
5.3 KiB
Python
# common library
|
|
|
|
import numpy as np
|
|
from stable_baselines3 import A2C
|
|
from stable_baselines3 import DDPG
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3 import SAC
|
|
from stable_baselines3 import TD3
|
|
from stable_baselines3.common.callbacks import BaseCallback
|
|
from stable_baselines3.common.noise import NormalActionNoise
|
|
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise
|
|
# from stable_baselines3.common.vec_env import DummyVecEnv
|
|
|
|
from freqtrade.freqai.prediction_models.RL import config
|
|
# from meta.env_stock_trading.env_stock_trading import StockTradingEnv
|
|
|
|
# RL models from stable-baselines
|
|
|
|
|
|
# Lookup table: algorithm name -> stable-baselines3 algorithm class.
MODELS = {
    "a2c": A2C,
    "ddpg": DDPG,
    "td3": TD3,
    "sac": SAC,
    "ppo": PPO,
}

# Default constructor kwargs per algorithm, read from the ``<NAME>_PARAMS``
# entry of the config module's namespace (e.g. ``A2C_PARAMS`` for "a2c").
MODEL_KWARGS = {
    algo_name: vars(config)[f"{algo_name.upper()}_PARAMS"] for algo_name in MODELS
}

# Lookup table: noise identifier -> stable-baselines3 action-noise class.
NOISE = dict(
    normal=NormalActionNoise,
    ornstein_uhlenbeck=OrnsteinUhlenbeckActionNoise,
)
|
|
|
|
|
|
class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.

    On every step the first entry of the current reward container is
    recorded under the ``train/reward`` key.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        """Record the current step reward and keep training.

        Returns
        -------
        bool
            Always ``True``.

        Raises
        ------
        KeyError
            If neither ``"rewards"`` nor ``"reward"`` is present in
            ``self.locals``.
        """
        # The reward is looked up under "rewards" first and "reward" second.
        # NOTE(review): presumably different SB3 algorithms/versions expose it
        # under different local names - confirm against the SB3 version in use.
        # Only the missing-key case triggers the fallback, so interrupts and
        # genuine logging errors are no longer swallowed.
        try:
            step_rewards = self.locals["rewards"]
        except KeyError:
            step_rewards = self.locals["reward"]
        self.logger.record(key="train/reward", value=step_rewards[0])
        return True
|
|
|
|
|
|
class RLPrediction_agent:
    """Provides implementations for DRL algorithms

    Based on:
    https://github.com/AI4Finance-Foundation/FinRL-Meta/blob/master/agents/stablebaselines3_models.py

    Attributes
    ----------
        env: gym environment class
            user-defined class

    Methods
    -------
        get_model()
            setup DRL algorithms
        train_model()
            train DRL algorithms in a train dataset
            and output the trained model
        DRL_prediction()
            make a prediction in a test dataset and get results
        DRL_prediction_load_from_file()
            load a saved model and roll it through an environment
    """

    def __init__(self, env):
        self.env = env

    def get_model(
        self,
        model_name,
        policy="MlpPolicy",
        policy_kwargs=None,
        model_kwargs=None,
        verbose=1,
        seed=None,
    ):
        """Instantiate the algorithm registered under ``model_name``.

        Parameters
        ----------
        model_name:
            Key into ``MODELS``.
        policy, policy_kwargs, verbose, seed:
            Forwarded unchanged to the algorithm constructor.
        model_kwargs:
            Extra constructor kwargs. When ``None`` the defaults from
            ``MODEL_KWARGS[model_name]`` are used. If it contains an
            ``"action_noise"`` entry holding a string, that string is looked
            up in ``NOISE`` and replaced by a noise instance sized to the
            last dimension of the environment's action space.

        Returns
        -------
        The constructed model.

        Raises
        ------
        NotImplementedError
            If ``model_name`` is not a key of ``MODELS``.
        """
        if model_name not in MODELS:
            raise NotImplementedError(
                f"Model '{model_name}' is not supported; "
                f"available models: {sorted(MODELS)}"
            )

        # Work on a shallow copy: the noise instance must never be written back
        # into the shared MODEL_KWARGS defaults (or the caller's dict), because
        # a later call would then look the instance itself up in NOISE and fail.
        defaults = MODEL_KWARGS[model_name] if model_kwargs is None else model_kwargs
        model_kwargs = dict(defaults)

        noise_spec = model_kwargs.get("action_noise")
        if isinstance(noise_spec, str):
            n_actions = self.env.action_space.shape[-1]
            model_kwargs["action_noise"] = NOISE[noise_spec](
                mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)
            )
        print(model_kwargs)
        model = MODELS[model_name](
            policy=policy,
            env=self.env,
            tensorboard_log=f"{config.TENSORBOARD_LOG_DIR}/{model_name}",
            verbose=verbose,
            policy_kwargs=policy_kwargs,
            seed=seed,
            **model_kwargs,
        )
        return model

    def train_model(self, model, tb_log_name, total_timesteps=5000):
        """Run ``model.learn`` with a ``TensorboardCallback`` and return its result.

        Parameters
        ----------
        model:
            Object exposing ``learn(total_timesteps, tb_log_name, callback)``.
        tb_log_name:
            Forwarded to ``learn`` as ``tb_log_name``.
        total_timesteps:
            Forwarded to ``learn`` as ``total_timesteps``.
        """
        model = model.learn(
            total_timesteps=total_timesteps,
            tb_log_name=tb_log_name,
            callback=TensorboardCallback(),
        )
        return model

    @staticmethod
    def DRL_prediction(model, environment):
        """Make a prediction by stepping ``model`` through ``environment``.

        The loop runs at most once per unique value of
        ``environment.df.index``. On the second-to-last of those iterations
        the asset and action memories are fetched from the wrapped
        environment via ``env_method``.

        Returns
        -------
        tuple
            ``(account_memory[0], actions_memory[0])``.

        Raises
        ------
        IndexError
            If the memories were never captured, i.e. the episode reported
            done before the second-to-last step or there are fewer than two
            unique index values.
        """
        test_env, test_obs = environment.get_sb_env()
        account_memory = []
        actions_memory = []
        # The reset result is discarded; the starting observation is the one
        # returned by get_sb_env() above.
        test_env.reset()

        # Loop-invariant: compute the number of unique index values once.
        n_steps = len(environment.df.index.unique())
        for i in range(n_steps):
            action, _states = model.predict(test_obs)
            test_obs, rewards, dones, info = test_env.step(action)
            if i == n_steps - 2:
                account_memory = test_env.env_method(method_name="save_asset_memory")
                actions_memory = test_env.env_method(method_name="save_action_memory")
            if dones[0]:
                print("hit end!")
                break
        return account_memory[0], actions_memory[0]

    @staticmethod
    def DRL_prediction_load_from_file(model_name, environment, cwd):
        """Load a saved model from ``cwd`` and run one episode on ``environment``.

        Parameters
        ----------
        model_name:
            Key into ``MODELS`` selecting the class whose ``load`` is called.
        environment:
            Object exposing ``reset``, ``step``, ``initial_total_asset``,
            ``cash``, ``price_array``, ``time`` and ``stocks``.
        cwd:
            Argument passed to ``load``.

        Returns
        -------
        list
            Total asset values, starting with
            ``environment.initial_total_asset`` followed by one entry per step.

        Raises
        ------
        NotImplementedError
            If ``model_name`` is not a key of ``MODELS``.
        ValueError
            If loading the model fails; the original error is chained as
            ``__cause__``.
        """
        if model_name not in MODELS:
            raise NotImplementedError(
                f"Model '{model_name}' is not supported; "
                f"available models: {sorted(MODELS)}"
            )
        try:
            # load agent
            model = MODELS[model_name].load(cwd)
        except Exception as err:
            # Keep the root cause attached; interrupts and SystemExit are no
            # longer converted into a ValueError.
            raise ValueError("Fail to load agent!") from err
        print("Successfully load model", cwd)

        # test on the testing env
        state = environment.reset()
        episode_total_assets = [environment.initial_total_asset]
        done = False
        while not done:
            action = model.predict(state)[0]
            state, reward, done, _ = environment.step(action)

            total_asset = (
                environment.cash
                + (environment.price_array[environment.time] * environment.stocks).sum()
            )
            episode_total_assets.append(total_asset)
            # the cumulative_return / initial_account
            episode_return = total_asset / environment.initial_total_asset

        print("episode_return", episode_return)
        print("Test Finished!")
        return episode_total_assets
|