Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model Parallel] Add configure sharded model hook #6679

Merged
merged 30 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9f8864f
Add base hook for model parallel
Mar 23, 2021
eac5344
fix callback signature
kaushikb11 Mar 25, 2021
32df0cb
Simplify hook
Mar 25, 2021
282a133
Add hook logic
Mar 25, 2021
7a94e72
add tests
kaushikb11 Mar 25, 2021
8091481
add property setter
kaushikb11 Mar 25, 2021
633fc77
add logic for being called once
kaushikb11 Mar 25, 2021
c99a36f
Update changelog
kaushikb11 Mar 25, 2021
a68c8d7
Merge branch 'master' into feat/model_parallel_hook
kaushikb11 Mar 25, 2021
9529a22
Fix
kaushikb11 Mar 25, 2021
3c1c782
fix return type
kaushikb11 Mar 25, 2021
a49ec3b
fix lambda callback test
kaushikb11 Mar 25, 2021
4dd55d7
Fix tests
kaushikb11 Mar 25, 2021
caad43c
Apply code suggestions
kaushikb11 Mar 25, 2021
a2574be
add logic for setup_optimizers_predispatch
kaushikb11 Mar 25, 2021
8c2bd6a
add common dummy model
kaushikb11 Mar 25, 2021
3240569
Swap call order
Mar 25, 2021
897bdbb
Remove test that isn't needed anymore
Mar 25, 2021
626fc7b
Update tests
kaushikb11 Mar 26, 2021
e94a7ae
Add a bit more doc
Mar 26, 2021
6a38417
Merge branch 'master' into feat/model_parallel_hook
Mar 26, 2021
202ef1a
Few code review fixes
Mar 29, 2021
0709baa
Update pytorch_lightning/accelerators/accelerator.py
SeanNaren Mar 29, 2021
9152d08
Change hook name
Mar 29, 2021
fbfe65f
Fix test
Mar 29, 2021
bae858f
Test setup hook, refactor names
Mar 29, 2021
41e9c22
Swap call order of callbacks and model initialization
Mar 29, 2021
76c7376
Change name of context manager
Mar 29, 2021
2dcafd0
Merge branch 'master' into feat/model_parallel_hook
Mar 29, 2021
aa35583
add docstring
tchaton Mar 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))


- Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679))


- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))


Expand Down
38 changes: 37 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
import contextlib
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
Expand Down Expand Up @@ -439,6 +440,18 @@ def results(self) -> Any:
"""
return self.training_type_plugin.results

@contextlib.contextmanager
def model_parallel_context(self) -> Generator:
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
"""
Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
shard the model instantly - useful for extremely large models. Can save memory and
initialization time.

Returns: Model parallel context.
"""
with self.training_type_plugin.model_parallel_context():
yield

# todo: remove in v1.5
def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
"""
Expand Down Expand Up @@ -466,3 +479,26 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
' It will be removed in v1.5.'
)
self.setup_precision_plugin(plugin)

@property
def call_configure_sharded_model_hook(self) -> bool:
"""
Allow model parallel hook to be called in suitable environments determined by the training type plugin.
This is useful for when we want to shard the model once within fit.
Returns: True if we want to call the model parallel setup hook.
"""
return self.training_type_plugin.call_configure_sharded_model_hook

@call_configure_sharded_model_hook.setter
def call_configure_sharded_model_hook(self, mode: bool) -> None:
self.training_type_plugin.call_configure_sharded_model_hook = mode

@property
def setup_optimizers_in_pre_dispatch(self) -> bool:
"""
Override to delay setting optimizers and schedulers till after dispatch.
This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
However this may break certain precision plugins such as APEX which require optimizers to be set.
Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
"""
return self.training_type_plugin.setup_optimizers_in_pre_dispatch
3 changes: 3 additions & 0 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class Callback(abc.ABC):
Subclass this class and override any of the relevant hooks
"""

def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None:
"""Called before configure sharded model"""

def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule) -> None:
"""Called before accelerator is being setup"""
pass
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/callbacks/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
self,
on_before_accelerator_backend_setup: Optional[Callable] = None,
setup: Optional[Callable] = None,
on_configure_sharded_model: Optional[Callable] = None,
teardown: Optional[Callable] = None,
on_init_start: Optional[Callable] = None,
on_init_end: Optional[Callable] = None,
Expand Down Expand Up @@ -83,6 +84,8 @@ def __init__(
self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup
if setup is not None:
self.setup = setup
if on_configure_sharded_model is not None:
self.on_configure_sharded_model = on_configure_sharded_model
if teardown is not None:
self.teardown = teardown
if on_init_start is not None:
Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,20 @@ def on_post_move_to_device(self):

"""

def configure_sharded_model(self) -> None:
"""
Hook to create modules in a distributed aware context. This is useful for when using sharded plugins,
where we'd like to shard the model instantly, which is useful for extremely large models
which can save memory and initialization time.

The accelerator manages whether to call this hook at every given stage.
For sharded plugins where model parallelism is required, the hook is usually on called once
to initialize the sharded parameters, and not called again in the same process.

By default for accelerators/plugins that do not use model sharding techniques,
this hook is called during each fit/val/test/predict stages.
"""


class DataHooks:
"""Hooks to be used for data related stuff."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Generator, Iterable, Optional, TYPE_CHECKING, Union

import torch
from torch.nn import Module
Expand All @@ -33,6 +34,7 @@ class TrainingTypePlugin(Plugin, ABC):
def __init__(self) -> None:
self._model = None
self._results = None
self._call_configure_sharded_model_hook = True

def connect(self, model: 'Module') -> None:
"""Called by the accelerator to connect the accelerator and the model with this plugin"""
Expand Down Expand Up @@ -192,3 +194,27 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
"""
return False

@contextlib.contextmanager
def model_parallel_context(self) -> Generator:
"""
Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
shard the model instantly, which is useful for extremely large models which can save memory and
initialization time.

Returns: Model parallel context.
"""
yield

@property
def call_configure_sharded_model_hook(self) -> bool:
"""
Allow model parallel hook to be called in suitable environments determined by the training type plugin.
This is useful for when we want to shard the model once within fit.
Returns: True if we want to call the model parallel setup hook.
"""
return self._call_configure_sharded_model_hook

@call_configure_sharded_model_hook.setter
def call_configure_sharded_model_hook(self, mode: bool) -> None:
self._call_configure_sharded_model_hook = mode
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None:
for callback in self.callbacks:
callback.on_before_accelerator_backend_setup(self, model)

def configure_sharded_model(self, model: LightningModule) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
callback.on_configure_sharded_model(self, model)

def setup(self, model: LightningModule, stage: Optional[str]) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def _setup_log():
"""Called when fit or test begins"""
return None

@staticmethod
def _on_configure_sharded_model_log():
"""Called before configure sharded model"""
return None

@staticmethod
def _teardown_log():
"""Called at the end of fit and test"""
Expand Down
10 changes: 10 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def fit(
self.accelerator.connect(model)
self.accelerator.setup_environment()
self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment
self.call_configure_sharded_model(model) # allow user to setup in model sharded environment
self.accelerator.setup(self, model) # note: this sets up self.lightning_module

# ----------------------------
Expand Down Expand Up @@ -1075,6 +1076,15 @@ def call_setup_hook(self, model: LightningModule) -> None:
self.setup(model, stage=state)
model.setup(stage=state)

def call_configure_sharded_model(self, model: LightningModule) -> None:
# Call configure sharded model hook if accelerator requests. In some cases
# we will not call the hook; the hook has initialized the sharded model for example.
if self.accelerator.call_configure_sharded_model_hook:
with self.accelerator.model_parallel_context():
model.configure_sharded_model()
self.configure_sharded_model(model)
self.accelerator.call_configure_sharded_model_hook = False

def call_teardown_hook(self, model: LightningModule) -> None:
state = self._teardown_state

Expand Down
105 changes: 105 additions & 0 deletions tests/accelerators/test_common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleDevicePlugin
from tests.accelerators.test_dp import CustomClassificationModelDP
from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf

Expand Down Expand Up @@ -44,3 +59,93 @@ def test_evaluate(tmpdir, trainer_kwargs):
# make sure weights didn't change
new_weights = model.layer_0.weight.clone().detach().cpu()
torch.testing.assert_allclose(old_weights, new_weights)


def test_model_parallel_setup_called(tmpdir):

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.configure_sharded_model_called = False
self.layer = None

def configure_sharded_model(self):
self.configure_sharded_model_called = True
self.layer = torch.nn.Linear(32, 2)

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
)
trainer.fit(model)

assert model.configure_sharded_model_called


class DummyModel(BoringModel):

def __init__(self):
super().__init__()
self.configure_sharded_model_called = False

def configure_sharded_model(self):
self.configure_sharded_model_called = True


def test_configure_sharded_model_false(tmpdir):
"""Ensure ``configure_sharded_model`` is not called, when turned off"""

class CustomPlugin(SingleDevicePlugin):

@property
def call_configure_sharded_model_hook(self) -> bool:
return False

model = DummyModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
plugins=CustomPlugin(device=torch.device("cpu"))
)
trainer.fit(model)

assert not model.configure_sharded_model_called


def test_accelerator_configure_sharded_model_called_once(tmpdir):
"""Ensure that the configure sharded model hook is called, and set to False after to ensure not called again."""

model = DummyModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
)
assert trainer.accelerator.call_configure_sharded_model_hook is True
trainer.fit(model)
assert trainer.accelerator.call_configure_sharded_model_hook is False


def test_configure_sharded_model_called_once(tmpdir):
"""Ensure ``configure_sharded_model`` is only called once"""

model = DummyModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
)
trainer.fit(model)

assert model.configure_sharded_model_called
model.configure_sharded_model_called = False

assert not model.configure_sharded_model_called
3 changes: 3 additions & 0 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_init_end(trainer),
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'fit'),
call.on_configure_sharded_model(trainer, model),
call.on_fit_start(trainer, model),
call.on_pretrain_routine_start(trainer, model),
call.on_pretrain_routine_end(trainer, model),
Expand Down Expand Up @@ -119,6 +120,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.on_init_end(trainer),
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'test'),
call.on_configure_sharded_model(trainer, model),
call.on_test_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_test_epoch_start(trainer, model),
Expand Down Expand Up @@ -153,6 +155,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.on_init_end(trainer),
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'validate'),
call.on_configure_sharded_model(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
Expand Down
2 changes: 2 additions & 0 deletions tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def test_call_back_validator(tmpdir):
'on_epoch_end',
'on_epoch_start',
'on_fit_end',
'on_configure_sharded_model',
'on_fit_start',
'on_init_end',
'on_init_start',
Expand Down Expand Up @@ -316,6 +317,7 @@ def test_call_back_validator(tmpdir):
"on_before_accelerator_backend_setup",
"on_fit_end",
"on_fit_start",
"on_configure_sharded_model",
"on_init_end",
"on_init_start",
"on_keyboard_interrupt",
Expand Down