Add Chunked SimPO Loss #386

Merged · 9 commits · Nov 19, 2024 · Changes from 5 commits
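This PR adds a chunked SimPO loss following the same fused-linear pattern as the existing chunked CPO loss. A minimal usage sketch, assuming the new function lives at `liger_kernel.chunked_loss.simpo_loss.LigerFusedLinearSimPOFunction` and accepts the same three positional arguments as its CPO counterpart in the benchmark below (both the module path and the signature are inferred from the CPO code, not confirmed here):

import torch

# Module path and function name are assumed from the CPO counterpart below.
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

B, T, H, V = 2, 1024, 4096, 128256  # shapes taken from the benchmark config
x = torch.randn(B, T, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
weight = torch.randn(V, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
target = torch.randint(V, (B, T), dtype=torch.long, device="cuda")

# The chunked formulation computes logits chunk-by-chunk, so the full
# (B, T, V) logits tensor is never materialized.
loss = LigerFusedLinearSimPOFunction.apply(x, weight, target)
loss.backward()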
48 changes: 48 additions & 0 deletions benchmark/data/all_benchmark_data.csv
@@ -667,3 +667,51 @@ fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,16,33418.421875,33418.421875,33418.421875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
fused_linear_cpo_loss,liger,forward,speed,ms,B,B,2,31.536447525024414,31.457439422607422,31.543052673339844,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1
fused_linear_cpo_loss,liger,forward,speed,ms,B,B,4,62.407745361328125,62.407745361328125,62.407745361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1
fused_linear_cpo_loss,liger,forward,speed,ms,B,B,8,123.64259338378906,123.64259338378906,123.64259338378906,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1
fused_linear_cpo_loss,liger,forward,speed,ms,B,B,16,245.66575622558594,245.66575622558594,245.66575622558594,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1
fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,2,14.516239166259766,14.514080047607422,14.52575969696045,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1
fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,4,26.087743759155273,25.943340301513672,26.269376754760742,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1
fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,8,51.85932922363281,51.85932922363281,51.85932922363281,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1
fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,16,104.99673461914062,104.99673461914062,104.99673461914062,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1
fused_linear_cpo_loss,liger,full,speed,ms,B,B,2,33.309967041015625,33.21604919433594,33.40388488769531,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1
fused_linear_cpo_loss,liger,full,speed,ms,B,B,4,63.053470611572266,63.053470611572266,63.053470611572266,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1
fused_linear_cpo_loss,liger,full,speed,ms,B,B,8,125.53849792480469,125.53849792480469,125.53849792480469,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1
fused_linear_cpo_loss,liger,full,speed,ms,B,B,16,250.22178649902344,250.22178649902344,250.22178649902344,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1
fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,2,39.45849609375,39.33102798461914,39.58596420288086,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1
fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,4,77.00272369384766,77.00272369384766,77.00272369384766,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1
fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,8,154.28419494628906,154.28419494628906,154.28419494628906,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1
fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,16,309.23162841796875,309.23162841796875,309.23162841796875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1
fused_linear_cpo_loss,liger,full,memory,MB,B,B,2,8161.34619140625,8161.34619140625,8161.34619140625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1
fused_linear_cpo_loss,liger,full,memory,MB,B,B,4,8209.361328125,8209.361328125,8209.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1
fused_linear_cpo_loss,liger,full,memory,MB,B,B,8,8305.392578125,8305.392578125,8305.392578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1
fused_linear_cpo_loss,liger,full,memory,MB,B,B,16,8497.455078125,8497.455078125,8497.455078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,2,30.28438377380371,30.107013702392578,30.284786224365234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,4,58.80876922607422,58.80876922607422,58.80876922607422,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,8,117.96163177490234,117.96163177490234,117.96163177490234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,16,235.60794067382812,235.60794067382812,235.60794067382812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,2,14.513839721679688,14.510687828063965,14.517855644226074,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,4,28.78099250793457,28.72719383239746,28.792186737060547,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,8,52.5733757019043,52.5733757019043,52.5733757019043,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,16,104.44764709472656,104.44764709472656,104.44764709472656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,2,31.566062927246094,31.457612991333008,31.674514770507812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,4,61.4403190612793,61.4403190612793,61.4403190612793,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,8,119.97705841064453,119.97705841064453,119.97705841064453,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,16,238.13417053222656,238.13417053222656,238.13417053222656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,2,39.811119079589844,39.65474319458008,39.96749496459961,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,4,77.20928192138672,77.20928192138672,77.20928192138672,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,8,153.6952667236328,153.6952667236328,153.6952667236328,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,16,307.7382507324219,307.7382507324219,307.7382507324219,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,2,7675.3291015625,7675.3291015625,7675.3291015625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,4,7723.3447265625,7723.3447265625,7723.3447265625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,8,7819.3759765625,7819.3759765625,7819.3759765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,16,8011.4384765625,8011.4384765625,8011.4384765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
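For reference, each CSV row records: kernel name, provider, operation mode, metric name, unit, the swept variable's name and label (batch size B here), its value, three measurements (apparently the median, 0.2, and 0.8 quantiles, matching the y_50/y_20/y_80 fields returned by the benchmark script below), the benchmark config JSON, the GPU, a timestamp, and the Liger version.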
191 changes: 191 additions & 0 deletions benchmark/scripts/benchmark_cpo_loss.py
@@ -0,0 +1,191 @@
import os
import sys

import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

# Make the repo root importable so the HF reference loss in `test/` resolves.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction


class TorchLMHeadCPO(torch.nn.Module):
    """Ground truth implementation of a linear layer fused with a torch-based CPO loss.

    :param H: hidden size
    :param V: vocab size
    :param dtype: parameter dtype
    :param ignore_index: target index to ignore in the loss
    """

def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
from test.chunked_loss.test_cpo_loss import HFCPOLoss

super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.cpo_loss = HFCPOLoss().get_batch_loss_metrics

def forward(self, x, y):
return self.cpo_loss(x, self.lin.weight, y)


class LigerLMHeadCPO(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
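        # Fused linear + chunked CPO loss: logits are computed chunk-by-chunk,
        # so the full (B*T, V) logits tensor is never materialized.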
self.cpo_loss = LigerFusedLinearCPOFunction.apply

def forward(self, x, y):
return self.cpo_loss(x, self.lin.weight, y)


#############################################################################
# Test the memory consumption of the fused linear CPO loss
#############################################################################


def bench_memory_fused_linear_cpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

device = "cuda"
torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)

    # Activations (B, T, H) and integer targets (B, T) emulate an LM-head workload.
    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_cpo(_input, target)
elif provider == "huggingface":
return torch_lm_head_cpo(_input, target)

def full():
y = fwd()
y.backward()

mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


#############################################################################
# Test the speed of the fused linear CPO loss
#############################################################################


def bench_speed_fused_linear_cpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

device = "cuda"

torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_cpo(_input, target)
elif provider == "huggingface":
return torch_lm_head_cpo(_input, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
    elif mode == "backward":
        y = fwd()

        # Time only the backward pass; retain_graph permits repeated backward
        # calls on one graph, and grad_to_none resets the input grad between runs.
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
elif mode == "full":

def full():
y = fwd()
y.backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

common_configs = {
"kernel_name": "fused_linear_cpo_loss",
"x_name": "B",
"x_label": "B",
"x_values": [2**i for i in range(1, 5)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 1024,
"H": 4096,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_fused_linear_cpo_loss,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_cpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
)
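As a usage note: the script is run directly (for example, `python benchmark/scripts/benchmark_cpo_loss.py`); judging from `args.overwrite` above, the argument parser presumably exposes an `--overwrite` flag controlling whether existing rows in `all_benchmark_data.csv` are replaced.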