From ff6650bbcef5d31b7522694cbeb73a21169460e9 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 12 Oct 2024 02:22:25 +0800 Subject: [PATCH] Add FusedLinearJSD (#300) ## Summary similar to the fuse linear CE. It handles the forward and backward pass of the final linear layer via JSD by avoiding the materialization of the large logits tensor. Since JSD is the last layer, we can compute the gradient at the forward pass. ## Testing Done Hidden size: 4096, Vocab size: 128256 ![fused_linear_jsd_memory](https://github.com/user-attachments/assets/231303d1-4734-49fb-8c69-8e60730563c2) ![fused_linear_jsd_speed](https://github.com/user-attachments/assets/d83c85ec-ab29-44e0-a3d9-ad85acf4577d) - Hardware Type: NVIDIA H100 80GB HBM3 (SXM5) - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Qingquan Song --- benchmark/data/all_benchmark_data.csv | 24 ++ .../scripts/benchmark_fused_linear_jsd.py | 250 ++++++++++++++++++ src/liger_kernel/ops/cross_entropy.py | 38 +-- .../ops/fused_linear_cross_entropy.py | 6 +- src/liger_kernel/ops/fused_linear_jsd.py | 201 ++++++++++++++ src/liger_kernel/ops/jsd.py | 6 +- src/liger_kernel/ops/utils.py | 36 +++ src/liger_kernel/transformers/__init__.py | 1 + src/liger_kernel/transformers/functional.py | 2 + .../transformers/fused_linear_jsd.py | 61 +++++ test/transformers/test_fused_linear_jsd.py | 209 +++++++++++++++ 11 files changed, 792 insertions(+), 42 deletions(-) create mode 100644 benchmark/scripts/benchmark_fused_linear_jsd.py create mode 100644 src/liger_kernel/ops/fused_linear_jsd.py create mode 100644 src/liger_kernel/transformers/fused_linear_jsd.py create mode 100644 test/transformers/test_fused_linear_jsd.py diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 65c74a7d4..32c8d01ab 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -481,3 +481,27 @@ jsd,torch,full,speed,ms,V,vocab size,16384,20.9442081451416,20.94247055053711,20 jsd,torch,full,speed,ms,V,vocab size,32768,42.113216400146484,42.113216400146484,42.113216400146484,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 jsd,torch,full,speed,ms,V,vocab size,65536,83.9959716796875,83.9959716796875,83.9959716796875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 jsd,torch,full,speed,ms,V,vocab size,131072,167.94175720214844,167.94175720214844,167.94175720214844,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,1024,110.02185821533203,110.02185821533203,110.02185821533203,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,2048,124.14070129394531,124.14070129394531,124.14070129394531,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,4096,143.15420532226562,143.15420532226562,143.15420532226562,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,8192,180.90406799316406,180.90406799316406,180.90406799316406,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,1024,9.556896209716797,9.550745964050293,9.576268196105957,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,2048,18.73731231689453,18.732704162597656,18.737701416015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,4096,37.830482482910156,37.80821990966797,37.85274124145508,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,8192,75.15289306640625,75.15289306640625,75.15289306640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,1024,111.16019439697266,111.16019439697266,111.16019439697266,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,2048,125.6825942993164,125.6825942993164,125.6825942993164,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,4096,144.00784301757812,144.00784301757812,144.00784301757812,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,8192,182.5832977294922,182.5832977294922,182.5832977294922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,1024,25.977184295654297,25.968351364135742,25.989356994628906,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,2048,49.48417663574219,49.47330093383789,49.495052337646484,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,4096,98.31510162353516,98.31510162353516,98.31510162353516,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,8192,195.29539489746094,195.29539489746094,195.29539489746094,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,1024,4652.48486328125,4652.48486328125,4652.48486328125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,2048,5231.93798828125,5231.93798828125,5231.93798828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,4096,6391.87548828125,6391.87548828125,6391.87548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,8192,8711.75,8711.75,8711.75,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859375,10609.005859375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py new file mode 100644 index 000000000..ee32428a7 --- /dev/null +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -0,0 +1,250 @@ +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD + + +class TorchJSD(torch.nn.Module): + def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float): + super(TorchJSD, self).__init__() + self.kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) + self.beta = beta + self.dtype = dtype + + def forward( + self, + log_q: torch.tensor, # input + log_p: torch.tensor, # target + ): + log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) + log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) + m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta) + loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl( + torch.log(m), log_q + ) + return loss.to(self.dtype) + + +class TorchLMHeadJSD(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based jsd loss. + + :param H: hidden size + :param V: vocab size + :param temperature: softmax temperature + :param beta: jsd beta + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + temperature: float = 1.0, + beta: float = 0.5, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.jsd = TorchJSD(beta, dtype=dtype) + self.temperature = temperature + + def forward(self, student_input, teacher_input): + student_logits = self.student_lin(student_input) + teacher_logits = self.teacher_lin(teacher_input) + student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1) + teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1) + + return self.jsd(student_prob, teacher_prob) + + +class LigerLMHeadJSD(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + temperature: float = 1.0, + beta: float = 0.5, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.fused_jsd = LigerFusedLinearJSD(beta, temperature) + + def forward(self, student_input, teacher_input): + return self.fused_jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + ) + + +############################################################################# +# Test the memory consumption of the fused linear JSD +############################################################################# + + +def bench_memory_fused_linear_jsd( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) + teacher_input = torch.rand(BT, H, dtype=dtype, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_jsd(student_input, teacher_input) + elif provider == "torch": + return torch_lm_head_jsd(student_input, teacher_input) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +# ############################################################################# +# # Test the speed of the fused linear JSD +# ############################################################################# + + +def bench_speed_fused_linear_jsd( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + mode = input.kernel_operation_mode + + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) + teacher_input = torch.rand(BT, H, dtype=dtype, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_jsd(student_input, teacher_input) + elif provider == "torch": + return torch_lm_head_jsd(student_input, teacher_input) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[ + student_input, + torch_lm_head_jsd.student_lin.weight, + torch_lm_head_jsd.teacher_lin.weight, + ], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_jsd", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, 14)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16} + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_jsd, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_jsd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 97c6c06cb..c72ba8d45 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from liger_kernel.ops.utils import element_mul_kernel + @triton.jit def liger_cross_entropy_kernel( @@ -159,42 +161,6 @@ def liger_cross_entropy_kernel( MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning -@triton.jit -def element_mul_kernel( - X_ptr, - X_stride, - grad_output_ptr, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - """ - This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. - The multiplication is performed in-place on the tensor pointed by X_ptr. - - Parameters: - X_ptr: Pointer to the input tensor. - X_stride (int): The stride of the input tensor. - grad_output_ptr: Pointer to the gradient output value. - n_cols (int): The number of columns in the input tensor. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - # Get the program ID and convert it to int64 to avoid overflow - program_id = tl.program_id(0).to(tl.int64) - - # Locate the start index - X_ptr += program_id * X_stride - - # Load the gradient output value - grad_output = tl.load(grad_output_ptr) - - # Perform the element-wise multiplication - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) - tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) - - def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction): BT, V = _input.shape n_rows = BT diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 73da9cd46..e9a28afbb 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -1,10 +1,8 @@ import torch import triton -from liger_kernel.ops.cross_entropy import ( - element_mul_kernel, - liger_cross_entropy_kernel, -) +from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel +from liger_kernel.ops.utils import element_mul_kernel # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling diff --git a/src/liger_kernel/ops/fused_linear_jsd.py b/src/liger_kernel/ops/fused_linear_jsd.py new file mode 100644 index 000000000..9f191c14a --- /dev/null +++ b/src/liger_kernel/ops/fused_linear_jsd.py @@ -0,0 +1,201 @@ +import torch +import triton + +from liger_kernel.ops.jsd import _jsd_kernel +from liger_kernel.ops.utils import element_mul_kernel + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 + + +def fused_linear_jsd_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + jsd_beta, + temperature, +): + device = student_input.device + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. + # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = student_input.shape + V = student_weight.shape[0] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2( + triton.cdiv(BT, inc_factor) + ) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + grad_weight = ( + torch.zeros_like(student_weight, device=device) + if student_weight.requires_grad + else None + ) + grad_input = torch.zeros_like(student_input) + # we use fp32 for loss accumulator + loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device) + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + + # chunk both inputs, shape: chunk_size x H + student_input_chunk = student_input[start_idx:end_idx] + teacher_input_chunk = teacher_input[start_idx:end_idx] + + # when doing matmul, use the original precision, shape: chunk_size x V + student_logits_chunk = student_input_chunk @ student_weight.t() + teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t() + chunk_n_rows = student_logits_chunk.shape[0] + + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size + # log-softmax with temperature + student_logits_chunk = student_logits_chunk / temperature + teacher_logits_chunk = teacher_logits_chunk / temperature + student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1) + teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1) + + # ensure _input and target are contiguous + student_prob_chunk = student_prob_chunk.contiguous() + teacher_prob_chunk = teacher_prob_chunk.contiguous() + + # Here we calculate the gradient of prob_chunk in place so we can save memory. + _jsd_kernel[(chunk_n_rows,)]( + X_ptr=student_prob_chunk, + X_stride=student_prob_chunk.stride(-2), + Y_ptr=teacher_prob_chunk, + Y_stride=teacher_prob_chunk.stride(-2), + loss_ptr=loss_1d_slice, + loss_stride=loss_1d_slice.stride(-2), + dX_ptr=student_prob_chunk, + dX_stride=student_prob_chunk.stride(-2), + beta=jsd_beta, + n_rows=BT, # batchmean + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + ) + loss_1d[start_idx:end_idx] = loss_1d_slice + # gradients of prob_chunk in place, shape: chunk_size x V + # gradients of logits_chunk in place, shape: chunk_size x V + student_logits_chunk = ( + student_prob_chunk + - torch.softmax(student_logits_chunk, dim=-1) + * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to( + student_prob_chunk.shape + ) + ) / temperature + grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight + + if grad_weight is not None: + torch.addmm( + input=grad_weight, + mat1=student_logits_chunk.t(), # gradients of logits_chunk + mat2=student_input_chunk, + out=grad_weight, + ) + + loss = torch.sum(loss_1d) + return loss.to(student_input.dtype), grad_input, grad_weight + + +def fused_linear_jsd_backward(grad_output, grad_input, grad_weight): + # If JSD is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + BT, H = grad_input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_weight, + grad_weight.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + return grad_input, grad_weight + + +class LigerFusedLinearJSDFunction(torch.autograd.Function): + """ + Fusing the last linear layer with generalized JSD + + Handle the forward and backward pass of the final linear layer via JSD by avoiding + the materialization of the large logits tensor. Since JSD is the last layer, we can + compute the gradient at the forward pass. + """ + + @staticmethod + def forward( + ctx, + student_input, + student_weight, + teacher_input, + teacher_weight, + jsd_beta=0.5, + temperature=1.0, + ): + """ + Args: + + student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. + student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size + teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. + teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size + jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5` + temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0` + + Returns: + loss (torch.Tensor): generalized JSD + """ + loss, grad_input, grad_weight = fused_linear_jsd_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + jsd_beta, + temperature, + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + ) + return loss + + @staticmethod + def backward(ctx, grad_output): + (grad_input, grad_weight) = ctx.saved_tensors + grad_input, grad_weight = fused_linear_jsd_backward( + grad_output, grad_input, grad_weight + ) + return (grad_input, grad_weight, None, None, None, None) diff --git a/src/liger_kernel/ops/jsd.py b/src/liger_kernel/ops/jsd.py index 1f9cab732..e2a77d1d3 100644 --- a/src/liger_kernel/ops/jsd.py +++ b/src/liger_kernel/ops/jsd.py @@ -42,6 +42,8 @@ def _jsd_kernel( log_M = tl.log(M) loss = beta * P * Y + (1 - beta) * Q * X - M * log_M + # reduction == "batchmean" + loss = loss / n_rows tl.store(loss_ptr + offsets, loss, mask=mask) dX = (1 - beta) * Q * (X - log_M) / n_rows @@ -73,8 +75,8 @@ def jsd_forward(_input, target, beta): n_cols=V, BLOCK_SIZE=BLOCK_SIZE, ) - # reduction == "batchmean" - loss = torch.sum(loss) / n_rows + + loss = torch.sum(loss) return loss.to(_input.dtype), dX diff --git a/src/liger_kernel/ops/utils.py b/src/liger_kernel/ops/utils.py index 99500896b..2c01f3ac1 100644 --- a/src/liger_kernel/ops/utils.py +++ b/src/liger_kernel/ops/utils.py @@ -68,3 +68,39 @@ def compare_version(package: str, operator: Callable, target: str): torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, } + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 63cf779b0..ffb8235cc 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -5,6 +5,7 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401 LigerFusedLinearCrossEntropyLoss, ) +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401 from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401 from liger_kernel.transformers.jsd import LigerJSD # noqa: F401 from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index a8fb14e20..f160887b8 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -2,6 +2,7 @@ from liger_kernel.ops.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyFunction, ) +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction from liger_kernel.ops.geglu import LigerGELUMulFunction from liger_kernel.ops.jsd import LigerJSDFunction from liger_kernel.ops.kl_div import LigerKLDivLossFunction @@ -19,3 +20,4 @@ liger_layer_norm = LigerLayerNormFunction.apply liger_kl_div = LigerKLDivLossFunction.apply liger_jsd = LigerJSDFunction.apply +liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply diff --git a/src/liger_kernel/transformers/fused_linear_jsd.py b/src/liger_kernel/transformers/fused_linear_jsd.py new file mode 100644 index 000000000..d9579fe4b --- /dev/null +++ b/src/liger_kernel/transformers/fused_linear_jsd.py @@ -0,0 +1,61 @@ +import torch.nn as nn + +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction + + +class LigerFusedLinearJSD(nn.Module): + r"""Fusing the last linear layer with generalized JSD + + Handle the forward and backward pass of the final linear layer via JSD by avoiding + the materialization of the large logits tensor. + + Args: + jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5` + temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0` + + Shape: + - student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension. + - student_weight: :math:`(V, H)`, where V is vocab size. + - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model. + - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different. + - Output: a scalar. + + Examples: + ```python + >>> (B, T, H, V) = (2, 2, 3, 5) + >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0) + >>> # generate inputs and weights + >>> student_input = torch.rand(B * T, H, device="cuda", requires_grad=True) + >>> student_lin = torch.nn.Linear(H, V, bias=False, device="cuda") + >>> # teacher input doesn't require grad, hidden_dim can be different from student's + >>> teacher_input = torch.rand(B * T, H * 2, device="cuda") + >>> teacher_lin = torch.nn.Linear(H * 2, V, bias=False, device="cuda") + >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight) + >>> output.backward() + ``` + """ + + def __init__(self, jsd_beta=0.5, temperature=1.0): + super().__init__() + assert ( + jsd_beta > 0 and jsd_beta < 1 + ), f"beta must be greater than 0 and less than 1. Got: {jsd_beta}" + assert temperature != 0, "temperature cannot be 0." + self.jsd_beta = jsd_beta + self.temperature = temperature + + def forward( + self, + student_input, + student_weight, + teacher_input, + teacher_weight, + ): + return LigerFusedLinearJSDFunction.apply( + student_input, + student_weight, + teacher_input, + teacher_weight, + self.jsd_beta, + self.temperature, + ) diff --git a/test/transformers/test_fused_linear_jsd.py b/test/transformers/test_fused_linear_jsd.py new file mode 100644 index 000000000..e5c3b1035 --- /dev/null +++ b/test/transformers/test_fused_linear_jsd.py @@ -0,0 +1,209 @@ +from test.transformers.test_jsd import JSD as TorchJSD +from test.utils import set_seed + +import pytest +import torch + +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction +from liger_kernel.transformers.functional import liger_fused_linear_jsd +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD + +set_seed(42) + + +class TorchLMHeadJSD(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based jsd loss. + + :param H: hidden size + :param V: vocab size + :param temperature: softmax temperature + :param beta: jsd beta + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + temperature: float = 1.0, + beta: float = 0.5, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.jsd = TorchJSD(beta, dtype=dtype) + self.temperature = temperature + + def forward(self, student_input, teacher_input): + student_logits = self.student_lin(student_input) + teacher_logits = self.teacher_lin(teacher_input) + student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1) + teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1) + + return self.jsd(student_prob, teacher_prob) + + +class LigerLMHeadJSD(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + temperature: float = 1.0, + beta: float = 0.5, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.fused_jsd = LigerFusedLinearJSD(beta, temperature) + + def forward(self, student_input, teacher_input): + return self.fused_jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + ) + + +############################################################################# +# Test the correctness of the fused linear JSD +############################################################################# + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 4, 2048, 3200), + (2, 2048, 4096, 32000), # llama2, mistral + # Comment out to speed up testing + # (4, 2048, 4096, 128256), # llama3 8B + # (4, 1024, 8192, 128256), # llama3 70B + (4, 423, 8192, 32000), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-2), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize( + "temperature, beta", + [ + (1.0, 0.5), + (2.0, 0.1), + ], +) +def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + beta=beta, + ).to(device) + liger_lm_head_jsd = LigerLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + beta=beta, + ).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H // 2, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + output1 = torch_lm_head_jsd(_input1, teacher_input) + output2 = liger_lm_head_jsd(_input2, teacher_input) + + assert torch.allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert torch.allclose( + torch_lm_head_jsd.student_lin.weight.grad, + liger_lm_head_jsd.student_lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 4, 2048, 3200), + (2, 2048, 4096, 32000), # llama2, mistral + # Comment out to speed up testing + # (4, 2048, 4096, 128256), # llama3 8B + # (4, 1024, 8192, 128256), # llama3 70B + (4, 423, 8192, 32000), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (0.5, torch.bfloat16, 5e-3, 5e-2), + (0.5, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("temperature, beta", [(1.0, 0.5), (2.0, 0.1)]) +def test_correctness_functional( + B, T, H, V, scalar, dtype, beta, temperature, atol, rtol +): + device = "cuda" + + # init the linear in all FusedLinearJSDs with the same weights + _weight = torch.rand(V, H // 2, device=device, dtype=dtype) + _weight1 = _weight.detach().clone().requires_grad_(True) + _weight2 = _weight.detach().clone().requires_grad_(True) + teacher_weight = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + output1 = liger_fused_linear_jsd( + _input1, _weight1, teacher_input, teacher_weight, beta, temperature + ) + output2 = LigerFusedLinearJSDFunction.apply( + _input2, _weight2, teacher_input, teacher_weight, beta, temperature + ) + + assert torch.allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert torch.allclose(_weight1.grad, _weight2.grad, atol=atol, rtol=rtol)