
Commit

[RLlib] Use PyTorch's implementation of grad norm clipping. (ray-project#36382)

Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
avnishn authored and arvind-chandra committed Aug 31, 2023
1 parent 1a733fa commit 8bb7b95
Showing 1 changed file with 22 additions and 19 deletions.
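For reference, the built-in PyTorch utility whose logic the new code mirrors is torch.nn.utils.clip_grad_norm_. Below is a minimal sketch of that utility applied to a toy model; the model, loss, and max_norm value are illustrative only and are not part of this commit.

import torch
import torch.nn as nn

# Toy model and dummy loss, purely for illustration.
model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# Rescale all parameter gradients in place so that their combined (global)
# L2 norm does not exceed max_norm; the pre-clipping total norm is returned.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(total_norm))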
41 changes: 22 additions & 19 deletions rllib/utils/torch_utils.py
@@ -147,26 +147,29 @@ def clip_gradients(
         grad_clip_by == "global_norm"
     ), f"`grad_clip_by` ({grad_clip_by}) must be one of [value|norm|global_norm]!"
 
-    # Compute the global L2-norm of all the gradient tensors.
-    global_norm = sum(
-        # `.norm()` is the square root of the sum of all squares.
-        # We need to "undo" the square root b/c we want to compute the global
-        # norm afterwards -> `** 2`.
-        t.norm(2) ** 2
-        for t in gradients_dict.values()
-        if t is not None
+    grads = [g for g in gradients_dict.values() if g is not None]
+    norm_type = 2.0
+    if len(grads) == 0:
+        return torch.tensor(0.0)
+    device = grads[0].device
+
+    total_norm = torch.norm(
+        torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]),
+        norm_type,
     )
-    # Now we do the square root.
-    global_norm = torch.sqrt(global_norm)
-
-    # Clip all the gradients.
-    if global_norm > grad_clip:
-        for tensor in gradients_dict.values():
-            if tensor is not None:
-                tensor.mul_(grad_clip / global_norm)
-
-    # Return the computed global norm scalar.
-    return global_norm
+    if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+        raise RuntimeError(
+            f"The total norm of order {norm_type} for gradients from "
+            "`parameters` is non-finite, so it cannot be clipped. "
+        )
+    clip_coef = grad_clip / (total_norm + 1e-6)
+    # Note: multiplying by the clamped coef is redundant when the coef is clamped to
+    # 1, but doing so avoids a `if clip_coef < 1:` conditional which can require a
+    # CPU <=> device synchronization when the gradients do not reside in CPU memory.
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+    for g in grads:
+        g.detach().mul_(clip_coef_clamped.to(g.device))
+    return total_norm
 
 
 @PublicAPI
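As a usage sketch (not part of the diff): with grad_clip_by="global_norm", the helper takes a dict of per-parameter gradients, rescales them in place so their combined L2 norm does not exceed grad_clip, and returns the pre-clipping global norm. Multiplying every gradient by the clamped coefficient, rather than branching on it, keeps the whole operation on-device, as the diff's comment notes. The import path, keyword usage, and dict keys below are assumptions based on the hunk above, not something this commit confirms.

import torch
from ray.rllib.utils.torch_utils import clip_gradients  # import path assumed

# Fake gradients keyed by arbitrary names; None entries are skipped,
# matching the `if g is not None` filter in the new code.
gradients_dict = {
    "layer1.weight": torch.randn(10, 10),
    "layer1.bias": torch.randn(10),
    "frozen_param": None,
}

# Clip by global L2 norm; the return value is the total norm measured
# before clipping, mirroring torch.nn.utils.clip_grad_norm_.
total_norm = clip_gradients(gradients_dict, grad_clip=0.5, grad_clip_by="global_norm")
print(float(total_norm))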
