Feature/tvd loss fused #1

Merged: 11 commits, Oct 26, 2024
36 changes: 36 additions & 0 deletions benchmark/data/all_benchmark_data.csv
@@ -505,3 +505,39 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859
fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,4096,0.47814399003982544,0.4774720072746277,0.4790079891681671,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,8192,0.906495988368988,0.905951976776123,0.9073920249938965,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,16384,1.8787360191345215,1.8778239488601685,1.8797119855880737,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,32768,3.5788800716400146,3.5772159099578857,3.58076810836792,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,65536,7.008831977844238,7.007718086242676,7.010636806488037,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,131072,13.88646411895752,13.88128662109375,13.890560150146484,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,4096,1.308608055114746,1.306502342224121,1.3104127645492554,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,8192,2.4735519886016846,2.472287893295288,2.4749441146850586,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,16384,4.828320026397705,4.826848030090332,4.830643177032471,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,32768,9.5206880569458,9.517024040222168,9.525145530700684,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,65536,19.01535987854004,19.011123657226562,19.01806640625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,131072,38.022865295410156,38.01945877075195,38.02627182006836,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,liger,full,speed,ms,V,vocab size,4096,2.626512050628662,2.621260643005371,2.646751880645752,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,liger,full,speed,ms,V,vocab size,8192,4.661711692810059,4.657618999481201,4.662930965423584,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,liger,full,speed,ms,V,vocab size,16384,9.088272094726562,9.080741882324219,9.092268943786621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,liger,full,speed,ms,V,vocab size,32768,18.116064071655273,18.112728118896484,18.118234634399414,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,liger,full,speed,ms,V,vocab size,65536,35.85124969482422,35.849971771240234,35.85252380371094,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,liger,full,speed,ms,V,vocab size,131072,71.1648941040039,71.1648941040039,71.1648941040039,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,torch,full,speed,ms,V,vocab size,4096,4.361599922180176,4.360159873962402,4.3639678955078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
tvd,torch,full,speed,ms,V,vocab size,8192,8.11302375793457,8.11075210571289,8.114463806152344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
tvd,torch,full,speed,ms,V,vocab size,16384,15.841055870056152,15.837087631225586,15.841856002807617,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
tvd,torch,full,speed,ms,V,vocab size,32768,31.71219253540039,31.706951141357422,31.715898513793945,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
tvd,torch,full,speed,ms,V,vocab size,65536,63.17919921875,63.17919921875,63.17919921875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
tvd,torch,full,speed,ms,V,vocab size,131072,126.0436782836914,126.0436782836914,126.0436782836914,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
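
At the largest size benchmarked (V = 131072, B = 8, T = 2048), the fused liger kernel runs a full forward+backward pass in ~71.2 ms versus ~126.0 ms for the eager torch baseline (roughly 1.8x faster), and peaks at 57,344 MB versus 65,536 MB (about 12.5% less memory). A minimal sketch for comparing the two providers from this CSV, assuming pandas is available and that the header follows Liger's benchmark schema (kernel_name, kernel_provider, kernel_operation_mode, metric_name, ..., x_value, y_value_50; the header row itself is outside this diff):

import pandas as pd

df = pd.read_csv("benchmark/data/all_benchmark_data.csv")
speed = df[
    (df["kernel_name"] == "tvd")
    & (df["metric_name"] == "speed")
    & (df["kernel_operation_mode"] == "full")
]
# Median full-pass latency (ms) per provider at each vocab size
print(speed.pivot_table(index="x_value", columns="kernel_provider", values="y_value_50"))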
136 changes: 136 additions & 0 deletions benchmark/scripts/benchmark_tvd.py
@@ -0,0 +1,136 @@
import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from liger_kernel.transformers.tvd import LigerTVDLoss


class TorchTVDLoss(torch.nn.Module):
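    """Eager PyTorch reference for total variation distance (TVD) loss; used as the benchmark baseline."""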
    def __init__(self, reduction="batchmean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, p, q):
        tvd = torch.abs(p - q) / 2.0
        if self.reduction == "mean":
            return torch.sum(tvd) / (p.size(0) * p.size(1))
        elif self.reduction == "sum":
            return torch.sum(tvd)
        elif self.reduction == "none":
            return tvd
        elif self.reduction == "batchmean":
            return torch.sum(tvd) / p.size(0)
        else:
            raise ValueError("Invalid reduction type.")

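# Sweep exponents: vocab sizes 2**S through 2**(E - 1)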
S, E = 12, 18


def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
reduction = "batchmean"
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
torch_tvd = TorchTVDLoss(reduction=reduction)
liger_tvd = LigerTVDLoss(reduction=reduction)

_input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
return liger_tvd(_input, target)
else:
return torch_tvd(_input, target)

if input.kernel_operation_mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
elif input.kernel_operation_mode == "backward":
y = fwd()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
quantiles=QUANTILES,
grad_to_none=[_input],
rep=100,
)
elif input.kernel_operation_mode == "full":

def full():
y = fwd()
y.backward(retain_graph=True)

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full, quantiles=QUANTILES, rep=100
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
reduction = "batchmean"
torch_tvd = TorchTVDLoss(reduction=reduction)
liger_tvd = LigerTVDLoss(reduction=reduction)

V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

_input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
return liger_tvd(_input, target)
else:
return torch_tvd(_input, target)

def full():
y = fwd()
y.backward(retain_graph=True)

mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()
common_args = {
"kernel_name": "tvd",
"x_name": "V",
"x_label": "vocab size",
"x_values": [2**i for i in range(12, 18)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [{"B": 8, "T": 2048}],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_memory_tvd,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_args,
)

run_benchmarks(
bench_test_fn=bench_speed_tvd,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_args,
)
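
The script plugs into the repo's benchmark harness. Assuming parse_benchmark_script_args exposes the standard CLI (at least --overwrite, consumed above), a run along the lines of

    python benchmark/scripts/benchmark_tvd.py --overwrite

regenerates the tvd rows in benchmark/data/all_benchmark_data.csv. The bare "from utils import ..." suggests the script is meant to be launched with benchmark/scripts on the import path.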
174 changes: 174 additions & 0 deletions src/liger_kernel/ops/tvd.py
@@ -0,0 +1,174 @@
from typing import Literal

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import ensure_contiguous

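# Per-program element cap; presumably 64 KB divided by 4 bytes per fp32 element.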
MAX_FUSED_SIZE = 65536 // 4

REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

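# Reduction modes are passed as integer constexprs: Triton specializes kernels
# on constexpr values and cannot compare strings inside @triton.jit code.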
_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

_str_to_reduction_mode = {
"none": _REDUCTION_MODE_NONE.value,
"sum": _REDUCTION_MODE_SUM.value,
"mean": _REDUCTION_MODE_MEAN.value,
"batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}

def get_num_warps(BLOCK_SIZE):
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32
elif BLOCK_SIZE >= 8192:
num_warps = 16
elif BLOCK_SIZE >= 2048:
num_warps = 8

return num_warps

@triton.jit
def _tv_distance_kernel(
p_ptr,
p_stride,
q_ptr,
q_stride,
loss_ptr,
loss_stride,
grads_ptr,
grads_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
pid = tl.program_id(0).to(tl.int64)
p_ptr += pid * p_stride
q_ptr += pid * q_stride
loss_ptr += pid * loss_stride
grads_ptr += pid * grads_stride

base_offsets = tl.arange(0, BLOCK_SIZE)

loss_sum = 0.0
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + base_offsets
mask = offsets < n_cols

p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
q = tl.load(q_ptr + offsets, mask=mask, other=0.0)

# TVD(P || Q) = 0.5 * |P - Q|
tv_loss = 0.5 * tl.abs(p - q)

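        # d/dp [0.5 * |p - q|] = 0.5 * sign(p - q); ties (p == q) take -0.5,
        # which is a valid subgradient of the absolute value.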
grad_res = tl.where(p > q, 0.5, -0.5)

tl.store(grads_ptr + offsets, grad_res, mask=mask)

if reduction == _REDUCTION_MODE_NONE:
tl.store(loss_ptr + offsets, tv_loss, mask=mask)
else:
loss_sum += tl.sum(tv_loss, axis=0)

if reduction != _REDUCTION_MODE_NONE:
tl.store(loss_ptr, loss_sum)

def tv_distance_forward_triton(p, q, reduction):
BT, V = p.shape

BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
num_warps = get_num_warps(BLOCK_SIZE)

grid = (BT,)

reduction = _str_to_reduction_mode[reduction]

out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
grads = torch.empty_like(p)

_tv_distance_kernel[grid](
p,
p.stride(0),
q,
q.stride(0),
output_tensor,
output_tensor.stride(0),
grads,
grads.stride(0),
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
reduction=reduction,
)

if reduction == _REDUCTION_MODE_BATCHMEAN.value:
return output_tensor.sum() / BT, grads
elif reduction == _REDUCTION_MODE_SUM.value:
return output_tensor.sum(dim=0), grads
elif reduction == _REDUCTION_MODE_MEAN.value:
return output_tensor.sum() / (BT * V), grads
else:
return output_tensor, grads

def tvd_backward_triton(grad_output, grads):

    # If the TVD loss is the last layer, grad_output is 1.0. Skip the multiplication then.
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
return grads

return grads * grad_output

class LigerTVDLossFunction(torch.autograd.Function):
"""
Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
"""

@staticmethod
@ensure_contiguous
def forward(
ctx, p: torch.Tensor, q: torch.Tensor, reduction: REDUCTION_LITERAL = "batchmean"
) -> torch.Tensor:
"""A forward pass for the Total Variation Distance Loss.

Args:
ctx: Torch autograd context
p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".

Returns:
torch.Tensor: The computed Total Variation Distance Loss.
"""
loss, grads = tv_distance_forward_triton(p, q, reduction)
ctx.save_for_backward(grads)
ctx.reduction = reduction
return loss

@staticmethod
@ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
"""A backward pass for the Total Variation Distance Loss.

Args:
ctx: Torch autograd context
grad_output (torch.Tensor): The gradient of the loss with respect to the output.

Returns:
tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the inputs.
"""
grads, = ctx.saved_tensors
BT, V = grads.shape

        grads = tvd_backward_triton(grad_output, grads)

        # Scale out of place: when grad_output is 1.0, tvd_backward_triton
        # returns the saved tensor itself, and in-place scaling would corrupt
        # repeated backward passes (e.g. with retain_graph=True).
        if ctx.reduction == "batchmean":
            grads = grads / BT
        elif ctx.reduction == "mean":
            grads = grads / (BT * V)

        return grads, None, None
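
A quick way to sanity-check the fused kernel against an eager PyTorch reference (a minimal sketch, assuming a CUDA device; not part of this diff):

import torch
from liger_kernel.ops.tvd import LigerTVDLossFunction

p = torch.randn(4, 128, device="cuda").softmax(dim=-1).requires_grad_(True)
q = torch.randn(4, 128, device="cuda").softmax(dim=-1)

loss = LigerTVDLossFunction.apply(p, q, "batchmean")
loss.backward()

# Eager reference: 0.5 * sum(|p - q|) / batch size
p_ref = p.detach().clone().requires_grad_(True)
ref = (0.5 * (p_ref - q).abs()).sum() / p_ref.size(0)
ref.backward()

torch.testing.assert_close(loss, ref)
# Gradients match except at exact ties p == q, where torch's |x| yields 0
# and the kernel yields -0.5; ties occur with probability ~0 for random inputs.
torch.testing.assert_close(p.grad, p_ref.grad)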
2 changes: 2 additions & 0 deletions src/liger_kernel/transformers/functional.py
@@ -10,6 +10,7 @@
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
from liger_kernel.ops.rope import LigerRopeFunction
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
from liger_kernel.ops.tvd import LigerTVDLossFunction

liger_swiglu = LigerSiLUMulFunction.apply
liger_cross_entropy = LigerCrossEntropyFunction.apply
@@ -21,3 +22,4 @@
liger_kl_div = LigerKLDivLossFunction.apply
liger_jsd = LigerJSDFunction.apply
liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
liger_tvd = LigerTVDLossFunction.apply
11 changes: 11 additions & 0 deletions src/liger_kernel/transformers/tvd.py
@@ -0,0 +1,11 @@
import torch.nn as nn

from liger_kernel.ops.tvd import LigerTVDLossFunction


class LigerTVDLoss(nn.Module):
    def __init__(self, reduction="batchmean"):
        super().__init__()
        self.reduction = reduction

def forward(self, p, q):
return LigerTVDLossFunction.apply(p, q, self.reduction)
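
Example usage of the new module (a minimal sketch; inputs are probability distributions of shape (BT, V), matching the kernel's expected layout):

import torch
from liger_kernel.transformers.tvd import LigerTVDLoss

tvd = LigerTVDLoss(reduction="batchmean")
p = torch.randn(8, 4096, device="cuda").softmax(dim=-1)
q = torch.randn(8, 4096, device="cuda").softmax(dim=-1)
loss = tvd(p, q)  # scalar: 0.5 * sum(|p - q|) / 8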