diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py index 083433165..d8214ba57 100644 --- a/fairscale/optim/oss.py +++ b/fairscale/optim/oss.py @@ -148,7 +148,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.load_local_state_dict(state_dict["state"][self.rank]) # Restore the global param_groups - self.param_groups = state_dict["param_groups"] + self.param_groups = recursive_copy_to_device(state_dict["param_groups"], non_blocking=True, device=self._device) def add_param_group(self, param_group: dict) -> None: super().add_param_group(param_group) diff --git a/tests/optim/test_oss.py b/tests/optim/test_oss.py index c69151d6f..a382f051a 100644 --- a/tests/optim/test_oss.py +++ b/tests/optim/test_oss.py @@ -62,6 +62,9 @@ def test_state_dict(): o.step() assert x == torch.tensor([0.9], device=DEVICE) + # Check that the exposed param_groups are on the proper device + assert o.param_groups[0]["params"][0].device == x.device + def test_local_state_dict(): x = torch.tensor([1.0], device=DEVICE, requires_grad=True)