Add label smoothing for cross entropy (linkedin#198)
## Summary

Aims to resolve linkedin#81.

## Details

### For loss:

Label smoothing regularization (LSR) replaces the label distribution $q(k) = \delta_{k,y}$ with

```math
q'(k) = (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}
```

The cross entropy with LSR is then

```math
\begin{align}
L' = H(q', p) &= -\sum^K_{k=1} \log p(k)\, q'(k) = -\sum^K_{k=1} \log p(k) \left((1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}\right)\\
&= -\sum^K_{k=1} \log p(k)\, (1 - \epsilon)\, q(k) - \sum^K_{k=1} \log p(k)\, \frac{\epsilon}{K} \\
&= (1 - \epsilon)H(q,p) - \frac{\epsilon}{K} \sum^K_{k=1} \log \text{softmax}(x_k)\\
&= (1 - \epsilon)L + \frac{\epsilon}{K}\ \text{SmoothLoss},
\end{align}
```

where $L = H(q,p)$ is the original loss and $\text{SmoothLoss} = -\sum^K_{k=1} \log \text{softmax}(x_k)$ is the smooth loss.

### For gradients:

The original:

```math
\begin{align}
\frac{\partial L}{\partial x_i} &= p(i) - q(i)\\
&= \begin{cases} \text{softmax}(x_i), & i \neq y \\ \text{softmax}(x_i) - 1, & i = y \end{cases}
\end{align}
```

With LSR:

```math
\begin{align}
\frac{\partial L'}{\partial x_i} &= p(i) - q'(i)\\
&= \text{softmax}(x_i) - (1 - \epsilon)\delta_{i,y} - \frac{\epsilon}{K}\\
&= \begin{cases} \text{softmax}(x_i) - \frac{\epsilon}{K}, & i \neq y \\ \text{softmax}(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}
\end{align}
```

We can handle the $i = y$ case by first computing $\text{softmax}(x_i) - \frac{\epsilon}{K}$ for all $i$ and then adding $-(1-\epsilon)$ at index $y$.

Reference: [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567)

## Testing Done

Add a unit test for label smoothing.
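The loss and gradient formulas from the Details section can be cross-checked numerically. Below is a minimal sketch in pure Python for a single sample; it is not the Triton kernel or the unit test added by this commit, and the helper names (`smoothed_ce`, `smoothed_ce_grad`, `numeric_grad`) are made up for illustration. It takes the smooth loss as $-\sum_k \log \text{softmax}(x_k)$, so that $L' = (1-\epsilon)L + \frac{\epsilon}{K}\,\text{SmoothLoss}$, and compares the closed-form gradient against central finite differences.

```python
import math


def log_softmax(x):
    """Numerically stable log-softmax over a list of logits."""
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - lse for v in x]


def smoothed_ce(x, y, eps):
    """L' = (1 - eps) * L + (eps / K) * SmoothLoss for one sample."""
    logp = log_softmax(x)
    K = len(x)
    orig_loss = -logp[y]      # L = H(q, p)
    smooth_loss = -sum(logp)  # SmoothLoss = -sum_k log softmax(x_k)
    return (1 - eps) * orig_loss + (eps / K) * smooth_loss


def smoothed_ce_grad(x, y, eps):
    """Closed form: softmax(x_i) - eps / K for all i, then add -(1 - eps) at i = y."""
    K = len(x)
    grad = [math.exp(lp) - eps / K for lp in log_softmax(x)]
    grad[y] -= 1 - eps
    return grad


def numeric_grad(x, y, eps, h=1e-6):
    """Central finite differences, used only to cross-check the closed form."""
    out = []
    for i in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[i] += h
        xm[i] -= h
        out.append((smoothed_ce(xp, y, eps) - smoothed_ce(xm, y, eps)) / (2 * h))
    return out


# Arbitrary illustrative inputs: K = 4 classes, target class y = 2, eps = 0.1.
logits = [2.0, -1.0, 0.5, 0.1]
analytic = smoothed_ce_grad(logits, y=2, eps=0.1)
numeric = numeric_grad(logits, y=2, eps=0.1)
max_abs_diff = max(abs(a - n) for a, n in zip(analytic, numeric))
print(f"max |analytic - numeric| = {max_abs_diff:.2e}")
```

With `eps=0` the sketch reduces to the ordinary cross entropy, and for any `eps` the gradient entries sum to zero because both `softmax` and `q'` sum to one.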
- Hardware Type: RTX-3080
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

```bash
❯ python3 -m pytest test/transformers/test_cross_entropy.py
============================================ test session starts =============================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
collected 94 items

test/transformers/test_cross_entropy.py .............................................................. [ 65%]
...............................F [100%]

================================================== FAILURES ==================================================
__________________________________ test_large_no_exception[8-16384-128256] ___________________________________

B = 8, T = 16384, V = 128256

    @pytest.mark.parametrize(
        "B, T, V",
        [
            (
                8,
                8192,
                128256,
            ),  # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64
            (8, 16384, 128256),  # _input = 32GB, total = ~64GB
        ],
    )
    # @pytest.mark.skipif(
    #     torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000,
    #     reason="Needs 64GB+ GPU memory.",
    # )
    def test_large_no_exception(B, T, V):
        # The large inputs were hitting cuda illegal memory access because of
        # triton-lang/triton#1058
>       _full_pass_once(B, T, V)

test/transformers/test_cross_entropy.py:401:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

B = 8, T = 16384, V = 128256

    def _full_pass_once(B, T, V):
        torch.manual_seed(0)
        liger_ce = LigerCrossEntropyLoss()
>       _input = torch.randn(
            B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16
        )
E       torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10.00 GiB of which 8.84 GiB is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 0 bytes is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

test/transformers/test_cross_entropy.py:374: OutOfMemoryError
========================================== short test summary info ===========================================
FAILED test/transformers/test_cross_entropy.py::test_large_no_exception[8-16384-128256] - torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10...
================================== 1 failed, 93 passed in 130.88s (0:02:10) ==================================
```

```bash
❯ make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
============================================ test session starts =============================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
collected 256 items

test/transformers/test_auto_model.py . [ 0%]
test/transformers/test_cross_entropy.py ssssssssssssssssssssssss............ssssssssssssssssssssssssss [ 24%]
ssssssssssssssssssssssssssssssss [ 37%]
test/transformers/test_embedding.py ........... [ 41%]
test/transformers/test_fused_linear_cross_entropy.py ................ [ 47%]
test/transformers/test_geglu.py ............ [ 52%]
test/transformers/test_layer_norm.py ................ [ 58%]
test/transformers/test_monkey_patch.py ..... [ 60%]
test/transformers/test_rms_norm.py ............................................................ [ 83%]
test/transformers/test_rope.py .................. [ 91%]
test/transformers/test_swiglu.py .................... [ 98%]
test/transformers/test_trainer_integration.py . [ 99%]
test/triton/test_triton_monkey_patch.py .. [100%]

================================ 174 passed, 82 skipped in 123.06s (0:02:03) =================================
```

```bash
❯ make checkstyle
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
    exit 1; \
fi
Skipped 2 files
All done! ✨ 🍰 ✨
68 files left unchanged.
```

```bash
❯ make test-convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
============================================ test session starts =============================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
collected 30 items

test/convergence/test_mini_models.py .............. [ 46%]
test/convergence/test_mini_models_no_logits.py ................ [100%]

======================================= 30 passed in 223.18s (0:03:43) =======================================
```
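The single failure in the first log appears to be a hardware limit rather than a label smoothing regression: the `[8-16384-128256]` case asks for a 31.31 GiB input tensor on a GPU that reports 10.00 GiB. The rough arithmetic below reproduces the numbers quoted in the log and in the test's comments. It assumes 2 bytes per bfloat16 element and ignores gradients and allocator overhead; `tensor_gib` is a hypothetical helper, not part of the test suite.

```python
def tensor_gib(batch, seq_len, vocab, bytes_per_elem=2):
    """Size in GiB of a (batch * seq_len, vocab) tensor; bfloat16 is 2 bytes per element."""
    return batch * seq_len * vocab * bytes_per_elem / 2**30


INT32_MAX = 2**31 - 1

# The two parametrized cases of test_large_no_exception shown in the log.
small_numel = 8 * 8192 * 128256
large_numel = 8 * 16384 * 128256

small_gib = tensor_gib(8, 8192, 128256)
large_gib = tensor_gib(8, 16384, 128256)

# 10.00 GiB is the total capacity reported in the OutOfMemoryError message.
fits_on_10_gib = large_gib <= 10.00

print(f"small case: {small_numel} elements, exceeds int32: {small_numel > INT32_MAX}")
print(f"small case: {small_gib:.2f} GiB, large case: {large_gib:.2f} GiB")
print(f"large case fits on a 10 GiB GPU: {fits_on_10_gib}")
```

The element count of the smaller case matches the `8405385216` in the test comment, and the larger case works out to the 31.31 GiB that PyTorch tried to allocate.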