# Add beta support for jsd (#290)
## Summary
Resolves #278.

## Details
### Forward:
```math
\begin{align}
JSD(X, Y, \beta) &= JSD_{\beta}(P \Vert Q)\\
&= \beta\, KL(P \Vert \beta P + (1-\beta)Q) + (1-\beta)\, KL(Q \Vert \beta P + (1-\beta)Q)\\
&= \sum \beta\, P Y + (1-\beta)\, Q X - M \log M
\end{align}
```
where $X=\log Q$, $Y=\log P$, and $M=\beta P + (1-\beta)Q$. The last line follows by expanding each KL term, e.g. $KL(P \Vert M) = \sum P(\log P - \log M) = \sum PY - P\log M$; the two weighted $\log M$ terms then combine into $-M\log M$.

### Gradients:
```math
\frac{\partial}{\partial X_i} JSD(X, Y, \beta) = (1-\beta)\, Q_i \left(X_i - \log M_i\right)
```
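
As a sanity check (not part of this PR), here is a minimal PyTorch sketch that evaluates the closed form above and verifies the analytic gradient against autograd; all names are illustrative:

```python
import torch

def jsd_ref(X, Y, beta):
    # X = log Q (input), Y = log P (target); closed form from the summary above
    P, Q = Y.exp(), X.exp()
    M = beta * P + (1 - beta) * Q
    return (beta * P * Y + (1 - beta) * Q * X - M * M.log()).sum()

beta = 0.3
X = torch.randn(4, 8).log_softmax(-1).requires_grad_(True)  # student log-probs
Y = torch.randn(4, 8).log_softmax(-1)                       # teacher log-probs

jsd_ref(X, Y, beta).backward()

# Analytic gradient from the formula above: (1 - beta) * Q_i * (X_i - log M_i)
with torch.no_grad():
    P, Q = Y.exp(), X.exp()
    M = beta * P + (1 - beta) * Q
    torch.testing.assert_close(X.grad, (1 - beta) * Q * (X - M.log()))
```

Note that the kernel additionally divides `dX` by `n_rows` because the loss uses a `batchmean`-style reduction; the sketch uses a plain sum, so that factor is omitted.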
## Testing Done

![jsd_memory](https://github.com/user-attachments/assets/a26e1a64-df4b-49fe-8564-01a6757cb76a)

![jsd_speed](https://github.com/user-attachments/assets/6f631bdb-5abf-44ed-875b-2596f3a30b8b)

- Hardware Type: H100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <tangshao28@gmail.com>
Tcc0403 and lancerts authored Oct 3, 2024
1 parent 60640e1 commit 6817c2d
Showing 8 changed files with 230 additions and 75 deletions.
2 changes: 2 additions & 0 deletions README.md
```diff
@@ -250,6 +250,7 @@ loss.backward()
 | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
 | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
 | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
+| JSD | `liger_kernel.transformers.LigerJSD` |
 
 - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
 - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
@@ -264,6 +265,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 <!-- TODO: verify vocab sizes are accurate -->
 - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
 - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
+- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speedup and ~54% memory reduction for 128k vocab size.
 
 ### Experimental Kernels
 
```
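
To show how the registered module is meant to be used, here is a hypothetical usage sketch (not taken from this diff — in particular, the `beta` keyword on `LigerJSD` is assumed to be forwarded to the kernel):

```python
import torch
from liger_kernel.transformers import LigerJSD

jsd = LigerJSD(beta=0.1)  # beta=0.5 would recover the standard, symmetric JSD

# Illustrative shapes: BT = 8 rows, V = 128 vocab entries, both in log-space
student_log_probs = torch.randn(8, 128, device="cuda").log_softmax(-1).requires_grad_(True)
teacher_log_probs = torch.randn(8, 128, device="cuda").log_softmax(-1)

loss = jsd(student_log_probs, teacher_log_probs)  # input (student) first, target (teacher) second
loss.backward()
```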
72 changes: 36 additions & 36 deletions benchmark/data/all_benchmark_data.csv
@@ -445,39 +445,39 @@ kl_div,torch,full,speed,ms,V,vocab size,16384,11.124671936035156,11.122162818908
kl_div,torch,full,speed,ms,V,vocab size,32768,23.052032470703125,23.050334930419922,23.052589416503906,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
kl_div,torch,full,speed,ms,V,vocab size,65536,46.063167572021484,46.05990219116211,46.06643295288086,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
kl_div,torch,full,speed,ms,V,vocab size,131072,92.06393432617188,92.06393432617188,92.06393432617188,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,4096,0.4657920002937317,0.4644480049610138,0.4670400023460388,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,8192,0.9084159731864929,0.9064639806747437,0.9099519848823547,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,16384,9.939423561096191,9.933785438537598,9.945216178894043,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,32768,20.06915283203125,20.05768394470215,20.087200164794922,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,65536,38.88547134399414,38.880577087402344,38.89036560058594,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,131072,77.7418212890625,77.7418212890625,77.7418212890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,4096,2.1717119216918945,2.1697471141815186,2.173452854156494,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,8192,4.2592315673828125,4.255411148071289,4.2608771324157715,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,16384,8.363903999328613,8.359071731567383,8.36620807647705,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,32768,16.591264724731445,16.588390350341797,16.595033645629883,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,65536,33.06208038330078,33.06206130981445,33.06536102294922,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,131072,66.0923843383789,66.0923843383789,66.0923843383789,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,liger,full,speed,ms,V,vocab size,4096,1.5683839321136475,1.4662528038024902,1.7244799137115479,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,8192,2.0588159561157227,2.055116891860962,2.093465566635132,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,16384,11.944751739501953,11.936684608459473,11.961983680725098,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,32768,24.27791976928711,24.254375457763672,24.299558639526367,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,65536,47.206687927246094,47.17191696166992,47.241458892822266,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,131072,94.15420532226562,94.15420532226562,94.15420532226562,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,torch,full,speed,ms,V,vocab size,4096,4.875328063964844,4.873446464538574,4.878073692321777,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,8192,9.582816123962402,9.57910442352295,9.58505630493164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,16384,18.931264877319336,18.92802619934082,18.934911727905273,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,32768,38.07579040527344,38.07549285888672,38.076087951660156,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,65536,75.97628784179688,75.97628784179688,75.97628784179688,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,131072,151.8501739501953,151.8501739501953,151.8501739501953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,4096,0.4651840031147003,0.4636736214160919,0.4659839868545532,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,8192,0.927888035774231,0.926751971244812,0.92952960729599,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,16384,10.96003246307373,10.942886352539062,10.970770835876465,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,32768,22.405792236328125,22.390380859375,22.41998863220215,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,65536,43.49095916748047,43.47438049316406,43.50754165649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,131072,87.0363540649414,87.0363540649414,87.0363540649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,4096,2.4744958877563477,2.4725184440612793,2.4764864444732666,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,8192,4.8528642654418945,4.851238250732422,4.854745864868164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,16384,9.532496452331543,9.528634071350098,9.535890579223633,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,32768,18.91379165649414,18.911853790283203,18.919116973876953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,65536,37.70152282714844,37.70074462890625,37.70229721069336,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,131072,75.37680053710938,75.37680053710938,75.37680053710938,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,liger,full,speed,ms,V,vocab size,4096,1.2074079513549805,1.1739968061447144,1.2760319709777832,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,8192,2.091792106628418,2.0771327018737793,2.106553554534912,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,16384,12.928031921386719,12.8988676071167,12.936230659484863,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,32768,26.55548858642578,26.550823211669922,26.570655822753906,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,65536,51.6833610534668,51.6833610534668,51.6833610534668,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,131072,103.12793731689453,103.12793731689453,103.12793731689453,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,torch,full,speed,ms,V,vocab size,4096,5.397359848022461,5.392876625061035,5.39998722076416,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,8192,10.60153579711914,10.597900390625,10.60470962524414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,16384,20.9442081451416,20.94247055053711,20.9469051361084,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,32768,42.113216400146484,42.113216400146484,42.113216400146484,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,65536,83.9959716796875,83.9959716796875,83.9959716796875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,131072,167.94175720214844,167.94175720214844,167.94175720214844,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
24 changes: 16 additions & 8 deletions benchmark/scripts/benchmark_jsd.py
```diff
@@ -13,17 +13,25 @@
 from liger_kernel.transformers.jsd import LigerJSD
 
 
-class TorchJSD(torch.nn.Module):
-    def __init__(self):
+class TorchJSD(nn.Module):
+    def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float):
         super(TorchJSD, self).__init__()
         self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
+        self.beta = beta
+        self.dtype = dtype
 
-    def forward(self, log_p: torch.tensor, log_q: torch.tensor):
+    def forward(
+        self,
+        log_q: torch.tensor,  # input
+        log_p: torch.tensor,  # target
+    ):
         log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
         log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
-        m = 0.5 * (torch.exp(log_p) + torch.exp(log_q))
-        log_m = torch.log(m)
-        loss = 0.5 * (self.kl(log_m, log_p) + self.kl(log_m, log_q))
-        return loss
+        m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)  # m = beta * P + (1 - beta) * Q
+        loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
+            torch.log(m), log_q
+        )
+        return loss.to(self.dtype)
 
 
 def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
```
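
The reference implementation above relies on the identity from the summary. A small CPU-only sketch (not part of the PR) confirming that the β-weighted KL definition equals the closed form the kernel computes:

```python
import torch

torch.manual_seed(0)
beta = 0.3
log_q = torch.randn(6, 32).log_softmax(-1)  # input / student
log_p = torch.randn(6, 32).log_softmax(-1)  # target / teacher

P, Q = log_p.exp(), log_q.exp()
M = beta * P + (1 - beta) * Q
n_rows = log_q.size(0)

# Definition: beta * KL(P || M) + (1 - beta) * KL(Q || M), batchmean reduction
kl = lambda prob, log_prob: (prob * (log_prob - M.log())).sum()
by_definition = (beta * kl(P, log_p) + (1 - beta) * kl(Q, log_q)) / n_rows

# Closed form: sum(beta * P * Y + (1 - beta) * Q * X - M * log M) / n_rows
closed_form = (beta * P * log_p + (1 - beta) * Q * log_q - M * M.log()).sum() / n_rows

torch.testing.assert_close(by_definition, closed_form)
```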
47 changes: 31 additions & 16 deletions src/liger_kernel/ops/jsd.py
```diff
@@ -15,6 +15,7 @@ def _jsd_kernel(
     loss_stride,
     dX_ptr,
     dX_stride,
+    beta,
     n_rows,
     n_cols,
     BLOCK_SIZE: tl.constexpr,
@@ -37,20 +38,20 @@
 
     Q = tl.exp(X)
     P = tl.exp(Y)
-    M = 0.5 * P + 0.5 * Q
+    M = beta * P + (1 - beta) * Q
     log_M = tl.log(M)
 
-    loss = 0.5 * (P * Y + Q * X - 2 * M * log_M)
+    loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
     tl.store(loss_ptr + offsets, loss, mask=mask)
 
-    dX = 0.5 * Q * (X - log_M) / n_rows
+    dX = (1 - beta) * Q * (X - log_M) / n_rows
     tl.store(dX_ptr + offsets, dX, mask=mask)
 
 
 MAX_FUSED_SIZE = 65536
 
 
-def jsd_forward(_input, target):
+def jsd_forward(_input, target, beta):
     BT, V = _input.shape
     n_rows = BT
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
@@ -67,6 +68,7 @@ def jsd_forward(_input, target):
         loss_stride=loss.stride(-2),
         dX_ptr=dX,
         dX_stride=dX.stride(-2),
+        beta=beta,
         n_rows=n_rows,
         n_cols=V,
         BLOCK_SIZE=BLOCK_SIZE,
@@ -77,23 +79,26 @@
 
 
 def jsd_backward(dX, grad_output):
-    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+    # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         return dX
     else:
         return grad_output * dX
 
 
 class LigerJSDFunction(torch.autograd.Function):
-    """
-    Class implementing the forward and backward pass for the JS Divergence using Triton, as defined by the following formula:
-    Parameters:
-        _input (tensor): predict values with shape (BT, V) in logspace
-        target (tensor): gournd truth values with shape (BT, V) in logspace
-    Returns:
-        loss (tensor): JSD
+    r"""
+    This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
+    .. math::
+        JSD(\beta)(P || Q)
+            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+
+    .. note::
+        As with all other losses in PyTorch, this function expects the first argument,
+        :attr:`_input`, to be the predictions, the output of the student model, in log-space
+        and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
+        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
+        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
     """
 
     @staticmethod
@@ -102,9 +107,18 @@ def forward(
         ctx,
         _input: torch.Tensor,
         target: torch.Tensor,
+        beta: float = 0.5,
     ) -> torch.Tensor:
-
-        loss, dX = jsd_forward(_input, target)
+        """
+        Args:
+            _input (torch.Tensor): predicted values with shape (BT, V) in log-space
+            target (torch.Tensor): ground truth values with shape (BT, V) in log-space
+            beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
+
+        Returns:
+            loss (torch.Tensor): generalized JSD
+        """
+        loss, dX = jsd_forward(_input, target, beta)
         ctx.save_for_backward(dX)
         return loss
@@ -116,4 +130,5 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
         return (
             dX,
             None,
+            None,
         )
```
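
For context, a hypothetical end-to-end call of the autograd function with the new `beta` argument (requires a CUDA device with Triton; shapes are illustrative):

```python
import torch
from liger_kernel.ops.jsd import LigerJSDFunction

# BT = batch * sequence length, V = vocab size (illustrative)
_input = torch.randn(16, 4096, device="cuda").log_softmax(-1).requires_grad_(True)  # student log-probs
target = torch.randn(16, 4096, device="cuda").log_softmax(-1)                       # teacher log-probs

loss = LigerJSDFunction.apply(_input, target, 0.1)  # beta = 0.1
loss.backward()  # per the backward above, only _input receives a gradient
```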
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
```diff
@@ -6,6 +6,7 @@
     LigerFusedLinearCrossEntropyLoss,
 )
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
+from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
 from liger_kernel.transformers.monkey_patch import (  # noqa: F401
     _apply_liger_kernel,
```
