Base classes for accelerator refactoring #5715

Merged: 22 commits, Jan 30, 2021
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -107,6 +107,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))


- Refactored Accelerators and Plugins (
[#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715),
)


### Deprecated

- `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
375 changes: 375 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -0,0 +1,375 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Iterable, Optional, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins import PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import LightningEnum


class Accelerator(object):
"""
The Accelerator Base Class.
An Accelerator is meant to deal with one type of hardware.

Currently there are accelerators for:
- CPU
- GPU
- TPU

Each Accelerator gets two plugins upon initialization:
One to handle differences in the training routine and one to handle different precision settings.

"""

def __init__(
self,
precision_plugin: PrecisionPlugin,
training_type_plugin: TrainingTypePlugin,
) -> None:
"""

Args:
precision_plugin: the plugin to handle precision-specific parts
training_type_plugin: the plugin to handle different training routines
"""
self.precision_plugin = precision_plugin
self.training_type_plugin = training_type_plugin

self.optimizers = None
self.lr_schedulers = None
self.optimizer_frequencies = None

def setup(self, trainer: "Trainer", model: LightningModule) -> None:
"""
Connects the plugins to the training process and creates the optimizers

Args:
trainer: the trainer instance to connect to
model: the model to train
"""
self.connect_training_type_plugin(self.training_type_plugin, model)
self.setup_optimizers(trainer, model)
self.connect_precision_plugin(self.precision_plugin)
self.optimizers = trainer.convert_to_lightning_optimizers(self.optimizers)

@property
def model(self) -> torch.nn.Module:
"""Returns the model. This can also be a wrapped LightningModule.
For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module`

"""
return self.training_type_plugin.model

@model.setter
def model(self, new_model: torch.nn.Module) -> None:
self.training_type_plugin.model = new_model

@property
def lightning_module(self) -> LightningModule:
"""Returns the pure LightningModule.
To get the potentially wrapped model use :attr:`Accelerator.model`

"""
return self.training_type_plugin.lightning_module

@property
def root_device(self) -> torch.device:
return self.training_type_plugin.root_device

def teardown(self):
"""This method is called to teardown the training process.
It is the right place to release memory and free other ressources.
"""
pass

def batch_to_device(self, batch: Any, device: torch.device) -> Any:
"""Moves the batch to the correct device.
The returned batch is of the same type as the input batch, just having all tensors on the correct device.

Args:
batch: The batch of samples to move to the correct device
device: The target device
"""
model = self.lightning_module
if model is not None:
return model.transfer_batch_to_device(batch, device)
return move_data_to_device(batch, device)

def on_train_start(self):
"""Hook to do something upon the training start"""
pass

def training_step(self, args):
"""The actual training step.

Args:
args: the arguments for the model's training step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): Integer displaying index of this batch
optimizer_idx (int): When using multiple optimizers, this argument will also be present.
hiddens (:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.

"""
batch = self.to_device(args[0])

args[0] = batch

with self.precision_plugin.train_step_context():
with self.training_type_plugin.train_step_context():
return self.lightning_module.training_step(*args)

def validation_step(self, args):
"""The actual validation step.

Args:
args: the arguments for the model's validation step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): The index of this batch
dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple val dataloaders used)
"""
batch = self.to_device(args[0])

args[0] = batch

with self.precision_plugin.val_step_context():
with self.training_type_plugin.val_step_context():
return self.lightning_module.validation_step(*args)
Comment on lines +154 to +156

Contributor:
This type of logic makes me think the precision plugin should live in the training type plugin, and that the base training type plugin should manage the precision logic.

Member Author:
Let's merge it as is and then change it later on :)

Contributor:
Agreed with @SeanNaren - that cleans up the interface for the accelerator. Conceptually, the accelerator simply manages the training plugin with additional lifecycle hooks, and the training plugin is responsible for precision/whatever else(?)

Member Author:
Yep, I also agree with @SeanNaren, and there are plans to support that on top of the existing PoC, since it will probably be easier to change then than it is now.
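
A rough sketch of the alternative floated above, using stand-in names only (this is not code from this PR or from Lightning): the training type plugin would own the precision plugin, so the accelerator would only ever open a single context.

# Hypothetical stand-ins illustrating the discussed alternative; not Lightning code.
import contextlib


class _TrainingTypePluginOwningPrecision:
    """Stand-in training type plugin that also manages the precision plugin."""

    def __init__(self, precision_plugin):
        self.precision_plugin = precision_plugin

    @contextlib.contextmanager
    def train_step_context(self):
        # open the precision context inside the training-type context, so the
        # accelerator no longer needs to know about precision at all
        with self.precision_plugin.train_step_context():
            yield


# the accelerator's training_step would then shrink to roughly:
#     with self.training_type_plugin.train_step_context():
#         return self.lightning_module.training_step(*args)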


def test_step(self, args):
"""The actual test step.

Args:
args: the arguments for the model's test step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): The index of this batch.
dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple test dataloaders used).
"""
batch = self.to_device(args[0])

args[0] = batch

with self.precision_plugin.test_step_context():
with self.training_type_plugin.test_step_context():
return self.lightning_module.test_step(*args)

def training_step_end(self, output):
"""A hook to do something at the end of the training step

Args:
output: the output of the training step
"""
return output

def test_step_end(self, output):
"""A hook to do something at the end of the test step

Args:
output: the output of the test step
"""
return output

def validation_step_end(self, output):
"""A hook to do something at the end of the validation step

Args:
output: the output of the validation step
"""
return output

def process_dataloader(
self, dataloader: Union[Iterable, torch.utils.data.DataLoader]
) -> Union[Iterable, torch.utils.data.DataLoader]:
"""Wraps the dataloader if necessary

Args:
dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
"""
return dataloader

def backward(
self,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
) -> torch.Tensor:
"""Forwards backward-calls to the precision plugin.

Args:
closure_loss: a tensor holding the loss value to backpropagate
optimizer: the optimizer to do the step later on.
opt_idx: the index of the optimizer
should_accumulate: whether to accumulate gradients
"""
output = self.precision_plugin.backward(
self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
)

# TODO: this is a hack, find a better solution for this (hook?)
# fixme: uncomment when this class is added
# if isinstance(self.training_type_plugin, HorovodPlugin):
# optimizer.synchronize()

return output

def optimizer_step(
self,
optimizer: torch.optim.Optimizer,
current_epoch: int,
batch_idx: int,
opt_idx: int,
lambda_closure: Callable,
):
"""performs the actual optimizer step.

Args:
optimizer: the optimizer performing the step
current_epoch: current training epoch
batch_idx: index of the current batch
opt_idx: index of the current optimizer
lambda_closure: closure calculating the loss value

"""
model_ref = self.lightning_module
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
# fixme: uncomment when this class is added
# is_native_amp = (
# isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE
# )
is_native_amp = False

self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)

# model hook
res = model_ref.optimizer_step(
epoch=current_epoch,
batch_idx=batch_idx,
optimizer=optimizer,
optimizer_idx=opt_idx,
optimizer_closure=lambda_closure,
on_tpu=False, # TPUAccelerator class sets this as True
using_native_amp=is_native_amp,
using_lbfgs=is_lbfgs,
)

self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)
return res

def optimizer_zero_grad(
self, current_epoch: int, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
) -> None:
"""Zeros all model parameter's gradients"""
model_ref = self.lightning_module
model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

def clip_gradients(self, optimizer: torch.optim.Optimizer, clip_val: Union[int, float]) -> None:
"""clips all the optimizer parameters to the given value"""

self.precision_plugin.clip_gradients(optimizer, clip_val)

def on_train_epoch_end(self, outputs) -> None:
"""Hook to do something on the end of an training epoch

Args:
outputs: the outputs of the training steps
"""
pass

def on_train_end(self) -> None:
"""Hook to do something at the end of the training"""
pass

def setup_optimizers(self, trainer: "Trainer", model: LightningModule):
"""creates optimizers and schedulers

Args:
trainer: the Trainer to which these optimizers should be connected
model: the model to be optimized by the created optimizers
"""
if trainer.testing is True:
return
optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
self.optimizers = optimizers
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies

def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
"""Attaches the training type plugin to the accelerator.
Also transfers ownership of the model to this plugin

"""
plugin.connect(model)

def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
"""Attaches the precision plugin to the accelerator"""
model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
self.model = model
self.optimizers = optimizers
self.lr_schedulers = schedulers

def to_device(self, batch: Any) -> Any:
"""Pushes the batch to the root device"""
return self.batch_to_device(batch, self.root_device)

@property
def amp_backend(self) -> Optional[LightningEnum]:
# fixme: uncomment when this class is added
# if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
# return AMPType.APEX
# elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
# return AMPType.NATIVE
# return None
pass

@property
def precision(self) -> int:
return self.precision_plugin.precision

@property
def scaler(self):
if hasattr(self.precision_plugin, "scaler"):
return self.precision_plugin.scaler

return None

@property
def rpc_enabled(self) -> bool:
return self.training_type_plugin.rpc_enabled

def optimizer_state(self, optimizer: Optimizer) -> dict:
"""
Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
plugins.
"""
if self.training_type_plugin and hasattr(self.training_type_plugin, "optimizer_state"):
return self.training_type_plugin.optimizer_state(optimizer)
return optimizer.state_dict()

def on_save(self, checkpoint):
return checkpoint
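
A quick sketch of how a hardware-specific accelerator could extend the base class above. The subclass name and the overridden hooks are purely illustrative assumptions; the concrete CPU/GPU/TPU accelerators added later in this refactor may look different.

# Illustrative subclass only; not part of this PR.
import torch

from pytorch_lightning.accelerators.accelerator import Accelerator


class MyDeviceAccelerator(Accelerator):
    """Sketch of an accelerator for one specific type of hardware."""

    def setup(self, trainer, model):
        # move the model to the target device before the base class connects the
        # plugins and creates the optimizers
        model.to(torch.device("cpu"))  # placeholder device for this sketch
        return super().setup(trainer, model)

    def on_train_start(self):
        # hardware-specific bookkeeping (e.g. clearing device caches) could go here
        pass
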
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -0,0 +1,3 @@
from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401
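
These are the plugin base classes that the new Accelerator relies on. Inferred only from the calls accelerator.py makes on them, their interfaces need to cover roughly the shape sketched below; this is an approximation, not the actual base class source from this PR, which may declare more or slightly different members.

# Approximate interface sketch derived from the call sites in accelerator.py above;
# not the real precision_plugin / training_type_plugin sources.
import contextlib
from typing import Any, Iterable, Tuple

import torch


class _TrainingTypePluginSketch:
    model: Any = None  # set via connect(); may hold a wrapped LightningModule

    def connect(self, model: torch.nn.Module) -> None: ...

    @property
    def lightning_module(self): ...  # the unwrapped LightningModule

    @property
    def root_device(self) -> torch.device: ...

    @property
    def rpc_enabled(self) -> bool: return False

    def train_step_context(self): return contextlib.nullcontext()
    def val_step_context(self): return contextlib.nullcontext()
    def test_step_context(self): return contextlib.nullcontext()
    def pre_optimizer_step(self, optimizer, opt_idx): ...
    def post_optimizer_step(self, optimizer, opt_idx): ...


class _PrecisionPluginSketch:
    precision: int = 32

    def connect(self, model, optimizers, lr_schedulers) -> Tuple[Any, Iterable, Iterable]:
        return model, optimizers, lr_schedulers

    def backward(self, model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs): ...
    def clip_gradients(self, optimizer, clip_val) -> None: ...
    def pre_optimizer_step(self, optimizer, opt_idx): ...
    def post_optimizer_step(self, optimizer, opt_idx): ...
    def train_step_context(self): return contextlib.nullcontext()
    def val_step_context(self): return contextlib.nullcontext()
    def test_step_context(self): return contextlib.nullcontext()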