
Commit

update on comments
tchaton committed Apr 21, 2021
1 parent 2fed6a2 commit d5d24e5
Showing 4 changed files with 25 additions and 31 deletions.
36 changes: 16 additions & 20 deletions flash/core/model.py
@@ -27,7 +27,7 @@
 from torch.optim.optimizer import Optimizer
 
 from flash.core.registry import FlashRegistry
-from flash.core.schedulers import _SCHEDULER_REGISTRY
+from flash.core.schedulers import _SCHEDULERS_REGISTRY
 from flash.core.utils import get_callable_dict
 from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess

@@ -68,7 +68,7 @@ class Task(LightningModule):
         postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task.
     """
 
-    schedulers = _SCHEDULER_REGISTRY
+    schedulers: FlashRegistry = _SCHEDULERS_REGISTRY
 
     def __init__(
         self,
@@ -181,7 +181,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
         batch = torch.stack(batch)
         return self(batch)
 
-    def configure_optimizers(self):
+    def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
         optimizer = self.optimizer
         if not isinstance(self.optimizer, Optimizer):
             self.optimizer_kwargs["lr"] = self.learning_rate
@@ -352,6 +352,8 @@ def available_schedulers(cls) -> List[str]:
 
     def get_num_training_steps(self) -> int:
         """Total training steps inferred from datamodule and devices."""
+        if not getattr(self, "trainer", None):
+            raise MisconfigurationException("The LightningModule isn't attached to the trainer yet.")
         if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
             dataset_size = self.trainer.limit_train_batches
         elif isinstance(self.trainer.limit_train_batches, float):
@@ -373,33 +375,27 @@ def get_num_training_steps(self) -> int:
         return max_estimated_steps
 
     def _compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> int:
+        if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0):
+            raise MisconfigurationException(
+                "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
+            )
         if isinstance(num_warmup_steps, float):
             # Convert float values to percentage of training steps to use as warmup
             num_warmup_steps *= num_training_steps
-        return int(num_warmup_steps)
+        return round(num_warmup_steps)
 
     def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler:
         scheduler = self.scheduler
         if isinstance(scheduler, _LRScheduler):
             return scheduler
         if isinstance(scheduler, str):
             scheduler_fn = self.schedulers.get(self.scheduler)
-            if "num_warmup_steps" in self.scheduler_kwargs:
-                num_warmup_steps = self.scheduler_kwargs.get("num_warmup_steps")
-                if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0):
-                    raise MisconfigurationException(
-                        "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
-                    )
-                num_training_steps = self.get_num_training_steps()
-                num_warmup_steps = self._compute_warmup(
-                    num_training_steps=num_training_steps,
-                    num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
-                )
-                return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
-            else:
-                raise MisconfigurationException(
-                    "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
-                )
+            num_training_steps: int = self.get_num_training_steps()
+            num_warmup_steps: int = self._compute_warmup(
+                num_training_steps=num_training_steps,
+                num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
+            )
+            return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
         elif issubclass(scheduler, _LRScheduler):
             return scheduler(optimizer, **self.scheduler_kwargs)
         raise MisconfigurationException(
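The warmup logic in this diff treats a float `num_warmup_steps` as a fraction of the total training steps. A minimal standalone sketch of that conversion (not the Flash implementation; `compute_warmup` and `ValueError` stand in for the method and `MisconfigurationException` above):

```python
def compute_warmup(num_training_steps: int, num_warmup_steps: float) -> int:
    """Convert a warmup fraction in [0, 1] into an absolute number of steps."""
    if not isinstance(num_warmup_steps, float) or not (0 <= num_warmup_steps <= 1):
        raise ValueError("`num_warmup_steps` should be provided as float between 0 and 1")
    # A float is interpreted as a percentage of the total training steps.
    return round(num_warmup_steps * num_training_steps)

# Example: a 0.1 warmup fraction over 500 training steps gives 50 warmup steps.
assert compute_warmup(num_training_steps=500, num_warmup_steps=0.1) == 50
```

The `int(...)` to `round(...)` change means a fractional step count is rounded to the nearest step rather than always truncated downward.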
10 changes: 7 additions & 3 deletions flash/core/schedulers.py
@@ -1,10 +1,14 @@
+from typing import Callable, List
+
 from flash.core.registry import FlashRegistry
 from flash.utils.imports import _TRANSFORMERS_AVAILABLE
 
-_SCHEDULER_REGISTRY = FlashRegistry("scheduler")
+_SCHEDULERS_REGISTRY = FlashRegistry("scheduler")
 
 if _TRANSFORMERS_AVAILABLE:
     from transformers import optimization
-    functions = [getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')]
+    functions: List[Callable] = [
+        getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')
+    ]
     for fn in functions:
-        _SCHEDULER_REGISTRY(fn, name=fn.__name__[4:])
+        _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:])
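The registration loop above strips the `get_` prefix from every scheduler factory in `transformers.optimization`, so `get_linear_schedule_with_warmup` is looked up as `linear_schedule_with_warmup` (the name the test below passes as `scheduler=`). A rough sketch of that naming convention, with a plain dict standing in for `FlashRegistry` and a dummy factory in place of the transformers one:

```python
from typing import Callable, Dict

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Dummy stand-in for the transformers factory of the same name."""
    ...

_registry: Dict[str, Callable] = {}
for fn in [get_linear_schedule_with_warmup]:
    # fn.__name__[4:] drops the leading "get_" (4 characters).
    _registry[fn.__name__[4:]] = fn

assert "linear_schedule_with_warmup" in _registry
```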
6 changes: 1 addition & 5 deletions flash/utils/imports.py
@@ -5,8 +5,4 @@
 _COCO_AVAILABLE = _module_available("pycocotools")
 _TIMM_AVAILABLE = _module_available("timm")
 _TORCHVISION_AVAILABLE = _module_available("torchvision")
-try:
-    import transformers
-    _TRANSFORMERS_AVAILABLE = True
-except ModuleNotFoundError:
-    _TRANSFORMERS_AVAILABLE = False
+_TRANSFORMERS_AVAILABLE = _module_available("transformers")
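The try/except import is replaced with the same `_module_available` helper already used for the other flags, which checks whether the package can be found without importing it. A rough approximation of what such a check does (an assumption, not the helper's exact implementation):

```python
from importlib.util import find_spec

def module_available(name: str) -> bool:
    """Return True if `name` is importable, without actually importing it."""
    try:
        return find_spec(name) is not None
    except ModuleNotFoundError:
        return False

_TRANSFORMERS_AVAILABLE = module_available("transformers")
```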
4 changes: 1 addition & 3 deletions tests/core/test_model.py
@@ -206,9 +206,7 @@ def test_optimization(tmpdir):
     ]
 
     optim = torch.optim.Adadelta(model.parameters())
-    with pytest.raises(
-        MisconfigurationException, match="`num_warmup_steps` should be provided as float between 0 and 1"
-    ):
+    with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."):
         task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup")
         optimizer, scheduler = task.configure_optimizers()
 
