Optimize fused_linear_cross_entropy when weight does not require grads (#237)

## Summary

Add some easy checks for `weight.requires_grad` to skip allocating and calculating weight gradients when they are not needed. The weight gradient matrix can be pretty large, so this can also be a significant memory saving (see the sketch below).
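As an illustration, here is a minimal sketch of the allocation guard. The function and tensor names (and the shapes) are hypothetical, not the actual Liger-Kernel code:

```python
import torch

def allocate_grad_buffers(x: torch.Tensor, weight: torch.Tensor):
    # The input gradient is always needed for the backward pass.
    grad_input = torch.zeros_like(x)
    # The weight gradient has shape (vocab_size, hidden_size) and can be
    # very large; skip the allocation entirely when the weight is frozen.
    grad_weight = torch.zeros_like(weight) if weight.requires_grad else None
    return grad_input, grad_weight

# With a frozen weight (e.g. when only adapters are trained), no
# weight-gradient buffer is ever allocated.
x = torch.randn(8, 4096)
weight = torch.randn(32000, 4096, requires_grad=False)
_, grad_weight = allocate_grad_buffers(x, weight)
assert grad_weight is None
```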
Also, a small micro-optimization: skip the `.item()` call on `total_n_non_ignore` (the subsequent calculations work fine with the tensor form) to defer CUDA synchronization. Otherwise the call would block until all of the `torch.zeros` initializations on the preceding lines finish, which can take a non-trivial amount of time. A sketch of this deferral follows the Testing Done section.

## Testing Done

The existing unit test already covers the case where the weight does not have gradients enabled, and it still passes forward/backward:
https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_fused_linear_cross_entropy.py#L165

The preceding test verifies the 'normal' case where the weight gradients are needed.

- Hardware Type: A100 80G
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
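For the synchronization point, here is a minimal sketch of the deferral. The tensor names and the `-100` ignore index are illustrative assumptions, not the kernel's actual code:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
target = torch.randint(-100, 32000, (4096,), device=device)

# Several device-side allocations queued before the reduction.
loss_1d = torch.zeros(4096, device=device)
grad_input = torch.zeros(4096, 4096, device=device)

total_n_non_ignore = (target != -100).sum()

# Calling .item() here would force the host to wait for everything queued
# above to finish before reading the scalar back:
#   n = total_n_non_ignore.item()  # implicit CUDA synchronization
# Keeping the value in tensor form lets the division stay on-device and
# defers the synchronization until the loss is actually materialized.
loss = loss_1d.sum() / total_n_non_ignore
```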