Skip to content

Commit

Permalink
Decide on the fly whether to pass a closure arg or not
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux committed Sep 12, 2020
1 parent 61aac92 commit e143e1f
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions fairscale/optim/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import copy
import inspect
from itertools import chain
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type
Expand Down Expand Up @@ -61,6 +62,9 @@ def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any =
split_param_groups = self.partition_parameters()
self.optim = optim(split_param_groups[self.rank], **defaults)

# Check if this optimnizer accepts a closure
self._pass_closure = "closure" in inspect.signature(self.optim.step).parameters.keys()

# Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = []

Expand Down Expand Up @@ -103,8 +107,11 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->
# Sync oss param_groups attributes in case they've been updated by a scheduler.
self._sync_param_groups()

# Run the optimizer step on this shard only
loss = self.optim.step(closure=closure, **kwargs) # type: ignore
# Run the optimizer step on this shard only:
if self._pass_closure:
loss = self.optim.step(closure=closure, **kwargs) # type: ignore
else:
loss = self.optim.step(**kwargs)

# Sync all the states. Broadcast requests are issued async, we check completeness before moving on
requests = []
Expand Down

1 comment on commit e143e1f

@blefaudeux
Copy link
Contributor Author

@blefaudeux blefaudeux commented on e143e1f Sep 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested with LARC, works fine. I need to update the unit test to make sure that the branch is covered, after that good to go
cc @prigoyal @mannatsingh

Please sign in to comment.