-
Notifications
You must be signed in to change notification settings - Fork 230
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Adds chunked ORPO loss kernel <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> Benchmarks ![Speed ORPO](https://github.com/user-attachments/assets/ae9e6f67-14cd-4189-9d64-9a2f94a3b3c6) ![Mem ORPO](https://github.com/user-attachments/assets/47c289f4-2876-4530-949c-2c2825bc0f79) References: 1. #227 2. https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: shisahni_LinkedIn <shisahni@linkedin.com>
- Loading branch information
Showing
8 changed files
with
700 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
import triton | ||
from utils import ( | ||
QUANTILES, | ||
SingleBenchmarkRunInput, | ||
SingleBenchmarkRunOutput, | ||
_test_memory, | ||
parse_benchmark_script_args, | ||
run_benchmarks, | ||
) | ||
|
||
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction | ||
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) | ||
|
||
|
||
class TorchLMHeadORPO(torch.nn.Module): | ||
"""Ground truth implementation of the linear fused with torch based cross entropy loss. | ||
:param H: hidden size | ||
:param V: vocab size | ||
:param ignore_index: index to ignore | ||
:param reduction: reduction method | ||
""" | ||
|
||
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | ||
from test.chunked_loss.test_orpo_loss import HF_ORPO_Loss | ||
|
||
super().__init__() | ||
self.lin = torch.nn.Linear( | ||
in_features=H, out_features=V, bias=False, dtype=dtype | ||
) | ||
self.orpo_loss = HF_ORPO_Loss().get_batch_loss_metrics | ||
|
||
def forward(self, x, y): | ||
return self.orpo_loss(x, self.lin.weight, y) | ||
|
||
|
||
class LigerLMHeadORPO(torch.nn.Module): | ||
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | ||
super().__init__() | ||
self.lin = torch.nn.Linear( | ||
in_features=H, out_features=V, bias=False, dtype=dtype | ||
) | ||
self.orpo_loss = LigerFusedLinearORPOFunction.apply | ||
|
||
def forward(self, x, y): | ||
return self.orpo_loss(x, self.lin.weight, y) | ||
|
||
|
||
############################################################################# | ||
# Test the memory consumption of the linear fused cross entropy loss | ||
############################################################################# | ||
|
||
|
||
def bench_memory_fused_linear_orpo_loss( | ||
input: SingleBenchmarkRunInput, | ||
) -> SingleBenchmarkRunOutput: | ||
B = input.x | ||
T = input.extra_benchmark_config["T"] | ||
H = input.extra_benchmark_config["H"] | ||
V = input.extra_benchmark_config["V"] | ||
dtype = input.extra_benchmark_config["dtype"] | ||
provider = input.kernel_provider | ||
|
||
device = "cuda" | ||
torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) | ||
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) | ||
|
||
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) | ||
target = torch.randint(V, (B, T), dtype=torch.long, device=device) | ||
|
||
def fwd(): | ||
if provider == "liger": | ||
return liger_lm_head_orpo(_input, target) | ||
elif provider == "huggingface": | ||
return torch_lm_head_orpo(_input, target) | ||
|
||
def full(): | ||
y = fwd() | ||
y.backward() | ||
|
||
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) | ||
return SingleBenchmarkRunOutput( | ||
y_20=mem_20, | ||
y_50=mem_50, | ||
y_80=mem_80, | ||
) | ||
|
||
|
||
# ############################################################################# | ||
# # Test the speed of the fused linear cross entropy loss | ||
# ############################################################################# | ||
|
||
|
||
def bench_speed_fused_linear_orpo_loss( | ||
input: SingleBenchmarkRunInput, | ||
) -> SingleBenchmarkRunOutput: | ||
B = input.x | ||
T = input.extra_benchmark_config["T"] | ||
H = input.extra_benchmark_config["H"] | ||
V = input.extra_benchmark_config["V"] | ||
dtype = input.extra_benchmark_config["dtype"] | ||
provider = input.kernel_provider | ||
mode = input.kernel_operation_mode | ||
|
||
device = "cuda" | ||
|
||
torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) | ||
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) | ||
|
||
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) | ||
target = torch.randint(V, (B, T), dtype=torch.long, device=device) | ||
|
||
def fwd(): | ||
if provider == "liger": | ||
return liger_lm_head_orpo(_input, target) | ||
elif provider == "huggingface": | ||
return torch_lm_head_orpo(_input, target) | ||
|
||
if mode == "forward": | ||
ms_50, ms_20, ms_80 = triton.testing.do_bench( | ||
fwd, | ||
rep=100, | ||
quantiles=QUANTILES, | ||
) | ||
elif mode == "backward": | ||
y = fwd() | ||
|
||
ms_50, ms_20, ms_80 = triton.testing.do_bench( | ||
lambda: y.backward(retain_graph=True), | ||
grad_to_none=[_input], | ||
rep=100, | ||
quantiles=QUANTILES, | ||
) | ||
elif mode == "full": | ||
|
||
def full(): | ||
y = fwd() | ||
y.backward() | ||
|
||
ms_50, ms_20, ms_80 = triton.testing.do_bench( | ||
full, | ||
rep=100, | ||
quantiles=QUANTILES, | ||
) | ||
return SingleBenchmarkRunOutput( | ||
y_20=ms_20, | ||
y_50=ms_50, | ||
y_80=ms_80, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_benchmark_script_args() | ||
|
||
common_configs = { | ||
"kernel_name": "fused_linear_orpo_loss", | ||
"x_name": "B", | ||
"x_label": "B", | ||
"x_values": [2**i for i in range(1, 5)], | ||
"kernel_providers": ["liger", "huggingface"], | ||
"extra_benchmark_configs": [ | ||
{ | ||
"T": 1024, | ||
"H": 4096, | ||
"V": 128256, | ||
"mode": "forward", | ||
"dtype": torch.bfloat16, | ||
} | ||
], | ||
"overwrite": args.overwrite, | ||
} | ||
|
||
run_benchmarks( | ||
bench_test_fn=bench_speed_fused_linear_orpo_loss, | ||
kernel_operation_modes=["forward", "full"], | ||
metric_name="speed", | ||
metric_unit="ms", | ||
**common_configs | ||
) | ||
run_benchmarks( | ||
bench_test_fn=bench_memory_fused_linear_orpo_loss, | ||
kernel_operation_modes=["full"], | ||
metric_name="memory", | ||
metric_unit="MB", | ||
**common_configs | ||
) |
Empty file.
Empty file.
Oops, something went wrong.