Skip to content

Commit

Permalink
[fix] OSS restore state to proper device (#46)
Browse files Browse the repository at this point in the history
* move the restored param groups to the original device

* adding a corresponding test
  • Loading branch information
blefaudeux authored Aug 20, 2020
1 parent 9d6c7b6 commit c2d6f4b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion fairscale/optim/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.load_local_state_dict(state_dict["state"][self.rank])

# Restore the global param_groups
- self.param_groups = state_dict["param_groups"]
+ self.param_groups = recursive_copy_to_device(state_dict["param_groups"], non_blocking=True, device=self._device)

def add_param_group(self, param_group: dict) -> None:
super().add_param_group(param_group)
Expand Down
3 changes: 3 additions & 0 deletions tests/optim/test_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def test_state_dict():
o.step()
assert x == torch.tensor([0.9], device=DEVICE)

+ # Check that the exposed param_groups are on the proper device
+ assert o.param_groups[0]["params"][0].device == x.device


def test_local_state_dict():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
Expand Down

0 comments on commit c2d6f4b

Please sign in to comment.