diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py
index ba8aae0465d..abe487c7299 100644
--- a/ignite/engine/__init__.py
+++ b/ignite/engine/__init__.py
@@ -1,7 +1,9 @@
+import math
 from collections.abc import Mapping
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import torch
+from torch.nn.utils.clip_grad import clip_grad_norm_
 
 import ignite.distributed as idist
 from ignite.engine.deterministic import DeterministicEngine
@@ -31,9 +33,7 @@
 def _prepare_batch(
     batch: Sequence[torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False
 ) -> Tuple[Union[torch.Tensor, Sequence, Mapping, str, bytes], ...]:
-    """Prepare batch for training: pass to a device with options.
-
-    """
+    """Prepare batch for training: pass to a device with options."""
     x, y = batch
     return (
         convert_tensor(x, device=device, non_blocking=non_blocking),
@@ -49,6 +49,7 @@ def supervised_training_step(
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
     output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    gradient_clip: float = math.inf,
 ) -> Callable:
     """Factory function for supervised training.
 
@@ -65,6 +66,8 @@ def supervised_training_step(
             tuple of tensors `(batch_x, batch_y)`.
         output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
+        gradient_clip (float): max norm of the gradients.
+            (default: math.inf)
 
     Returns:
         Callable: update function.
@@ -90,6 +93,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
         y_pred = model(x)
         loss = loss_fn(y_pred, y)
         loss.backward()
+        clip_grad_norm_(model.parameters(), gradient_clip)
         optimizer.step()
         return output_transform(x, y, y_pred, loss)
 
@@ -105,6 +109,7 @@ def supervised_training_step_amp(
     prepare_batch: Callable = _prepare_batch,
     output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
     scaler: Optional["torch.cuda.amp.GradScaler"] = None,
+    gradient_clip: float = math.inf,
 ) -> Callable:
     """Factory function for supervised training using ``torch.cuda.amp``.
 
@@ -122,6 +127,8 @@ def supervised_training_step_amp(
         output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         scaler (torch.cuda.amp.GradScaler, optional): GradScaler instance for gradient scaling. (default: None)
+        gradient_clip (float): max norm of the gradients.
+            (default: math.inf)
 
     Returns:
         Callable: update function
@@ -155,10 +162,13 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
             loss = loss_fn(y_pred, y)
         if scaler:
             scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            clip_grad_norm_(model.parameters(), gradient_clip)
             scaler.step(optimizer)
             scaler.update()
         else:
             loss.backward()
+            clip_grad_norm_(model.parameters(), gradient_clip)
             optimizer.step()
         return output_transform(x, y, y_pred, loss)
 
@@ -173,6 +183,7 @@ def supervised_training_step_apex(
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
     output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    gradient_clip: float = math.inf,
 ) -> Callable:
     """Factory function for supervised training using apex.
 
@@ -189,6 +200,8 @@ def supervised_training_step_apex(
             tuple of tensors `(batch_x, batch_y)`.
         output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
+        gradient_clip (float): max norm of the gradients.
+            (default: math.inf)
 
     Returns:
         Callable: update function.
@@ -220,6 +233,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
         loss = loss_fn(y_pred, y)
         with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward()
+        clip_grad_norm_(apex_amp.master_params(optimizer), gradient_clip)
         optimizer.step()
         return output_transform(x, y, y_pred, loss)
 
@@ -234,6 +248,7 @@ def supervised_training_step_tpu(
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
     output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    gradient_clip: float = math.inf,
 ) -> Callable:
     """Factory function for supervised training using ``torch_xla``.
 
@@ -250,6 +265,8 @@ def supervised_training_step_tpu(
             tuple of tensors `(batch_x, batch_y)`.
         output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
+        gradient_clip (float): max norm of the gradients.
+            (default: math.inf)
 
     Returns:
         Callable: update function.
@@ -279,6 +296,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
         y_pred = model(x)
         loss = loss_fn(y_pred, y)
         loss.backward()
+        xm.reduce_gradients(optimizer)
+        clip_grad_norm_(model.parameters(), gradient_clip)
         xm.optimizer_step(optimizer, barrier=True)
         return output_transform(x, y, y_pred, loss)
 
@@ -324,6 +343,7 @@ def create_supervised_trainer(
     deterministic: bool = False,
     amp_mode: Optional[str] = None,
     scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
+    gradient_clip: float = math.inf,
 ) -> Engine:
     """Factory function for creating a trainer for supervised models.
 
@@ -350,6 +370,8 @@ def create_supervised_trainer(
            and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
            If True, will create default GradScaler. If GradScaler instance is passed, it will be used instead.
            (default: False)
+        gradient_clip (float): max norm of the gradients.
+            (default: math.inf)
 
     Note:
         If ``scaler`` is True, GradScaler instance will be created internally and trainer state has attribute named
@@ -390,19 +412,19 @@ def create_supervised_trainer(
 
     if mode == "amp":
         _update = supervised_training_step_amp(
-            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, _scaler
+            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, _scaler, gradient_clip
         )
     elif mode == "apex":
         _update = supervised_training_step_apex(
-            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
+            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, gradient_clip
         )
     elif mode == "tpu":
         _update = supervised_training_step_tpu(
-            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
+            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, gradient_clip
        )
    else:
         _update = supervised_training_step(
-            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
+            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, gradient_clip
         )
 
     trainer = Engine(_update) if not deterministic else DeterministicEngine(_update)
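
A minimal usage sketch of the new argument, assuming the patch above is applied. The model, optimizer, loss, and data below are illustrative placeholders, not part of the change; the only new behaviour exercised is `gradient_clip`, which is forwarded to `clip_grad_norm_` as the max gradient norm and defaults to `math.inf` (no clipping).

```python
import torch
from torch import nn

from ignite.engine import create_supervised_trainer

# Placeholder model/optimizer/loss for illustration only.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# gradient_clip=1.0 clips gradients to a max L2 norm of 1.0 before every
# optimizer step; the default math.inf keeps the previous no-clipping behaviour.
trainer = create_supervised_trainer(model, optimizer, loss_fn, gradient_clip=1.0)

# Tiny synthetic dataset of (inputs, targets) batches.
data = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(8)]
trainer.run(data, max_epochs=2)
```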