From 6e9d9b42598b8a78331df4093b829dd84efd8060 Mon Sep 17 00:00:00 2001 From: "Anssi \"Miffyli\" Kanervisto" Date: Thu, 30 Jul 2020 21:25:35 +0300 Subject: [PATCH 1/2] Separate feature extractor networks for DQN networks --- docs/misc/changelog.rst | 1 + stable_baselines3/dqn/policies.py | 9 ++++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index dfa728a00..988da16d6 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -33,6 +33,7 @@ Bug Fixes: - Use ``cloudpickle.load`` instead of ``pickle.load`` in ``CloudpickleWrapper``. (@shwang) - Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37) - Fixed approximate entropy calculation in PPO and A2C. (@andyshih12) +- Fixed DQN target network sharing feature extractor with the main network. Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index ea0eaa278..495b61f57 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -133,8 +133,6 @@ def __init__( else: net_arch = [] - self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - self.features_dim = self.features_extractor.features_dim self.net_arch = net_arch self.activation_fn = activation_fn self.normalize_images = normalize_images @@ -142,8 +140,6 @@ def __init__( self.net_args = { "observation_space": self.observation_space, "action_space": self.action_space, - "features_extractor": self.features_extractor, - "features_dim": self.features_dim, "net_arch": self.net_arch, "activation_fn": self.activation_fn, "normalize_images": normalize_images, @@ -169,7 +165,10 @@ def _build(self, lr_schedule: Callable) -> None: self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) def make_q_net(self) -> QNetwork: - return QNetwork(**self.net_args).to(self.device) + # Make sure we always have separate networks for feature extractors etc + features_extractor = self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + features_dim = features_extractor.features_dim + return QNetwork(features_extractor=features_extractor, features_dim=features_dim, **self.net_args).to(self.device) def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: return self._predict(obs, deterministic=deterministic) From 9e5a2b4862b14abe61810aef7933ff177ab503d0 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 30 Jul 2020 20:44:33 +0200 Subject: [PATCH 2/2] [ci skip] Bump version --- docs/misc/changelog.rst | 2 +- stable_baselines3/version.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 988da16d6..7bb894f0b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Pre-Release 0.8.0a5 (WIP) +Pre-Release 0.8.0a6 (WIP) ------------------------------ Breaking Changes: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 87f68d6c9..db5057921 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.8.0a5 +0.8.0a6