Sharded Accelerator 1/n: Expose clip gradients to plugins via abstract class #4639

Merged (10 commits) on Nov 12, 2020
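
For orientation, a minimal sketch of the gradient-clipping dispatch this PR introduces, pieced together from the diff below. The attribute and method names (amp_backend, precision_connector.backend.clip_gradients, get_model) match the diff; the clip_after_backward wrapper and the bare trainer argument are illustrative stand-ins, not Lightning API.

import torch

def clip_after_backward(trainer, optimizer, grad_clip_val):
    # Sketch of Accelerator._clip_gradients after this PR: AMP/APEX backends
    # (NativeAMPPlugin / ApexPlugin, which now subclass PrecisionPlugin) own
    # the clipping; the fp32 default path stays inline for now.
    if trainer.amp_backend:
        trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer)
    else:
        model = trainer.get_model()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)

The user-facing knob is unchanged: Trainer(gradient_clip_val=...) still drives grad_clip_val; only the place where the clipping math lives moves behind the plugin interface.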
49 changes: 7 additions & 42 deletions pytorch_lightning/accelerators/accelerator.py
@@ -12,33 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
from enum import Enum
from typing import Any, Optional, Union

import torch

from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log

try:
from apex import amp
except ImportError:
amp = None

if torch.distributed.is_available():
from torch.distributed import ReduceOp
else:
class ReduceOp:
SUM = None

EPSILON = 1e-6
EPSILON_FP16 = 1e-5


class Accelerator(object):

@@ -139,48 +130,22 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

def clip_gradients(self, optimizer, clip_val=None):
# TODO: separate TPU case from here
self._clip_gradients(optimizer, clip_val)

def _clip_gradients(self, optimizer, clip_val=None):
# use the trainer's clip val if none passed
grad_clip_val = self.trainer.gradient_clip_val
if clip_val is not None:
grad_clip_val = clip_val
grad_clip_val = float(grad_clip_val)

# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
if grad_clip_val <= 0:
return
self._clip_gradients(optimizer, grad_clip_val)

model = self.trainer.get_model()
if self.trainer.amp_backend == AMPType.APEX:
parameters = amp.master_params(optimizer)
def _clip_gradients(self, optimizer, grad_clip_val):
if self.trainer.amp_backend:
self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer)
else:
parameters = model.parameters()

max_norm = grad_clip_val
norm_type = float(2.0)

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))

if norm_type == math.inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
else:
device = parameters[0].device
out = torch.empty(len(parameters), device=device)
for i, p in enumerate(parameters):
torch.norm(p.grad.data.to(device), norm_type, out=out[i])
total_norm = torch.norm(out, norm_type)

eps = EPSILON_FP16 if self.trainer.precision == 16 else EPSILON
clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
for p in parameters:
p.grad.data.mul_(clip_coef.to(p.grad.data.device))
model = self.trainer.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)
Contributor commented:
I wonder if it would make sense to have a precision plugin for 32-bit as the default case.
Then we wouldn't need this if/else block and could just call
self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer)
Just a refactoring idea for the future.

Contributor @awaelchli commented on Nov 12, 2020:
Talking about this block (GitHub selected more lines than intended):

def _clip_gradients(self, optimizer, grad_clip_val):
        if self.trainer.amp_backend:
            self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer)
        else:
            model = self.trainer.get_model()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)

Contributor (Author) commented:
I fully agree, a precision-32 plugin makes the most sense because the accelerators shouldn't need to care about these details (it ties into @justusschock's refactor proposal of making the accelerators hardware-only).

Out of scope for this PR, but definitely low-hanging fruit to reduce the liability of issues in the accelerators.

Member commented:
I'm going to do this in my refactor proposal :)
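
To make the suggestion above concrete, a rough sketch of what such a default plugin could look like. Only PrecisionPlugin (added in this PR) and the clip_gradients signature come from the diff; the class name Precision32Plugin and its wiring are hypothetical.

import torch

from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin


class Precision32Plugin(PrecisionPlugin):
    """Hypothetical default (fp32) plugin sketched from the review suggestion."""

    def __init__(self, trainer=None):
        self.trainer = trainer

    def clip_gradients(self, grad_clip_val, optimizer):
        # Same default clipping the accelerator currently performs inline.
        model = self.trainer.get_model()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)

With such a plugin always registered by the precision connector, Accelerator._clip_gradients could collapse to the single call quoted above.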


def on_train_epoch_end(self, outputs):
pass
34 changes: 30 additions & 4 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import math
import os
import re
from typing import Optional, Union, Any
@@ -261,10 +262,31 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
using_lbfgs=is_lbfgs
)

def clip_gradients(self, optimizer, clip_val=None):
# apply clip gradients
# TODO: separate TPU case from here
self._clip_gradients(optimizer, clip_val)
def _clip_gradients(self, optimizer, grad_clip_val):
# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
model = self.trainer.get_model()
parameters = model.parameters()
max_norm = grad_clip_val
norm_type = 2.0

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))

if norm_type == math.inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
else:
device = parameters[0].device
out = torch.empty(len(parameters), device=device)
for i, p in enumerate(parameters):
torch.norm(p.grad.data.to(device), norm_type, out=out[i])
total_norm = torch.norm(out, norm_type)

clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon)
clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
for p in parameters:
p.grad.data.mul_(clip_coef.to(p.grad.data.device))

def barrier(self, name: Optional[str] = None):
torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}")
@@ -343,3 +365,7 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return tensor

@property
def norm_clipping_epsilon(self):
return 1e-6
40 changes: 39 additions & 1 deletion pytorch_lightning/plugins/apex.py
@@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Tuple

import torch
from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities import AMPType

@@ -25,7 +28,7 @@
amp = None


class ApexPlugin:
class ApexPlugin(PrecisionPlugin):

def __init__(self, trainer=None):
self.trainer = trainer
@@ -98,3 +101,38 @@ def configure_apex(self, amp, model, optimizers, amp_level):
"""
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers

def clip_gradients(self, grad_clip_val, optimizer):
"""
This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` that uses a higher epsilon for fp16 weights.
This matters when ``amp_level`` is set to O2 and the master weights are kept in fp16.
Args:
grad_clip_val: Maximum norm of gradients.
optimizer: Optimizer with gradients that will be clipped.
"""
model = self.trainer.get_model()
parameters = model.parameters()
max_norm = grad_clip_val
norm_type = 2.0

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
if norm_type == math.inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else:
total_norm = torch.norm(
torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef.to(p.grad.device))

@property
def norm_clipping_epsilon(self):
return 1e-5
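
As a numerical aside on the epsilon choice (an interpretation, not stated in the PR beyond the docstring above): fp16 is far coarser than fp32, as the torch.finfo constants below show, so the guard is bumped from 1e-6 to 1e-5 to leave more headroom when total_norm is tiny and the O2 master weights live in fp16.

import torch

fp16 = torch.finfo(torch.float16)
print(fp16.eps)   # 0.0009765625   -> roughly 3 decimal digits of precision
print(fp16.tiny)  # 6.103515625e-05 -> smallest normal fp16 value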
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/native_amp.py
@@ -14,8 +14,10 @@

import torch

from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin

class NativeAMPPlugin:

class NativeAMPPlugin(PrecisionPlugin):

def __init__(self, trainer=None):
"""
@@ -51,3 +53,7 @@ def training_step(self, fx, args):
with torch.cuda.amp.autocast():
output = fx(*args)
return output

def clip_gradients(self, grad_clip_val, optimizer):
model = self.trainer.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)
35 changes: 35 additions & 0 deletions pytorch_lightning/plugins/precision_plugin.py
@@ -0,0 +1,35 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc


class PrecisionPlugin(abc.ABC):
"""
Abstract base class to extend for precision support (32-bit, 16-bit, etc.).

Subclasses implement any backend-specific logic required for a given precision mode, such as AMP/APEX or sharded
training.
"""

def connect(self, model, optimizers):
raise NotImplementedError

def training_step(self, fx, args):
raise NotImplementedError

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
raise NotImplementedError

def clip_gradients(self, grad_clip_val, optimizer):
raise NotImplementedError
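
To make the contract concrete, a minimal hypothetical subclass filling in all four hooks might look as follows. It only illustrates the interface; the name MyPrecisionPlugin is invented, and a real backend such as the sharded plugin this series is building toward would do actual work in connect and backward.

import torch

from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin


class MyPrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin showing the hooks a precision backend provides."""

    def __init__(self, trainer=None):
        self.trainer = trainer

    def connect(self, model, optimizers):
        # Wrap or patch the model and optimizers if the backend needs to.
        return model, optimizers

    def training_step(self, fx, args):
        # Run the LightningModule's training_step, e.g. under an autocast context.
        return fx(*args)

    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
        # Backpropagate, possibly through a loss scaler.
        closure_loss.backward(*args, **kwargs)
        return closure_loss

    def clip_gradients(self, grad_clip_val, optimizer):
        model = self.trainer.get_model()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)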