Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat ] OSS : optional closure argument for the optimizer #86

Merged
merged 4 commits into from
Sep 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions fairscale/optim/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,11 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->
# Sync oss param_groups attributes in case they've been updated by a scheduler.
self._sync_param_groups()

# Run the optimizer step on this shard only
loss = self.optim.step(closure=closure, **kwargs) # type: ignore
# Run the optimizer step on this shard only:
if closure is not None:
loss = self.optim.step(closure=closure, **kwargs) # type: ignore

This comment was marked as outdated.

else:
loss = self.optim.step(**kwargs)

# Sync all the states. Broadcast requests are issued async, we check completeness before moving on
requests = []
Expand Down
23 changes: 17 additions & 6 deletions tests/optim/test_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,12 @@ def test_lr_scheduler():
assert x == x2


class SGDWithStepKWArg(torch.optim.SGD):
def step(self, closure=None, kwarg=[]):
super().step()
kwarg.append(5)


def test_step_with_kwargs():
class SGDWithStepKWArg(torch.optim.SGD):
def step(self, closure=None, kwarg=[]):
super().step()
kwarg.append(5)

kwarg = []
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], SGDWithStepKWArg, lr=0.1)
Expand All @@ -119,6 +118,18 @@ def test_step_with_kwargs():
assert x == torch.tensor([0.9], device=DEVICE)


def test_step_without_closure():
class SGDWithoutClosure(torch.optim.SGD):
def step(self):
return super().step()

x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], SGDWithoutClosure, lr=0.1)
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)


def test_local_state_dict():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
Expand Down