This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix un-usability of manual optimization in Flash. #1342

Merged
Changes from 2 commits
2 changes: 1 addition & 1 deletion README.md
@@ -209,7 +209,7 @@ model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr
You can also register your own custom scheduler recipes beforehand and use them as shown above:

```py
@ImageClassifier.lr_schedulers
@ImageClassifier.lr_schedulers_registry
def my_steplr_recipe(optimizer):
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

```
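A hedged usage sketch, not part of this diff: once registered, the recipe can be selected by name in the `ImageClassifier` constructor shown at the top of this hunk (assuming the registry stores a decorated recipe under its function name).

```py
# Sketch only: mirrors the constructor from the hunk's context line above; the
# "my_steplr_recipe" key assumes FlashRegistry uses the function's name as the key.
from flash.image import ImageClassifier

model = ImageClassifier(
    backbone="resnet18",
    num_classes=2,
    optimizer="Adam",
    lr_scheduler="my_steplr_recipe",
)
```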
2 changes: 1 addition & 1 deletion docs/source/general/optimization.rst
@@ -176,7 +176,7 @@ Using the ``optimizers`` and ``lr_schedulers`` decorator pertaining to each :cla

>>> import torch
>>> from flash.image import ImageClassifier
>>> @ImageClassifier.lr_schedulers
>>> @ImageClassifier.lr_schedulers_registry
... def my_flash_steplr_recipe(optimizer):
... return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
...
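A hedged continuation of the doctest above, not part of this diff: assuming the registry keys a decorated recipe by its function name, the new recipe should then be discoverable through the `available_lr_schedulers` helper touched further down in `flash/core/model.py`.

```py
>>> # Assumption: FlashRegistry registers the recipe under its function name.
>>> "my_flash_steplr_recipe" in ImageClassifier.available_lr_schedulers()
True
```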
12 changes: 6 additions & 6 deletions flash/core/model.py
@@ -324,8 +324,8 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, FineTuningHooks
task.
"""

optimizers: FlashRegistry = _OPTIMIZERS_REGISTRY
lr_schedulers: FlashRegistry = _SCHEDULERS_REGISTRY
optimizers_registry: FlashRegistry = _OPTIMIZERS_REGISTRY
lr_schedulers_registry: FlashRegistry = _SCHEDULERS_REGISTRY
finetuning_strategies: FlashRegistry = _FINETUNING_STRATEGIES_REGISTRY
outputs: FlashRegistry = BASE_OUTPUTS
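For context, an illustration rather than part of this diff: on a `LightningModule` subclass, class attributes named `optimizers` and `lr_schedulers` shadow the `self.optimizers()` / `self.lr_schedulers()` accessors that Lightning provides for manual optimization (the same accessors the new test below calls), which appears to be what made manual optimization unusable. A minimal, self-contained sketch of that collision, using stand-in classes rather than Flash's or Lightning's real ones:

```py
# Stand-alone sketch of the attribute shadowing addressed by this rename;
# LightningModuleLike and RegistryLike are stand-ins, not real Flash/Lightning classes.


class LightningModuleLike:
    def lr_schedulers(self):
        # Lightning-style accessor called inside training_step under manual optimization.
        return "<configured scheduler>"


class RegistryLike:
    """Stand-in for FlashRegistry."""


class TaskBefore(LightningModuleLike):
    lr_schedulers = RegistryLike()  # class attribute shadows the inherited accessor


class TaskAfter(LightningModuleLike):
    lr_schedulers_registry = RegistryLike()  # distinct name keeps the accessor reachable


print(callable(TaskBefore().lr_schedulers))  # False: the attribute is the registry, not the method
print(TaskAfter().lr_schedulers())           # '<configured scheduler>'
```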

@@ -490,7 +490,7 @@ def _get_optimizer_class_from_registry(self, optimizer_key: str) -> Optimizer:
f"\nUse `{self.__class__.__name__}.available_optimizers()` to list the available optimizers."
f"\nList of available Optimizers: {self.available_optimizers()}."
)
optimizer_fn = self.optimizers.get(optimizer_key.lower())
optimizer_fn = self.optimizers_registry.get(optimizer_key.lower())
return optimizer_fn

def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
@@ -614,15 +614,15 @@ def get_backbone_details(cls, key) -> List[str]:
@classmethod
def available_optimizers(cls) -> List[str]:
"""Returns a list containing the keys of the available Optimizers."""
registry: Optional[FlashRegistry] = getattr(cls, "optimizers", None)
registry: Optional[FlashRegistry] = getattr(cls, "optimizers_registry", None)
if registry is None:
return []
return registry.available_keys()

@classmethod
def available_lr_schedulers(cls) -> List[str]:
"""Returns a list containing the keys of the available LR schedulers."""
registry: Optional[FlashRegistry] = getattr(cls, "lr_schedulers", None)
registry: Optional[FlashRegistry] = getattr(cls, "lr_schedulers_registry", None)
if registry is None:
return []
return registry.available_keys()
@@ -683,7 +683,7 @@ def _get_lr_scheduler_class_from_registry(self, lr_scheduler_key: str) -> Dict[s
f"\nUse `{self.__class__.__name__}.available_lr_schedulers()` to list the available schedulers."
f"\n>>> List of available LR Schedulers: {self.available_lr_schedulers()}."
)
lr_scheduler_fn: Dict[str, Any] = self.lr_schedulers.get(lr_scheduler_key.lower(), with_metadata=True)
lr_scheduler_fn: Dict[str, Any] = self.lr_schedulers_registry.get(lr_scheduler_key.lower(), with_metadata=True)
return deepcopy(lr_scheduler_fn)

def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
36 changes: 33 additions & 3 deletions tests/core/test_model.py
@@ -307,12 +307,12 @@ def test_available_backbones_raises():
_ = ImageClassifier.available_backbones()


@ClassificationTask.lr_schedulers
@ClassificationTask.lr_schedulers_registry
def custom_steplr_configuration_return_as_instance(optimizer):
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)


@ClassificationTask.lr_schedulers
@ClassificationTask.lr_schedulers_registry
def custom_steplr_configuration_return_as_dict(optimizer):
return {
"scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=10),
@@ -370,7 +370,7 @@ def test_optimizers_and_schedulers(tmpdir, optim, sched, interval):
@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_optimizer_learning_rate():
mock_optimizer = MagicMock()
Task.optimizers(mock_optimizer, "test")
Task.optimizers_registry(mock_optimizer, "test")

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())

@@ -444,6 +444,36 @@ def train_dataloader(self):
assert isinstance(trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.LambdaLR)


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_manual_optimization(tmpdir):
class ManualOptimizationTask(Task):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False

def training_step(self, batch: Any, batch_idx: int) -> Any:
optimizers = self.optimizers()
assert isinstance(optimizers, torch.optim.Optimizer)
optimizers.zero_grad()

output = self.step(batch, batch_idx, self.train_metrics)
self.manual_backward(output["loss"])

optimizers.step()

lr_schedulers = self.lr_schedulers()
assert isinstance(lr_schedulers, torch.optim.lr_scheduler._LRScheduler)
lr_schedulers.step()

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
train_dl = DataLoader(DummyDataset())
val_dl = DataLoader(DummyDataset())
task = ManualOptimizationTask(model, loss_fn=F.nll_loss, lr_scheduler=("steplr", {"step_size": 1}))

trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(task, train_dl, val_dl)


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_errors_and_exceptions_optimizers_and_schedulers():
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())