Commit

Merge branch 'master' into fix/deepspeed_logging_per_gpu
Sean Naren committed Nov 15, 2021
2 parents 415d0c5 + 65ebfed commit 7d9dd6f
Showing 10 changed files with 135 additions and 89 deletions.
10 changes: 9 additions & 1 deletion CHANGELOG.md
@@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Do not fail if batch size could not be inferred for logging when using DeepSpeed ([#10438](https://github.com/PyTorchLightning/pytorch-lightning/issues/10438))


- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))


-


@@ -142,6 +145,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))


- Fixed `to_torchscript()` causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/issues/10470))


- Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))


- Fixed an issue that prevented the Trainer from shutting down workers when execution is interrupted due to a failure ([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))


@@ -150,7 +159,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-


## [1.5.1] - 2021-11-09

### Fixed
17 changes: 12 additions & 5 deletions pytorch_lightning/core/lightning.py
@@ -115,6 +115,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._param_requires_grad_state = {}
self._metric_attributes: Optional[Dict[int, str]] = None
self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
# TODO: remove after the 1.6 release
self._running_torchscript = False

self._register_sharded_tensor_state_dict_hooks_if_available()

@@ -1893,6 +1895,8 @@ def to_torchscript(
"""
mode = self.training

self._running_torchscript = True

if method == "script":
torchscript_module = torch.jit.script(self.eval(), **kwargs)
elif method == "trace":
@@ -1918,6 +1922,8 @@
with fs.open(file_path, "wb") as f:
torch.jit.save(torchscript_module, f)

self._running_torchscript = False

return torchscript_module

@property
@@ -1927,11 +1933,12 @@ def model_size(self) -> float:
Note:
This property will not return correct value for Deepspeed (stage 3) and fully-sharded training.
"""
rank_zero_deprecation(
"The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
" Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
stacklevel=5,
)
if not self._running_torchscript: # remove with the deprecation removal
rank_zero_deprecation(
"The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
" Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
stacklevel=5,
)
return get_model_size_mb(self)

def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
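
For context, a minimal sketch of the behaviour this change restores, assuming a trivial `LightningModule` (the `BoringModel` name and layer sizes below are illustrative, not part of this commit): with `_running_torchscript` set for the duration of the export, scripting the module should no longer emit the `model_size` deprecation warning as a side effect.

```python
import torch
from torch import nn
import pytorch_lightning as pl


class BoringModel(pl.LightningModule):
    """Minimal module used only to exercise `to_torchscript`."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)


model = BoringModel()
# With the flag above in place, exporting should not trigger the
# `LightningModule.model_size` deprecation warning.
scripted = model.to_torchscript(method="script")
assert isinstance(scripted, torch.jit.ScriptModule)
```
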
6 changes: 5 additions & 1 deletion pytorch_lightning/core/mixins/device_dtype_mixin.py
@@ -17,6 +17,8 @@
import torch
from torch.nn import Module

import pytorch_lightning as pl


class DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ["device", "dtype"]
@@ -177,7 +179,9 @@ def __update_properties(
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
if not isinstance(module, DeviceDtypeModuleMixin):
# TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
# work when using `init_meta_context`.
if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
return
if device is not None:
module._device = device
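
As a rough illustration of what `__update_properties` keeps in sync (a usage sketch, not code from this commit; the `TinyModule` name and layer sizes are made up), moving a `LightningModule` to a new dtype should update the bookkeeping exposed by `DeviceDtypeModuleMixin`:

```python
import torch
from torch import nn
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    """Illustrative module; only the mixin bookkeeping is of interest here."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)


module = TinyModule().to(torch.double)
# `apply_fn` records the new dtype/device on every module it recognizes;
# widening the isinstance check keeps LightningModules covered even when
# `init_meta_context` has swapped in a generated subclass.
assert module.dtype == torch.double
assert module.device == torch.device("cpu")
```
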
20 changes: 12 additions & 8 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -94,12 +94,9 @@ def on_trainer_init(
" bar pass `enable_progress_bar = False` to the Trainer."
)

if enable_progress_bar:
self.trainer._progress_bar_callback = self.configure_progress_bar(
progress_bar_refresh_rate, process_position
)
else:
self.trainer._progress_bar_callback = None
self.trainer._progress_bar_callback = self.configure_progress_bar(
progress_bar_refresh_rate, process_position, enable_progress_bar
)

# configure the ModelSummary callback
self._configure_model_summary_callback(enable_model_summary, weights_summary)
@@ -215,7 +212,9 @@ def _configure_swa_callbacks(self):
if not existing_swa:
self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks

def configure_progress_bar(self, refresh_rate=None, process_position=0):
def configure_progress_bar(
self, refresh_rate: Optional[int] = None, process_position: int = 0, enable_progress_bar: bool = True
) -> Optional[ProgressBarBase]:
if os.getenv("COLAB_GPU") and refresh_rate is None:
# smaller refresh rate on colab causes crashes, choose a higher value
refresh_rate = 20
@@ -229,7 +228,12 @@ def configure_progress_bar(self, refresh_rate=None, process_position=0):
)
if len(progress_bars) == 1:
progress_bar_callback = progress_bars[0]
elif refresh_rate > 0:
if not enable_progress_bar:
raise MisconfigurationException(
"Trainer was configured with `enable_progress_bar=False`"
f" but found `{progress_bar_callback.__class__.__name__}` in callbacks list."
)
elif refresh_rate > 0 and enable_progress_bar:
progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position)
self.trainer.callbacks.append(progress_bar_callback)
else:
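
A short sketch of the behaviour the new check enforces (usage example, not code from this commit): passing a progress bar callback while disabling the progress bar now fails loudly at `Trainer` construction instead of being silently dropped.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    # Contradictory configuration: a progress bar instance is supplied
    # even though the progress bar has been disabled.
    Trainer(enable_progress_bar=False, callbacks=[TQDMProgressBar()])
except MisconfigurationException as err:
    print(err)  # "Trainer was configured with `enable_progress_bar=False` but found ..."
```
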
15 changes: 13 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -84,7 +84,7 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.meta import materialize_module
from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import (
@@ -1406,10 +1406,21 @@ def _call_setup_hook(self) -> None:

def _call_configure_sharded_model(self) -> None:
with self.accelerator.model_sharded_context():
materialize_module(self.lightning_module)
self._handle_meta_model()
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")

def _handle_meta_model(self) -> None:
if not is_on_meta_device(self.lightning_module):
return

if isinstance(self.training_type_plugin, DDPSpawnPlugin):
raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")

materialize_module(self.lightning_module)
# the trainer reference is lost during materialization
self.lightning_module.trainer = proxy(self)

def _call_teardown_hook(self) -> None:
fn = self.state.fn._setup_fn

46 changes: 30 additions & 16 deletions pytorch_lightning/utilities/meta.py
@@ -18,13 +18,14 @@
from functools import partial
from itertools import chain
from types import ModuleType
from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type

import torch
from torch import nn, Tensor
from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module:

# cache subclasses to optimize the search when resetting the meta device later on.
__STORAGE_META__ = {}

__CREATED_MODULES__ = set()


@@ -237,45 +237,52 @@ def _set_meta_device() -> None:

for subclass in get_all_subclasses(torch.nn.modules.module.Module):

if isinstance(subclass, (Sequential, ModuleList, ModuleDict)):
if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule):
continue

# if a subclass has already been stored, we should use the cache
if str(subclass) in __STORAGE_META__:
# reset the class import package to its rightfull state.
# reset the class import package to its rightful state.
mods, subclass, meta_class = __STORAGE_META__[subclass]
for mod in mods:
setattr(mod, subclass.__name__, meta_class)
continue

class _IsinstanceMetaclass(type(subclass)):
def __instancecheck__(self, instance: Any) -> bool:
"""Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
return isinstance(instance, self.__bases__[0])

# Create a class subclassing current `subclass` overriding its new method.
# this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta`
# version of the current subclass module
class _MetaClass(subclass):
class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
@classmethod
@contextmanager
def instantiation_context(cls, materialize: bool):
def instantiation_context(cls):
_unset_meta_device(from_created=True)
yield
_set_meta_device_populated(from_created=True)

@classmethod
def materialize(cls, materialize_fn: Callable):
with cls.instantiation_context(materialize=True):
with cls.instantiation_context():
obj = materialize_fn()
return obj

@staticmethod
def add_subclasses(subclass):
"""This is used to unrol the instantion tree while creating the modules."""
__CREATED_MODULES__.add(subclass)
"""This is used to unroll the instantiation tree while creating the modules."""
# Don't store the LightningModule as skipped from the Meta process.
if subclass != pl.LightningModule:
__CREATED_MODULES__.add(subclass)
if subclass.__bases__[0] != torch.nn.modules.module.Module:
_MetaClass.add_subclasses(subclass.__bases__[0])
_MaterializerModule.add_subclasses(subclass.__bases__[0])

def __new__(cls, *args, **kwargs):
subclass = cls.__bases__[0]
cls.add_subclasses(subclass)
with cls.instantiation_context(materialize=False):
with cls.instantiation_context():
obj = init_meta(subclass, *args, **kwargs)

obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
@@ -294,9 +301,8 @@ def search(mod: ModuleType) -> List[ModuleType]:
# nn.Module class can be imported at different level and they all need to be mocked.
# Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear
# Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear
# needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass
out = []
out.append(search(mod))
# needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule
out = [search(mod)]
for name in submodules[1:]:
mod = getattr(mod, name)
out.append(search(mod))
Expand All @@ -305,11 +311,11 @@ def search(mod: ModuleType) -> List[ModuleType]:
mods = [mod for mod in chain(*out) if mod]

# store the modules search so it doesn't have to be performed again for this class
__STORAGE_META__[subclass] = (mods, subclass, _MetaClass)
__STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule)

# replace all subclass by its meta form
for mod in mods:
setattr(mod, subclass.__name__, _MetaClass)
setattr(mod, subclass.__name__, _MaterializerModule)


@contextmanager
@@ -321,3 +327,11 @@ def init_meta_context() -> Generator:
_set_meta_device()
yield
_unset_meta_device()


def is_on_meta_device(module: nn.Module) -> bool:
try:
param = next(module.parameters())
return param.device.type == "meta"
except StopIteration:
return False
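
To round out the picture, a hedged sketch of how these helpers fit together (assuming PyTorch 1.10+, which the meta utilities require; plain `nn.Linear` is used purely for illustration and is not taken from this commit):

```python
from torch import nn
from pytorch_lightning.utilities.meta import (
    init_meta_context,
    is_on_meta_device,
    materialize_module,
)

with init_meta_context():
    # Inside the context, modules are created on the "meta" device:
    # shapes and dtypes only, no parameter storage is allocated.
    module = nn.Linear(1024, 1024)

assert is_on_meta_device(module)

# Allocate real parameters once the module is actually needed.
# `materialize_module` may return a new instance, so keep its result.
module = materialize_module(module)
assert not is_on_meta_device(module)
```
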