fix sparse gradient clipping for torch>=2.0 #288

Merged · 1 commit · May 28, 2024
62 changes: 62 additions & 0 deletions pecos/utils/torch_util.py
@@ -12,6 +12,7 @@

import numpy as np
import torch
from typing import Union, Iterable

LOGGER = logging.getLogger(__name__)

@@ -72,3 +73,64 @@ def apply_mask(hidden_states, masks):
    hidden_dim = hidden_states.shape[-1]
    hidden_states.view(-1, hidden_dim)[~masks.view(-1).type(torch.ByteTensor), :] = 0
    return hidden_states


def clip_grad_norm_(
    parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
) -> torch.Tensor:
    r"""
    Re-implementation of torch.nn.utils.clip_grad_norm_ from torch==1.13,
    kept here so that gradient clipping keeps working for sparse gradients under torch>=2.0.
    REF: https://pytorch.org/docs/1.13/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_

    Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.0)
    device = grads[0].device
    # norm_type was cast to float above, so compare against float("inf") rather than the string "inf"
    if norm_type == float("inf"):
        norms = [g.detach().abs().max().to(device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type
        )
    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm
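
For reference, a minimal usage sketch (not part of this diff): it assumes the helper is importable as pecos.utils.torch_util, matching the file path above, and uses a toy nn.Embedding with sparse=True, the kind of sparse gradient that the stock torch.nn.utils.clip_grad_norm_ no longer handles under torch>=2.0 according to this PR.

import torch
from pecos.utils import torch_util  # assumed import path, matching the file above

# Toy module whose weight receives a sparse COO gradient.
emb = torch.nn.Embedding(num_embeddings=1000, embedding_dim=64, sparse=True)
loss = emb(torch.tensor([1, 2, 3])).sum()
loss.backward()  # emb.weight.grad is now a sparse tensor

# Drop-in call with the same signature as torch.nn.utils.clip_grad_norm_.
total_norm = torch_util.clip_grad_norm_(emb.parameters(), max_norm=1.0)
print(float(total_norm))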
8 changes: 5 additions & 3 deletions pecos/xmc/xlinear/model.py
@@ -537,9 +537,11 @@ def predict(
  Ye = self.predict(
      X[i : i + max_pred_chunk, :],
      pred_params=pred_params,
-     selected_outputs_csr=selected_outputs_csr[i : i + max_pred_chunk, :]
-     if selected_outputs_csr is not None
-     else None,
+     selected_outputs_csr=(
+         selected_outputs_csr[i : i + max_pred_chunk, :]
+         if selected_outputs_csr is not None
+         else None
+     ),
      **new_kwargs,
  )
  Ys.append(Ye)
34 changes: 22 additions & 12 deletions pecos/xmc/xtransformer/matcher.py
@@ -784,18 +784,20 @@ def _predict(
  if not only_embeddings:
      text_model_W_seq, text_model_b_seq = self.text_model(
          output_indices=inputs["label_indices"],
-         num_device=len(self.text_encoder.device_ids)
-         if hasattr(self.text_encoder, "device_ids")
-         else 1,
+         num_device=(
+             len(self.text_encoder.device_ids)
+             if hasattr(self.text_encoder, "device_ids")
+             else 1
+         ),
      )

  outputs = self.text_encoder(
      input_ids=inputs["input_ids"],
      attention_mask=inputs["attention_mask"],
      token_type_ids=inputs["token_type_ids"],
-     label_embedding=None
-     if only_embeddings
-     else (text_model_W_seq, text_model_b_seq),
+     label_embedding=(
+         None if only_embeddings else (text_model_W_seq, text_model_b_seq)
+     ),
  )

  if not only_embeddings:
@@ -1088,9 +1090,11 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
  }
  text_model_W_seq, text_model_b_seq = self.text_model(
      output_indices=inputs["label_indices"],
-     num_device=len(self.text_encoder.device_ids)
-     if hasattr(self.text_encoder, "device_ids")
-     else 1,
+     num_device=(
+         len(self.text_encoder.device_ids)
+         if hasattr(self.text_encoder, "device_ids")
+         else 1
+     ),
  )
  outputs = self.text_encoder(
      input_ids=inputs["input_ids"],
@@ -1119,9 +1123,15 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
  scheduler.step()  # update learning rate schedule
  optimizer.zero_grad()  # clear gradient accumulation

- torch.nn.utils.clip_grad_norm_(
-     self.text_model.parameters(), train_params.max_grad_norm
- )
+ if self.text_model.is_sparse:
+     torch_util.clip_grad_norm_(
+         self.text_model.parameters(), train_params.max_grad_norm
+     )
+ else:
+     torch.nn.utils.clip_grad_norm_(
+         self.text_model.parameters(), train_params.max_grad_norm
+     )

  emb_optimizer.step()  # perform gradient update
  emb_scheduler.step()  # update learning rate schedule
  emb_optimizer.zero_grad()  # clear gradient accumulation
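
For context, a hedged repro sketch (not part of this diff) of why the is_sparse branch above routes to torch_util.clip_grad_norm_: the PR title implies that the built-in clipper rejects sparse gradients for torch>=2.0, so the snippet below only assumes that some error is raised and guards it with try/except. The toy embedding is a hypothetical stand-in for a sparse text_model.

import torch

emb = torch.nn.Embedding(10, 4, sparse=True)  # toy stand-in for a model with sparse gradients
emb(torch.tensor([0, 1])).sum().backward()    # emb.weight.grad is a sparse tensor

try:
    # Expected to fail on sparse gradients under torch>=2.0 (assumption based on this PR).
    torch.nn.utils.clip_grad_norm_(emb.parameters(), max_norm=1.0)
except RuntimeError as err:
    print(f"built-in clip_grad_norm_ failed on sparse grads: {err}")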
3 changes: 2 additions & 1 deletion setup.py
@@ -115,7 +115,8 @@ def get_blas_lib_dir(cls):
  install_requires = numpy_requires + [
      'scipy>=1.4.1',
      'scikit-learn>=0.24.1',
-     'torch>=1.8.0,<2.0.0',
+     'torch==1.13; python_version<"3.8"',
+     'torch>=2.0; python_version>="3.8"',
      'sentencepiece>=0.1.86,!=0.1.92', # 0.1.92 results in error for transformers
      'transformers>=4.1.1; python_version<"3.9"',
      'transformers>=4.4.2; python_version>="3.9"'