From fdcc206b007e10d1599e3f910536de2dc9c7f78c Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Fri, 4 Sep 2020 17:14:45 -0700
Subject: [PATCH 1/3] sync local and global param_group keys

---
 fairscale/optim/oss.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py
index ccd37efe9..934584c1c 100644
--- a/fairscale/optim/oss.py
+++ b/fairscale/optim/oss.py
@@ -67,6 +67,12 @@ def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any =
         # Current device is set by the parameters allocated to this rank
         self._device = split_param_groups[self.rank][0]["params"][0].device
 
+        # Sync local and global param_groups keys
+        for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
+            for k, v in local_group.items():
+                if k != "params":
+                    global_group[k] = v
+
     def partition_parameters(self) -> List[List[dict]]:
         """Partitions parameters across distributed ranks.
 
@@ -94,8 +100,8 @@ def partition_parameters(self) -> List[List[dict]]:
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
-        # Sync lr in case its been update by an LRScheduler.
-        self._sync_lr()
+        # Sync oss param_groups attributes in case they've been updated by a scheduler.
+        self._sync_pg_attributes()
 
         # Run the optimizer step on this shard only
         loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
@@ -116,8 +122,8 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
 
         This needs to be called on all replicas
         """
-        # Sync lr in case its been update by an LRScheduler.
-        self._sync_lr()
+        # Sync lr and other attributes in case its been updated
+        self._sync_pg_attributes()
 
         if self.rank == recipient_rank:
             # Pull the sharded state from all the other replicas
@@ -176,9 +182,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             {"state": state_dict["state"][self.rank], "param_groups": state_dict["param_groups"][self.rank]}
         )
 
-        # Update the param_groups attribute for this instance
-        # TODO(ben)
-
     def add_param_group(self, param_group: dict) -> None:
         super().add_param_group(param_group)
         if not self.in_super_constructor:
@@ -186,10 +189,13 @@ def add_param_group(self, param_group: dict) -> None:
             if len(param_groups) == len(self.optim.param_groups) + 1:
                 self.optim.add_param_group(param_groups[-1])
 
-    def _sync_lr(self) -> None:
-        """Sync learning rate (needed to support LRScheduler)."""
+    def _sync_pg_attributes(self) -> None:
+        """Sync learning rate and other optimizer attributes (needed to support schedulers)."""
         for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
-            local_group["lr"] = global_group["lr"]
+            for k in local_group.keys():
+                if k != "params":
+                    # Params have been sharded and should not be synced here
+                    local_group[k] = global_group[k]
 
     def _collect_sharded_states(self) -> List[Dict[str, Any]]:
         """

From 33dedbd8c7c952ed657fa9a36840bee63e20b57e Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Fri, 4 Sep 2020 17:16:34 -0700
Subject: [PATCH 2/3] hopefully better choice of names

---
 fairscale/optim/oss.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py
index 934584c1c..e5b478f99 100644
--- a/fairscale/optim/oss.py
+++ b/fairscale/optim/oss.py
@@ -101,7 +101,7 @@ def partition_parameters(self) -> List[List[dict]]:
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
         # Sync oss param_groups attributes in case they've been updated by a scheduler.
-        self._sync_pg_attributes()
+        self._sync_param_groups()
 
         # Run the optimizer step on this shard only
         loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
@@ -123,7 +123,7 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
         This needs to be called on all replicas
         """
         # Sync lr and other attributes in case its been updated
-        self._sync_pg_attributes()
+        self._sync_param_groups()
 
         if self.rank == recipient_rank:
             # Pull the sharded state from all the other replicas
@@ -189,7 +189,7 @@ def add_param_group(self, param_group: dict) -> None:
             if len(param_groups) == len(self.optim.param_groups) + 1:
                 self.optim.add_param_group(param_groups[-1])
 
-    def _sync_pg_attributes(self) -> None:
+    def _sync_param_groups(self) -> None:
         """Sync learning rate and other optimizer attributes (needed to support schedulers)."""
         for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
             for k in local_group.keys():

From 93d571880a29c01c9014f17e2a2c0708f67f732d Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Tue, 8 Sep 2020 11:24:28 -0700
Subject: [PATCH 3/3] better test coverage, check for all dict state attributes

---
 tests/optim/test_oss.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/tests/optim/test_oss.py b/tests/optim/test_oss.py
index fc5c45ea2..57adf3146 100644
--- a/tests/optim/test_oss.py
+++ b/tests/optim/test_oss.py
@@ -54,11 +54,17 @@ def test_state_dict():
     assert "param_groups" in state_dict.keys()
     assert "state" in state_dict.keys()
 
-    # Check that the pulled state is what we expect
+    # Check that the pulled state is what we expect, and that we have all the expected keys
     assert state_dict["param_groups"][0][0]["lr"] == 0.1
+    assert state_dict["param_groups"][0][0]["momentum"] == 0.9
+    assert not state_dict["param_groups"][0][0]["nesterov"]
+    assert state_dict["param_groups"][0][0]["weight_decay"] == 0.0
+    assert state_dict["param_groups"][0][0]["dampening"] == 0.0
 
     # Check that the pulled state and the .param_groups attribute are in sync
-    assert state_dict["param_groups"][0][0]["lr"] == o.param_groups[0]["lr"]
+    for k in state_dict["param_groups"][0][0].keys():
+        if k != "params":
+            assert state_dict["param_groups"][0][0][k] == o.param_groups[0][k]
 
     # Check that it's correctly loaded
     o = optim.OSS([x], lr=0.01)
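
Taken together, the three patches keep every hyperparameter key (lr, momentum, weight_decay, ...) mirrored between the wrapper's exposed param_groups and the wrapped, sharded optimizer, skipping only "params" since those are partitioned per rank. The snippet below is a standalone sketch of that copy, not fairscale code: it uses a plain torch.optim.SGD as the "inner" optimizer and a hypothetical free function sync_param_groups standing in for the OSS._sync_param_groups method introduced above.

import torch
from torch.optim import SGD


def sync_param_groups(exposed_groups, inner):
    # Copy every hyperparameter from the exposed groups onto the inner
    # (sharded) optimizer's groups; "params" is skipped because the inner
    # optimizer only holds its own shard of the parameters.
    for exposed, local in zip(exposed_groups, inner.param_groups):
        for key in local.keys():
            if key != "params":
                local[key] = exposed[key]


if __name__ == "__main__":
    weight = torch.randn(4, requires_grad=True)
    inner = SGD([weight], lr=0.1, momentum=0.9)

    # Stand-in for the wrapper's param_groups: plain dict copies whose
    # hyperparameter values are independent of the inner optimizer's.
    exposed_groups = [dict(group) for group in inner.param_groups]

    exposed_groups[0]["lr"] = 0.01               # e.g. a scheduler lowered the lr
    assert inner.param_groups[0]["lr"] == 0.1    # not propagated yet

    sync_param_groups(exposed_groups, inner)
    assert inner.param_groups[0]["lr"] == 0.01       # propagated
    assert inner.param_groups[0]["momentum"] == 0.9  # other keys stay in sync too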
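
Why the sync is re-run in step() and consolidate_state_dict(): PyTorch LR schedulers rewrite the param_groups of the optimizer instance they were constructed with, in place, so a wrapper exposing its own param_groups has to push those values down before delegating. A minimal demonstration with plain SGD, no fairscale involved:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

weight = torch.randn(3, requires_grad=True)
opt = SGD([weight], lr=0.1, momentum=0.9)
sched = StepLR(opt, step_size=1, gamma=0.1)

weight.sum().backward()
opt.step()
sched.step()  # rewrites opt.param_groups[0]["lr"] in place

# The group dict itself was mutated; this is what an OSS-style wrapper has to
# forward to its sharded optimizer before the next real step.
assert abs(opt.param_groups[0]["lr"] - 0.01) < 1e-12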
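
The new assertions in PATCH 3/3 compare every serialized hyperparameter against the live param_groups. A rough equivalent of that loop can be exercised without a process group by running it against a plain SGD; with OSS, the state_dict groups are additionally indexed by rank, as in the test above. Sketch only, under those assumptions:

import torch
from torch.optim import SGD

x = torch.randn(2, requires_grad=True)
opt = SGD([x], lr=0.1, momentum=0.9, weight_decay=0.0)
group_state = opt.state_dict()["param_groups"][0]

# Every serialized hyperparameter should match the live group; "params" is
# skipped because state_dict stores parameter indices rather than tensors.
for key in group_state:
    if key != "params":
        assert group_state[key] == opt.param_groups[0][key]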