Update ReinforcementLearner_DDPG_TD3.py

Clean up set policy code.
This commit is contained in:
Shane 2024-05-26 20:21:16 +10:00 committed by GitHub
parent 3436e8aa1d
commit bb62b0fc5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -83,19 +83,10 @@ class ReinforcementLearner_DDPG_TD3(BaseReinforcementLearningModel):
model_params["learning_rate"] = linear_schedule(_lr) model_params["learning_rate"] = linear_schedule(_lr)
logger.info(f"Learning rate linear schedule enabled, initial value: {_lr}") logger.info(f"Learning rate linear schedule enabled, initial value: {_lr}")
if any(model in self.freqai_info["rl_config"]["model_type"] for model in ["DDPG", "TD3"]):
model_params["policy_kwargs"] = dict(
#net_arch=self.net_arch,
net_arch=dict(qf=self.net_arch, pi=self.net_arch),
activation_fn=th.nn.ReLU,
optimizer_class=th.optim.Adam
)
else:
model_params["policy_kwargs"] = dict( model_params["policy_kwargs"] = dict(
net_arch=dict(vf=self.net_arch, pi=self.net_arch), net_arch=dict(vf=self.net_arch, pi=self.net_arch),
activation_fn=th.nn.ReLU, activation_fn=th.nn.ReLU,
optimizer_class=th.optim.Adam optimizer_class=th.optim.Adam
)
return model_params return model_params