diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 65c74a7d4..32c8d01ab 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -481,3 +481,27 @@ jsd,torch,full,speed,ms,V,vocab size,16384,20.9442081451416,20.94247055053711,20 jsd,torch,full,speed,ms,V,vocab size,32768,42.113216400146484,42.113216400146484,42.113216400146484,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 jsd,torch,full,speed,ms,V,vocab size,65536,83.9959716796875,83.9959716796875,83.9959716796875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 jsd,torch,full,speed,ms,V,vocab size,131072,167.94175720214844,167.94175720214844,167.94175720214844,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,1024,110.02185821533203,110.02185821533203,110.02185821533203,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,2048,124.14070129394531,124.14070129394531,124.14070129394531,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,4096,143.15420532226562,143.15420532226562,143.15420532226562,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,8192,180.90406799316406,180.90406799316406,180.90406799316406,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,1024,9.556896209716797,9.550745964050293,9.576268196105957,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,2048,18.73731231689453,18.732704162597656,18.737701416015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,4096,37.830482482910156,37.80821990966797,37.85274124145508,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,8192,75.15289306640625,75.15289306640625,75.15289306640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,1024,111.16019439697266,111.16019439697266,111.16019439697266,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,2048,125.6825942993164,125.6825942993164,125.6825942993164,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,4096,144.00784301757812,144.00784301757812,144.00784301757812,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x 
T,8192,182.5832977294922,182.5832977294922,182.5832977294922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,1024,25.977184295654297,25.968351364135742,25.989356994628906,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,2048,49.48417663574219,49.47330093383789,49.495052337646484,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,4096,98.31510162353516,98.31510162353516,98.31510162353516,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,8192,195.29539489746094,195.29539489746094,195.29539489746094,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,1024,4652.48486328125,4652.48486328125,4652.48486328125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,2048,5231.93798828125,5231.93798828125,5231.93798828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,4096,6391.87548828125,6391.87548828125,6391.87548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,8192,8711.75,8711.75,8711.75,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859375,10609.005859375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py new file mode 100644 index 000000000..ee32428a7 --- /dev/null +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -0,0 +1,250 @@ +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD + + +class TorchJSD(torch.nn.Module): + 
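"""Reference generalized JSD used as the torch baseline.
+
+    With M = (1 - beta) * P + beta * Q, this computes
+    beta * KL(P || M) + (1 - beta) * KL(Q || M), reduced batchmean.
+    """
+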
def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float): + super(TorchJSD, self).__init__() + self.kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) + self.beta = beta + self.dtype = dtype + + def forward( + self, + log_q: torch.tensor, # input + log_p: torch.tensor, # target + ): + log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) + log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) + m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta) + loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl( + torch.log(m), log_q + ) + return loss.to(self.dtype) + + +class TorchLMHeadJSD(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based jsd loss. + + :param H: hidden size + :param V: vocab size + :param temperature: softmax temperature + :param beta: jsd beta + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + temperature: float = 1.0, + beta: float = 0.5, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.jsd = TorchJSD(beta, dtype=dtype) + self.temperature = temperature + + def forward(self, student_input, teacher_input): + student_logits = self.student_lin(student_input) + teacher_logits = self.teacher_lin(teacher_input) + student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1) + teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1) + + return self.jsd(student_prob, teacher_prob) + + +class LigerLMHeadJSD(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + temperature: float = 1.0, + beta: float = 0.5, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.fused_jsd = LigerFusedLinearJSD(beta, temperature) + + def forward(self, student_input, teacher_input): + return self.fused_jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + ) + + +############################################################################# +# Test the memory consumption of the fused linear JSD +############################################################################# + + +def bench_memory_fused_linear_jsd( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) + teacher_input = torch.rand(BT, 
H, dtype=dtype, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_jsd(student_input, teacher_input) + elif provider == "torch": + return torch_lm_head_jsd(student_input, teacher_input) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +# ############################################################################# +# # Test the speed of the fused linear JSD +# ############################################################################# + + +def bench_speed_fused_linear_jsd( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + mode = input.kernel_operation_mode + + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) + teacher_input = torch.rand(BT, H, dtype=dtype, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_jsd(student_input, teacher_input) + elif provider == "torch": + return torch_lm_head_jsd(student_input, teacher_input) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[ + student_input, + torch_lm_head_jsd.student_lin.weight, + torch_lm_head_jsd.teacher_lin.weight, + ], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_jsd", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, 14)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16} + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_jsd, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_jsd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 97c6c06cb..c72ba8d45 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from liger_kernel.ops.utils import element_mul_kernel 
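+# element_mul_kernel now lives in liger_kernel.ops.utils so it can be shared
+# with fused_linear_cross_entropy and the new fused_linear_jsd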
+ @triton.jit def liger_cross_entropy_kernel( @@ -159,42 +161,6 @@ def liger_cross_entropy_kernel( MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning -@triton.jit -def element_mul_kernel( - X_ptr, - X_stride, - grad_output_ptr, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - """ - This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. - The multiplication is performed in-place on the tensor pointed by X_ptr. - - Parameters: - X_ptr: Pointer to the input tensor. - X_stride (int): The stride of the input tensor. - grad_output_ptr: Pointer to the gradient output value. - n_cols (int): The number of columns in the input tensor. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - # Get the program ID and convert it to int64 to avoid overflow - program_id = tl.program_id(0).to(tl.int64) - - # Locate the start index - X_ptr += program_id * X_stride - - # Load the gradient output value - grad_output = tl.load(grad_output_ptr) - - # Perform the element-wise multiplication - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) - tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) - - def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction): BT, V = _input.shape n_rows = BT diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 73da9cd46..e9a28afbb 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -1,10 +1,8 @@ import torch import triton -from liger_kernel.ops.cross_entropy import ( - element_mul_kernel, - liger_cross_entropy_kernel, -) +from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel +from liger_kernel.ops.utils import element_mul_kernel # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling diff --git a/src/liger_kernel/ops/fused_linear_jsd.py b/src/liger_kernel/ops/fused_linear_jsd.py new file mode 100644 index 000000000..9f191c14a --- /dev/null +++ b/src/liger_kernel/ops/fused_linear_jsd.py @@ -0,0 +1,201 @@ +import torch +import triton + +from liger_kernel.ops.jsd import _jsd_kernel +from liger_kernel.ops.utils import element_mul_kernel + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 + + +def fused_linear_jsd_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + jsd_beta, + temperature, +): + device = student_input.device + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
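+    # (each chunk then only materializes a chunk_size x V slice of the logits)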
+ # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = student_input.shape + V = student_weight.shape[0] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2( + triton.cdiv(BT, inc_factor) + ) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + grad_weight = ( + torch.zeros_like(student_weight, device=device) + if student_weight.requires_grad + else None + ) + grad_input = torch.zeros_like(student_input) + # we use fp32 for loss accumulator + loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device) + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + + # chunk both inputs, shape: chunk_size x H + student_input_chunk = student_input[start_idx:end_idx] + teacher_input_chunk = teacher_input[start_idx:end_idx] + + # when doing matmul, use the original precision, shape: chunk_size x V + student_logits_chunk = student_input_chunk @ student_weight.t() + teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t() + chunk_n_rows = student_logits_chunk.shape[0] + + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size + # log-softmax with temperature + student_logits_chunk = student_logits_chunk / temperature + teacher_logits_chunk = teacher_logits_chunk / temperature + student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1) + teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1) + + # ensure _input and target are contiguous + student_prob_chunk = student_prob_chunk.contiguous() + teacher_prob_chunk = teacher_prob_chunk.contiguous() + + # Here we calculate the gradient of prob_chunk in place so we can save memory. + _jsd_kernel[(chunk_n_rows,)]( + X_ptr=student_prob_chunk, + X_stride=student_prob_chunk.stride(-2), + Y_ptr=teacher_prob_chunk, + Y_stride=teacher_prob_chunk.stride(-2), + loss_ptr=loss_1d_slice, + loss_stride=loss_1d_slice.stride(-2), + dX_ptr=student_prob_chunk, + dX_stride=student_prob_chunk.stride(-2), + beta=jsd_beta, + n_rows=BT, # batchmean + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + ) + loss_1d[start_idx:end_idx] = loss_1d_slice + # gradients of prob_chunk in place, shape: chunk_size x V + # gradients of logits_chunk in place, shape: chunk_size x V + student_logits_chunk = ( + student_prob_chunk + - torch.softmax(student_logits_chunk, dim=-1) + * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to( + student_prob_chunk.shape + ) + ) / temperature + grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight + + if grad_weight is not None: + torch.addmm( + input=grad_weight, + mat1=student_logits_chunk.t(), # gradients of logits_chunk + mat2=student_input_chunk, + out=grad_weight, + ) + + loss = torch.sum(loss_1d) + return loss.to(student_input.dtype), grad_input, grad_weight + + +def fused_linear_jsd_backward(grad_output, grad_input, grad_weight): + # If JSD is the last layer, grad_output is 1.0. 
Skip the mul to save time + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + BT, H = grad_input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_weight, + grad_weight.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + return grad_input, grad_weight + + +class LigerFusedLinearJSDFunction(torch.autograd.Function): + """ + Fusing the last linear layer with generalized JSD + + Handle the forward and backward pass of the final linear layer via JSD by avoiding + the materialization of the large logits tensor. Since JSD is the last layer, we can + compute the gradient at the forward pass. + """ + + @staticmethod + def forward( + ctx, + student_input, + student_weight, + teacher_input, + teacher_weight, + jsd_beta=0.5, + temperature=1.0, + ): + """ + Args: + + student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. + student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size + teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. + teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size + jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5` + temperature (float): temperature in softmax function to control the output probability distribution. 
Default: `1.0` + + Returns: + loss (torch.Tensor): generalized JSD + """ + loss, grad_input, grad_weight = fused_linear_jsd_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + jsd_beta, + temperature, + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + ) + return loss + + @staticmethod + def backward(ctx, grad_output): + (grad_input, grad_weight) = ctx.saved_tensors + grad_input, grad_weight = fused_linear_jsd_backward( + grad_output, grad_input, grad_weight + ) + return (grad_input, grad_weight, None, None, None, None) diff --git a/src/liger_kernel/ops/jsd.py b/src/liger_kernel/ops/jsd.py index 1f9cab732..e2a77d1d3 100644 --- a/src/liger_kernel/ops/jsd.py +++ b/src/liger_kernel/ops/jsd.py @@ -42,6 +42,8 @@ def _jsd_kernel( log_M = tl.log(M) loss = beta * P * Y + (1 - beta) * Q * X - M * log_M + # reduction == "batchmean" + loss = loss / n_rows tl.store(loss_ptr + offsets, loss, mask=mask) dX = (1 - beta) * Q * (X - log_M) / n_rows @@ -73,8 +75,8 @@ def jsd_forward(_input, target, beta): n_cols=V, BLOCK_SIZE=BLOCK_SIZE, ) - # reduction == "batchmean" - loss = torch.sum(loss) / n_rows + + loss = torch.sum(loss) return loss.to(_input.dtype), dX diff --git a/src/liger_kernel/ops/utils.py b/src/liger_kernel/ops/utils.py index 99500896b..2c01f3ac1 100644 --- a/src/liger_kernel/ops/utils.py +++ b/src/liger_kernel/ops/utils.py @@ -68,3 +68,39 @@ def compare_version(package: str, operator: Callable, target: str): torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, } + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. 
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 63cf779b0..ffb8235cc 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -5,6 +5,7 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401 LigerFusedLinearCrossEntropyLoss, ) +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401 from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401 from liger_kernel.transformers.jsd import LigerJSD # noqa: F401 from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index a8fb14e20..f160887b8 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -2,6 +2,7 @@ from liger_kernel.ops.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyFunction, ) +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction from liger_kernel.ops.geglu import LigerGELUMulFunction from liger_kernel.ops.jsd import LigerJSDFunction from liger_kernel.ops.kl_div import LigerKLDivLossFunction @@ -19,3 +20,4 @@ liger_layer_norm = LigerLayerNormFunction.apply liger_kl_div = LigerKLDivLossFunction.apply liger_jsd = LigerJSDFunction.apply +liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply diff --git a/src/liger_kernel/transformers/fused_linear_jsd.py b/src/liger_kernel/transformers/fused_linear_jsd.py new file mode 100644 index 000000000..d9579fe4b --- /dev/null +++ b/src/liger_kernel/transformers/fused_linear_jsd.py @@ -0,0 +1,61 @@ +import torch.nn as nn + +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction + + +class LigerFusedLinearJSD(nn.Module): + r"""Fusing the last linear layer with generalized JSD + + Handle the forward and backward pass of the final linear layer via JSD by avoiding + the materialization of the large logits tensor. + + Args: + jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5` + temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0` + + Shape: + - student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension. + - student_weight: :math:`(V, H)`, where V is vocab size. + - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model. + - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different. + - Output: a scalar. 
+ + Examples: + ```python + >>> (B, T, H, V) = (2, 2, 3, 5) + >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0) + >>> # generate inputs and weights + >>> student_input = torch.rand(B * T, H, device="cuda", requires_grad=True) + >>> student_lin = torch.nn.Linear(H, V, bias=False, device="cuda") + >>> # teacher input doesn't require grad, hidden_dim can be different from student's + >>> teacher_input = torch.rand(B * T, H * 2, device="cuda") + >>> teacher_lin = torch.nn.Linear(H * 2, V, bias=False, device="cuda") + >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight) + >>> output.backward() + ``` + """ + + def __init__(self, jsd_beta=0.5, temperature=1.0): + super().__init__() + assert ( + jsd_beta > 0 and jsd_beta < 1 + ), f"beta must be greater than 0 and less than 1. Got: {jsd_beta}" + assert temperature != 0, "temperature cannot be 0." + self.jsd_beta = jsd_beta + self.temperature = temperature + + def forward( + self, + student_input, + student_weight, + teacher_input, + teacher_weight, + ): + return LigerFusedLinearJSDFunction.apply( + student_input, + student_weight, + teacher_input, + teacher_weight, + self.jsd_beta, + self.temperature, + ) diff --git a/test/transformers/test_fused_linear_jsd.py b/test/transformers/test_fused_linear_jsd.py new file mode 100644 index 000000000..e5c3b1035 --- /dev/null +++ b/test/transformers/test_fused_linear_jsd.py @@ -0,0 +1,209 @@ +from test.transformers.test_jsd import JSD as TorchJSD +from test.utils import set_seed + +import pytest +import torch + +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction +from liger_kernel.transformers.functional import liger_fused_linear_jsd +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD + +set_seed(42) + + +class TorchLMHeadJSD(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based jsd loss. 
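+    The student projection deliberately uses a smaller hidden size (H // 2)
+    than the teacher (H) to exercise mismatched hidden dimensions.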
+ + :param H: hidden size + :param V: vocab size + :param temperature: softmax temperature + :param beta: jsd beta + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + temperature: float = 1.0, + beta: float = 0.5, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.jsd = TorchJSD(beta, dtype=dtype) + self.temperature = temperature + + def forward(self, student_input, teacher_input): + student_logits = self.student_lin(student_input) + teacher_logits = self.teacher_lin(teacher_input) + student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1) + teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1) + + return self.jsd(student_prob, teacher_prob) + + +class LigerLMHeadJSD(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + temperature: float = 1.0, + beta: float = 0.5, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.fused_jsd = LigerFusedLinearJSD(beta, temperature) + + def forward(self, student_input, teacher_input): + return self.fused_jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + ) + + +############################################################################# +# Test the correctness of the fused linear JSD +############################################################################# + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 4, 2048, 3200), + (2, 2048, 4096, 32000), # llama2, mistral + # Comment out to speed up testing + # (4, 2048, 4096, 128256), # llama3 8B + # (4, 1024, 8192, 128256), # llama3 70B + (4, 423, 8192, 32000), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-2), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize( + "temperature, beta", + [ + (1.0, 0.5), + (2.0, 0.1), + ], +) +def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + beta=beta, + ).to(device) + liger_lm_head_jsd = LigerLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + beta=beta, + ).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H // 2, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + output1 = torch_lm_head_jsd(_input1, teacher_input) + output2 = liger_lm_head_jsd(_input2, teacher_input) + + assert torch.allclose(output1, output2, atol=atol, rtol=rtol) + + 
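# backward through both implementations; input and student-weight grads must match
+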
output1.backward() + output2.backward() + + assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert torch.allclose( + torch_lm_head_jsd.student_lin.weight.grad, + liger_lm_head_jsd.student_lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 4, 2048, 3200), + (2, 2048, 4096, 32000), # llama2, mistral + # Comment out to speed up testing + # (4, 2048, 4096, 128256), # llama3 8B + # (4, 1024, 8192, 128256), # llama3 70B + (4, 423, 8192, 32000), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (0.5, torch.bfloat16, 5e-3, 5e-2), + (0.5, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("temperature, beta", [(1.0, 0.5), (2.0, 0.1)]) +def test_correctness_functional( + B, T, H, V, scalar, dtype, beta, temperature, atol, rtol +): + device = "cuda" + + # init the linear in all FusedLinearJSDs with the same weights + _weight = torch.rand(V, H // 2, device=device, dtype=dtype) + _weight1 = _weight.detach().clone().requires_grad_(True) + _weight2 = _weight.detach().clone().requires_grad_(True) + teacher_weight = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + output1 = liger_fused_linear_jsd( + _input1, _weight1, teacher_input, teacher_weight, beta, temperature + ) + output2 = LigerFusedLinearJSDFunction.apply( + _input2, _weight2, teacher_input, teacher_weight, beta, temperature + ) + + assert torch.allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert torch.allclose(_weight1.grad, _weight2.grad, atol=atol, rtol=rtol)
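
For reviewers who want a quick end-to-end check of the new module, here is a minimal distillation-style sketch. The sizes and tensor names below are illustrative only, not part of this PR (the benchmark above uses H=4096, V=128256):

```python
import torch

from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD

# small, hypothetical sizes for a smoke test
B, T, H, V = 2, 128, 512, 1024
device, dtype = "cuda", torch.bfloat16

student_lin = torch.nn.Linear(H, V, bias=False, device=device, dtype=dtype)
teacher_lin = torch.nn.Linear(H, V, bias=False, device=device, dtype=dtype)
fused_jsd = LigerFusedLinearJSD(jsd_beta=0.5, temperature=1.0)

student_hidden = torch.rand(B * T, H, device=device, dtype=dtype, requires_grad=True)
teacher_hidden = torch.rand(B * T, H, device=device, dtype=dtype)

# the fused op projects both hidden states chunk by chunk, so the full
# (B*T) x V logits of student and teacher are never materialized at once
loss = fused_jsd(student_hidden, student_lin.weight, teacher_hidden, teacher_lin.weight)
loss.backward()  # gradients reach student_hidden and student_lin.weight only
```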