[Model Parallel] Add configure sharded model hook (#6679)
* Add base hook for model parallel

* fix callback signature

* Simplify hook

* Add hook logic

* add tests

* add property setter

* add logic for being called once

* Update changelog

* Fix

* fix return type

* fix lambda callback test

* Fix tests

* Apply code suggestions

* add logic for setup_optimizers_predispatch

* add common dummy model

* Swap call order

* Remove test that isn't needed anymore

* Update tests

* Add a bit more doc

* Few code review fixes

* Update pytorch_lightning/accelerators/accelerator.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Change hook name

* Fix test

* Test setup hook, refactor names

* Swap call order of callbacks and model initialization

* Change name of context manager

Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
4 people authored Mar 29, 2021
1 parent 646cf2f commit f79a13e
Showing 13 changed files with 230 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -70,6 +70,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))


- Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679))


- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))


44 changes: 43 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
import contextlib
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
@@ -439,6 +440,18 @@ def results(self) -> Any:
"""
return self.training_type_plugin.results

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
"""
Provide a hook to create modules in a distributed-aware context so that the model can be
sharded instantly upon creation, which is useful for extremely large models and can save
memory and initialization time.
Returns: Model parallel context.
"""
with self.training_type_plugin.model_sharded_context():
yield

# todo: remove in v1.5
def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
"""
@@ -468,4 +481,33 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
self.setup_precision_plugin(plugin)

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
self.training_type_plugin.save_checkpoint(checkpoint, filepath)

@property
def call_configure_sharded_model_hook(self) -> bool:
"""
Allow the model parallel hook to be called in suitable environments, as determined by the training type plugin.
This is useful when we want to shard the model only once within fit.
Returns: True if we want to call the model parallel setup hook.
"""
return self.training_type_plugin.call_configure_sharded_model_hook

@call_configure_sharded_model_hook.setter
def call_configure_sharded_model_hook(self, mode: bool) -> None:
self.training_type_plugin.call_configure_sharded_model_hook = mode

@property
def setup_optimizers_in_pre_dispatch(self) -> bool:
"""
Override to delay setting optimizers and schedulers until after dispatch.
This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
However, this may break certain precision plugins, such as APEX, which require optimizers to be set earlier.
Returns: If True, delay optimizer setup until pre_dispatch; otherwise, set up optimizers within setup.
"""
return self.training_type_plugin.setup_optimizers_in_pre_dispatch
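
The accelerator-level `model_sharded_context` above is a thin delegation to the training type plugin's own context manager. A minimal, self-contained sketch of that re-yield pattern, assuming made-up `ToyPlugin`/`ToyAccelerator` classes that are not part of this diff:

```python
import contextlib
from typing import Generator


class ToyPlugin:
    """Stand-in for a TrainingTypePlugin: a no-op sharding context by default."""

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        print("entering plugin sharding context")
        yield
        print("leaving plugin sharding context")


class ToyAccelerator:
    """Stand-in for Accelerator: re-yields whatever context the plugin provides."""

    def __init__(self, plugin: ToyPlugin) -> None:
        self.training_type_plugin = plugin

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        with self.training_type_plugin.model_sharded_context():
            yield


accelerator = ToyAccelerator(ToyPlugin())
with accelerator.model_sharded_context():
    # modules created here are instantiated inside the plugin's context
    pass
```
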
3 changes: 3 additions & 0 deletions pytorch_lightning/callbacks/base.py
@@ -29,6 +29,9 @@ class Callback(abc.ABC):
Subclass this class and override any of the relevant hooks
"""

def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None:
"""Called before configure sharded model"""

def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule) -> None:
"""Called before accelerator is being setup"""
pass
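
For callback authors, the new hook can be overridden like any other; a hedged sketch in which the `ShardingLogger` class is illustrative and not part of this change:

```python
from pytorch_lightning.callbacks import Callback


class ShardingLogger(Callback):
    """Illustrative callback that reports when sharded-model configuration is about to run."""

    def on_configure_sharded_model(self, trainer, pl_module) -> None:
        print(f"configure_sharded_model about to run for {type(pl_module).__name__}")
```
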
3 changes: 3 additions & 0 deletions pytorch_lightning/callbacks/lambda_function.py
@@ -42,6 +42,7 @@ def __init__(
self,
on_before_accelerator_backend_setup: Optional[Callable] = None,
setup: Optional[Callable] = None,
on_configure_sharded_model: Optional[Callable] = None,
teardown: Optional[Callable] = None,
on_init_start: Optional[Callable] = None,
on_init_end: Optional[Callable] = None,
@@ -83,6 +84,8 @@ def __init__(
self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup
if setup is not None:
self.setup = setup
if on_configure_sharded_model is not None:
self.on_configure_sharded_model = on_configure_sharded_model
if teardown is not None:
self.teardown = teardown
if on_init_start is not None:
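
The same hook can also be attached inline through `LambdaCallback`; a small usage sketch, assuming the callback is passed to a standard `Trainer`:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LambdaCallback

# attach the new hook without writing a Callback subclass
sharded_hook = LambdaCallback(
    on_configure_sharded_model=lambda trainer, pl_module: print("about to configure sharded model")
)
trainer = Trainer(callbacks=[sharded_hook], fast_dev_run=True)
```
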
14 changes: 14 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -310,6 +310,20 @@ def on_post_move_to_device(self):
"""

def configure_sharded_model(self) -> None:
"""
Hook to create modules in a distributed-aware context. This is useful when using sharded plugins,
where we'd like to shard the model instantly to save memory and initialization time
for extremely large models.
The accelerator manages whether to call this hook at every given stage.
For sharded plugins where model parallelism is required, the hook is usually called only once
to initialize the sharded parameters, and it is not called again in the same process.
By default, for accelerators/plugins that do not use model sharding techniques,
this hook is called during each of the fit/val/test/predict stages.
"""


class DataHooks:
"""Hooks to be used for data related stuff."""
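
A hedged usage sketch of the new `LightningModule` hook, mirroring the test further below (the `TwoLayerModule` name is illustrative): heavy layers are deferred to `configure_sharded_model` so a sharded plugin can instantiate them inside its distributed-aware context.

```python
import torch
from pytorch_lightning import LightningModule


class TwoLayerModule(LightningModule):

    def __init__(self):
        super().__init__()
        # defer creation of the heavy block until configure_sharded_model is called
        self.block = None

    def configure_sharded_model(self) -> None:
        self.block = torch.nn.Sequential(
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 2),
        )

    def forward(self, x):
        return self.block(x)
```
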
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -300,9 +300,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
trainer: PyTorch Lightning Trainer
checkpoint: dict containing model and trainer state
filepath: write-target file's path
weights_only: saving model weights only
"""
# Todo: TypeError: 'mappingproxy' object does not support item assignment
self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
34 changes: 33 additions & 1 deletion pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Generator, Iterable, Optional, TYPE_CHECKING, Union

import torch
from torch.nn import Module
@@ -35,6 +36,7 @@ class TrainingTypePlugin(Plugin, ABC):
def __init__(self) -> None:
self._model = None
self._results = None
self._call_configure_sharded_model_hook = True

def connect(self, model: 'Module') -> None:
"""Called by the accelerator to connect the accelerator and the model with this plugin"""
@@ -196,6 +198,12 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
return False

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
# dump states as a checkpoint dictionary object
if self.is_global_zero:
checkpoint = self.on_save(checkpoint)
@@ -210,3 +218,27 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
f' An attribute is not picklable {err}'
)
atomic_save(checkpoint, filepath)

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
"""
Provide a hook to create modules in a distributed-aware context so that the model can be
sharded instantly upon creation, which is useful for extremely large models and can save
memory and initialization time.
Returns: Model parallel context.
"""
yield

@property
def call_configure_sharded_model_hook(self) -> bool:
"""
Allow the model parallel hook to be called in suitable environments, as determined by the training type plugin.
This is useful when we want to shard the model only once within fit.
Returns: True if we want to call the model parallel setup hook.
"""
return self._call_configure_sharded_model_hook

@call_configure_sharded_model_hook.setter
def call_configure_sharded_model_hook(self, mode: bool) -> None:
self._call_configure_sharded_model_hook = mode
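
A hedged sketch of how a plugin might override the new surface. The `MetaInitPlugin` name and the placeholder context body are assumptions, not part of this diff; a real sharded plugin would enter its framework's module-initialization context inside `model_sharded_context`.

```python
import contextlib
from typing import Generator

import torch
from pytorch_lightning.plugins import SingleDevicePlugin


class MetaInitPlugin(SingleDevicePlugin):

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        # placeholder: a real plugin would enter a sharded/meta-device init context here
        print("creating modules inside the sharded context")
        yield

    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # delay optimizer creation until after the wrapped model exists
        return True


plugin = MetaInitPlugin(device=torch.device("cpu"))
```
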
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
@@ -38,6 +38,11 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None:
for callback in self.callbacks:
callback.on_before_accelerator_backend_setup(self, model)

def configure_sharded_model(self, model: LightningModule) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
callback.on_configure_sharded_model(self, model)

def setup(self, model: LightningModule, stage: Optional[str]) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
@@ -55,6 +55,11 @@ def _setup_log():
"""Called when fit or test begins"""
return None

@staticmethod
def _on_configure_sharded_model_log():
"""Called before configure sharded model"""
return None

@staticmethod
def _teardown_log():
"""Called at the end of fit and test"""
10 changes: 10 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -436,6 +436,7 @@ def fit(
self.accelerator.connect(model)
self.accelerator.setup_environment()
self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment
self.call_configure_sharded_model(model) # allow user to setup in model sharded environment
self.accelerator.setup(self, model) # note: this sets up self.lightning_module

# ----------------------------
@@ -1082,6 +1083,15 @@ def call_setup_hook(self, model: LightningModule) -> None:
self.setup(model, stage=state)
model.setup(stage=state)

def call_configure_sharded_model(self, model: LightningModule) -> None:
# Call the configure_sharded_model hook if the accelerator requests it. In some cases
# we will not call the hook; for example, when it has already initialized the sharded model.
if self.accelerator.call_configure_sharded_model_hook:
with self.accelerator.model_sharded_context():
model.configure_sharded_model()
self.configure_sharded_model(model)
self.accelerator.call_configure_sharded_model_hook = False

def call_teardown_hook(self, model: LightningModule) -> None:
state = self._teardown_state

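
The guard above means the user hook runs at most once per process by default. A hedged sketch of what that looks like from user code, mirroring the accelerator test below; re-enabling the flag before a second fit is an assumption about intended use, not something this diff documents:

```python
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # as used in the tests below

model = BoringModel()
trainer = Trainer(fast_dev_run=True)

assert trainer.accelerator.call_configure_sharded_model_hook       # enabled before fit
trainer.fit(model)
assert not trainer.accelerator.call_configure_sharded_model_hook   # flipped off after the first call

# to run the user hook again (e.g. ahead of another fit), re-enable the flag explicitly
trainer.accelerator.call_configure_sharded_model_hook = True
```
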
105 changes: 105 additions & 0 deletions tests/accelerators/test_common.py
@@ -1,9 +1,24 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleDevicePlugin
from tests.accelerators.test_dp import CustomClassificationModelDP
from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf

@@ -44,3 +59,93 @@ def test_evaluate(tmpdir, trainer_kwargs):
# make sure weights didn't change
new_weights = model.layer_0.weight.clone().detach().cpu()
torch.testing.assert_allclose(old_weights, new_weights)


def test_model_parallel_setup_called(tmpdir):

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.configure_sharded_model_called = False
self.layer = None

def configure_sharded_model(self):
self.configure_sharded_model_called = True
self.layer = torch.nn.Linear(32, 2)

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
)
trainer.fit(model)

assert model.configure_sharded_model_called


class DummyModel(BoringModel):

def __init__(self):
super().__init__()
self.configure_sharded_model_called = False

def configure_sharded_model(self):
self.configure_sharded_model_called = True


def test_configure_sharded_model_false(tmpdir):
"""Ensure ``configure_sharded_model`` is not called, when turned off"""

class CustomPlugin(SingleDevicePlugin):

@property
def call_configure_sharded_model_hook(self) -> bool:
return False

model = DummyModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
plugins=CustomPlugin(device=torch.device("cpu"))
)
trainer.fit(model)

assert not model.configure_sharded_model_called


def test_accelerator_configure_sharded_model_called_once(tmpdir):
"""Ensure that the configure sharded model hook is called, and set to False after to ensure not called again."""

model = DummyModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
)
assert trainer.accelerator.call_configure_sharded_model_hook is True
trainer.fit(model)
assert trainer.accelerator.call_configure_sharded_model_hook is False


def test_configure_sharded_model_called_once(tmpdir):
"""Ensure ``configure_sharded_model`` is only called once"""

model = DummyModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
)
trainer.fit(model)

assert model.configure_sharded_model_called
model.configure_sharded_model_called = False

assert not model.configure_sharded_model_called
3 changes: 3 additions & 0 deletions tests/callbacks/test_callbacks.py
@@ -48,6 +48,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_init_end(trainer),
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'fit'),
call.on_configure_sharded_model(trainer, model),
call.on_fit_start(trainer, model),
call.on_pretrain_routine_start(trainer, model),
call.on_pretrain_routine_end(trainer, model),
@@ -119,6 +120,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.on_init_end(trainer),
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'test'),
call.on_configure_sharded_model(trainer, model),
call.on_test_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_test_epoch_start(trainer, model),
@@ -153,6 +155,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.on_init_end(trainer),
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'validate'),
call.on_configure_sharded_model(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
2 changes: 2 additions & 0 deletions tests/trainer/logging_/test_logger_connector.py
@@ -281,6 +281,7 @@ def test_call_back_validator(tmpdir):
'on_epoch_end',
'on_epoch_start',
'on_fit_end',
'on_configure_sharded_model',
'on_fit_start',
'on_init_end',
'on_init_start',
@@ -317,6 +318,7 @@
"on_before_accelerator_backend_setup",
"on_fit_end",
"on_fit_start",
"on_configure_sharded_model",
"on_init_end",
"on_init_start",
"on_keyboard_interrupt",
Expand Down
