Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gradient clipping to create_supervised_trainer() #1681

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions ignite/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import math
from collections.abc import Mapping
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch
from torch.nn.utils.clip_grad import clip_grad_norm_

import ignite.distributed as idist
from ignite.engine.deterministic import DeterministicEngine
Expand Down Expand Up @@ -31,9 +33,7 @@
def _prepare_batch(
batch: Sequence[torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False
) -> Tuple[Union[torch.Tensor, Sequence, Mapping, str, bytes], ...]:
"""Prepare batch for training: pass to a device with options.

"""
"""Prepare batch for training: pass to a device with options."""
x, y = batch
return (
convert_tensor(x, device=device, non_blocking=non_blocking),
Expand All @@ -49,6 +49,7 @@ def supervised_training_step(
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
gradient_clip: float = math.inf,
) -> Callable:
"""Factory function for supervised training.

Expand All @@ -65,6 +66,8 @@ def supervised_training_step(
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_clip (float): max norm of the gradients.
(default: math.inf)

Returns:
Callable: update function.
Expand All @@ -90,6 +93,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
clip_grad_norm_(model.parameters(), gradient_clip)
optimizer.step()
return output_transform(x, y, y_pred, loss)

Expand All @@ -105,6 +109,7 @@ def supervised_training_step_amp(
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
gradient_clip: float = math.inf,
) -> Callable:
"""Factory function for supervised training using ``torch.cuda.amp``.

Expand All @@ -122,6 +127,8 @@ def supervised_training_step_amp(
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
scaler (torch.cuda.amp.GradScaler, optional): GradScaler instance for gradient scaling. (default: None)
gradient_clip (float): max norm of the gradients.
(default: math.inf)

Returns:
Callable: update function
Expand Down Expand Up @@ -155,10 +162,13 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
loss = loss_fn(y_pred, y)
if scaler:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
clip_grad_norm_(model.parameters(), gradient_clip)
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
clip_grad_norm_(model.parameters(), gradient_clip)
optimizer.step()
return output_transform(x, y, y_pred, loss)

Expand All @@ -173,6 +183,7 @@ def supervised_training_step_apex(
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
gradient_clip: float = math.inf,
) -> Callable:
"""Factory function for supervised training using apex.

Expand All @@ -189,6 +200,8 @@ def supervised_training_step_apex(
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_clip (float): max norm of the gradients.
(default: math.inf)

Returns:
Callable: update function.
Expand Down Expand Up @@ -220,6 +233,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
loss = loss_fn(y_pred, y)
with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
clip_grad_norm_(apex_amp.master_params(optimizer), gradient_clip)
optimizer.step()
return output_transform(x, y, y_pred, loss)

Expand All @@ -234,6 +248,7 @@ def supervised_training_step_tpu(
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
gradient_clip: float = math.inf,
) -> Callable:
"""Factory function for supervised training using ``torch_xla``.

Expand All @@ -250,6 +265,8 @@ def supervised_training_step_tpu(
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_clip (float): max norm of the gradients.
(default: math.inf)

Returns:
Callable: update function.
Expand Down Expand Up @@ -279,6 +296,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
xm.reduce_gradients(optimizer)
clip_grad_norm_(model.parameters(), gradient_clip)
xm.optimizer_step(optimizer, barrier=True)
return output_transform(x, y, y_pred, loss)

Expand Down Expand Up @@ -324,6 +343,7 @@ def create_supervised_trainer(
deterministic: bool = False,
amp_mode: Optional[str] = None,
scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
gradient_clip: float = math.inf,
) -> Engine:
"""Factory function for creating a trainer for supervised models.

Expand All @@ -350,6 +370,8 @@ def create_supervised_trainer(
and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
If True, will create default GradScaler. If GradScaler instance is passed, it will be used instead.
(default: False)
gradient_clip (float): max norm of the gradients.
(default: math.inf)

Note:
If ``scaler`` is True, GradScaler instance will be created internally and trainer state has attribute named
Expand Down Expand Up @@ -390,19 +412,19 @@ def create_supervised_trainer(

if mode == "amp":
_update = supervised_training_step_amp(
model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, _scaler
model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, _scaler, gradient_clip
)
elif mode == "apex":
_update = supervised_training_step_apex(
model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, gradient_clip
)
elif mode == "tpu":
_update = supervised_training_step_tpu(
model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, gradient_clip
)
else:
_update = supervised_training_step(
model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, gradient_clip
)

trainer = Engine(_update) if not deterministic else DeterministicEngine(_update)
Expand Down