From eaf691435d79c3b653faa848ecc9f8648403cf5b Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Thu, 12 Nov 2020 17:18:09 +0000 Subject: [PATCH] Sharded Accelerator 1/n: Expose clip gradients to plugins via abstract class (#4639) * Added abstract precision plugin to expose clip_gradients function, use within accelerator to clip gradients * Exclude model from override, keep optimizer (needed for sharded clip gradients), add override for O2 support apex * Fix doc * Applied codereview changes * Refactored clip function to encapsulate tpu changes with tpu accelerator. Default to standard clip function for vanilla torch * Pass correct grad clip val * Moved var to property * Apply code review suggestions (cherry picked from commit bacabaebaf16b0492cf9090b75238215c2c19de5) --- pytorch_lightning/accelerators/accelerator.py | 52 ++++--------------- .../accelerators/tpu_accelerator.py | 31 +++++++++-- pytorch_lightning/plugins/apex.py | 39 +++++++++++++- pytorch_lightning/plugins/native_amp.py | 10 +++- pytorch_lightning/plugins/precision_plugin.py | 38 ++++++++++++++ 5 files changed, 120 insertions(+), 50 deletions(-) create mode 100644 pytorch_lightning/plugins/precision_plugin.py diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 3b762e08ed5e6..a0d8f6f21a2f7 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -12,33 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import math from enum import Enum from typing import Any, Optional, Union import torch +from torch.optim import Optimizer -from pytorch_lightning.utilities import AMPType, rank_zero_warn +from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict import torch.distributed as torch_distrib from pytorch_lightning import _logger as log -try: - from apex import amp -except ImportError: - amp = None - if torch.distributed.is_available(): from torch.distributed import ReduceOp else: class ReduceOp: SUM = None -EPSILON = 1e-6 -EPSILON_FP16 = 1e-5 - class Accelerator(object): @@ -139,48 +131,22 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) def clip_gradients(self, optimizer, clip_val=None): - # TODO: separate TPU case from here - self._clip_gradients(optimizer, clip_val) - - def _clip_gradients(self, optimizer, clip_val=None): # use the trainer's clip val if none passed grad_clip_val = self.trainer.gradient_clip_val if clip_val is not None: grad_clip_val = clip_val grad_clip_val = float(grad_clip_val) - # this code is a modification of torch.nn.utils.clip_grad_norm_ - # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md if grad_clip_val <= 0: return + self._clip_gradients(optimizer, grad_clip_val) - model = self.trainer.get_model() - if self.trainer.amp_backend == AMPType.APEX: - parameters = amp.master_params(optimizer) + def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0): + if self.trainer.amp_backend: + self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer, norm_type) else: - parameters = model.parameters() - - max_norm = grad_clip_val - norm_type = float(2.0) - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - - if norm_type == math.inf: - total_norm = max(p.grad.data.abs().max() for p in parameters) - else: - device = parameters[0].device - out = torch.empty(len(parameters), device=device) - for i, p in enumerate(parameters): - torch.norm(p.grad.data.to(device), norm_type, out=out[i]) - total_norm = torch.norm(out, norm_type) - - eps = EPSILON_FP16 if self.trainer.precision == 16 else EPSILON - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps) - clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) - for p in parameters: - p.grad.data.mul_(clip_coef.to(p.grad.data.device)) + model = self.trainer.get_model() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type) def on_train_epoch_end(self, outputs): pass @@ -201,7 +167,7 @@ def setup_optimizers(self, model): self.trainer.optimizer_frequencies = optimizer_frequencies def init_ddp_connection( - self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True + self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True ) -> None: os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 15386b133f8bd..54ee57b74a16a 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import io +import math import os import re from typing import Optional, Union, Any import torch import torch.multiprocessing as mp +from torch.optim import Optimizer from pytorch_lightning import _logger as log from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp @@ -261,10 +263,27 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure): using_lbfgs=is_lbfgs ) - def clip_gradients(self, optimizer, clip_val=None): - # apply clip gradients - # TODO: separate TPU case from here - self._clip_gradients(optimizer, clip_val) + def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0): + # this code is a modification of torch.nn.utils.clip_grad_norm_ + # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md + model = self.trainer.get_model() + parameters = model.parameters() + max_norm = grad_clip_val + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + + device = parameters[0].device + out = torch.empty(len(parameters), device=device) + for i, p in enumerate(parameters): + torch.norm(p.grad.data.to(device), norm_type, out=out[i]) + total_norm = torch.norm(out, norm_type) + + clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon) + clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) + for p in parameters: + p.grad.data.mul_(clip_coef.to(p.grad.data.device)) def barrier(self, name: Optional[str] = None): torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}") @@ -343,3 +362,7 @@ def sync_tensor(self, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return tensor + + @property + def norm_clipping_epsilon(self): + return 1e-6 diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py index 0c8665e3719f3..654f7202fb9d1 100644 --- a/pytorch_lightning/plugins/apex.py +++ b/pytorch_lightning/plugins/apex.py @@ -11,11 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +import math +from typing import List, Tuple, Union +import torch from torch.optim.optimizer import Optimizer from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities import AMPType @@ -25,7 +28,7 @@ amp = None -class ApexPlugin: +class ApexPlugin(PrecisionPlugin): def __init__(self, trainer=None): self.trainer = trainer @@ -98,3 +101,35 @@ def configure_apex(self, amp, model, optimizers, amp_level): """ model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level) return model, optimizers + + def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float): + """ + This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights. + This is important when setting amp_level to O2, and the master weights are in fp16. + Args: + grad_clip_val: Maximum norm of gradients. + optimizer: Optimizer with gradients that will be clipped. + norm_type: (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + """ + model = self.trainer.get_model() + parameters = model.parameters() + max_norm = float(grad_clip_val) + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + total_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon) + if clip_coef < 1: + for p in parameters: + p.grad.detach().mul_(clip_coef.to(p.grad.device)) + + @property + def norm_clipping_epsilon(self): + return 1e-5 diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py index 98bc8dfc87d25..1a6649986132c 100644 --- a/pytorch_lightning/plugins/native_amp.py +++ b/pytorch_lightning/plugins/native_amp.py @@ -11,11 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union import torch +from torch.optim import Optimizer +from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin -class NativeAMPPlugin: + +class NativeAMPPlugin(PrecisionPlugin): def __init__(self, trainer=None): """ @@ -51,3 +55,7 @@ def training_step(self, fx, args): with torch.cuda.amp.autocast(): output = fx(*args) return output + + def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float): + model = self.trainer.get_model() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type) diff --git a/pytorch_lightning/plugins/precision_plugin.py b/pytorch_lightning/plugins/precision_plugin.py new file mode 100644 index 0000000000000..0102f677391ff --- /dev/null +++ b/pytorch_lightning/plugins/precision_plugin.py @@ -0,0 +1,38 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +from typing import Union + +from torch.optim import Optimizer + + +class PrecisionPlugin(abc.ABC): + """ + Abstract class to extend for precision support (32/16 etc). + + This is extended to cover any specific logic required for precision support such as AMP/APEX or sharded + training. + """ + + def connect(self, model, optimizers): + raise NotImplementedError + + def training_step(self, fx, args): + raise NotImplementedError + + def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): + raise NotImplementedError + + def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float): + raise NotImplementedError