Add sb3 learn progress bar

This commit is contained in:
initrv 2023-04-02 02:59:02 +03:00
parent dc7e834911
commit cab82e8e60
5 changed files with 9 additions and 2 deletions

View File

@ -73,7 +73,8 @@
10,
20
],
"plot_feature_importances": 0
"plot_feature_importances": 0,
"progress_bar": false
},
"data_split_parameters": {
"test_size": 0.33,

View File

@ -85,6 +85,7 @@ Mandatory parameters are marked as **Required** and have to be set in one of the
| `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each.
| `randomize_starting_position` | Randomize the starting point of each episode to avoid overfitting. <br> **Datatype:** bool. <br> Default: `False`.
| `drop_ohlc_from_features` | Do not include the normalized ohlc data in the feature set passed to the agent during training (ohlc will still be used for driving the environment in all cases) <br> **Datatype:** Boolean. <br> **Default:** `False`
| `progress_bar` | Display a progress bar with the current progress, elapsed time and estimated remaining time. <br> **Datatype:** Boolean. <br> Default: `False`.
### Additional parameters

View File

@ -599,6 +599,7 @@ CONF_SCHEMA = {
"policy_type": {"type": "string", "default": "MlpPolicy"},
"net_arch": {"type": "array", "default": [128, 128]},
"randomize_startinng_position": {"type": "boolean", "default": False},
"progress_bar": {"type": "boolean", "default": False},
"model_reward_parameters": {
"type": "object",
"properties": {

View File

@ -71,7 +71,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
model.learn(
total_timesteps=int(total_timesteps),
callback=[self.eval_callback, self.tensorboard_callback]
callback=[self.eval_callback, self.tensorboard_callback],
progress_bar=self.freqai_info["rl_config"]["progress_bar"]
)
if Path(dk.data_path / "best_model.zip").is_file():

View File

@ -8,3 +8,6 @@ sb3-contrib==1.7.0; python_version < '3.11'
# Gym is forced to this version by stable-baselines3.
setuptools==65.5.1 # Should be removed when gym is fixed.
gym==0.21; python_version < '3.11'
# Progress bar for stable-baselines3 and sb3-contrib
tqdm==4.65.0; python_version < '3.11'
rich==13.3.3; python_version < '3.11'