From b752e81d1a423d7d0423ed5105e048229bc2a5c6 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 27 Dec 2024 12:42:45 -0500 Subject: [PATCH 01/17] add lora benchmark files Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 867 +++++++++++++++++++++++++++ benchmarks/kernels/utils.py | 209 +++++++ 2 files changed, 1076 insertions(+) create mode 100644 benchmarks/kernels/benchmark_lora.py create mode 100644 benchmarks/kernels/utils.py diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py new file mode 100644 index 0000000000000..e9519da7c34ff --- /dev/null +++ b/benchmarks/kernels/benchmark_lora.py @@ -0,0 +1,867 @@ +import argparse +import copy +import json +import pickle +import time +from dataclasses import dataclass +from enum import Enum, auto +from itertools import product +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import ArgPool, Bench, CudaGraphBenchParams +from weight_shapes import WEIGHT_SHAPES + +from vllm.lora.ops.bgmv_expand import bgmv_expand +from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice +from vllm.lora.ops.bgmv_shrink import bgmv_shrink +from vllm.lora.ops.sgmv_expand import sgmv_expand +from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice +from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_TP_SIZES = [1] +DEFAULT_BATCH_SIZES = [ + 1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 768, 896, 1024, + 2048, 3072, 4096, 5120, 6144, 7168, 8192 +] +DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384] +DEFAULT_LORA_RANKS = [16] +DEFAULT_NUM_LORAS = [1, 2, 3, 4] +DEFAULT_SORT_BY_LORA_IDS = [False, True] +DEFAULT_SEQ_LENGTHS = [1] + + +## Utilities +def dtype_to_str(dtype: torch.dtype): + if dtype == torch.float16: + return "f16" + if dtype == torch.bfloat16: + return "bf16" + if dtype == torch.float32: + return "f32" + raise ValueError(f"Unsupported dtype {dtype}") + + +def make_rand_lora_weight_tensor(k: int, + n: int, + num_loras: int, + dtype: torch.dtype, + device: str = "cuda") -> torch.Tensor: + + # LoRA weights column major + return torch.rand((num_loras, n, k), dtype=dtype).to(device) + + +def make_rand_tensors( + m: int, + k: int, + n: int, + num_loras: int, + num_slices: Optional[int], + a_dtype: torch.dtype, + b_dtype: torch.dtype, + c_dtype: torch.dtype, + device: str = "cuda", +) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: + # Make input / output tensors + # Input matrix A of shape {m, k} + # num_slices Input matrices B of shape {k, n} + # Output matrix C of shape {m, n * num_slices} + num_slices = num_slices if num_slices is not None else 1 + + A = torch.rand((m, k), dtype=a_dtype).to(device) + + # LoRA weights column major + Bs = [ + make_rand_lora_weight_tensor(k, n, num_loras, b_dtype, device) + for _ in range(num_slices) + ] + + C = torch.zeros((m, n * num_slices), dtype=c_dtype).to(device) + + return A, Bs, C + + +def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, + sort_by_lora_id: bool, + device: str) -> torch.Tensor: + """ + All prompts are mapped to a Lora ID in range [0, num_active_loras). + where 0 refers to first lora, 1 refers to second lora and so on. 
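+    When sort_by_lora_id is True, prompts are assigned LoRA IDs in order so
+    that prompts sharing the same LoRA are contiguous; otherwise the
+    assignment is random.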
+ """ + assert num_active_loras > 0 + + if not sort_by_lora_id: + return torch.randint(0, + num_active_loras, (num_prompts, ), + dtype=torch.long) + + # Divide LoRAs equally and in order. + part_size = num_prompts // num_active_loras + part_size = max(part_size, 1) + + lora_id = 0 + prompt_lora_mapping = [] + while len(prompt_lora_mapping) < num_prompts: + prompt_lora_mapping.extend([lora_id] * part_size) + lora_id = lora_id + 1 if lora_id < num_active_loras else lora_id + return torch.tensor(prompt_lora_mapping[:num_prompts], + dtype=torch.long, + device=device) + + +def make_token_lora_mapping(num_tokens: int, num_prompts: int, + prompt_lora_mapping: torch.Tensor, + seq_len_tensor: torch.Tensor, device: str): + """ + Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor + """ + assert prompt_lora_mapping.shape[0] == num_prompts + + # token to lora index mapping + token_lora_mapping = [0] * num_tokens + current_offset = 0 + for b_id in range(num_prompts): + lora_index = prompt_lora_mapping[b_id].item() + s = current_offset + e = s + seq_len_tensor[b_id].item() + token_lora_mapping[s:e] = [lora_index] * (e - s) + current_offset += seq_len_tensor[b_id].item() + + return torch.tensor(token_lora_mapping, dtype=torch.long, device=device) + + +## LoRA Ops to Benchmark and its properties +class OpType(Enum): + SGMV_SHRINK = auto() + BGMV_SHRINK = auto() + SGMV_EXPAND = auto() + BGMV_EXPAND = auto() + SGMV_EXPAND_SLICE = auto() + BGMV_EXPAND_SLICE = auto() + + @staticmethod + def from_str(s: str) -> "OpType": + if s.lower() == 'sgmv_shrink': + return OpType.SGMV_SHRINK + if s.lower() == 'sgmv_expand': + return OpType.SGMV_EXPAND + if s.lower() == 'bgmv_shrink': + return OpType.BGMV_SHRINK + if s.lower() == 'bgmv_expand': + return OpType.BGMV_EXPAND + if s.lower() == "sgmv_expand_slice": + return OpType.SGMV_EXPAND_SLICE + if s.lower() == "bgmv_expand_slice": + return OpType.BGMV_EXPAND_SLICE + raise ValueError(f"Unrecognized str {s} to convert to OpType") + + def is_shrink_fn(self) -> bool: + return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK] + + def is_expand_fn(self) -> bool: + return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND] + + def is_expand_slice_fn(self) -> bool: + return self in [OpType.SGMV_EXPAND_SLICE, OpType.BGMV_EXPAND_SLICE] + + def is_prefill_op(self) -> bool: + return self in [ + OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.SGMV_EXPAND_SLICE + ] + + def is_decode_op(self) -> bool: + return self in [ + OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE + ] + + def num_slices(self) -> int: + if self.is_expand_slice_fn(): + return 3 + return 1 + + def mkn(self, batch_size: int, seq_length: int, hidden_size: int, + lora_rank: int) -> Tuple[int, int, int]: + num_tokens = batch_size * seq_length + if self.is_shrink_fn(): + m = num_tokens + k = hidden_size + n = lora_rank + else: + assert self.is_expand_fn() or self.is_expand_slice_fn() + m = num_tokens + k = lora_rank + n = hidden_size + return m, k, n + + def matmul_dtypes( + self, op_dtype: torch.dtype + ) -> Tuple[torch.dtype, torch.dtype, torch.dtype]: + """ + return a type, b type and c type for A x B = C + """ + if self.is_shrink_fn(): + return op_dtype, op_dtype, torch.float32 + else: + assert self.is_expand_fn() or self.is_expand_slice_fn() + return torch.float32, op_dtype, op_dtype + + def bench_fn(self) -> Callable: + + def emulate_sgmv_expand_slice(kwargs_list: List[Dict[str, Any]]): + for x in kwargs_list: + sgmv_expand_slice(**x) + + def emulate_bgmv_expand_slice(kwargs_list: 
List[Dict[str, Any]]): + for x in kwargs_list: + bgmv_expand_slice(**x) + + if self == OpType.SGMV_SHRINK: + return sgmv_shrink + if self == OpType.SGMV_EXPAND: + return sgmv_expand + if self == OpType.BGMV_SHRINK: + return bgmv_shrink + if self == OpType.BGMV_EXPAND: + return bgmv_expand + if self == OpType.SGMV_EXPAND_SLICE: + return emulate_sgmv_expand_slice + if self == OpType.BGMV_EXPAND_SLICE: + return emulate_bgmv_expand_slice + raise ValueError(f"Unrecognized optype {self}") + + +@dataclass +class BenchmarkContext: + """ + LoRA benchmark context + """ + batch_size: int + hidden_size: int + num_loras: int + num_active_loras: int + lora_rank: int + sort_by_lora_id: bool + dtype: torch.dtype + seq_length: Optional[int] = None + num_slices: Optional[int] = None # num_slices for expand_slice kernels + + def with_seq_length(self, seq_length: int) -> "BenchmarkContext": + ctx = copy.copy(self) + ctx.seq_length = seq_length + return ctx + + def with_num_slices(self, num_slices: Optional[int]) -> "BenchmarkContext": + ctx = copy.copy(self) + ctx.num_slices = num_slices + return ctx + + def bench_label(self) -> str: + return f"lora-{self.dtype}" + + def bench_sublabel(self, op_type: OpType) -> str: + m, k, n = op_type.mkn(self.batch_size, self.seq_length, + self.hidden_size, self.lora_rank) + desc = { + 'bs': self.batch_size, + 'sl': self.seq_length, + 'm': m, + 'k': k, + 'n': n, + 'num_loras': self.num_loras, + 'sort_by_lora': self.sort_by_lora_id + } + return json.dumps(desc) + + +@dataclass +class BenchmarkTensors: + """ + Input/Output tensors used for benchmarks + """ + # matmul tensors + input: torch.Tensor + lora_weights_lst: List[torch.Tensor] + output: torch.Tensor + # metadata tensors + seq_lens: torch.Tensor + seq_start_loc: torch.Tensor + prompt_lora_mapping: torch.Tensor + token_lora_mapping: torch.Tensor + + def io_types(self) -> str: + return (f"{dtype_to_str(self.input.dtype)}x" + f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>" + f"{dtype_to_str(self.output.dtype)}") + + @staticmethod + def make(ctx: BenchmarkContext, + op_type: OpType, + device: str = "cuda") -> "BenchmarkTensors": + + ## Make input / output matmul tensors + a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) + m, k, n = op_type.mkn(ctx.batch_size, ctx.seq_length, ctx.hidden_size, + ctx.lora_rank) + input_tensor, lora_weights, output_tensor = \ + make_rand_tensors(m, k, n, ctx.num_loras, + num_slices = ctx.num_slices, + a_dtype = a_type, + b_dtype = b_type, + c_dtype = c_type) + + ## Make metadata tensors + # Keep the metadata tensors in the CPU for further processing if needed. + # The tensors get moved to the GPU before benchmarking. 
+ assert ctx.num_active_loras <= ctx.num_loras + total_tokens = ctx.batch_size * ctx.seq_length + + # Prepare seq lens tensor + seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1, + (ctx.batch_size, )) + # Prepare seq_start_loc tensor + seq_start_loc_tensor = torch.cumsum(torch.tensor( + [0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), + dim=0) + assert total_tokens == seq_len_tensor.sum() + # Prepare prompt lora indices tensor + prompt_lora_indices_tensor = make_prompt_lora_mapping( + ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu") + # Prepare token lora indices tensor + token_lora_indices_tensor = make_token_lora_mapping( + total_tokens, ctx.batch_size, prompt_lora_indices_tensor, + seq_len_tensor, "cpu") + + return BenchmarkTensors(input_tensor, lora_weights, output_tensor, + seq_len_tensor, seq_start_loc_tensor, + prompt_lora_indices_tensor, + token_lora_indices_tensor) + + def sanity_check(self) -> None: + """ + Fails asserts when non-conformality is detected. + """ + # Check that the tensors have the right shapes + m = self.input.shape[0] + k = self.input.shape[1] + n = self.output.shape[1] + + # check matmul tensors + assert self.output.shape[0] == m + assert len(self.lora_weights_lst) >= 1 + num_slices = len(self.lora_weights_lst) + for w in self.lora_weights_lst: + _, w_n, w_k = w.shape # n, k flipped due to col-major ordering. + assert (w_n, w_k) == (n, k) or (w_n * num_slices, w_k) == (n, k) + # check metadata tensors + assert torch.sum(self.seq_lens) == m + num_seqs = self.seq_lens.shape[0] + assert self.seq_start_loc.shape[0] == num_seqs + assert self.prompt_lora_mapping.shape[0] == num_seqs + assert self.token_lora_mapping.shape[0] == m + + def to_device(self, device: str): + """ + Transfer tensors to device if the tensors aren't already on the device + """ + + def to_device(tensor: torch.Tensor): + if tensor.device != device: + tensor = tensor.to(device=device) + return tensor + + self.input = to_device(self.input) + self.output = to_device(self.output) + self.seq_lens = to_device(self.seq_lens) + self.seq_start_loc = to_device(self.seq_start_loc) + self.prompt_lora_mapping = to_device(self.prompt_lora_mapping) + self.token_lora_mapping = to_device(self.token_lora_mapping) + for i in range(len(self.lora_weights_lst)): + self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i]) + + def metadata(self) -> Tuple[int, int, int]: + """ + Return num_seqs, num_tokens and max_seq_len + """ + num_seqs = self.seq_lens.shape[0] + num_tokens = self.input.shape[0] + max_seq_len = torch.max(self.seq_lens).item() + return num_seqs, num_tokens, max_seq_len + + def convert_to_sgmv_benchmark_tensors(self): + """ + for sgmv punica kernels, when consecutive sequences have the + same LoRA ID, we just merge them together. 
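+        For example, a prompt_lora_mapping of [0, 0, 1, 1] with seq_lens
+        [2, 2, 2, 2] collapses into a mapping of [0, 1] with seq_lens [4, 4].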
+ This happens in punica.py::compute_metadata + """ + + # Collapse seq_lens and seq_start_loc + _, seq_lens = torch.unique_consecutive(self.token_lora_mapping, + return_counts=True) + cum_result = torch.cumsum(seq_lens, dim=0) + seq_start_loc = torch.zeros_like(seq_lens) + seq_start_loc[1:].copy_(cum_result[:-1]) + + # Collapse prompt mapping + prompt_lora_mapping = torch.unique_consecutive( + self.prompt_lora_mapping) + + assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \ + f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}" + + self.prompt_lora_mapping = prompt_lora_mapping.to( + dtype=self.prompt_lora_mapping.dtype) + self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype) + self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype) + + ## Benchmark function args. + def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]: + assert len(self.lora_weights_lst) == 1 + + self.convert_to_sgmv_benchmark_tensors() + self.sanity_check() + self.to_device(self.input.device) + + num_seqs, num_tokens, max_seq_len = self.metadata() + return { + 'inputs': self.input, + 'lora_a_weights': self.lora_weights_lst[0], + 'output_tensor': self.output, + 'b_seq_start_loc': self.seq_start_loc, + 'seq_len_tensor': self.seq_lens, + 'lora_indices_tensor': self.prompt_lora_mapping, + 'batches': num_seqs, + 'max_seq_length': max_seq_len, + 'token_nums': num_tokens, + 'scaling': 1.0, + } + + def as_sgmv_expand_kwargs(self) -> Dict[str, Any]: + assert len(self.lora_weights_lst) == 1 + + self.convert_to_sgmv_benchmark_tensors() + self.sanity_check() + self.to_device(self.input.device) + + num_seqs, num_tokens, max_seq_len = self.metadata() + return { + 'inputs': self.input, + 'lora_b_weights': self.lora_weights_lst[0], + 'output_tensor': self.output, + 'b_seq_start_loc': self.seq_start_loc, + 'seq_len_tensor': self.seq_lens, + 'lora_indices_tensor': self.prompt_lora_mapping, + 'batches': num_seqs, + 'max_seq_length': max_seq_len, + 'token_nums': num_tokens, + 'add_inputs': True, + } + + def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]: + assert len(self.lora_weights_lst) == 1 + self.to_device(self.input.device) + return { + 'inputs': self.input, + 'lora_a_weights': self.lora_weights_lst[0], + 'output_tensor': self.output, + 'lora_indices_tensor': self.token_lora_mapping, + 'scaling': 1.0 + } + + def as_bgmv_expand_kwargs(self): + assert len(self.lora_weights_lst) == 1 + self.to_device(self.input.device) + return { + 'inputs': self.input, + 'lora_b_weights': self.lora_weights_lst[0], + 'output_tensor': self.output, + 'lora_indices_tensor': self.token_lora_mapping, + 'add_inputs': True + } + + def as_sgmv_expand_slice_kwargs(self) -> Dict[str, Any]: + assert len(self.lora_weights_lst) > 1 + self.convert_to_sgmv_benchmark_tensors() + self.sanity_check() + + self.to_device(self.input.device) + num_seqs, num_tokens, max_seq_len = self.metadata() + + num_slices = len(self.lora_weights_lst) + slice_size = self.lora_weights_lst[0].shape[-2] # n + assert slice_size * num_slices == self.output.shape[-1] + + kwargs_list = [] + for i in range(num_slices): + kwargs_list.append({ + 'inputs': self.input, + 'lora_b_weights': self.lora_weights_lst[i], + 'output_tensor': self.output, + 'b_seq_start_loc': self.seq_start_loc, + 'seq_len_tensor': self.seq_lens, + 'lora_indices_tensor': self.prompt_lora_mapping, + 'batches': num_seqs, + 'max_seq_length': max_seq_len, + 'token_nums': num_tokens, + 'slice_offset': i * slice_size, + 'slice_size': slice_size, + 'add_inputs': True, + }) + return {'kwargs_list': 
kwargs_list} + + def as_bgmv_expand_slice_kwargs(self) -> Dict[str, Any]: + assert len(self.lora_weights_lst) > 1 + num_slices = len(self.lora_weights_lst) + slice_size = self.lora_weights_lst[0].shape[-2] # n + assert slice_size * num_slices == self.output.shape[-1] + + self.to_device(self.input.device) + + kwargs_list = [] + for i in range(num_slices): + kwargs_list.append({ + 'inputs': self.input, + 'lora_b_weights': self.lora_weights_lst[i], + 'output_tensor': self.output, + 'lora_indices_tensor': self.token_lora_mapping, + 'slice_offset': i * slice_size, + 'slice_size': slice_size, + 'add_inputs': True + }) + return {'kwargs_list': kwargs_list} + + def bench_fn_kwargs(self, op_type: OpType) -> Dict[str, Any]: + if op_type == OpType.SGMV_SHRINK: + return self.as_sgmv_shrink_kwargs() + if op_type == OpType.SGMV_EXPAND: + return self.as_sgmv_expand_kwargs() + if op_type == OpType.BGMV_SHRINK: + return self.as_bgmv_shrink_kwargs() + if op_type == OpType.BGMV_EXPAND: + return self.as_bgmv_expand_kwargs() + if op_type == OpType.SGMV_EXPAND_SLICE: + return self.as_sgmv_expand_slice_kwargs() + if op_type == OpType.BGMV_EXPAND_SLICE: + return self.as_bgmv_expand_slice_kwargs() + raise ValueError(f"Unrecognized optype {self}") + + +def bench_optype(ctx: BenchmarkContext, + arg_pool_size: int, + op_type: OpType, + with_cuda_graph: bool = False) -> TMeasurement: + + assert arg_pool_size >= 1 + + # BenchmarkContext -> BenchmarkTensors + bench_tensors : List[BenchmarkTensors] = \ + [BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)] + for bt in bench_tensors: + bt.sanity_check() + + # BenchmarkTensors -> Dict (kwargs) + kwargs_list = [bt.bench_fn_kwargs(op_type) for bt in bench_tensors] + + # Merge into a single kwargs and quality arguments as ArgPool + kwargs = {k: ArgPool([]) for k in kwargs_list[0]} + for _kwargs in kwargs_list: + for k, v in _kwargs.items(): + kwargs[k].values.append(v) + + description = f"{op_type.name}({bench_tensors[0].io_types()})" + cuda_graph_params = CudaGraphBenchParams( + num_ops_in_cuda_graph=arg_pool_size) if with_cuda_graph else None + with Bench(cuda_graph_params, + ctx.bench_label(), ctx.bench_sublabel(op_type), description, + op_type.bench_fn(), **kwargs) as bench: + return bench.run() + + +def bench_baseline(ctx: BenchmarkContext, + arg_pool_size: int, + op_type: OpType, + with_cuda_graph: bool = False) -> TMeasurement: + + batch_size, hidden_size, lora_rank, seq_length, dtype = (ctx.batch_size, + ctx.hidden_size, + ctx.lora_rank, + ctx.seq_length, + ctx.dtype) + + m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank) + if op_type.is_expand_slice_fn(): + # For a fairer comparison. 
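+        # The expand_slice kernels write num_slices output slices of width n
+        # each, so the equivalent single matmul must produce an output that
+        # is n * num_slices wide.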
+ n = n * ctx.num_slices + + # Get matmul input and output tensors for A x B = C + As, Bs, Cs = [], [], [] + for _ in range(arg_pool_size): + As.append(torch.rand((m, k), dtype=dtype).to("cuda")) + Bs.append(torch.rand((n, k), dtype=dtype).to("cuda").t()) + Cs.append(torch.rand((m, n), dtype=dtype).to("cuda")) + + # Make torch.mm kwargs + mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)} + + description = (f"torch.mm({dtype_to_str(dtype)}" + f"x{dtype_to_str(dtype)}" + f"=>{dtype_to_str(dtype)})") + cuda_graph_params = CudaGraphBenchParams( + num_ops_in_cuda_graph=arg_pool_size) if with_cuda_graph else None + with Bench(cuda_graph_params, ctx.bench_label(), + ctx.bench_sublabel(op_type), description, torch.mm, + **mm_kwargs) as bench: + return bench.run() + + +# runner +def print_timers(timers: List[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + +def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): + + timers = [] + for bench_ctx in bench_ctxs: + for seq_len in args.seq_lengths: + bench_ops: List[OpType] = [] + if seq_len == 1: + # bench all decode ops + bench_ops = [op for op in args.op_types if op.is_decode_op()] + else: + # bench all prefill ops + bench_ops = [op for op in args.op_types if op.is_prefill_op()] + + seq_len_timers = [] + for bench_op in bench_ops: + _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices( + bench_op.num_slices()) + seq_len_timers.append( + bench_baseline(_ctx, args.arg_pool_size, bench_op, + args.with_cuda_graph)) + seq_len_timers.append( + bench_optype(_ctx, args.arg_pool_size, bench_op, + args.with_cuda_graph)) + + print_timers(seq_len_timers) + timers.extend(seq_len_timers) + + # Result stdout dump + print("== All Results ====") + print_timers(timers) + + # Result file dump + timestamp = int(time.time()) + pkl_file = f"lora_bench-{timestamp}.pkl" + print (f"Writing benchmarks to {pkl_file}") + with open(pkl_file, "wb") as f: + pickle.dump(timers, f) + + +def as_benchmark_contexts(hidden_sizes: List[int], lora_ranks: List[int], + args: argparse.Namespace) -> List[BenchmarkContext]: + + ctxs: List[BenchmarkContext] = [] + for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa + args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, + args.sort_by_lora_id): + ctxs.append( + BenchmarkContext( + batch_size=batch_size, + hidden_size=hidden_size, + lora_rank=lora_rank, + num_loras=num_loras, + num_active_loras=args.num_active_loras + if args.num_active_loras else num_loras, + # To be filled based on the OpType to benchmark + seq_length=None, + sort_by_lora_id=sort_by_lora_id, + dtype=args.dtype, + # To be filled based on the OpType to benchmark + num_slices=None)) + + return ctxs + + +def run_list_bench(args: argparse.Namespace): + print(args) + + print("List bench :\n" + f" Hidden Sizes {args.hidden_sizes}" + f" LoRA Ranks {args.lora_ranks}") + + # Get all benchmarking contexts + bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( + hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args) + + run(args, bench_contexts) + + +def run_range_bench(args: argparse.Namespace): + print(args) + + hidden_sizes = list( + range(args.hidden_sizes_start, args.hidden_sizes_end + 1, + args.hidden_sizes_increment)) + lora_ranks = list( + range(args.lora_ranks_start, args.lora_ranks_end + 1, + args.lora_ranks_increment)) + + print("Range bench :\n" + f" Hidden Sizes {hidden_sizes}" + f" LoRA Ranks {lora_ranks}") + + # Get all benchmarking 
contexts + bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( + hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args) + + run(args, bench_contexts) + + +def run_model_bench(args: argparse.Namespace): + print(args) + + def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]: + hidden_sizes = set() + for KN, tp_split_dim in WEIGHT_SHAPES[model]: + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + hidden_sizes.add(KN[1]) + return hidden_sizes + + # Get all hidden sizes + hidden_sizes: set[int] = set() + for model_name, tp_size in product(args.models, args.tp_sizes): + hidden_sizes = hidden_sizes.union( + hidden_sizes_from_model(model_name, tp_size)) + + print("Model bench :\n" + f" Hidden Sizes {hidden_sizes}" + f" LoRA Ranks {args.lora_ranks}") + + # Get all benchmarking contexts + bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( + hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args) + + run(args, bench_contexts) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "torch.float16": + return torch.float16 + if dt == "torch.bfloat16": + return torch.bfloat16 + raise ValueError("unsupported dtype") + + def get_bool(s: str) -> bool: + return s.lower() in ['true', '1'] + + def add_common_command_args(p: argparse.ArgumentParser): + p.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['torch.float16', 'torch.bfloat16']") + + p.add_argument( + "--arg-pool-size", + type=int, + default=32, + help="Run profiles with a pool of input/output/meta tensors instead" + "of simply reusing the same tensors for all runs") + + p.add_argument("--with-cuda-graph", + action="store_true", + help="when set profiling is done using cudagraph") + p.add_argument("--num-loras", + nargs="+", + type=int, + default=DEFAULT_NUM_LORAS) + p.add_argument("--num-active-loras", + type=int, + default=None, + help="Active LoRAs. 
When None, all LoRAs are active") + p.add_argument("--sort-by-lora-id", + nargs="+", + type=get_bool, + default=DEFAULT_SORT_BY_LORA_IDS) + p.add_argument("--op-types", + nargs="+", + type=OpType.from_str, + default=list(OpType)) + p.add_argument('--seq-lengths', + nargs="+", + type=int, + default=DEFAULT_SEQ_LENGTHS) + p.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + + parser = FlexibleArgumentParser( + description=""" +Benchmark LoRA kernels: + + list_bench example: + python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --with-cuda-graph + + model_bench example: + python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --with-cuda-graph + + range_bench example: + python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --with-cuda-graph --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + subparsers = parser.add_subparsers(dest="cmd", required=True) + + list_parser = subparsers.add_parser("list_bench") + list_parser.add_argument("--hidden-sizes", + nargs="+", + type=int, + default=DEFAULT_HIDDEN_SIZES) + list_parser.add_argument("--lora-ranks", + nargs="+", + type=int, + default=DEFAULT_LORA_RANKS) + add_common_command_args(list_parser) + list_parser.set_defaults(func=run_list_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--hidden-sizes-start", type=int, required=True) + range_parser.add_argument("--hidden-sizes-end", type=int, required=True) + range_parser.add_argument("--hidden-sizes-increment", + type=int, + required=True) + range_parser.add_argument("--lora-ranks-start", type=int, required=True) + range_parser.add_argument("--lora-ranks-end", type=int, required=True) + range_parser.add_argument("--lora-ranks-increment", + type=int, + required=True) + add_common_command_args(range_parser) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--lora-ranks", + nargs="+", + type=int, + default=DEFAULT_LORA_RANKS) + add_common_command_args(model_parser) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py new file mode 100644 index 0000000000000..7ca43f0ec8230 --- /dev/null +++ b/benchmarks/kernels/utils.py @@ -0,0 +1,209 @@ +import dataclasses +from typing import Any, Callable, Iterable, Optional + +import torch +import 
torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement + + +@dataclasses.dataclass +class CudaGraphBenchParams: + num_ops_in_cuda_graph: int + + +@dataclasses.dataclass +class ArgPool: + ''' + When some argument of the benchmarking function is annotated with this type, + the benchmarking class (BenchMM) will collapse the argument to a pick a + single value from the given list of values, during function invocation. + For every invocation during a benchmarking run, it will choose a + different value from the list. + ''' + values: Iterable[Any] + + def __getitem__(self, index): + return self.values[index] + + +class Bench: + + class ArgsIterator: + + def __init__(self, args_list, kwargs_list): + assert len(args_list) == len(kwargs_list) + self.args_list = args_list + self.kwargs_list = kwargs_list + self.n = len(self.args_list) + self.idx = 0 + + def __next__(self): + while True: + yield (self.args_list[self.idx], self.kwargs_list[self.idx]) + self.idx += 1 + self.idx = self.idx % self.n + + def reset(self): + self.idx = 0 + + @property + def n_args(self): + return self.n + + def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], + label: str, sub_label: str, description: str, fn: Callable, + *args, **kwargs): + + self.cuda_graph_params = cuda_graph_params + self.use_cuda_graph = self.cuda_graph_params is not None + self.label = label + self.sub_label = sub_label + self.description = description + self.fn = fn + + # Process args + self._args = args + self._kwargs = kwargs + self.args_list, self.kwargs_list = self.collapse_argpool( + *args, **kwargs) + self.args_iterator = self.ArgsIterator(self.args_list, + self.kwargs_list) + + # Cudagraph runner + self.g = None + if self.use_cuda_graph: + self.g = self.get_cuda_graph_runner() + + # benchmark run params + self.min_run_time = 1 + + def collapse_argpool(self, *args, **kwargs): + argpool_args = [arg for arg in args if isinstance(arg, ArgPool)] + [ + arg for arg in kwargs.values() if isinstance(arg, ArgPool) + ] + if len(argpool_args) == 0: + return [args], [kwargs] + + # Make sure all argpools are of the same size + argpool_size = len(argpool_args[0].values) + assert all([argpool_size == len(arg.values) for arg in argpool_args]) + + # create copies of the args + args_list = [] + kwargs_list = [] + for _ in range(argpool_size): + args_list.append(args) + kwargs_list.append(kwargs.copy()) + + for i in range(argpool_size): + # collapse args; Just pick the ith value + args_list[i] = tuple([ + arg[i] if isinstance(arg, ArgPool) else arg + for arg in args_list[i] + ]) + + # collapse kwargs + kwargs_i = kwargs_list[i] + arg_pool_keys = [ + k for k, v in kwargs_i.items() if isinstance(v, ArgPool) + ] + for k in arg_pool_keys: + # again just pick the ith value + kwargs_i[k] = kwargs_i[k][i] + kwargs_list[i] = kwargs_i + + return args_list, kwargs_list + + def get_cuda_graph_runner(self): + assert self.use_cuda_graph + assert self.args_iterator is not None + + num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph + + # warmup + args_it = self.args_iterator.__next__() + for _ in range(2): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(num_graph_ops): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + return g + + def run_cudagrah(self) -> TMeasurement: + assert 
self.use_cuda_graph + globals = {'g': self.g} + + return TBenchmark.Timer( + stmt="g.replay()", + globals=globals, + label=(f"{self.label}" + f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops"), + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run_eager(self) -> TMeasurement: + setup = None + stmt = None + globals = None + + has_arg_pool = self.args_iterator.n_args > 1 + if has_arg_pool: + setup = ''' + args_iterator.reset() + args_it = args_iterator.__next__() + ''' + stmt = ''' + args, kwargs = next(args_it) + fn(*args, **kwargs) + ''' + globals = {'fn': self.fn, 'args_iterator': self.args_iterator} + else: + # no arg pool. Just use the args and kwargs directly + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + args, kwargs = next(args_it) + + setup = "" + stmt = ''' + fn(*args, **kwargs) + ''' + globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} + + return TBenchmark.Timer( + stmt=stmt, + setup=setup, + globals=globals, + label=self.label, + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run(self) -> TMeasurement: + timer = None + if self.use_cuda_graph: # noqa SIM108 + timer = self.run_cudagrah() + else: + timer = self.run_eager() + if not timer.meets_confidence() or timer.has_warnings: + print("Doesn't meet confidence - re-running bench ...") + return self.run() + return timer + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type: + print(f"exc type {exc_type}") + print(f"exc value {exc_value}") + print(f"exc traceback {traceback}") From d66a4b099cea09f6934c290ea151f85537e05bea Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 28 Dec 2024 02:22:39 -0500 Subject: [PATCH 02/17] format Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 5 +++-- benchmarks/kernels/utils.py | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index e9519da7c34ff..920295e5f8781 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -129,7 +129,7 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int, lora_index = prompt_lora_mapping[b_id].item() s = current_offset e = s + seq_len_tensor[b_id].item() - token_lora_mapping[s:e] = [lora_index] * (e - s) + token_lora_mapping[s:e] = [lora_index] * (e - s) current_offset += seq_len_tensor[b_id].item() return torch.tensor(token_lora_mapping, dtype=torch.long, device=device) @@ -623,6 +623,7 @@ def print_timers(timers: List[TMeasurement]): compare = TBenchmark.Compare(timers) compare.print() + def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): timers = [] @@ -657,7 +658,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): # Result file dump timestamp = int(time.time()) pkl_file = f"lora_bench-{timestamp}.pkl" - print (f"Writing benchmarks to {pkl_file}") + print(f"Writing benchmarks to {pkl_file}") with open(pkl_file, "wb") as f: pickle.dump(timers, f) diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py index 7ca43f0ec8230..3b71689a751c4 100644 --- a/benchmarks/kernels/utils.py +++ b/benchmarks/kernels/utils.py @@ -145,8 +145,10 @@ def run_cudagrah(self) -> TMeasurement: return TBenchmark.Timer( stmt="g.replay()", globals=globals, - label=(f"{self.label}" - f" | cugraph 
{self.cuda_graph_params.num_ops_in_cuda_graph} ops"), + label=( + f"{self.label}" + f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops" + ), sub_label=self.sub_label, description=self.description, ).blocked_autorange(min_run_time=self.min_run_time) From 2406afc6ac1ce928700dcf5f7aa07e4d0b18ca18 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 28 Dec 2024 02:28:57 -0500 Subject: [PATCH 03/17] fix Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 920295e5f8781..05e69791f6ad4 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -108,7 +108,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, prompt_lora_mapping = [] while len(prompt_lora_mapping) < num_prompts: prompt_lora_mapping.extend([lora_id] * part_size) - lora_id = lora_id + 1 if lora_id < num_active_loras else lora_id + lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id return torch.tensor(prompt_lora_mapping[:num_prompts], dtype=torch.long, device=device) From 5b437881bb8e852cc37dad08a959f513e0cc7836 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 2 Jan 2025 11:45:45 -0500 Subject: [PATCH 04/17] add output directory Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 05e69791f6ad4..3c663b5aa33db 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -7,6 +7,7 @@ from enum import Enum, auto from itertools import product from typing import Any, Callable, Dict, List, Optional, Tuple +from pathlib import Path import torch import torch.utils.benchmark as TBenchmark @@ -655,12 +656,17 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): print("== All Results ====") print_timers(timers) - # Result file dump - timestamp = int(time.time()) - pkl_file = f"lora_bench-{timestamp}.pkl" - print(f"Writing benchmarks to {pkl_file}") - with open(pkl_file, "wb") as f: - pickle.dump(timers, f) + if args.output_directory: + # Result file dump + od = Path(args.output_directory) + if not od.exists(): + od.mkdir() + + timestamp = int(time.time()) + pkl_file = od / f"lora_bench-{timestamp}.pkl" + print(f"Writing benchmarks to {pkl_file}") + with open(pkl_file, "wb") as f: + pickle.dump(timers, f) def as_benchmark_contexts(hidden_sizes: List[int], lora_ranks: List[int], @@ -803,6 +809,8 @@ def add_common_command_args(p: argparse.ArgumentParser): nargs="+", type=int, default=DEFAULT_BATCH_SIZES) + p.add_argument('-o', '--output-directory', + type=str) parser = FlexibleArgumentParser( description=""" From e356facfdde27e2491adf7307bf89f12c5021f10 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 2 Jan 2025 12:29:15 -0500 Subject: [PATCH 05/17] fix num slices Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 3c663b5aa33db..ffd024651fde1 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -180,10 +180,10 @@ def is_decode_op(self) -> 
bool: OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE ] - def num_slices(self) -> int: + def num_slices(self) -> List[int]: if self.is_expand_slice_fn(): - return 3 - return 1 + return [2, 3] + return [1] def mkn(self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int) -> Tuple[int, int, int]: @@ -274,7 +274,8 @@ def bench_sublabel(self, op_type: OpType) -> str: 'k': k, 'n': n, 'num_loras': self.num_loras, - 'sort_by_lora': self.sort_by_lora_id + 'sort_by_lora': self.sort_by_lora_id, + 'num_slices' : self.num_slices, } return json.dumps(desc) @@ -640,14 +641,14 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): seq_len_timers = [] for bench_op in bench_ops: - _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices( - bench_op.num_slices()) - seq_len_timers.append( - bench_baseline(_ctx, args.arg_pool_size, bench_op, - args.with_cuda_graph)) - seq_len_timers.append( - bench_optype(_ctx, args.arg_pool_size, bench_op, - args.with_cuda_graph)) + for num_slices in bench_op.num_slices(): + _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(num_slices) + seq_len_timers.append( + bench_baseline(_ctx, args.arg_pool_size, bench_op, + args.with_cuda_graph)) + seq_len_timers.append( + bench_optype(_ctx, args.arg_pool_size, bench_op, + args.with_cuda_graph)) print_timers(seq_len_timers) timers.extend(seq_len_timers) From bcc15188ba6f4c26f4f55905a412669b26f33242 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 2 Jan 2025 12:41:40 -0500 Subject: [PATCH 06/17] format Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index ffd024651fde1..685c46f192646 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -6,8 +6,8 @@ from dataclasses import dataclass from enum import Enum, auto from itertools import product -from typing import Any, Callable, Dict, List, Optional, Tuple from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.utils.benchmark as TBenchmark @@ -275,7 +275,7 @@ def bench_sublabel(self, op_type: OpType) -> str: 'n': n, 'num_loras': self.num_loras, 'sort_by_lora': self.sort_by_lora_id, - 'num_slices' : self.num_slices, + 'num_slices': self.num_slices, } return json.dumps(desc) @@ -642,7 +642,8 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): seq_len_timers = [] for bench_op in bench_ops: for num_slices in bench_op.num_slices(): - _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(num_slices) + _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices( + num_slices) seq_len_timers.append( bench_baseline(_ctx, args.arg_pool_size, bench_op, args.with_cuda_graph)) @@ -659,7 +660,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): if args.output_directory: # Result file dump - od = Path(args.output_directory) + od = Path(args.output_directory) if not od.exists(): od.mkdir() @@ -810,8 +811,7 @@ def add_common_command_args(p: argparse.ArgumentParser): nargs="+", type=int, default=DEFAULT_BATCH_SIZES) - p.add_argument('-o', '--output-directory', - type=str) + p.add_argument('-o', '--output-directory', type=str) parser = FlexibleArgumentParser( description=""" From 96332eef348407afa3d51041199403a8a84be51e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 2 Jan 2025 
14:11:02 -0500 Subject: [PATCH 07/17] add expand_fn_add_inputs arg Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 68 +++++++++++++++++++--------- 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 685c46f192646..b506956bd3542 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -34,6 +34,7 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4] DEFAULT_SORT_BY_LORA_IDS = [False, True] DEFAULT_SEQ_LENGTHS = [1] +DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False] ## Utilities @@ -442,7 +443,7 @@ def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]: 'scaling': 1.0, } - def as_sgmv_expand_kwargs(self) -> Dict[str, Any]: + def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]: assert len(self.lora_weights_lst) == 1 self.convert_to_sgmv_benchmark_tensors() @@ -460,7 +461,7 @@ def as_sgmv_expand_kwargs(self) -> Dict[str, Any]: 'batches': num_seqs, 'max_seq_length': max_seq_len, 'token_nums': num_tokens, - 'add_inputs': True, + 'add_inputs': add_inputs, } def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]: @@ -474,7 +475,7 @@ def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]: 'scaling': 1.0 } - def as_bgmv_expand_kwargs(self): + def as_bgmv_expand_kwargs(self, add_inputs: bool): assert len(self.lora_weights_lst) == 1 self.to_device(self.input.device) return { @@ -482,10 +483,10 @@ def as_bgmv_expand_kwargs(self): 'lora_b_weights': self.lora_weights_lst[0], 'output_tensor': self.output, 'lora_indices_tensor': self.token_lora_mapping, - 'add_inputs': True + 'add_inputs': add_inputs } - def as_sgmv_expand_slice_kwargs(self) -> Dict[str, Any]: + def as_sgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]: assert len(self.lora_weights_lst) > 1 self.convert_to_sgmv_benchmark_tensors() self.sanity_check() @@ -511,11 +512,11 @@ def as_sgmv_expand_slice_kwargs(self) -> Dict[str, Any]: 'token_nums': num_tokens, 'slice_offset': i * slice_size, 'slice_size': slice_size, - 'add_inputs': True, + 'add_inputs': add_inputs, }) return {'kwargs_list': kwargs_list} - def as_bgmv_expand_slice_kwargs(self) -> Dict[str, Any]: + def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]: assert len(self.lora_weights_lst) > 1 num_slices = len(self.lora_weights_lst) slice_size = self.lora_weights_lst[0].shape[-2] # n @@ -532,32 +533,42 @@ def as_bgmv_expand_slice_kwargs(self) -> Dict[str, Any]: 'lora_indices_tensor': self.token_lora_mapping, 'slice_offset': i * slice_size, 'slice_size': slice_size, - 'add_inputs': True + 'add_inputs': add_inputs, }) return {'kwargs_list': kwargs_list} - def bench_fn_kwargs(self, op_type: OpType) -> Dict[str, Any]: + def bench_fn_kwargs(self, op_type: OpType, add_inputs: Optional[bool] = None) -> Dict[str, Any]: + if op_type.is_shrink_fn(): + assert add_inputs is None + else: + assert add_inputs is not None + if op_type == OpType.SGMV_SHRINK: return self.as_sgmv_shrink_kwargs() if op_type == OpType.SGMV_EXPAND: - return self.as_sgmv_expand_kwargs() + return self.as_sgmv_expand_kwargs(add_inputs) if op_type == OpType.BGMV_SHRINK: return self.as_bgmv_shrink_kwargs() if op_type == OpType.BGMV_EXPAND: - return self.as_bgmv_expand_kwargs() + return self.as_bgmv_expand_kwargs(add_inputs) if op_type == OpType.SGMV_EXPAND_SLICE: - return self.as_sgmv_expand_slice_kwargs() + return self.as_sgmv_expand_slice_kwargs(add_inputs) if op_type == OpType.BGMV_EXPAND_SLICE: - return self.as_bgmv_expand_slice_kwargs() + 
return self.as_bgmv_expand_slice_kwargs(add_inputs) raise ValueError(f"Unrecognized optype {self}") def bench_optype(ctx: BenchmarkContext, arg_pool_size: int, op_type: OpType, - with_cuda_graph: bool = False) -> TMeasurement: + with_cuda_graph: bool = False, + expand_fn_add_inputs: Optional[bool] = None) -> TMeasurement: assert arg_pool_size >= 1 + if op_type.is_shrink_fn(): + assert expand_fn_add_inputs is None + else: + assert expand_fn_add_inputs is not None # BenchmarkContext -> BenchmarkTensors bench_tensors : List[BenchmarkTensors] = \ @@ -566,7 +577,7 @@ def bench_optype(ctx: BenchmarkContext, bt.sanity_check() # BenchmarkTensors -> Dict (kwargs) - kwargs_list = [bt.bench_fn_kwargs(op_type) for bt in bench_tensors] + kwargs_list = [bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) for bt in bench_tensors] # Merge into a single kwargs and quality arguments as ArgPool kwargs = {k: ArgPool([]) for k in kwargs_list[0]} @@ -574,7 +585,8 @@ def bench_optype(ctx: BenchmarkContext, for k, v in _kwargs.items(): kwargs[k].values.append(v) - description = f"{op_type.name}({bench_tensors[0].io_types()})" + describe_args = f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else "" + description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})" cuda_graph_params = CudaGraphBenchParams( num_ops_in_cuda_graph=arg_pool_size) if with_cuda_graph else None with Bench(cuda_graph_params, @@ -583,10 +595,14 @@ def bench_optype(ctx: BenchmarkContext, return bench.run() -def bench_baseline(ctx: BenchmarkContext, +def bench_torch_mm(ctx: BenchmarkContext, arg_pool_size: int, op_type: OpType, with_cuda_graph: bool = False) -> TMeasurement: + """ + Benchmark basic torch.mm as a roofline. + input op_type is used in determining the m, k, n dimensions for the matmul. 
+ """ batch_size, hidden_size, lora_rank, seq_length, dtype = (ctx.batch_size, ctx.hidden_size, @@ -644,12 +660,18 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): for num_slices in bench_op.num_slices(): _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices( num_slices) + # Benchmark torch.mm as a roofline seq_len_timers.append( - bench_baseline(_ctx, args.arg_pool_size, bench_op, + bench_torch_mm(_ctx, args.arg_pool_size, bench_op, args.with_cuda_graph)) - seq_len_timers.append( - bench_optype(_ctx, args.arg_pool_size, bench_op, - args.with_cuda_graph)) + + # Benchmark bench_op + expand_fn_add_inputs = [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + for add_input_arg in expand_fn_add_inputs: + seq_len_timers.append( + bench_optype(_ctx, args.arg_pool_size, bench_op, + args.with_cuda_graph, + add_input_arg)) print_timers(seq_len_timers) timers.extend(seq_len_timers) @@ -811,6 +833,10 @@ def add_common_command_args(p: argparse.ArgumentParser): nargs="+", type=int, default=DEFAULT_BATCH_SIZES) + p.add_argument("--expand-fn-add-inputs", + nargs="+", + type=get_bool, + default=DEFAULT_EXPAND_FN_ADD_INPUTS) p.add_argument('-o', '--output-directory', type=str) parser = FlexibleArgumentParser( From 75ca94d5bf9e64fd6d40fb5839345646cfa38cf9 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 2 Jan 2025 14:12:17 -0500 Subject: [PATCH 08/17] format Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index b506956bd3542..5ff7589e9beb5 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -537,7 +537,9 @@ def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]: }) return {'kwargs_list': kwargs_list} - def bench_fn_kwargs(self, op_type: OpType, add_inputs: Optional[bool] = None) -> Dict[str, Any]: + def bench_fn_kwargs(self, + op_type: OpType, + add_inputs: Optional[bool] = None) -> Dict[str, Any]: if op_type.is_shrink_fn(): assert add_inputs is None else: @@ -577,7 +579,10 @@ def bench_optype(ctx: BenchmarkContext, bt.sanity_check() # BenchmarkTensors -> Dict (kwargs) - kwargs_list = [bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) for bt in bench_tensors] + kwargs_list = [ + bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) + for bt in bench_tensors + ] # Merge into a single kwargs and quality arguments as ArgPool kwargs = {k: ArgPool([]) for k in kwargs_list[0]} @@ -585,8 +590,10 @@ def bench_optype(ctx: BenchmarkContext, for k, v in _kwargs.items(): kwargs[k].values.append(v) - describe_args = f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else "" - description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})" + describe_args = (f"add_inputs={expand_fn_add_inputs}" + if expand_fn_add_inputs is not None else "") + description = ( + f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})") cuda_graph_params = CudaGraphBenchParams( num_ops_in_cuda_graph=arg_pool_size) if with_cuda_graph else None with Bench(cuda_graph_params, @@ -666,12 +673,13 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): args.with_cuda_graph)) # Benchmark bench_op - expand_fn_add_inputs = [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + expand_fn_add_inputs = [ + None + ] if 
bench_op.is_shrink_fn() else args.expand_fn_add_inputs for add_input_arg in expand_fn_add_inputs: seq_len_timers.append( bench_optype(_ctx, args.arg_pool_size, bench_op, - args.with_cuda_graph, - add_input_arg)) + args.with_cuda_graph, add_input_arg)) print_timers(seq_len_timers) timers.extend(seq_len_timers) From f8c900ef9dc490095908881e4500d0c5612d8c31 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 2 Jan 2025 16:20:29 -0500 Subject: [PATCH 09/17] add test_correctness Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 109 ++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 3 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 5ff7589e9beb5..a75a3b0ce02b2 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -137,6 +137,52 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int, return torch.tensor(token_lora_mapping, dtype=torch.long, device=device) +def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, + lora_weights: List[torch.Tensor], + seq_lens_cpu: torch.Tensor, + prompt_lora_mapping_cpu: torch.Tensor, scaling: float, + add_inputs: Optional[bool]): + """ + Torch group gemm reference implementation to test correctness of + benchmarking operations. + """ + batches = seq_lens_cpu.size(0) + out_list = [] + current_offset = 0 + for lora_index, b_length in zip(range(batches), seq_lens_cpu): + x = input[current_offset:b_length + current_offset, :] + current_offset += b_length + w = lora_weights[prompt_lora_mapping_cpu[lora_index]] + result = torch.nn.functional.linear(x, w) + result *= scaling + out_list.append(result) + torch.cat(out_list, dim=0) + + cat_result = torch.cat(out_list, dim=0) + + if add_inputs: + ref_out += cat_result + else: + ref_out.copy_(cat_result) + + +def ref_group_gemm_with_slices(ref_out: torch.Tensor, input: torch.Tensor, + lora_weights: List[torch.Tensor], + seq_lens_cpu: torch.Tensor, + prompt_lora_mapping_cpu: torch.Tensor, + scaling: float, add_inputs: Optional[bool], + num_slices: int, hidden_size: int): + for slice_idx in range(num_slices): + slice_offset = slice_idx * hidden_size + ref_group_gemm(ref_out[:, slice_offset:slice_offset + hidden_size], + input, + lora_weights[slice_idx], + seq_lens_cpu, + prompt_lora_mapping_cpu, + scaling=scaling, + add_inputs=add_inputs) + + ## LoRA Ops to Benchmark and its properties class OpType(Enum): SGMV_SHRINK = auto() @@ -559,12 +605,54 @@ def bench_fn_kwargs(self, return self.as_bgmv_expand_slice_kwargs(add_inputs) raise ValueError(f"Unrecognized optype {self}") + def test_correctness(self, op_type: OpType, + expand_fn_add_inputs: Optional[bool]) -> bool: + """ + Test correctness of the given benchmarking operation against a + grouped gemm reference implementation. 
+ """ + seq_lens_cpu = self.seq_lens.to(device="cpu") + prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu") + + ref_output = self.output.clone() + + num_slices = len(self.lora_weights_lst) + hidden_size = self.lora_weights_lst[0].shape[-2] # n + assert hidden_size * num_slices == self.output.shape[-1] + + do_input_cast: bool = op_type.is_expand_fn( + ) or op_type.is_expand_slice_fn() + weight_dtype = self.lora_weights_lst[0].dtype + ref_group_gemm_with_slices( + ref_output, + self.input.clone().to( + dtype=weight_dtype) if do_input_cast else self.input, + self.lora_weights_lst, + seq_lens_cpu, + prompt_lora_mapping_cpu, + scaling=1.0, + add_inputs=expand_fn_add_inputs, + num_slices=num_slices, + hidden_size=hidden_size, + ) + + op_type.bench_fn()( + **self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) + + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[self.output.dtype] + return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol) + def bench_optype(ctx: BenchmarkContext, arg_pool_size: int, op_type: OpType, with_cuda_graph: bool = False, - expand_fn_add_inputs: Optional[bool] = None) -> TMeasurement: + expand_fn_add_inputs: Optional[bool] = None, + test_correcntess: bool = False) -> TMeasurement: assert arg_pool_size >= 1 if op_type.is_shrink_fn(): @@ -578,6 +666,9 @@ def bench_optype(ctx: BenchmarkContext, for bt in bench_tensors: bt.sanity_check() + if test_correcntess: + assert bench_tensors[0].test_correctness(op_type, expand_fn_add_inputs) + # BenchmarkTensors -> Dict (kwargs) kwargs_list = [ bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) @@ -679,7 +770,8 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): for add_input_arg in expand_fn_add_inputs: seq_len_timers.append( bench_optype(_ctx, args.arg_pool_size, bench_op, - args.with_cuda_graph, add_input_arg)) + args.with_cuda_graph, add_input_arg, + args.test_correctness)) print_timers(seq_len_timers) timers.extend(seq_len_timers) @@ -845,7 +937,18 @@ def add_common_command_args(p: argparse.ArgumentParser): nargs="+", type=get_bool, default=DEFAULT_EXPAND_FN_ADD_INPUTS) - p.add_argument('-o', '--output-directory', type=str) + p.add_argument( + '-o', + '--output-directory', + type=str, + help=("Output directory to store a the list of benchmarking" + "TMeasurement objects as a pickle file")) + + p.add_argument( + "--test-correctness", + action='store_true', + help=("When enabled, the benchmarking objects are additionally " + "checked for correctness")) parser = FlexibleArgumentParser( description=""" From d9aadfac570ac0fa15a0965d917973270ddd0aac Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 3 Jan 2025 13:03:42 -0500 Subject: [PATCH 10/17] add correctness testing and prints Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 97 ++++++++++++++++++++-------- benchmarks/kernels/utils.py | 14 ++-- 2 files changed, 76 insertions(+), 35 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index a75a3b0ce02b2..0e843a7592a03 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -608,9 +608,22 @@ def bench_fn_kwargs(self, def test_correctness(self, op_type: OpType, expand_fn_add_inputs: Optional[bool]) -> bool: """ - Test correctness of the given benchmarking operation against a - grouped gemm reference implementation. 
+ Test correctness of self.output against a grouped gemm reference implementation. + + For expand-related operations with add_inputs = True, since the benchmarking + setup runs the function multiple times, the accumulation into the self.output + is intractable. Correctness testing is skipped for that case. """ + + if op_type.is_shrink_fn(): + assert expand_fn_add_inputs is None + else: + assert expand_fn_add_inputs is not None + + if expand_fn_add_inputs: + print (f"WARNING: Skipping correctness testing for {op_type} with add_inputs={expand_fn_add_inputs}") + return True + seq_lens_cpu = self.seq_lens.to(device="cpu") prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu") @@ -636,9 +649,6 @@ def test_correctness(self, op_type: OpType, hidden_size=hidden_size, ) - op_type.bench_fn()( - **self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) - rtol, atol = { torch.float16: (6e-2, 6e-2), torch.bfloat16: (6e-2, 6e-2), @@ -650,9 +660,9 @@ def test_correctness(self, op_type: OpType, def bench_optype(ctx: BenchmarkContext, arg_pool_size: int, op_type: OpType, - with_cuda_graph: bool = False, + cuda_graph_nops: Optional[int] = None, expand_fn_add_inputs: Optional[bool] = None, - test_correcntess: bool = False) -> TMeasurement: + test_correctness: bool = False) -> TMeasurement: assert arg_pool_size >= 1 if op_type.is_shrink_fn(): @@ -666,9 +676,6 @@ def bench_optype(ctx: BenchmarkContext, for bt in bench_tensors: bt.sanity_check() - if test_correcntess: - assert bench_tensors[0].test_correctness(op_type, expand_fn_add_inputs) - # BenchmarkTensors -> Dict (kwargs) kwargs_list = [ bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) @@ -681,22 +688,31 @@ def bench_optype(ctx: BenchmarkContext, for k, v in _kwargs.items(): kwargs[k].values.append(v) + cuda_graph_params = None + if cuda_graph_nops: + cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) + describe_args = (f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else "") description = ( f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})") - cuda_graph_params = CudaGraphBenchParams( - num_ops_in_cuda_graph=arg_pool_size) if with_cuda_graph else None + + timer = None with Bench(cuda_graph_params, ctx.bench_label(), ctx.bench_sublabel(op_type), description, op_type.bench_fn(), **kwargs) as bench: - return bench.run() + timer = bench.run() + + if test_correctness: + assert all([bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors]) + + return timer def bench_torch_mm(ctx: BenchmarkContext, arg_pool_size: int, op_type: OpType, - with_cuda_graph: bool = False) -> TMeasurement: + cuda_graph_nops: Optional[int] = None) -> TMeasurement: """ Benchmark basic torch.mm as a roofline. input op_type is used in determining the m, k, n dimensions for the matmul. 
@@ -726,8 +742,9 @@ def bench_torch_mm(ctx: BenchmarkContext, description = (f"torch.mm({dtype_to_str(dtype)}" f"x{dtype_to_str(dtype)}" f"=>{dtype_to_str(dtype)})") - cuda_graph_params = CudaGraphBenchParams( - num_ops_in_cuda_graph=arg_pool_size) if with_cuda_graph else None + cuda_graph_params = None + if cuda_graph_nops: + cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) with Bench(cuda_graph_params, ctx.bench_label(), ctx.bench_sublabel(op_type), description, torch.mm, **mm_kwargs) as bench: @@ -735,13 +752,31 @@ def bench_torch_mm(ctx: BenchmarkContext, # runner -def print_timers(timers: List[TMeasurement]): +def use_cuda_graph_recommendation() -> str: + return """ + Triton kernels have a significant launch overhead with launched directly via python. + This overhead is more noticeable for small the problem sizes. For these cases, it is + recommended to use the script with `--cuda-graph-nops N` to benchmark N consecutive invocations + of the benchmarking operations from inside a CUDA Graph. Note that the returned measurement + is for N invocations of the operation. + """ + +def print_timers(timers: List[TMeasurement], args: Optional[argparse.Namespace] = None): compare = TBenchmark.Compare(timers) compare.print() + if args and args.cuda_graph_nops: + print (f"The timings reported above is for {args.cuda_graph_nops} consecutive invocations of the benchmarking functions. Please divide by {args.cuda_graph_nops} for single invocation timings ") + def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): + if args.cuda_graph_nops is not None: + assert args.cuda_graph_nops > 0 + print (f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph") + else: + print (f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}") + timers = [] for bench_ctx in bench_ctxs: for seq_len in args.seq_lengths: @@ -761,7 +796,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): # Benchmark torch.mm as a roofline seq_len_timers.append( bench_torch_mm(_ctx, args.arg_pool_size, bench_op, - args.with_cuda_graph)) + args.cuda_graph_nops)) # Benchmark bench_op expand_fn_add_inputs = [ @@ -770,7 +805,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): for add_input_arg in expand_fn_add_inputs: seq_len_timers.append( bench_optype(_ctx, args.arg_pool_size, bench_op, - args.with_cuda_graph, add_input_arg, + args.cuda_graph_nops, add_input_arg, args.test_correctness)) print_timers(seq_len_timers) @@ -778,7 +813,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): # Result stdout dump print("== All Results ====") - print_timers(timers) + print_timers(timers, args) if args.output_directory: # Result file dump @@ -904,11 +939,16 @@ def add_common_command_args(p: argparse.ArgumentParser): type=int, default=32, help="Run profiles with a pool of input/output/meta tensors instead" - "of simply reusing the same tensors for all runs") + "of simply reusing the same tensors for all runs. A bigger arg-pool " + "mitigates hardware caching effects during benchmarking.") - p.add_argument("--with-cuda-graph", - action="store_true", - help="when set profiling is done using cudagraph") + p.add_argument("--cuda-graph-nops", + type=int, + help=("when set profiling is done using cudagraph, " + "with the given number of operations in a graph." 
+ "Note that the measurement returned is the time " + "taken for N consecutive executions of the benchmarking " + "functions, where N is the value of this argument.")) p.add_argument("--num-loras", nargs="+", type=int, @@ -951,17 +991,18 @@ def add_common_command_args(p: argparse.ArgumentParser): "checked for correctness")) parser = FlexibleArgumentParser( - description=""" + description=f""" Benchmark LoRA kernels: + {use_cuda_graph_recommendation()} list_bench example: - python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --with-cuda-graph + python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 model_bench example: - python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --with-cuda-graph + python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 range_bench example: - python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --with-cuda-graph --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 + python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 """, # noqa: E501 formatter_class=argparse.RawTextHelpFormatter) diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py index 3b71689a751c4..fd255306e905b 100644 --- a/benchmarks/kernels/utils.py +++ b/benchmarks/kernels/utils.py @@ -129,13 +129,13 @@ def get_cuda_graph_runner(self): self.args_iterator.reset() args_it = self.args_iterator.__next__() - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - for _ in range(num_graph_ops): - args, kwargs = next(args_it) - self.fn(*args, **kwargs) + #stream = torch.cuda.current_stream() + #with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(num_graph_ops): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) return g def 
run_cudagrah(self) -> TMeasurement: From 75ca40b784b3fec53916a5a17c23c40149f3cca3 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 3 Jan 2025 14:36:01 -0500 Subject: [PATCH 11/17] use stream capture Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py index fd255306e905b..3b71689a751c4 100644 --- a/benchmarks/kernels/utils.py +++ b/benchmarks/kernels/utils.py @@ -129,13 +129,13 @@ def get_cuda_graph_runner(self): self.args_iterator.reset() args_it = self.args_iterator.__next__() - #stream = torch.cuda.current_stream() - #with torch.cuda.stream(stream): - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - for _ in range(num_graph_ops): - args, kwargs = next(args_it) - self.fn(*args, **kwargs) + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(num_graph_ops): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) return g def run_cudagrah(self) -> TMeasurement: From c7d6620c021b4d41107f9fed3dbacec6981990ed Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 3 Jan 2025 14:42:30 -0500 Subject: [PATCH 12/17] format Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 61 +++++++++++++++++----------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 0e843a7592a03..6739ee2142fc0 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -608,11 +608,13 @@ def bench_fn_kwargs(self, def test_correctness(self, op_type: OpType, expand_fn_add_inputs: Optional[bool]) -> bool: """ - Test correctness of self.output against a grouped gemm reference implementation. + Test correctness of self.output against a grouped gemm reference + implementation. - For expand-related operations with add_inputs = True, since the benchmarking - setup runs the function multiple times, the accumulation into the self.output - is intractable. Correctness testing is skipped for that case. + For expand-related operations with add_inputs = True, since the + benchmarking setup runs the function multiple times, the accumulation + into the self.output is intractable. Correctness testing is skipped + for that case. """ if op_type.is_shrink_fn(): @@ -621,7 +623,8 @@ def test_correctness(self, op_type: OpType, assert expand_fn_add_inputs is not None if expand_fn_add_inputs: - print (f"WARNING: Skipping correctness testing for {op_type} with add_inputs={expand_fn_add_inputs}") + print(f"WARNING: Skipping correctness testing for {op_type} with " + f"add_inputs={expand_fn_add_inputs}") return True seq_lens_cpu = self.seq_lens.to(device="cpu") @@ -704,7 +707,10 @@ def bench_optype(ctx: BenchmarkContext, timer = bench.run() if test_correctness: - assert all([bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors]) + assert all([ + bt.test_correctness(op_type, expand_fn_add_inputs) + for bt in bench_tensors + ]) return timer @@ -754,28 +760,36 @@ def bench_torch_mm(ctx: BenchmarkContext, # runner def use_cuda_graph_recommendation() -> str: return """ - Triton kernels have a significant launch overhead with launched directly via python. - This overhead is more noticeable for small the problem sizes. 
For these cases, it is - recommended to use the script with `--cuda-graph-nops N` to benchmark N consecutive invocations - of the benchmarking operations from inside a CUDA Graph. Note that the returned measurement - is for N invocations of the operation. + Triton kernels have a significant launch overhead with + launched directly via python. This overhead is more noticeable + for small the problem sizes. For these cases, it is recommended + to use the script with `--cuda-graph-nops N` to benchmark N + consecutive invocations of the benchmarking operations from + inside a CUDA Graph. Note that the returned measurement is for N + invocations of the operation. """ -def print_timers(timers: List[TMeasurement], args: Optional[argparse.Namespace] = None): + +def print_timers(timers: List[TMeasurement], + args: Optional[argparse.Namespace] = None): compare = TBenchmark.Compare(timers) compare.print() if args and args.cuda_graph_nops: - print (f"The timings reported above is for {args.cuda_graph_nops} consecutive invocations of the benchmarking functions. Please divide by {args.cuda_graph_nops} for single invocation timings ") + print(f"The timings reported above is for {args.cuda_graph_nops} " + "consecutive invocations of the benchmarking functions. " + "Please divide by {args.cuda_graph_nops} for single invocation " + "timings ") def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): if args.cuda_graph_nops is not None: assert args.cuda_graph_nops > 0 - print (f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph") + print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA " + "Graph") else: - print (f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}") + print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}") timers = [] for bench_ctx in bench_ctxs: @@ -939,16 +953,17 @@ def add_common_command_args(p: argparse.ArgumentParser): type=int, default=32, help="Run profiles with a pool of input/output/meta tensors instead" - "of simply reusing the same tensors for all runs. A bigger arg-pool " + "of simply reusing the same tensors for all runs. A bigger arg-pool" "mitigates hardware caching effects during benchmarking.") - p.add_argument("--cuda-graph-nops", - type=int, - help=("when set profiling is done using cudagraph, " - "with the given number of operations in a graph." - "Note that the measurement returned is the time " - "taken for N consecutive executions of the benchmarking " - "functions, where N is the value of this argument.")) + p.add_argument( + "--cuda-graph-nops", + type=int, + help=("when set profiling is done using cudagraph, " + "with the given number of operations in a graph." 
+ "Note that the measurement returned is the time " + "taken for N consecutive executions of the benchmarking " + "functions, where N is the value of this argument.")) p.add_argument("--num-loras", nargs="+", type=int, From 06127d3c7189cd1046b450f68eaafbe23ea43f6d Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 3 Jan 2025 15:47:16 -0500 Subject: [PATCH 13/17] fix comment print test only benchmark tensors that participated Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 6739ee2142fc0..97c4a7987af6c 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -709,7 +709,8 @@ def bench_optype(ctx: BenchmarkContext, if test_correctness: assert all([ bt.test_correctness(op_type, expand_fn_add_inputs) - for bt in bench_tensors + for bt in bench_tensors[:cuda_graph_nops if cuda_graph_nops + is not None else arg_pool_size] ]) return timer @@ -778,7 +779,7 @@ def print_timers(timers: List[TMeasurement], if args and args.cuda_graph_nops: print(f"The timings reported above is for {args.cuda_graph_nops} " "consecutive invocations of the benchmarking functions. " - "Please divide by {args.cuda_graph_nops} for single invocation " + f"Please divide by {args.cuda_graph_nops} for single invocation " "timings ") From 433d1298a61cb8fb3679f590d77c68252e6d1c33 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 7 Jan 2025 09:47:53 -0500 Subject: [PATCH 14/17] add comments Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 97c4a7987af6c..3141b64a4d5e6 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -722,6 +722,10 @@ def bench_torch_mm(ctx: BenchmarkContext, cuda_graph_nops: Optional[int] = None) -> TMeasurement: """ Benchmark basic torch.mm as a roofline. + + When all the input tokens have the same LoRA ID, the LoRA kernels are just + a matmul. This torch.mm benchmark serves as a roofline for that case. + input op_type is used in determining the m, k, n dimensions for the matmul. """ @@ -746,9 +750,10 @@ def bench_torch_mm(ctx: BenchmarkContext, # Make torch.mm kwargs mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)} - description = (f"torch.mm({dtype_to_str(dtype)}" - f"x{dtype_to_str(dtype)}" - f"=>{dtype_to_str(dtype)})") + description = ( + f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}" + f"x{dtype_to_str(dtype)}" + f"=>{dtype_to_str(dtype)})") cuda_graph_params = None if cuda_graph_nops: cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) @@ -777,10 +782,18 @@ def print_timers(timers: List[TMeasurement], compare.print() if args and args.cuda_graph_nops: - print(f"The timings reported above is for {args.cuda_graph_nops} " - "consecutive invocations of the benchmarking functions. " - f"Please divide by {args.cuda_graph_nops} for single invocation " - "timings ") + print( + f"Note : The timings reported above is for {args.cuda_graph_nops} " + "consecutive invocations of the benchmarking functions. 
" + f"Please divide by {args.cuda_graph_nops} for single invocation " + "timings.") + + print("Note on Comparison with torch.mm : The torch.mm numbers are " + "benchmark numbers of a simple matmul emulating the single lora " + "case. It is provided as a roofline for comparing our LoRA Kernel " + "implementations. It is expected that the LoRA kernels will be " + "slower than torch.mm in cases where num_loras is big. But for " + "small num_loras the goal should be to match the torch.mm numbers.") def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): From faebdc994d7418e6f1846d248817d78b5253b526 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 8 Jan 2025 21:40:55 -0500 Subject: [PATCH 15/17] refactor Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 421 +++++++++++++++------------ benchmarks/kernels/utils.py | 5 +- 2 files changed, 244 insertions(+), 182 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 3141b64a4d5e6..7f5df010a913e 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -19,8 +19,8 @@ from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) @@ -37,7 +37,7 @@ DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False] -## Utilities +# Utilities def dtype_to_str(dtype: torch.dtype): if dtype == torch.float16: return "f16" @@ -59,32 +59,27 @@ def make_rand_lora_weight_tensor(k: int, def make_rand_tensors( - m: int, - k: int, - n: int, - num_loras: int, - num_slices: Optional[int], + a_shape: Tuple[int], + b_shape: Tuple[int], + c_shape: Tuple[int], a_dtype: torch.dtype, b_dtype: torch.dtype, c_dtype: torch.dtype, + num_slices: int, device: str = "cuda", ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: - # Make input / output tensors - # Input matrix A of shape {m, k} - # num_slices Input matrices B of shape {k, n} - # Output matrix C of shape {m, n * num_slices} - num_slices = num_slices if num_slices is not None else 1 - - A = torch.rand((m, k), dtype=a_dtype).to(device) + """ + Make LoRA input/output matrices. 
+ """ + A = torch.rand(a_shape, dtype=a_dtype).to(device) # LoRA weights column major Bs = [ - make_rand_lora_weight_tensor(k, n, num_loras, b_dtype, device) + torch.rand(b_shape, dtype=b_dtype).to(device) for _ in range(num_slices) ] - C = torch.zeros((m, n * num_slices), dtype=c_dtype).to(device) - + C = torch.zeros(c_shape, dtype=c_dtype).to(device) return A, Bs, C @@ -166,30 +161,14 @@ def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, ref_out.copy_(cat_result) -def ref_group_gemm_with_slices(ref_out: torch.Tensor, input: torch.Tensor, - lora_weights: List[torch.Tensor], - seq_lens_cpu: torch.Tensor, - prompt_lora_mapping_cpu: torch.Tensor, - scaling: float, add_inputs: Optional[bool], - num_slices: int, hidden_size: int): - for slice_idx in range(num_slices): - slice_offset = slice_idx * hidden_size - ref_group_gemm(ref_out[:, slice_offset:slice_offset + hidden_size], - input, - lora_weights[slice_idx], - seq_lens_cpu, - prompt_lora_mapping_cpu, - scaling=scaling, - add_inputs=add_inputs) - - -## LoRA Ops to Benchmark and its properties class OpType(Enum): + """ + LoRA Ops to benchmark and its properties. + """ SGMV_SHRINK = auto() BGMV_SHRINK = auto() SGMV_EXPAND = auto() BGMV_EXPAND = auto() - SGMV_EXPAND_SLICE = auto() BGMV_EXPAND_SLICE = auto() @staticmethod @@ -202,8 +181,6 @@ def from_str(s: str) -> "OpType": return OpType.BGMV_SHRINK if s.lower() == 'bgmv_expand': return OpType.BGMV_EXPAND - if s.lower() == "sgmv_expand_slice": - return OpType.SGMV_EXPAND_SLICE if s.lower() == "bgmv_expand_slice": return OpType.BGMV_EXPAND_SLICE raise ValueError(f"Unrecognized str {s} to convert to OpType") @@ -214,23 +191,26 @@ def is_shrink_fn(self) -> bool: def is_expand_fn(self) -> bool: return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND] - def is_expand_slice_fn(self) -> bool: - return self in [OpType.SGMV_EXPAND_SLICE, OpType.BGMV_EXPAND_SLICE] - def is_prefill_op(self) -> bool: - return self in [ - OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.SGMV_EXPAND_SLICE - ] + return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND] def is_decode_op(self) -> bool: return self in [ OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE ] + def is_expand_slice_fn(self) -> bool: + return self in [OpType.BGMV_EXPAND_SLICE] + def num_slices(self) -> List[int]: - if self.is_expand_slice_fn(): + if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]: + # SGMV kernels supports slices + return [1, 2, 3] + if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]: + return [1] + if self in [OpType.BGMV_EXPAND_SLICE]: return [2, 3] - return [1] + raise ValueError(f"Unrecognized OpType {self}") def mkn(self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int) -> Tuple[int, int, int]: @@ -258,11 +238,33 @@ def matmul_dtypes( assert self.is_expand_fn() or self.is_expand_slice_fn() return torch.float32, op_dtype, op_dtype - def bench_fn(self) -> Callable: + def matmul_shapes( + self, batch_size: int, seq_length: int, hidden_size: int, + lora_rank: int, num_loras: int, + num_slices: int) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]: + """ + Given num_slices, return the shapes of the A, B, and C matrices + in A x B = C, for the op_type + """ + m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank) - def emulate_sgmv_expand_slice(kwargs_list: List[Dict[str, Any]]): - for x in kwargs_list: - sgmv_expand_slice(**x) + b_shape = (num_loras, n, k) # col-major + if self == OpType.SGMV_SHRINK: + # SGMV shrink supports num_slices inherently in the kernel + return ((m, k), 
b_shape, (num_slices, m, n)) + if self == OpType.SGMV_EXPAND: + # SGMV expand supports num_slices inherently in the kernel + return ((num_slices, m, k), b_shape, (m, n * num_slices)) + if self == OpType.BGMV_SHRINK: + return ((m, k), b_shape, (m, n)) + if self == OpType.BGMV_EXPAND: + return ((m, k), b_shape, (m, n)) + if self == OpType.BGMV_EXPAND_SLICE: + return ((num_slices, m, k), b_shape, (m, n * num_slices)) + + raise ValueError(f"Unrecognized op_type {self}") + + def bench_fn(self) -> Callable: def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]): for x in kwargs_list: @@ -276,12 +278,58 @@ def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]): return bgmv_shrink if self == OpType.BGMV_EXPAND: return bgmv_expand - if self == OpType.SGMV_EXPAND_SLICE: - return emulate_sgmv_expand_slice if self == OpType.BGMV_EXPAND_SLICE: return emulate_bgmv_expand_slice raise ValueError(f"Unrecognized optype {self}") + def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor, + lora_weights: List[torch.Tensor], + **kwargs) -> Callable: + """Each benchmark operation expected the input, lora_weights and outputs + in a slightly different format. Refer to self.matmul_shapes(). + run_ref_group_gemm accounts for those differences in executing a + reference group gemm for correctness testing. + """ + w_dtype = lora_weights[0].dtype + num_slices = len(lora_weights) + if self == OpType.SGMV_SHRINK: + for slice_idx in range(num_slices): + ref_group_gemm(ref_out=output[slice_idx, :], + input=input, + lora_weights=lora_weights[slice_idx], + **kwargs) + if self == OpType.SGMV_EXPAND: + hidden_size = lora_weights[0].shape[1] + for slice_idx in range(num_slices): + slice_offset = slice_idx * hidden_size + ref_group_gemm( + ref_out=output[:, slice_offset:slice_offset + hidden_size], + input=input[slice_idx].clone().to(dtype=w_dtype), + lora_weights=lora_weights[slice_idx], + **kwargs) + if self == OpType.BGMV_SHRINK: + assert num_slices == 1 + ref_group_gemm(ref_out=output, + input=input, + lora_weights=lora_weights[0], + **kwargs) + if self == OpType.BGMV_EXPAND: + assert num_slices == 1 + ref_group_gemm(ref_out=output, + input=input.clone().to(dtype=w_dtype), + lora_weights=lora_weights[0], + **kwargs) + if self == OpType.BGMV_EXPAND_SLICE: + hidden_size = lora_weights[0].shape[1] + for slice_idx in range(num_slices): + slice_offset = slice_idx * hidden_size + ref_group_gemm( + ref_out=output[:, slice_offset:slice_offset + hidden_size], + input=input[slice_idx].clone().to(dtype=w_dtype), + lora_weights=lora_weights[slice_idx], + **kwargs) + raise ValueError(f"Unrecognized optype {self}") + @dataclass class BenchmarkContext: @@ -296,14 +344,14 @@ class BenchmarkContext: sort_by_lora_id: bool dtype: torch.dtype seq_length: Optional[int] = None - num_slices: Optional[int] = None # num_slices for expand_slice kernels + num_slices: Optional[int] = None # num_slices for slice based ops def with_seq_length(self, seq_length: int) -> "BenchmarkContext": ctx = copy.copy(self) ctx.seq_length = seq_length return ctx - def with_num_slices(self, num_slices: Optional[int]) -> "BenchmarkContext": + def with_num_slices(self, num_slices: int) -> "BenchmarkContext": ctx = copy.copy(self) ctx.num_slices = num_slices return ctx @@ -352,18 +400,16 @@ def make(ctx: BenchmarkContext, op_type: OpType, device: str = "cuda") -> "BenchmarkTensors": - ## Make input / output matmul tensors + # Make input / output matmul tensors. 
+ a_shape, b_shape, c_shape = op_type.matmul_shapes( + ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank, + ctx.num_loras, ctx.num_slices) a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) - m, k, n = op_type.mkn(ctx.batch_size, ctx.seq_length, ctx.hidden_size, - ctx.lora_rank) input_tensor, lora_weights, output_tensor = \ - make_rand_tensors(m, k, n, ctx.num_loras, - num_slices = ctx.num_slices, - a_dtype = a_type, - b_dtype = b_type, - c_dtype = c_type) + make_rand_tensors(a_shape, b_shape, c_shape, a_type, b_type, c_type, + num_slices = ctx.num_slices) - ## Make metadata tensors + # Make metadata tensors. # Keep the metadata tensors in the CPU for further processing if needed. # The tensors get moved to the GPU before benchmarking. assert ctx.num_active_loras <= ctx.num_loras @@ -394,24 +440,13 @@ def sanity_check(self) -> None: """ Fails asserts when non-conformality is detected. """ - # Check that the tensors have the right shapes - m = self.input.shape[0] - k = self.input.shape[1] - n = self.output.shape[1] - - # check matmul tensors - assert self.output.shape[0] == m - assert len(self.lora_weights_lst) >= 1 - num_slices = len(self.lora_weights_lst) - for w in self.lora_weights_lst: - _, w_n, w_k = w.shape # n, k flipped due to col-major ordering. - assert (w_n, w_k) == (n, k) or (w_n * num_slices, w_k) == (n, k) + num_tokens = self.input.shape[-2] # check metadata tensors - assert torch.sum(self.seq_lens) == m + assert torch.sum(self.seq_lens) == num_tokens num_seqs = self.seq_lens.shape[0] assert self.seq_start_loc.shape[0] == num_seqs assert self.prompt_lora_mapping.shape[0] == num_seqs - assert self.token_lora_mapping.shape[0] == m + assert self.token_lora_mapping.shape[0] == num_tokens def to_device(self, device: str): """ @@ -437,13 +472,14 @@ def metadata(self) -> Tuple[int, int, int]: Return num_seqs, num_tokens and max_seq_len """ num_seqs = self.seq_lens.shape[0] - num_tokens = self.input.shape[0] + num_tokens = self.token_lora_mapping.shape[0] max_seq_len = torch.max(self.seq_lens).item() - return num_seqs, num_tokens, max_seq_len + num_slices = len(self.lora_weights_lst) + return num_seqs, num_tokens, max_seq_len, num_slices def convert_to_sgmv_benchmark_tensors(self): """ - for sgmv punica kernels, when consecutive sequences have the + For sgmv punica kernels, when consecutive sequences have the same LoRA ID, we just merge them together. This happens in punica.py::compute_metadata """ @@ -467,18 +503,31 @@ def convert_to_sgmv_benchmark_tensors(self): self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype) self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype) - ## Benchmark function args. def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]: - assert len(self.lora_weights_lst) == 1 - self.convert_to_sgmv_benchmark_tensors() self.sanity_check() self.to_device(self.input.device) - num_seqs, num_tokens, max_seq_len = self.metadata() + num_seqs, num_tokens, max_seq_len, num_slices = self.metadata() + + # Sanity check matrix shapes. 
+ i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape [num_tokens, hidden_size] + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [num_loras, lora_rank, hidden_size] + assert len(lw_shape) == 3 + assert lw_shape[2] == hidden_size + lora_rank = lw_shape[1] + # Expected output shape [num_slices, num_tokens, lora_rank] + assert len(o_shape) == 3 + assert o_shape == (num_slices, num_tokens, lora_rank) + return { 'inputs': self.input, - 'lora_a_weights': self.lora_weights_lst[0], + 'lora_a_weights': self.lora_weights_lst, 'output_tensor': self.output, 'b_seq_start_loc': self.seq_start_loc, 'seq_len_tensor': self.seq_lens, @@ -490,16 +539,32 @@ def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]: } def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]: - assert len(self.lora_weights_lst) == 1 self.convert_to_sgmv_benchmark_tensors() self.sanity_check() self.to_device(self.input.device) - num_seqs, num_tokens, max_seq_len = self.metadata() + num_seqs, num_tokens, max_seq_len, num_slices = self.metadata() + + # Sanity check matrix shapes. + i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape : [num_slices, num_tokens, lora_rank] + assert len(i_shape) == 3 + assert i_shape[0] == num_slices + assert i_shape[1] == num_tokens + lora_rank = i_shape[2] + # Expected lora weight shape : [num_lora, hidden_size, lora_rank] + assert len(lw_shape) == 3 + assert lw_shape[2] == lora_rank + hidden_size = lw_shape[1] + # Expected output shape : [num_tokens, hidden_size * num_slices] + assert len(o_shape) == 2 + assert o_shape == (num_tokens, hidden_size * num_slices) + return { 'inputs': self.input, - 'lora_b_weights': self.lora_weights_lst[0], + 'lora_b_weights': self.lora_weights_lst, 'output_tensor': self.output, 'b_seq_start_loc': self.seq_start_loc, 'seq_len_tensor': self.seq_lens, @@ -507,12 +572,30 @@ def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]: 'batches': num_seqs, 'max_seq_length': max_seq_len, 'token_nums': num_tokens, + 'offset_start': 0, 'add_inputs': add_inputs, } def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]: assert len(self.lora_weights_lst) == 1 self.to_device(self.input.device) + + _, num_tokens, _, _ = self.metadata() + # Sanity check shapes + i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape [num_tokens, hidden_size] + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [num_loras, lora_rank, hidden_size] + assert len(lw_shape) == 3 + assert lw_shape[2] == hidden_size + lora_rank = lw_shape[1] + # Expected output shape [num_tokens, lora_rank] + assert len(o_shape) == 2 + assert o_shape == (num_tokens, lora_rank) + return { 'inputs': self.input, 'lora_a_weights': self.lora_weights_lst[0], @@ -524,6 +607,23 @@ def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]: def as_bgmv_expand_kwargs(self, add_inputs: bool): assert len(self.lora_weights_lst) == 1 self.to_device(self.input.device) + + _, num_tokens, _, _ = self.metadata() + # Sanity check shapes + i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape [num_tokens, lora_rank] + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + lora_rank = i_shape[1] + # Expected lora weight shape 
[num_loras, hidden_size, lora_rank] + assert len(lw_shape) == 3 + assert lw_shape[2] == lora_rank + hidden_size = lw_shape[1] + # Expected output shape [num_tokens, hidden_size] + assert len(o_shape) == 2 + assert o_shape == (num_tokens, hidden_size) + return { 'inputs': self.input, 'lora_b_weights': self.lora_weights_lst[0], @@ -532,53 +632,36 @@ def as_bgmv_expand_kwargs(self, add_inputs: bool): 'add_inputs': add_inputs } - def as_sgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]: - assert len(self.lora_weights_lst) > 1 - self.convert_to_sgmv_benchmark_tensors() - self.sanity_check() - - self.to_device(self.input.device) - num_seqs, num_tokens, max_seq_len = self.metadata() - - num_slices = len(self.lora_weights_lst) - slice_size = self.lora_weights_lst[0].shape[-2] # n - assert slice_size * num_slices == self.output.shape[-1] - - kwargs_list = [] - for i in range(num_slices): - kwargs_list.append({ - 'inputs': self.input, - 'lora_b_weights': self.lora_weights_lst[i], - 'output_tensor': self.output, - 'b_seq_start_loc': self.seq_start_loc, - 'seq_len_tensor': self.seq_lens, - 'lora_indices_tensor': self.prompt_lora_mapping, - 'batches': num_seqs, - 'max_seq_length': max_seq_len, - 'token_nums': num_tokens, - 'slice_offset': i * slice_size, - 'slice_size': slice_size, - 'add_inputs': add_inputs, - }) - return {'kwargs_list': kwargs_list} - def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]: - assert len(self.lora_weights_lst) > 1 - num_slices = len(self.lora_weights_lst) - slice_size = self.lora_weights_lst[0].shape[-2] # n - assert slice_size * num_slices == self.output.shape[-1] + + _, num_tokens, _, num_slices = self.metadata() + # Sanity check shapes + i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape [num_slices, num_tokens, lora_rank] + assert len(i_shape) == 3 + assert i_shape[0] == num_slices + assert i_shape[1] == num_tokens + lora_rank = i_shape[2] + # Expected lora weight shape [num_loras, hidden_size, lora_rank] + assert len(lw_shape) == 3 + assert lw_shape[2] == lora_rank + hidden_size = lw_shape[1] + # Expected output shape [num_tokens, hidden_size * num_slices] + assert len(o_shape) == 2 + assert o_shape == (num_tokens, hidden_size * num_slices) self.to_device(self.input.device) kwargs_list = [] for i in range(num_slices): kwargs_list.append({ - 'inputs': self.input, + 'inputs': self.input[i], 'lora_b_weights': self.lora_weights_lst[i], 'output_tensor': self.output, 'lora_indices_tensor': self.token_lora_mapping, - 'slice_offset': i * slice_size, - 'slice_size': slice_size, + 'slice_offset': i * hidden_size, + 'slice_size': hidden_size, 'add_inputs': add_inputs, }) return {'kwargs_list': kwargs_list} @@ -599,8 +682,6 @@ def bench_fn_kwargs(self, return self.as_bgmv_shrink_kwargs() if op_type == OpType.BGMV_EXPAND: return self.as_bgmv_expand_kwargs(add_inputs) - if op_type == OpType.SGMV_EXPAND_SLICE: - return self.as_sgmv_expand_slice_kwargs(add_inputs) if op_type == OpType.BGMV_EXPAND_SLICE: return self.as_bgmv_expand_slice_kwargs(add_inputs) raise ValueError(f"Unrecognized optype {self}") @@ -608,55 +689,32 @@ def bench_fn_kwargs(self, def test_correctness(self, op_type: OpType, expand_fn_add_inputs: Optional[bool]) -> bool: """ - Test correctness of self.output against a grouped gemm reference - implementation. 
- - For expand-related operations with add_inputs = True, since the - benchmarking setup runs the function multiple times, the accumulation - into the self.output is intractable. Correctness testing is skipped - for that case. + Test correctness of op_type implementation against a grouped gemm + reference implementation. """ - - if op_type.is_shrink_fn(): - assert expand_fn_add_inputs is None - else: - assert expand_fn_add_inputs is not None - - if expand_fn_add_inputs: - print(f"WARNING: Skipping correctness testing for {op_type} with " - f"add_inputs={expand_fn_add_inputs}") - return True - seq_lens_cpu = self.seq_lens.to(device="cpu") prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu") - ref_output = self.output.clone() - num_slices = len(self.lora_weights_lst) - hidden_size = self.lora_weights_lst[0].shape[-2] # n - assert hidden_size * num_slices == self.output.shape[-1] + self.output.zero_() + op_type.bench_fn()( + **self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) - do_input_cast: bool = op_type.is_expand_fn( - ) or op_type.is_expand_slice_fn() - weight_dtype = self.lora_weights_lst[0].dtype - ref_group_gemm_with_slices( + op_type.run_ref_group_gemm( ref_output, - self.input.clone().to( - dtype=weight_dtype) if do_input_cast else self.input, + self.input, self.lora_weights_lst, - seq_lens_cpu, - prompt_lora_mapping_cpu, + seq_lens_cpu=seq_lens_cpu, + prompt_lora_mapping_cpu=prompt_lora_mapping_cpu, scaling=1.0, - add_inputs=expand_fn_add_inputs, - num_slices=num_slices, - hidden_size=hidden_size, - ) + add_inputs=expand_fn_add_inputs) rtol, atol = { torch.float16: (6e-2, 6e-2), torch.bfloat16: (6e-2, 6e-2), torch.float32: (1e-2, 1e-2), }[self.output.dtype] + return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol) @@ -679,40 +737,46 @@ def bench_optype(ctx: BenchmarkContext, for bt in bench_tensors: bt.sanity_check() + # Test correctness of our implementation. + if test_correctness: + assert all([ + bt.test_correctness(op_type, expand_fn_add_inputs) + for bt in bench_tensors + ]) + # BenchmarkTensors -> Dict (kwargs) kwargs_list = [ bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) for bt in bench_tensors ] - # Merge into a single kwargs and quality arguments as ArgPool + # Clear LoRA optimization hash-maps. 
+ _LORA_A_PTR_DICT.clear() + _LORA_B_PTR_DICT.clear() + # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup + for kwargs in kwargs_list: + op_type.bench_fn()(**kwargs) + torch.cuda.synchronize() + + # Merge into a single kwargs and qualify arguments as ArgPool kwargs = {k: ArgPool([]) for k in kwargs_list[0]} for _kwargs in kwargs_list: for k, v in _kwargs.items(): kwargs[k].values.append(v) - cuda_graph_params = None - if cuda_graph_nops: - cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) - describe_args = (f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else "") description = ( f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})") + cuda_graph_params = None + if cuda_graph_nops: + cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) timer = None with Bench(cuda_graph_params, ctx.bench_label(), ctx.bench_sublabel(op_type), description, op_type.bench_fn(), **kwargs) as bench: timer = bench.run() - - if test_correctness: - assert all([ - bt.test_correctness(op_type, expand_fn_add_inputs) - for bt in bench_tensors[:cuda_graph_nops if cuda_graph_nops - is not None else arg_pool_size] - ]) - return timer @@ -736,9 +800,8 @@ def bench_torch_mm(ctx: BenchmarkContext, ctx.dtype) m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank) - if op_type.is_expand_slice_fn(): - # For a fairer comparison. - n = n * ctx.num_slices + # For a fairer comparison. + n = n * ctx.num_slices # Get matmul input and output tensors for A x B = C As, Bs, Cs = [], [], [] @@ -1016,8 +1079,8 @@ def add_common_command_args(p: argparse.ArgumentParser): p.add_argument( "--test-correctness", action='store_true', - help=("When enabled, the benchmarking objects are additionally " - "checked for correctness")) + help=("When enabled, the benchmarking functions are tested" + "for correctness before the actual benchmarking")) parser = FlexibleArgumentParser( description=f""" diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py index 3b71689a751c4..fee877b6f76fa 100644 --- a/benchmarks/kernels/utils.py +++ b/benchmarks/kernels/utils.py @@ -13,13 +13,13 @@ class CudaGraphBenchParams: @dataclasses.dataclass class ArgPool: - ''' + """ When some argument of the benchmarking function is annotated with this type, the benchmarking class (BenchMM) will collapse the argument to a pick a single value from the given list of values, during function invocation. For every invocation during a benchmarking run, it will choose a different value from the list. 
- ''' + """ values: Iterable[Any] def __getitem__(self, index): @@ -128,7 +128,6 @@ def get_cuda_graph_runner(self): self.args_iterator.reset() args_it = self.args_iterator.__next__() - stream = torch.cuda.Stream() with torch.cuda.stream(stream): g = torch.cuda.CUDAGraph() From 1b412912c8931c289c8584b1ed74f9945a0fbe7a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 13 Jan 2025 23:19:44 -0500 Subject: [PATCH 16/17] fix example commands Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 7f5df010a913e..e3a75522cafea 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -1088,13 +1088,13 @@ def add_common_command_args(p: argparse.ArgumentParser): {use_cuda_graph_recommendation()} list_bench example: - python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 + python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 model_bench example: - python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 + python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 range_bench example: - python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 + python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 """, # noqa: E501 formatter_class=argparse.RawTextHelpFormatter) From e31e1fda2e4d9f5a3051ae55848fd87db9a1c980 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 16 Jan 2025 10:59:42 +0000 Subject: [PATCH 17/17] fix imports Signed-off-by: Varun Sundar Rabindranath --- benchmarks/kernels/benchmark_lora.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git 
a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index e3a75522cafea..e1f613e1da509 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -15,12 +15,12 @@ from utils import ArgPool, Bench, CudaGraphBenchParams from weight_shapes import WEIGHT_SHAPES -from vllm.lora.ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_shrink import sgmv_shrink -from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand +from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice +from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink +from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand +from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
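
(A simplified, standalone sketch of the CUDA-graph timing path this series relies on — get_cuda_graph_runner in utils.py, surfaced as --cuda-graph-nops. It is not code from the patches: torch.mm stands in for the LoRA kernels, N_OPS consecutive calls are captured into a single torch.cuda.CUDAGraph, and the replay time is divided by N_OPS, mirroring the per-invocation note emitted by print_timers.)

import torch

N_OPS = 32  # analogous to --cuda-graph-nops
a = torch.rand(128, 4096, device="cuda", dtype=torch.float16)
b = torch.rand(4096, 16, device="cuda", dtype=torch.float16)
c = torch.empty(128, 16, device="cuda", dtype=torch.float16)

# Warm up on a side stream before capture (required by CUDA graphs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    torch.mm(a, b, out=c)
torch.cuda.current_stream().wait_stream(s)

# Capture N_OPS consecutive invocations into one graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    for _ in range(N_OPS):
        torch.mm(a, b, out=c)

# Time a single replay of the whole graph; divide by N_OPS
# to get a per-invocation estimate.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
g.replay()
end.record()
torch.cuda.synchronize()
print(f"per-invocation time: {start.elapsed_time(end) / N_OPS:.4f} ms")

Replaying one captured graph avoids the per-kernel Python launch overhead, which is why small problem sizes benefit the most, as use_cuda_graph_recommendation points out.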