Optimize fused_linear_cross_entropy when weight does not require grads (#237)

## Summary

Add some easy checks for `weight.requires_grad` to skip allocating +
calculating weight gradients if they're not needed. The weight gradient
matrix can be pretty large, so this can also be a significant memory
savings.
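
For a rough sense of scale, here is a back-of-the-envelope sketch with hypothetical Llama-3-style dimensions (the numbers are illustrative only and not taken from this PR; the allocation pattern mirrors the `zeros_like(weight)` call in the diff below):

```python
# Hypothetical dimensions, for illustration only.
V, H = 128256, 4096   # vocab size x hidden size of the lm_head weight
elem_bytes = 2        # grad_weight = torch.zeros_like(weight), so bf16 weights give a bf16 grad
print(V * H * elem_bytes / 2**30)  # ~0.98 GiB that no longer needs to be allocated
```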

Also, a small micro-optimization: skip the `.item()` call on
`total_n_non_ignore` (the subsequent calculations work fine with the
tensor form) to defer CUDA synchronization; otherwise the host would
block waiting for all the `torch.zeros` initializations on the preceding
lines to finish, which may take a non-trivial amount of time.
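
A minimal sketch of the synchronization difference (shapes, dtype, and values are arbitrary; it only assumes a CUDA device is available):

```python
import torch

device = "cuda"
target = torch.randint(0, 100, (4096,), device=device)
ignore_index = -100

# .item() copies the scalar back to the host, which blocks until every kernel
# queued earlier (e.g. the torch.zeros allocations) has finished executing.
n_sync = (target != ignore_index).sum().item()

# Keeping the count as a 0-dim tensor lets the later arithmetic stay on-device,
# so the synchronization is deferred until a host value is actually required.
n_async = (target != ignore_index).sum()
scaled = torch.ones(4096, device=device) / n_async  # works fine with the tensor form
```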

## Testing Done

The existing unit test already covers the case where the weight does not
have gradients enabled, and the forward/backward passes still succeed:
https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_fused_linear_cross_entropy.py#L165

And the preceding test verifies the 'normal' case where the weight
gradients are needed.
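
As an illustration of the path being tested, here is a hedged sketch of invoking the fused op with a frozen weight (the argument order `(_input, weight, target, bias)` for `LigerFusedLinearCrossEntropyFunction.apply` is assumed from this file's signature, and the shapes are arbitrary):

```python
import torch
from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)

B, T, H, V = 2, 128, 512, 1000
_input = torch.randn(B * T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16)  # requires_grad=False: frozen lm_head
target = torch.randint(0, V, (B * T,), device="cuda")

loss = LigerFusedLinearCrossEntropyFunction.apply(_input, weight, target)
loss.backward()

assert _input.grad is not None  # input gradient is still produced
assert weight.grad is None      # no grad_weight is allocated or computed
```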

- Hardware Type: A100 80G
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
hansonw authored Sep 9, 2024
1 parent b5d8cbf commit acd8272
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -37,13 +37,16 @@ def fused_linear_cross_entropy_forward(
     ) # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
 
-    grad_weight = torch.zeros_like(weight, device=device)
+    grad_weight = (
+        torch.zeros_like(weight, device=device) if weight.requires_grad else None
+    )
     grad_input = torch.zeros_like(_input, device=device)
     grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
 
-    total_n_non_ignore = (target != ignore_index).sum().item()
+    # NOTE: skip .item() here to avoid CUDA synchronization
+    total_n_non_ignore = (target != ignore_index).sum()
 
     for chunk_id in range(num_chunks):
         start_idx = chunk_id * chunk_size
@@ -101,14 +104,16 @@ def fused_linear_cross_entropy_forward(
             n_non_ignore / total_n_non_ignore
         ) # chunk_size x V
         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
-        torch.addmm(
-            input=grad_weight,
-            mat1=logits_chunk.t(),
-            mat2=_input_chunk,
-            out=grad_weight,
-            alpha=n_non_ignore / total_n_non_ignore,
-            beta=1.0,
-        )
+
+        if grad_weight is not None:
+            torch.addmm(
+                input=grad_weight,
+                mat1=logits_chunk.t(),
+                mat2=_input_chunk,
+                out=grad_weight,
+                alpha=n_non_ignore / total_n_non_ignore,
+                beta=1.0,
+            )
 
         if bias is not None:
             torch.add(
@@ -143,17 +148,18 @@ def fused_linear_cross_entropy_backward(
         )
 
         # handle grad_weight
-        V, H = grad_weight.shape
-        n_rows = V
+        if grad_weight is not None:
+            V, H = grad_weight.shape
+            n_rows = V
 
-        element_mul_kernel[(n_rows,)](
-            grad_weight,
-            grad_weight.stride(-2),
-            grad_output,
-            H,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
-        )
+            element_mul_kernel[(n_rows,)](
+                grad_weight,
+                grad_weight.stride(-2),
+                grad_output,
+                H,
+                BLOCK_SIZE=BLOCK_SIZE,
+                num_warps=32,
+            )
 
         if grad_bias is not None:
             V = grad_bias.shape[0]
@@ -196,7 +202,7 @@ def forward(
         # downcast to dtype and store for backward
         ctx.save_for_backward(
             grad_input.detach(),
-            grad_weight.detach(),
+            grad_weight.detach() if grad_weight is not None else None,
             grad_bias.detach() if bias is not None else None,
         )
         return loss
