From d1aba5f7b6e677ce9f3de512b3ef4621f1bfab46 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 21 Apr 2021 10:22:47 +0100 Subject: [PATCH 1/6] add support for schedulers --- flash/core/model.py | 92 +++++++++++++++++++++++++++++++++++-- flash/core/schedulers.py | 10 ++++ flash/data/data_pipeline.py | 3 ++ flash/utils/imports.py | 5 ++ tests/core/test_model.py | 67 +++++++++++++++++++++++++++ 5 files changed, 173 insertions(+), 4 deletions(-) create mode 100644 flash/core/schedulers.py diff --git a/flash/core/model.py b/flash/core/model.py index 9914b4cb61..78c911740c 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -21,9 +21,14 @@ from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer +from transformers.trainer_utils import SchedulerType from flash.core.registry import FlashRegistry +from flash.core.schedulers import _SCHEDULER_REGISTRY from flash.core.utils import get_callable_dict from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess @@ -64,11 +69,16 @@ class Task(LightningModule): postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task. """ + schedulers = _SCHEDULER_REGISTRY + def __init__( self, model: Optional[nn.Module] = None, loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 5e-5, preprocess: Preprocess = None, @@ -78,7 +88,11 @@ def __init__( if model is not None: self.model = model self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn) - self.optimizer_cls = optimizer + self.optimizer = optimizer + self.scheduler = scheduler + self.optimizer_kwargs = optimizer_kwargs or {} + self.scheduler_kwargs = scheduler_kwargs or {} + self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) self.learning_rate = learning_rate # TODO: should we save more? 
Bug on some regarding yaml if we save metrics @@ -168,8 +182,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A batch = torch.stack(batch) return self(batch) - def configure_optimizers(self) -> torch.optim.Optimizer: - return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate) + def configure_optimizers(self): + optimizer = self.optimizer + if not isinstance(self.optimizer, Optimizer): + self.optimizer_kwargs["lr"] = self.learning_rate + optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs) + if self.scheduler: + return [optimizer], [self._instantiate_scheduler(optimizer)] + return optimizer def configure_finetune_callback(self) -> List[Callback]: return [] @@ -323,3 +343,67 @@ def available_models(cls) -> List[str]: if registry is None: return [] return registry.available_keys() + + @classmethod + def available_schedulers(cls) -> List[str]: + registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None) + if registry is None: + return [] + return registry.available_keys() + + def get_num_training_steps(self) -> int: + """Total training steps inferred from datamodule and devices.""" + if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0: + dataset_size = self.trainer.limit_train_batches + elif isinstance(self.trainer.limit_train_batches, float): + # limit_train_batches is a percentage of batches + dataset_size = len(self.train_dataloader()) + dataset_size = int(dataset_size * self.trainer.limit_train_batches) + else: + dataset_size = len(self.train_dataloader()) + + num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) + if self.trainer.tpu_cores: + num_devices = max(num_devices, self.trainer.tpu_cores) + + effective_batch_size = self.trainer.accumulate_grad_batches * num_devices + max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs + + if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps: + return self.trainer.max_steps + return max_estimated_steps + + def _compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> int: + if isinstance(num_warmup_steps, float): + # Convert float values to percentage of training steps to use as warmup + num_warmup_steps *= num_training_steps + return int(num_warmup_steps) + + def _instantiate_scheduler(self, optimizer: Optimizer) -> SchedulerType: + scheduler = self.scheduler + if isinstance(scheduler, _LRScheduler): + return scheduler + if isinstance(scheduler, str): + scheduler_fn = self.schedulers.get(self.scheduler) + if "num_warmup_steps" in self.scheduler_kwargs: + num_warmup_steps = self.scheduler_kwargs.get("num_warmup_steps") + if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0): + raise MisconfigurationException( + "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`" + ) + num_training_steps = self.get_num_training_steps() + num_warmup_steps = self._compute_warmup( + num_training_steps=num_training_steps, + num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"), + ) + return scheduler_fn(optimizer, num_warmup_steps, num_training_steps) + else: + raise MisconfigurationException( + "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`" + ) + elif issubclass(scheduler, _LRScheduler): + return scheduler(optimizer, **self.scheduler_kwargs) + raise 
MisconfigurationException( + "scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` " + f"or a built-in scheduler in {self.available_schedulers()}" + ) diff --git a/flash/core/schedulers.py b/flash/core/schedulers.py new file mode 100644 index 0000000000..6fdc934eee --- /dev/null +++ b/flash/core/schedulers.py @@ -0,0 +1,10 @@ +from flash.core.registry import FlashRegistry +from flash.utils.imports import _TRANSFORMERS_AVAILABLE + +_SCHEDULER_REGISTRY = FlashRegistry("scheduler") + +if _TRANSFORMERS_AVAILABLE: + from transformers.optimization import TYPE_TO_SCHEDULER_FUNCTION + + for v in TYPE_TO_SCHEDULER_FUNCTION.values(): + _SCHEDULER_REGISTRY(v, name=v.__name__[4:]) diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 46be3c823c..fe75404f1c 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -16,6 +16,7 @@ import weakref from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union +import torch from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import imports @@ -285,6 +286,8 @@ def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None def _attach_preprocess_to_model( self, model: 'Task', stage: Optional[RunningStage] = None, device_transform_only: bool = False ) -> None: + device_collate_fn = torch.nn.Identity() + if not stage: stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] diff --git a/flash/utils/imports.py b/flash/utils/imports.py index 5e17ba6d3e..a381141b45 100644 --- a/flash/utils/imports.py +++ b/flash/utils/imports.py @@ -5,3 +5,8 @@ _COCO_AVAILABLE = _module_available("pycocotools") _TIMM_AVAILABLE = _module_available("timm") _TORCHVISION_AVAILABLE = _module_available("torchvision") +try: + import transformers + _TRANSFORMERS_AVAILABLE = True +except ModuleNotFoundError: + _TRANSFORMERS_AVAILABLE = False diff --git a/tests/core/test_model.py b/tests/core/test_model.py index d0b0048b23..2a0be64c66 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -20,12 +20,16 @@ import pytorch_lightning as pl import torch from PIL import Image +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn, Tensor from torch.nn import functional as F +from torch.utils.data import DataLoader +import flash from flash.core.classification import ClassificationTask from flash.tabular import TabularClassifier from flash.text import SummarizationTask, TextClassifier +from flash.utils.imports import _TRANSFORMERS_AVAILABLE from flash.vision import ImageClassificationData, ImageClassifier # ======== Mock functions ======== @@ -160,3 +164,66 @@ class Foo(ImageClassifier): backbones = None assert Foo.available_backbones() == [] + + +def test_optimization(tmpdir): + + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) + optim = torch.optim.Adam(model.parameters()) + task = ClassificationTask(model, optimizer=optim, scheduler=None) + + optimizer = task.configure_optimizers() + assert optimizer == optim + + task = ClassificationTask(model, optimizer=torch.optim.Adadelta, optimizer_kwargs={"eps": 0.5}, scheduler=None) + optimizer = task.configure_optimizers() + assert isinstance(optimizer, torch.optim.Adadelta) + assert optimizer.defaults["eps"] == 0.5 + + task = ClassificationTask( + model, + 
optimizer=torch.optim.Adadelta, + scheduler=torch.optim.lr_scheduler.StepLR, + scheduler_kwargs={"step_size": 1} + ) + optimizer, scheduler = task.configure_optimizers() + assert isinstance(optimizer[0], torch.optim.Adadelta) + assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + + optim = torch.optim.Adadelta(model.parameters()) + task = ClassificationTask(model, optimizer=optim, scheduler=torch.optim.lr_scheduler.StepLR(optim, step_size=1)) + optimizer, scheduler = task.configure_optimizers() + assert isinstance(optimizer[0], torch.optim.Adadelta) + assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + + if _TRANSFORMERS_AVAILABLE: + from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION + + assert task.available_schedulers() == [ + 'constant_schedule', 'constant_schedule_with_warmup', 'cosine_schedule_with_warmup', + 'cosine_with_hard_restarts_schedule_with_warmup', 'linear_schedule_with_warmup', + 'polynomial_decay_schedule_with_warmup' + ] + + optim = torch.optim.Adadelta(model.parameters()) + with pytest.raises( + MisconfigurationException, match="`num_warmup_steps` should be provided as float between 0 and 1" + ): + task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup") + optimizer, scheduler = task.configure_optimizers() + + task = ClassificationTask( + model, + optimizer=optim, + scheduler="linear_schedule_with_warmup", + scheduler_kwargs={"num_warmup_steps": 0.1}, + loss_fn=F.nll_loss, + ) + trainer = flash.Trainer(max_epochs=100) + ds = DummyDataset() + trainer.fit(task, train_dataloader=DataLoader(ds)) + optimizer, scheduler = task.configure_optimizers() + assert isinstance(optimizer[0], torch.optim.Adadelta) + assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) + expected = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.LINEAR].__name__ + assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected From 4833eea8063e8a5765b7daeb4eb98f456ad26218 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 21 Apr 2021 10:31:59 +0100 Subject: [PATCH 2/6] update changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c50aaab6da..627e907654 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Switch to use `torchmetrics` ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169)) +- Better support for `optimizer` and `schedulers` ([#232](https://github.com/PyTorchLightning/lightning-flash/pull/232)) + + ### Fixed @@ -28,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added - Added `RetinaNet` & `backbones` to `ObjectDetector` Task ([#121](https://github.com/PyTorchLightning/lightning-flash/pull/121)) -- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116), +- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116), [#117](https://github.com/PyTorchLightning/lightning-flash/pull/117), [#118](https://github.com/PyTorchLightning/lightning-flash/pull/118)) From 795f8d167c7e466b4bc6348656a3bfcdcc7117aa Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 21 Apr 2021 10:36:02 +0100 Subject: [PATCH 3/6] resolve typing --- flash/core/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 78c911740c..ccaebcffb6 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -25,7 +25,6 @@ from torch import nn from torch.optim.lr_scheduler import _LRScheduler from torch.optim.optimizer import Optimizer -from transformers.trainer_utils import SchedulerType from flash.core.registry import FlashRegistry from flash.core.schedulers import _SCHEDULER_REGISTRY @@ -379,7 +378,7 @@ def _compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, num_warmup_steps *= num_training_steps return int(num_warmup_steps) - def _instantiate_scheduler(self, optimizer: Optimizer) -> SchedulerType: + def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: scheduler = self.scheduler if isinstance(scheduler, _LRScheduler): return scheduler From 644f062a61bc8fe2913c7c6a61cf79fa5a2b58ce Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 21 Apr 2021 11:09:38 +0100 Subject: [PATCH 4/6] update task --- flash/core/schedulers.py | 8 ++++---- tests/core/test_model.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/flash/core/schedulers.py b/flash/core/schedulers.py index 6fdc934eee..bff5f4d138 100644 --- a/flash/core/schedulers.py +++ b/flash/core/schedulers.py @@ -4,7 +4,7 @@ _SCHEDULER_REGISTRY = FlashRegistry("scheduler") if _TRANSFORMERS_AVAILABLE: - from transformers.optimization import TYPE_TO_SCHEDULER_FUNCTION - - for v in TYPE_TO_SCHEDULER_FUNCTION.values(): - _SCHEDULER_REGISTRY(v, name=v.__name__[4:]) + from transformers import optimization + functions = [getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')] + for fn in functions: + _SCHEDULER_REGISTRY(fn, name=fn.__name__[4:]) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 2a0be64c66..5b18a711db 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -197,7 +197,7 @@ def test_optimization(tmpdir): assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) if _TRANSFORMERS_AVAILABLE: - from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION + from transformers.optimization import get_linear_schedule_with_warmup assert task.available_schedulers() == [ 'constant_schedule', 'constant_schedule_with_warmup', 'cosine_schedule_with_warmup', @@ -219,11 +219,11 @@ def test_optimization(tmpdir): scheduler_kwargs={"num_warmup_steps": 0.1}, loss_fn=F.nll_loss, ) - trainer = flash.Trainer(max_epochs=100) + trainer = flash.Trainer(max_epochs=1, limit_train_batches=2) ds = DummyDataset() trainer.fit(task, train_dataloader=DataLoader(ds)) optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) - expected = 
TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.LINEAR].__name__ + expected = get_linear_schedule_with_warmup.__name__ assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected From 2fed6a29f4deea33329fe59f39b086a99445c529 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 21 Apr 2021 11:11:52 +0100 Subject: [PATCH 5/6] change for log softmax --- tests/core/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 5b18a711db..2426cff78f 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -168,7 +168,7 @@ class Foo(ImageClassifier): def test_optimization(tmpdir): - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) optim = torch.optim.Adam(model.parameters()) task = ClassificationTask(model, optimizer=optim, scheduler=None) From d5d24e5d7c7bdbffbe7dd1b0a669796c5bc58d66 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 21 Apr 2021 15:39:14 +0100 Subject: [PATCH 6/6] udpate on comments --- flash/core/model.py | 36 ++++++++++++++++-------------------- flash/core/schedulers.py | 10 +++++++--- flash/utils/imports.py | 6 +----- tests/core/test_model.py | 4 +--- 4 files changed, 25 insertions(+), 31 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index ccaebcffb6..02dc367932 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -27,7 +27,7 @@ from torch.optim.optimizer import Optimizer from flash.core.registry import FlashRegistry -from flash.core.schedulers import _SCHEDULER_REGISTRY +from flash.core.schedulers import _SCHEDULERS_REGISTRY from flash.core.utils import get_callable_dict from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess @@ -68,7 +68,7 @@ class Task(LightningModule): postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task. 
""" - schedulers = _SCHEDULER_REGISTRY + schedulers: FlashRegistry = _SCHEDULERS_REGISTRY def __init__( self, @@ -181,7 +181,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A batch = torch.stack(batch) return self(batch) - def configure_optimizers(self): + def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: optimizer = self.optimizer if not isinstance(self.optimizer, Optimizer): self.optimizer_kwargs["lr"] = self.learning_rate @@ -352,6 +352,8 @@ def available_schedulers(cls) -> List[str]: def get_num_training_steps(self) -> int: """Total training steps inferred from datamodule and devices.""" + if not getattr(self, "trainer", None): + raise MisconfigurationException("The LightningModule isn't attached to the trainer yet.") if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0: dataset_size = self.trainer.limit_train_batches elif isinstance(self.trainer.limit_train_batches, float): @@ -373,10 +375,14 @@ def get_num_training_steps(self) -> int: return max_estimated_steps def _compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> int: + if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0): + raise MisconfigurationException( + "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`" + ) if isinstance(num_warmup_steps, float): # Convert float values to percentage of training steps to use as warmup num_warmup_steps *= num_training_steps - return int(num_warmup_steps) + return round(num_warmup_steps) def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: scheduler = self.scheduler @@ -384,22 +390,12 @@ def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: return scheduler if isinstance(scheduler, str): scheduler_fn = self.schedulers.get(self.scheduler) - if "num_warmup_steps" in self.scheduler_kwargs: - num_warmup_steps = self.scheduler_kwargs.get("num_warmup_steps") - if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0): - raise MisconfigurationException( - "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`" - ) - num_training_steps = self.get_num_training_steps() - num_warmup_steps = self._compute_warmup( - num_training_steps=num_training_steps, - num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"), - ) - return scheduler_fn(optimizer, num_warmup_steps, num_training_steps) - else: - raise MisconfigurationException( - "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`" - ) + num_training_steps: int = self.get_num_training_steps() + num_warmup_steps: int = self._compute_warmup( + num_training_steps=num_training_steps, + num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"), + ) + return scheduler_fn(optimizer, num_warmup_steps, num_training_steps) elif issubclass(scheduler, _LRScheduler): return scheduler(optimizer, **self.scheduler_kwargs) raise MisconfigurationException( diff --git a/flash/core/schedulers.py b/flash/core/schedulers.py index bff5f4d138..eee60cc8f8 100644 --- a/flash/core/schedulers.py +++ b/flash/core/schedulers.py @@ -1,10 +1,14 @@ +from typing import Callable, List + from flash.core.registry import FlashRegistry from flash.utils.imports import _TRANSFORMERS_AVAILABLE -_SCHEDULER_REGISTRY = FlashRegistry("scheduler") +_SCHEDULERS_REGISTRY = FlashRegistry("scheduler") if _TRANSFORMERS_AVAILABLE: 
from transformers import optimization - functions = [getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')] + functions: List[Callable] = [ + getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler') + ] for fn in functions: - _SCHEDULER_REGISTRY(fn, name=fn.__name__[4:]) + _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:]) diff --git a/flash/utils/imports.py b/flash/utils/imports.py index a381141b45..5252a3e3d5 100644 --- a/flash/utils/imports.py +++ b/flash/utils/imports.py @@ -5,8 +5,4 @@ _COCO_AVAILABLE = _module_available("pycocotools") _TIMM_AVAILABLE = _module_available("timm") _TORCHVISION_AVAILABLE = _module_available("torchvision") -try: - import transformers - _TRANSFORMERS_AVAILABLE = True -except ModuleNotFoundError: - _TRANSFORMERS_AVAILABLE = False +_TRANSFORMERS_AVAILABLE = _module_available("transformers") diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 2426cff78f..450b662dbd 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -206,9 +206,7 @@ def test_optimization(tmpdir): ] optim = torch.optim.Adadelta(model.parameters()) - with pytest.raises( - MisconfigurationException, match="`num_warmup_steps` should be provided as float between 0 and 1" - ): + with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."): task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup") optimizer, scheduler = task.configure_optimizers()
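
Taken together, the patches let a `Task` accept an optimizer instance or class (plus `optimizer_kwargs`) and a scheduler given as an instance, a class, or, when `transformers` is installed, a registered name such as "linear_schedule_with_warmup". The following is a minimal usage sketch, closely mirroring the new test in tests/core/test_model.py; it assumes `transformers` is installed and uses `my_dataset` as a stand-in for any map-style dataset (the dataset name is not part of the patch):

    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch.utils.data import DataLoader

    import flash
    from flash.core.classification import ClassificationTask

    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=-1))

    # Built-in scheduler referenced by name; the names come from the
    # transformers-backed registry and can be listed with available_schedulers().
    task = ClassificationTask(
        model,
        optimizer=torch.optim.Adadelta(model.parameters()),
        scheduler="linear_schedule_with_warmup",
        # num_warmup_steps is a float between 0 and 1, i.e. a fraction of the
        # total training steps inferred from the trainer and dataloader.
        scheduler_kwargs={"num_warmup_steps": 0.1},
        loss_fn=F.nll_loss,
    )
    print(task.available_schedulers())

    # The warmup computation needs an attached trainer, so the scheduler is
    # only instantiated once fit() calls configure_optimizers().
    trainer = flash.Trainer(max_epochs=1, limit_train_batches=2)
    trainer.fit(task, train_dataloader=DataLoader(my_dataset))

Alternatively, a plain torch scheduler class can be passed directly, e.g. scheduler=torch.optim.lr_scheduler.StepLR with scheduler_kwargs={"step_size": 1}, which has no dependency on transformers.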