diff --git a/flash/core/optimizers/optimizers.py b/flash/core/optimizers/optimizers.py
index 778f4ba247..5c5f9816a5 100644
--- a/flash/core/optimizers/optimizers.py
+++ b/flash/core/optimizers/optimizers.py
@@ -1,3 +1,4 @@
+from functools import partial
 from inspect import isclass
 from typing import Callable, List
 
@@ -17,7 +18,16 @@
         _optimizers.append(_optimizer)
 
 for fn in _optimizers:
-    _OPTIMIZERS_REGISTRY(fn, name=fn.__name__.lower())
+    name = fn.__name__.lower()
+    if name == "sgd":
+
+        def wrapper(fn, parameters, lr=None, **kwargs):
+            if lr is None:
+                raise TypeError("The `learning_rate` argument is required when the optimizer is SGD.")
+            return fn(parameters, lr, **kwargs)
+
+        fn = partial(wrapper, fn)
+    _OPTIMIZERS_REGISTRY(fn, name=name)
 
 
 if _TORCH_OPTIMIZER_AVAILABLE:
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index e71f15582a..ce3f62fc44 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -374,6 +374,11 @@ def test_optimizer_learning_rate():
     ClassificationTask(model, optimizer="test", learning_rate=10).configure_optimizers()
     mock_optimizer.assert_called_once_with(mock.ANY, lr=10)
 
+    mock_optimizer.reset_mock()
+
+    with pytest.raises(TypeError, match="The `learning_rate` argument is required"):
+        ClassificationTask(model, optimizer="sgd").configure_optimizers()
+
 
 @pytest.mark.skipif(not _TORCH_OPTIMIZER_AVAILABLE, reason="torch_optimizer isn't installed.")
 @pytest.mark.parametrize("optim", ["Yogi"])
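
For reference, a minimal standalone sketch (not part of the patch) of the partial-based wrapper that the first hunk registers for SGD. The `wrapper` function and the error message mirror the diff; `sgd_factory`, the toy `nn.Linear` model, and the try/except usage are illustrative assumptions only, shown here to demonstrate the intended behaviour outside the registry.

# Illustrative sketch only: demonstrates the partial-wrapped SGD factory
# from the diff above; `sgd_factory` and the toy model are assumptions.
from functools import partial

from torch import nn, optim


def wrapper(fn, parameters, lr=None, **kwargs):
    # Mirrors the diff: refuse to build SGD without an explicit learning rate.
    if lr is None:
        raise TypeError("The `learning_rate` argument is required when the optimizer is SGD.")
    return fn(parameters, lr, **kwargs)


sgd_factory = partial(wrapper, optim.SGD)

model = nn.Linear(4, 2)
optimizer = sgd_factory(model.parameters(), lr=0.1)  # builds a torch.optim.SGD as usual

try:
    sgd_factory(model.parameters())  # no lr -> TypeError, matching the new test
except TypeError as err:
    print(err)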