This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Add SGD error message
ethanwharris committed Feb 15, 2022
1 parent 9d07113 commit 5c66179
Showing 2 changed files with 16 additions and 1 deletion.
12 changes: 11 additions & 1 deletion flash/core/optimizers/optimizers.py
@@ -1,3 +1,4 @@
+from functools import partial
 from inspect import isclass
 from typing import Callable, List

@@ -17,7 +18,16 @@
     _optimizers.append(_optimizer)
 
 for fn in _optimizers:
-    _OPTIMIZERS_REGISTRY(fn, name=fn.__name__.lower())
+    name = fn.__name__.lower()
+    if name == "sgd":
+
+        def wrapper(fn, parameters, lr=None, **kwargs):
+            if lr is None:
+                raise TypeError("The `learning_rate` argument is required when the optimizer is SGD.")
+            return fn(parameters, lr, **kwargs)
+
+        fn = partial(wrapper, fn)
+    _OPTIMIZERS_REGISTRY(fn, name=name)
 
 
 if _TORCH_OPTIMIZER_AVAILABLE:
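
To make the change easier to follow, here is a minimal standalone sketch of the wrapping pattern introduced in the hunk above, applied directly to torch.optim.SGD outside of Flash's registry. The nn.Linear model and the sgd_factory name are illustrative only and not part of the commit:

    from functools import partial

    from torch import nn
    from torch.optim import SGD


    def wrapper(fn, parameters, lr=None, **kwargs):
        # Raise a descriptive error up front instead of failing inside the optimizer constructor.
        if lr is None:
            raise TypeError("The `learning_rate` argument is required when the optimizer is SGD.")
        return fn(parameters, lr, **kwargs)


    model = nn.Linear(2, 2)
    sgd_factory = partial(wrapper, SGD)

    sgd_factory(model.parameters(), lr=0.01)  # builds the optimizer as usual
    sgd_factory(model.parameters())           # raises the TypeError added by this commit

In the commit itself, the wrapped callable replaces the raw SGD class in `_OPTIMIZERS_REGISTRY`, so tasks that resolve their optimizer from the string "sgd" surface the clearer message whenever no learning rate is supplied.
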
5 changes: 5 additions & 0 deletions tests/core/test_model.py
@@ -374,6 +374,11 @@ def test_optimizer_learning_rate():
     ClassificationTask(model, optimizer="test", learning_rate=10).configure_optimizers()
     mock_optimizer.assert_called_once_with(mock.ANY, lr=10)
 
+    mock_optimizer.reset_mock()
+
+    with pytest.raises(TypeError, match="The `learning_rate` argument is required"):
+        ClassificationTask(model, optimizer="sgd").configure_optimizers()
+
 
 @pytest.mark.skipif(not _TORCH_OPTIMIZER_AVAILABLE, reason="torch_optimizer isn't installed.")
 @pytest.mark.parametrize("optim", ["Yogi"])
