# freqtrade_origin/freqtrade/freqai/prediction_models/ReinforcementLearner.py
#
# 172 lines
# 6.9 KiB
# Python
# Raw Normal View History

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Type
import torch as th
from stable_baselines3.common.callbacks import ProgressBarCallback
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
# 2023-04-26 17:43:42 +00:00
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class ReinforcementLearner(BaseReinforcementLearningModel):
    """
    Reinforcement Learning Model prediction model.

    Users can inherit from this class to make their own RL model with custom
    environment/training controls. Define the file as follows:

    ```
    from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner

    class MyCoolRLModel(ReinforcementLearner):
    ```

    Save the file to `user_data/freqaimodels`, then run it with:

    freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat

    Here the users can override any of the functions
    available in the `IFreqaiModel` inheritance tree. Most importantly for RL, this
    is where the user overrides `MyRLEnv` (see below), to define custom
    `calculate_reward()` function, or to override any other parts of the environment.

    This class also allows users to override any other part of the IFreqaiModel tree.
    For example, the user can override `def fit()` or `def train()` or `def predict()`
    to take fine-tuned control over these processes.

    Another common override may be `def data_cleaning_predict()` where the user can
    take fine-tuned control over the data handling pipeline.
    """

    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
        """
        User customizable fit method.

        Builds a fresh agent (or reuses the previously trained one when continual
        learning is active and a model for the pair already exists), trains it for
        ``train_cycles * len(train_features)`` timesteps, and returns the model to use
        for inference.

        :param data_dictionary: dict = common data dictionary containing all train/test
            features/labels/weights.
        :param dk: FreqaiDatakitchen = data kitchen for current pair.
        :return:
            model Any = trained model to be used for inference in dry/live/backtesting.
            If ``best_model.zip`` exists in ``dk.data_path`` after training, that model
            is loaded and returned; otherwise the final trained model is returned.
        """
        train_df = data_dictionary["train_features"]
        total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)

        policy_kwargs = {"activation_fn": th.nn.ReLU, "net_arch": self.net_arch}

        # Tensorboard logs are grouped per base currency (the part before "/" in the pair).
        tb_path: Optional[Path] = None
        if self.activate_tensorboard:
            tb_path = dk.full_path / "tensorboard" / dk.pair.split("/")[0]

        if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
            model = self.MODELCLASS(
                self.policy_type,
                self.train_env,
                policy_kwargs=policy_kwargs,
                tensorboard_log=tb_path,
                **self.freqai_info.get("model_training_parameters", {}),
            )
        else:
            logger.info(
                "Continual training activated - starting training from previously trained agent."
            )
            model = self.dd.model_dictionary[dk.pair]
            model.set_env(self.train_env)

        callbacks: List[Any] = [self.eval_callback, self.tensorboard_callback]
        progressbar_callback: Optional[ProgressBarCallback] = None
        if self.rl_config.get("progress_bar", False):
            progressbar_callback = ProgressBarCallback()
            callbacks.insert(0, progressbar_callback)

        try:
            model.learn(
                total_timesteps=int(total_timesteps),
                callback=callbacks,
            )
        finally:
            # Run the progress bar's end-of-training hook even when learn() raises.
            if progressbar_callback is not None:
                progressbar_callback.on_training_end()

        if (dk.data_path / "best_model.zip").is_file():
            logger.info("Callback found a best model.")
            # The loader is handed the path without the ".zip" suffix.
            best_model = self.MODELCLASS.load(dk.data_path / "best_model")
            return best_model

        logger.info("Couldn't find best model, using final model instead.")

        return model

    MyRLEnv: Type[BaseEnvironment]

    class MyRLEnv(Base5ActionRLEnv):  # type: ignore[no-redef]
        """
        User can override any function in BaseRLEnv and gym.Env. Here the user
        sets a custom reward based on profit and trade duration.
        """

        def calculate_reward(self, action: int) -> float:
            """
            An example reward function. This is the one function that users will likely
            wish to inject their own creativity into.

            Warning!
            This function is a showcase of functionality designed to show as many possible
            environment control features as possible. It is also designed to run quickly
            on small computers. This is a benchmark, it is *not* for live production.

            :param action: int = The action made by the agent for the current candle.
            :return:
                float = the reward to give to the agent for current step (used for optimization
                of weights in NN)
            """
            # first, penalize if the action is not valid
            if not self._is_valid(action):
                self.tensorboard_log("invalid", category="actions")
                return -2.0

            pnl = self.get_unrealized_profit()
            factor = 100.0

            # reward agent for entering trades
            if action == Actions.Long_enter.value and self._position == Positions.Neutral:
                return 25.0
            if action == Actions.Short_enter.value and self._position == Positions.Neutral:
                return 25.0

            # discourage agent from not entering trades
            if action == Actions.Neutral.value and self._position == Positions.Neutral:
                return -1.0

            max_trade_duration = self.rl_config.get("max_trade_duration_candles", 300)
            trade_duration = self._current_tick - self._last_trade_tick  # type: ignore

            # scale the exit reward up for trades within the duration limit, down otherwise
            if trade_duration <= max_trade_duration:
                factor *= 1.5
            else:
                factor *= 0.5

            # discourage sitting in position: penalty grows with time spent in the trade
            if (
                self._position in (Positions.Short, Positions.Long)
                and action == Actions.Neutral.value
            ):
                return float(-1 * trade_duration / max_trade_duration)

            # close long
            if action == Actions.Long_exit.value and self._position == Positions.Long:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
                return float(pnl * factor)

            # close short
            if action == Actions.Short_exit.value and self._position == Positions.Short:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
                return float(pnl * factor)

            return 0.0