diff --git a/docs/freqai-parameter-table.md b/docs/freqai-parameter-table.md index 02426ec13..f2a52a9b8 100644 --- a/docs/freqai-parameter-table.md +++ b/docs/freqai-parameter-table.md @@ -82,6 +82,7 @@ Mandatory parameters are marked as **Required** and have to be set in one of the | `model_reward_parameters` | Parameters used inside the customizable `calculate_reward()` function in `ReinforcementLearner.py`
**Datatype:** int. | `add_state_info` | Tell FreqAI to include state information in the feature set for training and inferencing. The current state variables include trade duration, current profit, trade position. This is only available in dry/live runs, and is automatically switched to false for backtesting.
**Datatype:** bool.
Default: `False`. | `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[, dict(vf=[], pi=[])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each. +| `randomize_starting_position` | Randomize the starting point of each episode to avoid overfitting.
**Datatype:** bool.
Default: `False`. ### Additional parameters diff --git a/freqtrade/constants.py b/freqtrade/constants.py index 878c38929..d869b89f6 100644 --- a/freqtrade/constants.py +++ b/freqtrade/constants.py @@ -591,6 +591,7 @@ CONF_SCHEMA = { "model_type": {"type": "string", "default": "PPO"}, "policy_type": {"type": "string", "default": "MlpPolicy"}, "net_arch": {"type": "array", "default": [128, 128]}, + "randomize_startinng_position": {"type": "boolean", "default": False}, "model_reward_parameters": { "type": "object", "properties": { diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py index 5d881ba32..8f940dd1b 100644 --- a/freqtrade/freqai/RL/BaseEnvironment.py +++ b/freqtrade/freqai/RL/BaseEnvironment.py @@ -122,9 +122,10 @@ class BaseEnvironment(gym.Env): self._done = False if self.starting_point is True: - length_of_data = int(self._end_tick/4) - start_tick = random.randint(self.window_size+1, length_of_data) - self._start_tick = start_tick + if self.rl_config.get('randomize_starting_position', False): + length_of_data = int(self._end_tick / 4) + start_tick = random.randint(self.window_size + 1, length_of_data) + self._start_tick = start_tick self._position_history = (self._start_tick * [None]) + [self._position] else: self._position_history = (self.window_size * [None]) + [self._position]