From e790f5e2bc8d3667dac26638975967f30ce7b960 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Fri, 5 Jul 2024 18:11:18 +0200 Subject: [PATCH 1/2] Updated DQN optimizer input to only include q_network parameters --- docs/misc/changelog.rst | 1 + stable_baselines3/dqn/policies.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8df321129..358b347f9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -20,6 +20,7 @@ Bug Fixes: - ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. (will-maclean) - Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) - Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302) +- Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger) `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 9d2cf94df..5afa8608f 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -166,8 +166,9 @@ def _build(self, lr_schedule: Schedule) -> None: self.q_net_target.set_training_mode(False) # Setup optimizer with initial learning rate + q_net_parameters = list(self.q_net.parameters()) self.optimizer = self.optimizer_class( # type: ignore[call-arg] - self.parameters(), + q_net_parameters, lr=lr_schedule(1), **self.optimizer_kwargs, ) From 642666db947df5a48b0252bf903b56330746f81c Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 5 Jul 2024 18:48:56 +0200 Subject: [PATCH 2/2] Update version --- docs/misc/changelog.rst | 2 +- stable_baselines3/dqn/policies.py | 3 +-- stable_baselines3/version.txt | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 358b347f9..78eb2bd0e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a4 (WIP) +Release 2.4.0a5 (WIP) -------------------------- Breaking Changes: diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 5afa8608f..bfefc8137 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -166,9 +166,8 @@ def _build(self, lr_schedule: Schedule) -> None: self.q_net_target.set_training_mode(False) # Setup optimizer with initial learning rate - q_net_parameters = list(self.q_net.parameters()) self.optimizer = self.optimizer_class( # type: ignore[call-arg] - q_net_parameters, + self.q_net.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs, ) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 2d22b1587..a1fd35b5f 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a4 +2.4.0a5