Update BaseReinforcementLearningModel.py

Add support for DDPG and TD3.
This commit is contained in:
Shane 2024-05-24 21:29:38 +10:00 committed by GitHub
parent dc5766fb10
commit c83dd2d806
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
torch.multiprocessing.set_sharing_strategy("file_system")
SB3_MODELS = ["PPO", "A2C", "DQN"]
SB3_MODELS = ["PPO", "A2C", "DQN", "DDPG", "TD3"]
SB3_CONTRIB_MODELS = ["TRPO", "ARS", "RecurrentPPO", "MaskablePPO", "QRDQN"]