diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index dfa728a00..7bb894f0b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Pre-Release 0.8.0a5 (WIP) +Pre-Release 0.8.0a6 (WIP) ------------------------------ Breaking Changes: @@ -33,6 +33,7 @@ Bug Fixes: - Use ``cloudpickle.load`` instead of ``pickle.load`` in ``CloudpickleWrapper``. (@shwang) - Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37) - Fixed approximate entropy calculation in PPO and A2C. (@andyshih12) +- Fixed DQN target network sharing its feature extractor with the main network. Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index ea0eaa278..495b61f57 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -133,8 +133,6 @@ def __init__( else: net_arch = [] - self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - self.features_dim = self.features_extractor.features_dim self.net_arch = net_arch self.activation_fn = activation_fn self.normalize_images = normalize_images @@ -142,8 +140,6 @@ def __init__( self.net_args = { "observation_space": self.observation_space, "action_space": self.action_space, - "features_extractor": self.features_extractor, - "features_dim": self.features_dim, "net_arch": self.net_arch, "activation_fn": self.activation_fn, "normalize_images": normalize_images, @@ -169,7 +165,10 @@ def _build(self, lr_schedule: Callable) -> None: self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) def make_q_net(self) -> QNetwork: - return QNetwork(**self.net_args).to(self.device) + # Create a new features extractor on every call so that Q-networks built here never share one + features_extractor = self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + features_dim = 
features_extractor.features_dim + return QNetwork(features_extractor=features_extractor, features_dim=features_dim, **self.net_args).to(self.device) def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: return self._predict(obs, deterministic=deterministic) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 87f68d6c9..db5057921 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.8.0a5 +0.8.0a6