diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8accf9dc6b1ab..f49974c46a033 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Do not fail if batch size could not be inferred for logging when using DeepSpeed ([#10438](https://github.com/PyTorchLightning/pytorch-lightning/issues/10438))
 
+- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))
+
+
 -
 
@@ -142,6 +145,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
 
+- Fixed `to_torchscript()` causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/issues/10470))
+
+
+- Fixed `isinstance` not working with `init_meta_context`, and the materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/pytorch-lightning/pull/10493))
+
+
 - Fixed an issue that prevented the Trainer from shutting down workers when execution is interrupted due to failure ([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))
 
@@ -150,7 +159,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
-
 ## [1.5.1] - 2021-11-09
 
 ### Fixed
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index b6f064d7d9802..dc3ce5f0f4063 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -115,6 +115,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._param_requires_grad_state = {}
         self._metric_attributes: Optional[Dict[int, str]] = None
         self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
+        # TODO: remove after the 1.6 release
+        self._running_torchscript = False
 
         self._register_sharded_tensor_state_dict_hooks_if_available()
 
@@ -1893,6 +1895,8 @@ def to_torchscript(
         """
         mode = self.training
 
+        self._running_torchscript = True
+
         if method == "script":
             torchscript_module = torch.jit.script(self.eval(), **kwargs)
         elif method == "trace":
@@ -1918,6 +1922,8 @@ def to_torchscript(
             with fs.open(file_path, "wb") as f:
                 torch.jit.save(torchscript_module, f)
 
+        self._running_torchscript = False
+
         return torchscript_module
 
     @property
@@ -1927,11 +1933,12 @@ def model_size(self) -> float:
         Note:
             This property will not return correct value for Deepspeed (stage 3) and fully-sharded training.
         """
-        rank_zero_deprecation(
-            "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
-            " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
-            stacklevel=5,
-        )
+        if not self._running_torchscript:  # remove this check when the deprecation is removed in v1.7
+            rank_zero_deprecation(
+                "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
+ " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.", + stacklevel=5, + ) return get_model_size_mb(self) def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py index e02790edddd1e..e8b122989cd9c 100644 --- a/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -17,6 +17,8 @@ import torch from torch.nn import Module +import pytorch_lightning as pl + class DeviceDtypeModuleMixin(Module): __jit_unused_properties__ = ["device", "dtype"] @@ -177,7 +179,9 @@ def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None ) -> None: def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: - if not isinstance(module, DeviceDtypeModuleMixin): + # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't + # work when using `init_meta_context`. + if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): return if device is not None: module._device = device diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 4d41734ed90e6..6a54e973ffcf3 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -94,12 +94,9 @@ def on_trainer_init( " bar pass `enable_progress_bar = False` to the Trainer." ) - if enable_progress_bar: - self.trainer._progress_bar_callback = self.configure_progress_bar( - progress_bar_refresh_rate, process_position - ) - else: - self.trainer._progress_bar_callback = None + self.trainer._progress_bar_callback = self.configure_progress_bar( + progress_bar_refresh_rate, process_position, enable_progress_bar + ) # configure the ModelSummary callback self._configure_model_summary_callback(enable_model_summary, weights_summary) @@ -215,7 +212,9 @@ def _configure_swa_callbacks(self): if not existing_swa: self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks - def configure_progress_bar(self, refresh_rate=None, process_position=0): + def configure_progress_bar( + self, refresh_rate: Optional[int] = None, process_position: int = 0, enable_progress_bar: bool = True + ) -> Optional[ProgressBarBase]: if os.getenv("COLAB_GPU") and refresh_rate is None: # smaller refresh rate on colab causes crashes, choose a higher value refresh_rate = 20 @@ -229,7 +228,12 @@ def configure_progress_bar(self, refresh_rate=None, process_position=0): ) if len(progress_bars) == 1: progress_bar_callback = progress_bars[0] - elif refresh_rate > 0: + if not enable_progress_bar: + raise MisconfigurationException( + "Trainer was configured with `enable_progress_bar=False`" + f" but found `{progress_bar_callback.__class__.__name__}` in callbacks list." 
+                )
+        elif refresh_rate > 0 and enable_progress_bar:
             progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position)
             self.trainer.callbacks.append(progress_bar_callback)
         else:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4cbb33c9b4766..19efdce8e3549 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -84,7 +84,7 @@
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
-from pytorch_lightning.utilities.meta import materialize_module
+from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import (
@@ -1406,10 +1406,21 @@ def _call_setup_hook(self) -> None:
 
     def _call_configure_sharded_model(self) -> None:
         with self.accelerator.model_sharded_context():
-            materialize_module(self.lightning_module)
+            self._handle_meta_model()
             self.call_hook("configure_sharded_model")
             self.call_hook("on_configure_sharded_model")
 
+    def _handle_meta_model(self) -> None:
+        if not is_on_meta_device(self.lightning_module):
+            return
+
+        if isinstance(self.training_type_plugin, DDPSpawnPlugin):
+            raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")
+
+        materialize_module(self.lightning_module)
+        # the trainer reference is lost during materialization
+        self.lightning_module.trainer = proxy(self)
+
     def _call_teardown_hook(self) -> None:
         fn = self.state.fn._setup_fn
diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py
index 60e6cc791b7ae..6d3c1d6b5f11b 100644
--- a/pytorch_lightning/utilities/meta.py
+++ b/pytorch_lightning/utilities/meta.py
@@ -18,13 +18,14 @@
 from functools import partial
 from itertools import chain
 from types import ModuleType
-from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type
 
 import torch
 from torch import nn, Tensor
 from torch.nn import Module
 from torch.nn.modules.container import ModuleDict, ModuleList, Sequential
 
+import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
 
 # cache subclasses to optimize the search when resetting the meta device later on.
 __STORAGE_META__ = {}
-
 __CREATED_MODULES__ = set()
 
 
@@ -237,45 +237,52 @@ def _set_meta_device() -> None:
 
     for subclass in get_all_subclasses(torch.nn.modules.module.Module):
 
-        if isinstance(subclass, (Sequential, ModuleList, ModuleDict)):
+        if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule):
             continue
 
         # if a subclass has already been stored, we should use the cache
        if str(subclass) in __STORAGE_META__:
-            # reset the class import package to its rightfull state.
+            # reset the class import package to its rightful state.
             mods, subclass, meta_class = __STORAGE_META__[subclass]
             for mod in mods:
                 setattr(mod, subclass.__name__, meta_class)
             continue
 
+        class _IsinstanceMetaclass(type(subclass)):
+            def __instancecheck__(self, instance: Any) -> bool:
+                """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
+                return isinstance(instance, self.__bases__[0])
+
         # Create a class subclassing current `subclass` overriding its new method.
         # this will enable us to use `torch.distributed.nn.utils.init_meta` to create a `meta`
         # version of the current subclass module
-        class _MetaClass(subclass):
+        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
             @classmethod
             @contextmanager
-            def instantiation_context(cls, materialize: bool):
+            def instantiation_context(cls):
                 _unset_meta_device(from_created=True)
                 yield
                 _set_meta_device_populated(from_created=True)
 
             @classmethod
             def materialize(cls, materialize_fn: Callable):
-                with cls.instantiation_context(materialize=True):
+                with cls.instantiation_context():
                     obj = materialize_fn()
                 return obj
 
             @staticmethod
             def add_subclasses(subclass):
-                """This is used to unrol the instantion tree while creating the modules."""
-                __CREATED_MODULES__.add(subclass)
+                """This is used to unroll the instantiation tree while creating the modules."""
+                # Don't store the LightningModule, as it is skipped by the meta process.
+                if subclass != pl.LightningModule:
+                    __CREATED_MODULES__.add(subclass)
                 if subclass.__bases__[0] != torch.nn.modules.module.Module:
-                    _MetaClass.add_subclasses(subclass.__bases__[0])
+                    _MaterializerModule.add_subclasses(subclass.__bases__[0])
 
             def __new__(cls, *args, **kwargs):
                 subclass = cls.__bases__[0]
                 cls.add_subclasses(subclass)
-                with cls.instantiation_context(materialize=False):
+                with cls.instantiation_context():
                     obj = init_meta(subclass, *args, **kwargs)
 
                 obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
@@ -294,9 +301,8 @@ def search(mod: ModuleType) -> List[ModuleType]:
 
         # nn.Module class can be imported at different levels and they all need to be mocked.
         # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear
         # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear
-        # needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass
-        out = []
-        out.append(search(mod))
+        # need to be replaced by the torch.nn.modules.linear.Linear _MaterializerModule
+        out = [search(mod)]
         for name in submodules[1:]:
             mod = getattr(mod, name)
             out.append(search(mod))
@@ -305,11 +311,11 @@ def search(mod: ModuleType) -> List[ModuleType]:
         mods = [mod for mod in chain(*out) if mod]
 
         # store the modules search so it doesn't have to be performed again for this class
-        __STORAGE_META__[subclass] = (mods, subclass, _MetaClass)
+        __STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule)
 
         # replace all subclass by its meta form
         for mod in mods:
-            setattr(mod, subclass.__name__, _MetaClass)
+            setattr(mod, subclass.__name__, _MaterializerModule)
 
 
 @contextmanager
@@ -321,3 +327,11 @@ def init_meta_context() -> Generator:
     _set_meta_device()
     yield
     _unset_meta_device()
+
+
+def is_on_meta_device(module: nn.Module) -> bool:
+    try:
+        param = next(module.parameters())
+        return param.device.type == "meta"
+    except StopIteration:
+        return False
diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py
index 99fe02ce21a11..a8371591759d7 100644
--- a/tests/callbacks/test_tqdm_progress_bar.py
+++ b/tests/callbacks/test_tqdm_progress_bar.py
@@ -14,7 +14,7 @@
 import os
 import pickle
 import sys
-from typing import Optional, Union
+from typing import Union
 from unittest import mock
 from unittest.mock import ANY, call, Mock
@@ -32,65 +32,54 @@
 
 @pytest.mark.parametrize(
-    "callbacks,refresh_rate",
+    "kwargs",
     [
-        ([], None),
-        ([], 1),
-        ([], 2),
-        ([TQDMProgressBar(refresh_rate=1)], 0),
-        ([TQDMProgressBar(refresh_rate=2)], 0),
-        ([TQDMProgressBar(refresh_rate=2)], 1),
+        # won't print but is still set
+        {"callbacks": TQDMProgressBar(refresh_rate=0)},
+        {"callbacks": TQDMProgressBar()},
+        {"progress_bar_refresh_rate": 1},
     ],
 )
-def test_tqdm_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
+def test_tqdm_progress_bar_on(tmpdir, kwargs):
     """Test different ways the progress bar can be turned on."""
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        callbacks=callbacks,
-        progress_bar_refresh_rate=refresh_rate,
-        max_epochs=1,
-        overfit_batches=5,
-    )
+    if "progress_bar_refresh_rate" in kwargs:
+        with pytest.deprecated_call(match=r"progress_bar_refresh_rate=.*` is deprecated"):
+            trainer = Trainer(default_root_dir=tmpdir, **kwargs)
+    else:
+        trainer = Trainer(default_root_dir=tmpdir, **kwargs)
 
     progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
-    # Trainer supports only a single progress bar callback at the moment
     assert len(progress_bars) == 1
     assert progress_bars[0] is trainer.progress_bar_callback
 
 
-@pytest.mark.parametrize(
-    "callbacks,refresh_rate,enable_progress_bar",
-    [([], 0, True), ([], False, True), ([ModelCheckpoint(dirpath="../trainer")], 0, True), ([], 1, False)],
-)
-def test_tqdm_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool):
+@pytest.mark.parametrize("kwargs", [{"enable_progress_bar": False}, {"progress_bar_refresh_rate": 0}])
+def test_tqdm_progress_bar_off(tmpdir, kwargs):
     """Test different ways the progress bar can be turned off."""
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        callbacks=callbacks,
-        progress_bar_refresh_rate=refresh_rate,
-        enable_progress_bar=enable_progress_bar,
-    )
-
-    progress_bars = [c for c in trainer.callbacks if isinstance(c, TQDMProgressBar)]
-    assert 0 == len(progress_bars)
-    assert not trainer.progress_bar_callback
+    if "progress_bar_refresh_rate" in kwargs:
+        with pytest.deprecated_call(match=r"progress_bar_refresh_rate=.*` is deprecated"):
+            trainer = Trainer(default_root_dir=tmpdir, **kwargs)
+    else:
+        trainer = Trainer(default_root_dir=tmpdir, **kwargs)
+    progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
+    assert not len(progress_bars)
 
 
 def test_tqdm_progress_bar_misconfiguration():
     """Test that Trainer doesn't accept multiple progress bars."""
+    # Trainer supports only a single progress bar callback at the moment
     callbacks = [TQDMProgressBar(), TQDMProgressBar(), ModelCheckpoint(dirpath="../trainer")]
     with pytest.raises(MisconfigurationException, match=r"^You added multiple progress bar callbacks"):
         Trainer(callbacks=callbacks)
 
+    with pytest.raises(MisconfigurationException, match=r"enable_progress_bar=False` but found `TQDMProgressBar"):
+        Trainer(callbacks=TQDMProgressBar(), enable_progress_bar=False)
+
 
 def test_tqdm_progress_bar_totals(tmpdir):
     """Test that the progress finishes with the correct total steps processed."""
     model = BoringModel()
 
-    trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=1, max_epochs=1)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
     bar = trainer.progress_bar_callback
     assert float("inf") == bar.total_train_batches
     assert 0 == bar.total_val_batches
@@ -209,14 +198,15 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         self.test_batches_seen += 1
 
     progress_bar = CurrentProgressBar(refresh_rate=refresh_rate)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        callbacks=[progress_bar],
-        progress_bar_refresh_rate=101,  # should not matter if custom callback provided
-        limit_train_batches=1.0,
-        num_sanity_val_steps=2,
-        max_epochs=3,
-    )
+    with pytest.deprecated_call(match=r"progress_bar_refresh_rate=101\)` is deprecated"):
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            callbacks=[progress_bar],
+            progress_bar_refresh_rate=101,  # should not matter if custom callback provided
+            limit_train_batches=1.0,
+            num_sanity_val_steps=2,
+            max_epochs=3,
+        )
     assert trainer.progress_bar_callback.refresh_rate == refresh_rate
 
     trainer.fit(model)
@@ -276,9 +266,6 @@ def test_tqdm_progress_bar_default_value(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir)
     assert trainer.progress_bar_callback.refresh_rate == 1
 
-    trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None)
-    assert trainer.progress_bar_callback.refresh_rate == 1
-
 
 @mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
 def test_tqdm_progress_bar_value_on_colab(tmpdir):
@@ -286,10 +273,14 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir)
     assert trainer.progress_bar_callback.refresh_rate == 20
 
-    trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None)
-    assert trainer.progress_bar_callback.refresh_rate == 20
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar())
+    assert trainer.progress_bar_callback.refresh_rate == 1  # FIXME: should be 20
+
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar(refresh_rate=19))
+    assert trainer.progress_bar_callback.refresh_rate == 19
 
-    trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19)
+    with pytest.deprecated_call(match=r"progress_bar_refresh_rate=19\)` is deprecated"):
+        trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19)
     assert trainer.progress_bar_callback.refresh_rate == 19
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 63a2211934ece..6bd7db1aeff8d 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -791,7 +791,7 @@ def val_dataloader(self):
         max_epochs=1,
         val_check_interval=val_check_interval,
         num_sanity_val_steps=0,
-        progress_bar_refresh_rate=0,
+        enable_progress_bar=False,
     )
     trainer.fit(model)
@@ -829,7 +829,7 @@ def val_dataloader(self):
         max_epochs=1,
         val_check_interval=val_check_interval,
         num_sanity_val_steps=0,
-        progress_bar_refresh_rate=0,
+        enable_progress_bar=False,
     )
 
     with pytest.raises(CustomException):
         # will stop during validation
@@ -880,7 +880,7 @@ def val_dataloader(self):
         max_epochs=1,
         val_check_interval=val_check_interval,
         num_sanity_val_steps=0,
-        progress_bar_refresh_rate=0,
+        enable_progress_bar=False,
     )
     trainer.fit(model, ckpt_path=ckpt_path)
diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 2cb68aa2e95bd..e3c353c3eb063 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -22,6 +22,7 @@
     LearningRateMonitor,
     ModelCheckpoint,
     ModelSummary,
+    ProgressBarBase,
     TQDMProgressBar,
 )
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
@@ -143,10 +144,11 @@ def test_attach_model_callbacks():
     def _attach_callbacks(trainer_callbacks, model_callbacks):
         model = LightningModule()
         model.configure_callbacks = lambda: model_callbacks
+        has_progress_bar = any(isinstance(cb, ProgressBarBase) for cb in trainer_callbacks + model_callbacks)
         trainer = Trainer(
             enable_checkpointing=False,
-            enable_progress_bar=False,
-            enable_model_summary=None,
+            enable_progress_bar=has_progress_bar,
+            enable_model_summary=False,
             callbacks=trainer_callbacks,
         )
         trainer.model = model
diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py
index 8e36a86c3beef..581b949d9167f 100644
--- a/tests/utilities/test_meta.py
+++ b/tests/utilities/test_meta.py
@@ -14,7 +14,7 @@
 from torch import nn
 
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities.meta import init_meta_context, materialize_module
+from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module
 from tests.helpers.runif import RunIf
 
 
@@ -31,18 +31,23 @@ def __init__(self, num_layers: int):
         self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)])
 
 
-@RunIf(min_torch="1.10.0")
+@RunIf(special=True, min_torch="1.10.0")
 def test_init_meta_context():
     with init_meta_context():
         m = nn.Linear(in_features=1, out_features=1)
+        assert isinstance(m, nn.Linear)
         assert m.weight.device.type == "meta"
+        assert is_on_meta_device(m)
         mlp = MLP(4)
         assert mlp.layer[0].weight.device.type == "meta"
 
         mlp = materialize_module(mlp)
         assert mlp.layer[0].weight.device.type == "cpu"
 
+        assert not is_on_meta_device(mlp)
+        assert not is_on_meta_device(nn.Module())
+
         model = BoringModel(4)
         assert model.layer[0].weight.device.type == "meta"
         materialize_module(model)
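
Below is a minimal usage sketch, not part of the patch, illustrating the two user-facing behaviors this diff introduces. It assumes PyTorch Lightning 1.5.x with these changes applied, plus PyTorch >= 1.10 for the meta-device portion; all names are taken from the code above.

```python
import pytest
from torch import nn

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module

# 1) Passing a progress bar callback while disabling the progress bar now
#    raises a MisconfigurationException instead of silently ignoring the conflict.
with pytest.raises(MisconfigurationException):
    Trainer(callbacks=[TQDMProgressBar()], enable_progress_bar=False)

# 2) Modules created under `init_meta_context` now pass `isinstance` checks
#    against their original class (via `_IsinstanceMetaclass`), and
#    `is_on_meta_device` reports where their parameters live.
with init_meta_context():
    module = nn.Linear(in_features=1, out_features=1)

assert isinstance(module, nn.Linear)
assert is_on_meta_device(module)  # parameters are on the "meta" device

module = materialize_module(module)  # allocates real (CPU) parameters
assert not is_on_meta_device(module)
```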