diff --git a/compress.py b/compress.py index 34b7c7d..75ba484 100644 --- a/compress.py +++ b/compress.py @@ -3,9 +3,9 @@ import sys from loguru import logger from utils import set_seed, dump_to_huggingface_repos, load_model_and_tokenizer -from palu.rank_search import rank_search +from palu.compress.rank_search import rank_search from tqdm import tqdm -from palu.decomposition import compress_model +from palu.compress.decomposition import compress_model from run_lm_eval import run_lm_eval_zero_shot import os diff --git a/kernel/abx_rope.py b/kernel/abx_rope.py deleted file mode 100644 index 2ec7199..0000000 --- a/kernel/abx_rope.py +++ /dev/null @@ -1,280 +0,0 @@ -"""We want triton==3.0.0 for this script -""" - -import torch -import triton -import triton.language as tl -import argparse - -from .pytorch_reference import LlamaRotaryEmbedding, apply_rotary_pos_emb_pytorch - - -def set_random_seed(seed=0): - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - -@triton.jit -def get_freq_multi_tokens(starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr): - DIM: tl.constexpr = 128 # in model, dim = self.params.dim // self.params.n_heads - DIM_2: tl.constexpr = 64 - freqs = tl.arange(0, DIM_2) * 2 - freqs = freqs.to(tl.float32) / DIM - freqs = tl.extra.cuda.libdevice.fast_powf(theta, freqs) - freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :] - return tl.extra.cuda.libdevice.fast_cosf(freqs), tl.extra.cuda.libdevice.fast_sinf(freqs) - - -def get_configs(): - configs = [] - for block_l in [16, 32, 64, 128]: - for block_r in [16, 32]: - for num_warps in [1, 4, 8, 16]: - for num_stages in [1, 2, 3]: - configs.append( - triton.Config({'BLOCK_SIZE_L': block_l, 'BLOCK_SIZE_R': block_r}, - num_stages=num_stages, num_warps=num_warps)) - # return configs - # return [triton.Config({'BLOCK_SIZE_L': 128, 'BLOCK_SIZE_R': 32}, num_warps=4, num_stages=3)] # for gs=4 - # return [triton.Config({'BLOCK_SIZE_L': 64, 'BLOCK_SIZE_R': 32}, num_warps=4, num_stages=3)] # for gs=2 - return [triton.Config({'BLOCK_SIZE_L': 64, 'BLOCK_SIZE_R': 32}, num_warps=4, num_stages=1)] # for gs=1 - -@triton.autotune( - configs= get_configs(), - key=["seq_len"] -) -@triton.jit -def _abx_fwd( - a_ptr, b_ptr, x_ptr, out_ptr, - stride_az, stride_aa, stride_ad, - stride_bz, stride_br, stride_bd, - stride_xhg, stride_xl, stride_xr, - stride_oz, stride_oa, stride_ol, - R, D, seq_len, - BLOCK_SIZE_D: tl.constexpr, - BLOCK_SIZE_R: tl.constexpr, - BLOCK_SIZE_L: tl.constexpr, - NUM_GROUPS: tl.constexpr, - THETA: tl.constexpr, -): - pid_h = tl.program_id(axis=0) # number of heads - pid_l = tl.program_id(axis=1) # nubmer of block along seq_length dimension - - # Assuming NUM_GROUPS = 4, then pid_h = 0, 1, 2, 3 will be assigned to head group 0 - HEAD_GROUPS_ID = pid_h // (32 // NUM_GROUPS) - offs_ds = tl.arange(0, BLOCK_SIZE_D) # same as offs_bds - offs_rs = tl.arange(0, BLOCK_SIZE_R) - offs_ls = (pid_l * BLOCK_SIZE_L) + tl.arange(0, BLOCK_SIZE_L) - - A_ptrs = a_ptr + pid_h * stride_az + (0*stride_aa + offs_ds[None, :]*stride_ad) # assume a is always (bs, 1, d) - B_ptrs = b_ptr + pid_h * stride_bz + (offs_rs[:, None]*stride_br + offs_ds[None, :]*stride_bd) - X_ptrs = x_ptr + HEAD_GROUPS_ID * stride_xhg + (offs_ls[:, None]*stride_xl + offs_rs[None, :]*stride_xr) - O_ptrs = out_ptr + pid_h * stride_oz + (0*stride_oa + offs_ls[None, :]*stride_ol) - - # Fix BLOCK_SIZE_D = 64, and head_dim = 128 - xb_0 
= tl.zeros((BLOCK_SIZE_L, BLOCK_SIZE_D), dtype=tl.float32) - xb_1 = tl.zeros((BLOCK_SIZE_L, BLOCK_SIZE_D), dtype=tl.float32) - for _ in range(0, tl.cdiv(R, BLOCK_SIZE_R)): - # Load next block of B, X - x = tl.load(X_ptrs) - b_0 = tl.load(B_ptrs) - b_1 = tl.load(B_ptrs + BLOCK_SIZE_D * stride_bd) - # Accumulate along R dimension. - xb_0 = tl.dot(x, b_0, xb_0) - xb_1 = tl.dot(x, b_1, xb_1) - # Advance the pointers to next blocks - B_ptrs += BLOCK_SIZE_R * stride_br - X_ptrs += BLOCK_SIZE_R * stride_xr - - xb_0 = xb_0.to(tl.float16) - xb_1 = xb_1.to(tl.float16) - - # RoPE - start_block = pid_l * BLOCK_SIZE_L - cos, sin = get_freq_multi_tokens(starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_L) - cos = cos.to(tl.float16) - sin = sin.to(tl.float16) - - xb_rope_0 = xb_0 * cos - xb_1 * sin - xb_rope_1 = xb_1 * cos + xb_0 * sin - xb_0 = xb_rope_0.to(tl.float16) - xb_1 = xb_rope_1.to(tl.float16) - - # GEMV - a_0 = tl.load(A_ptrs) - a_1 = tl.load(A_ptrs + BLOCK_SIZE_D * stride_ad) - abx_0 = tl.sum(a_0 * xb_0, 1) - abx_1 = tl.sum(a_1 * xb_1, 1) - abx = abx_0 + abx_1 - tl.store(O_ptrs, abx[None, :]) - - -def abx(a: torch.Tensor, b: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - # U x V x X - assert a.dim() == 3 - assert b.dim() == 3 - assert x.dim() == 3 - - num_heads, _, head_dim = a.shape - num_heads,rank_per_head_groups, head_dim = b.shape - num_groups, seq_len, rank_per_head_groups = x.shape - # Allocate output tensor - out = torch.empty((num_heads, 1, seq_len), dtype=x.dtype, device=x.device) - BLOCK_SIZE_D = 64 - # BLOCK_SIZE_R = 32 - # BLOCK_SIZE_L = 128 - # num_stages = 1 - # num_warps = 8 - NUM_GROUPS = num_groups - - grid = lambda META: (32, triton.cdiv(seq_len, META["BLOCK_SIZE_L"])) - _abx_fwd[grid]( - a, b, x, out, - a.stride(0), a.stride(1), a.stride(2), - b.stride(0), b.stride(1), b.stride(2), - x.stride(0), x.stride(1), x.stride(2), - out.stride(0), out.stride(1), out.stride(2), - R = rank_per_head_groups, - D = head_dim, - seq_len = seq_len, - BLOCK_SIZE_D = BLOCK_SIZE_D, - # BLOCK_SIZE_L = BLOCK_SIZE_L, - # BLOCK_SIZE_R = BLOCK_SIZE_R, - # num_stages=num_stages, - # num_warps=num_warps, - NUM_GROUPS = NUM_GROUPS, - THETA = 10000., - ) - return out - -def torch_abx(a, b, x): - # Input shape - # a: (num_heads, 1, head_dim) - # b: (num_heads, rank_per_groups, head_dim) - # x: (num_groups, seq_len, rank_per_groups) - - # Recompute the key states - # x: (num_groups, 1, seq_len, rank_per_groups) - # b: (num_groups, group_size, rank_per_groups, head_dim) - # xb: (num_heads, seq_len, head_dim) - x_expand = x.unsqueeze(1) - b_reshape = b.reshape(-1, b.shape[0] // x.shape[0], b.shape[-2], b.shape[-1]) - xb = x_expand @ b_reshape - xb = xb.reshape(b.shape[0], -1, b.shape[-1]) - - # Apply RoPE - cos, sin = LlamaRotaryEmbedding(dim=128, end=x.shape[1]) - xb_rope = apply_rotary_pos_emb_pytorch(x=xb, cos=cos, sin=sin) - axb = a @ xb_rope.transpose(-1, -2).to(torch.float16) - return axb - -def run_benchmark(args): - configs = [] - configs.append( - triton.testing.Benchmark( - x_names=["seq_len"], - x_vals=args.target_seq_lens, - line_arg="provider", - line_vals=["WX", "torch", "ours"], - line_names=["WX", "Torch", "Ours"], - styles=[("gray", "--"), ("green", "--"), ("blue", "-")], - ylabel="us", - plot_name=f"low-rank-rank-{args.total_rank}-group-{args.num_groups}", - args={ - "dtype": torch.float16, - "num_heads": args.num_heads, - "head_dim": args.head_dim, - "total_rank": args.total_rank, - "num_groups": args.num_groups, # number of head groups - }, - )) - - 
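# Providers compared in this benchmark:
#   "WX":    uncompressed baseline, a plain matmul between the query and the full key cache
#   "torch": PyTorch low-rank reference (torch_abx) that reconstructs keys, applies RoPE, then multiplies
#   "ours":  fused Triton kernel (abx) doing reconstruction, RoPE and the final GEMV in one pass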
@triton.testing.perf_report(configs) - def bench_low_rank(num_heads, head_dim, total_rank, seq_len, num_groups, provider, dtype=torch.float16, device="cuda"): - rank_per_groups = total_rank // num_groups - - warmup = 25 - rep = 100 - A = torch.randn(num_heads, 1, head_dim, dtype=dtype, device=device) - B = torch.randn(num_heads, rank_per_groups, head_dim, dtype=dtype, device=device) - X = torch.randn(num_groups, seq_len, rank_per_groups, dtype=dtype, device=device) - org_A = torch.randn(num_heads, 1, head_dim, dtype=dtype, device=device) - org_X = torch.randn(num_heads, seq_len, head_dim, dtype=dtype, device=device) - - - quantiles = [0.5, 0.2, 0.8] - if provider == "torch": - def fn(): return torch_abx(A, B, X) - ms, min_ms, max_ms = triton.testing.do_bench( - fn, quantiles=quantiles, warmup=warmup, rep=rep) - - if provider == "ours": - def fn(): return abx(A, B, X) - ms, min_ms, max_ms = triton.testing.do_bench( - fn, quantiles=quantiles, warmup=warmup, rep=rep) - - if provider == "WX": - def fn(): return torch.matmul(org_A, org_X.transpose(-1, -2)) - ms, min_ms, max_ms = triton.testing.do_bench( - fn, quantiles=quantiles, warmup=warmup, rep=rep) - - return ms*1000, min_ms*1000, max_ms*1000 - - import os - # create a directory to store the results - os.makedirs('results', exist_ok=True) - bench_low_rank.run(print_data=True, show_plots=True, save_path='results/') - -def run_test(args): - num_heads = args.num_heads - head_dim = args.head_dim - total_rank = args.total_rank - seq_len = 64 - num_groups = args.num_groups - rank_per_groups = total_rank // num_groups - dtype = torch.float16 - device = "cuda" - - A = torch.randn(num_heads, 1, head_dim, dtype=dtype, device=device) - B = torch.randn(num_heads, rank_per_groups, head_dim, dtype=dtype, device=device) - X = torch.randn(num_groups, seq_len, rank_per_groups, dtype=dtype, device=device) - - x, xb, xb_rope, xb_rope_0, xb_rope_1, axb, cos, sin, freqs = torch_abx(A, B, X) - ours = abx(A, B, X) - - print("Max diff: ", torch.max(torch.abs(axb - ours))) - -def parse_args(): - parser = argparse.ArgumentParser(description="Argument Parser") - parser.add_argument("--total_rank", type=int, default=2048, help="Total rank") - parser.add_argument("--num_heads", type=int, default=32, help="Number of heads, default to 32 (llama)") - parser.add_argument("--head_dim", type=int, default=128, help="Head dimension, default to 128 (llama)") - parser.add_argument("--group_size", type=int, default=4, help="Number of heads per group") - parser.add_argument("--target_seq_lens", nargs="+", type=int, - default=[4096, 16384, 65536, 262144], help="Target sequence lengths") - parser.add_argument("--check", action="store_true", help="Check the correctness of the implementation") - args = parser.parse_args() - return args - -def main(args): - args.num_groups = args.num_heads // args.group_size - args.group_rank = args.total_rank // args.num_groups - print("Start benchmarking fused low-rank KV Cache Kernels...") - print("Total Rank: ", args.total_rank) - print("Number of Heads: ", args.num_heads) - print("Head Dimension: ", args.head_dim) - print("Group Size:", args.group_size) - print("Number of Groups: ", args.num_groups) - print("Rank per Group: ", args.group_rank) - if args.check: - run_test(args) - else: - run_benchmark(args) - -if __name__ == "__main__": - set_random_seed() - args = parse_args() - main(args) - diff --git a/kernel/palu_attention.py b/kernel/palu_attention.py deleted file mode 100644 index 3d408ee..0000000 --- a/kernel/palu_attention.py +++ /dev/null 
@@ -1,308 +0,0 @@ -import math -import warnings -from typing import Optional, Tuple - -import torch -from torch import nn - -from transformers.models.llama.modeling_llama import ( - Cache, apply_rotary_pos_emb, - LlamaAttention, LlamaConfig, -) - -from .abx_rope import abx as recompute_k_gemv - - -class HeadwiseLowRankModule(nn.Module): - """ Headwise Low-Rank module """ - def __init__(self, ranks, in_features, out_features, bias): - super().__init__() - - self.ranks = ranks - self.num_groups = len(ranks) - self.in_features = in_features - self.out_features = out_features - self.group_dim = out_features // self.num_groups - - if (self.group_dim * self.num_groups) != self.out_features: - raise ValueError( - f"out_features must be divisible by num_groups (got `out_features`: {self.out_features}" - f" and `num_groups`: {self.num_groups})." - ) - - self.VT = nn.Linear(in_features, sum(ranks), bias=False) - - # Create the list of linear layers first - Us = [] - for r in ranks: - linear_layer = nn.Linear(r, self.group_dim, bias=bias) - nn.init.normal_(linear_layer.weight) - Us.append(linear_layer) - - self.U_list = nn.ModuleList(Us) - - def forward(self, hidden_states: torch.Tensor): - """ hidden_states: Tensor of shape (batch_size, seq_len, in_features) """ - assert hidden_states.dim() == 3, f"hidden_states should have 3 dimensions, got {hidden_states.dim()}" - - hidden_states = self.VT(hidden_states) - - # hidden_states: Tensor of shape (batch_size, seq_len, r1 + r2 + ... ) - outputs = [] - total_ranks = 0 - for i in range(self.num_groups): - outputs.append(self.U_list[i](hidden_states[:, :, total_ranks: total_ranks+self.ranks[i]])) - total_ranks += self.ranks[i] - - return torch.cat(outputs, dim=-1) - - def project_to_latent(self, hidden_states: torch.Tensor): - """ hidden_states: Tensor of shape (batch_size, seq_len, in_features) """ - assert hidden_states.dim() == 3, f"hidden_states should have 3 dimensions, got {hidden_states.dim()}" - - hidden_states = self.VT(hidden_states) - - return hidden_states - - def reconstruct(self, hidden_states: torch.Tensor): - """ hidden_states: Tensor of shape (batch_size, seq_len, sum(ranks)) """ - assert hidden_states.dim() == 3, f"hidden_states should have 3 dimensions, got {hidden_states.dim()}" - - outputs = [] - total_ranks = 0 - for i in range(self.num_groups): - outputs.append(self.U_list[i](hidden_states[:, :, total_ranks: total_ranks+self.ranks[i]])) - total_ranks += self.ranks[i] - - return torch.cat(outputs, dim=-1) - - @staticmethod - def from_linear( - old_module: nn.Linear, - ranks: list, - attn_module: LlamaAttention = None, - ): - new_module = HeadwiseLowRankModule(ranks, old_module.in_features, old_module.out_features, bias=old_module.bias is not None) - w = old_module.weight.data.reshape(len(ranks), -1, old_module.in_features).float() - - wl = [] - wr = [] - for i in range(len(ranks)): - l, s, r = torch.linalg.svd(w[i], full_matrices=False) - l = l[:, 0:ranks[i]] - s = s[0:ranks[i]] - r = r[0:ranks[i], :] - l = l.mul(s) - - # l: (head_dim, rank), r: (rank, hidden_size) - wl.append(l) - wr.append(r) - - # load to U - for i in range(len(ranks)): - if new_module.U_list[i].weight.data.shape != wl[i].shape: - raise ValueError(f"{new_module.U_list[i].weight.data.shape} != {wl[i].shape}") - new_module.U_list[i].weight.data = wl[i].contiguous() - - # Create B matrix for kernel - if attn_module is not None: - U_list_T = [x.weight.data.T for x in new_module.U_list] - b = torch.stack(U_list_T) - b = b.reshape(new_module.num_groups, 
new_module.ranks[0], attn_module.group_size, attn_module.head_dim) - b = b.transpose(1, 2) - b = b.reshape(attn_module.num_heads, new_module.ranks[0], attn_module.head_dim) - new_module.B = nn.Parameter(b) - - # load to VT - # shape (sum(ranks), hidden_size) - VT_weight = torch.cat(wr, dim=0).contiguous() - assert new_module.VT.weight.data.shape == VT_weight.shape - new_module.VT.weight.data = VT_weight - - return new_module - -class LlamaPaluAttention(LlamaAttention): - """ - Llama Attention with Low-Rank KV-Cache with Palu. This module inherits from - `LlamaAttention` but change linear layer and add custom Triton kernel. - """ - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - - self.group_size = config.group_size - self.num_groups = config.num_groups - self.total_rank_k = config.total_rank_k - self.total_rank_v = config.total_rank_v - self.group_rank_k = self.total_rank_k // self.num_groups - self.group_rank_v = self.total_rank_v // self.num_groups - self.fused_hidden_dim_o = self.group_rank_v * self.num_heads - self.rank_k_list = [self.group_rank_k for _ in range(self.num_groups)] - self.rank_v_list = [self.group_rank_v for _ in range(self.num_groups)] - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = HeadwiseLowRankModule(self.rank_k_list, self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = HeadwiseLowRankModule(self.rank_v_list, self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.fused_hidden_dim_o, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - golden_kernel: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - # key_states = self.k_proj(hidden_states) - # value_states = self.v_proj(hidden_states) - key_h_states = self.k_proj.project_to_latent(hidden_states) - value_h_states = self.v_proj.project_to_latent(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - # key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - # value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - key_h_states = key_h_states.view(bsz, q_len, self.num_groups, self.group_rank_k).transpose(1, 2) - value_h_states = value_h_states.view(bsz, q_len, self.num_groups, self.group_rank_v).transpose(1, 2) - - # kv_seq_len = key_states.shape[-2] - kv_seq_len = key_h_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len) - # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_h_states, value_h_states = past_key_value.update(key_h_states, value_h_states, self.layer_idx) - - - if q_len > 1: - # Prompting - # Recompute the key states - key_h_states = key_h_states.transpose(1, 2).reshape(bsz, kv_seq_len, self.total_rank_k) - key_states = self.k_proj.reconstruct(key_h_states) - key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - # Apply RoPE after recomputing the key states - cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - else: - # Generating (Apply our reconsturction kernel) - # A: (num_heads, 1, head_dim) - # B: (num_heads, rank_per_groups, head_dim) - # X: (num_head_groups, seq_len, rank_per_groups) - # TODO: Optimize RoPE & sqrt(head_dim) into kernel - # TODO: Check if sin & cos are share among different blocks - cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len) - query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin, position_ids) - A = query_states.squeeze(0) - B = self.k_proj.B - X = key_h_states.squeeze(0) - attn_weights = recompute_k_gemv(A, B, X).unsqueeze(0) / math.sqrt(self.head_dim) - - # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - - # Upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - - # Original version - # value_states = self.v_proj.reconstruct(value_h_states) - # value_states = value_states.reshape(1, q_len, self.num_heads, self.head_dim).transpose(1, 2) - # attn_output = torch.matmul(attn_weights, value_states) - - # Fusion version - # attn_weights: (bsz, num_groups, q_len * group_size, kv_seq_len) - attn_h_weights = attn_weights.reshape(1, self.num_groups, q_len * self.group_size, kv_seq_len) - attn_h_output = torch.matmul(attn_h_weights, value_h_states) - # attn_h_output: (bsz, num_heads, q_len * group_size, group_rank) - attn_output = attn_h_output.reshape(1, self.num_heads, q_len, self.group_rank_v) - - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - @staticmethod - 
def from_attention( - module: LlamaAttention, - config: LlamaConfig, - no_fusion: bool = False, - ): - new_module = LlamaPaluAttention(config, module.layer_idx) - new_module.q_proj = module.q_proj - new_module.k_proj = HeadwiseLowRankModule.from_linear(module.k_proj, new_module.rank_k_list, new_module) - new_module.v_proj = HeadwiseLowRankModule.from_linear(module.v_proj, new_module.rank_v_list) - - # No fusion version - if no_fusion: - new_module.o_proj = module.o_proj - return new_module - - # Fusion version - # new_module.v_proj = new_v_proj.VT - - # fuse v_proj.U into o_proj - new_o_weight = torch.zeros(new_module.o_proj.weight.size()) - - head_dim = module.head_dim - num_groups = config.num_groups - group_size = config.group_size - group_rank = new_module.group_rank_v - - total_dims_2, total_ranks, total_fused_dims = 0, 0, 0 - for i in range(num_groups): - total_dims = 0 - for _ in range(group_size): - new_o_weight[:, total_fused_dims:total_fused_dims + group_rank] = \ - module.o_proj.weight[:, total_dims_2:total_dims_2 + head_dim] @ \ - new_module.v_proj.U_list[i].weight[total_dims:total_dims + head_dim, :] - total_dims += head_dim - total_dims_2 += head_dim - total_fused_dims += group_rank - - total_ranks += group_rank - - with torch.no_grad(): - new_module.o_proj.weight.copy_(new_o_weight) - - return new_module diff --git a/kernel/pytorch_reference.py b/kernel/pytorch_reference.py deleted file mode 100644 index 77b063d..0000000 --- a/kernel/pytorch_reference.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch - -def LlamaRotaryEmbedding(dim: int, end: int, theta: float = 10000.0): - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) - t = torch.arange(end, dtype=torch.int64).type_as(inv_freq) - freqs = torch.outer(t, inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - return emb.cos(), emb.sin() - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - -def apply_rotary_pos_emb_pytorch(x, cos, sin, unsqueeze_dim=0): - cos = cos.unsqueeze(unsqueeze_dim).to(x.device) - sin = sin.unsqueeze(unsqueeze_dim).to(x.device) - x_emb = (x * cos) + (rotate_half(x) * sin) - return x_emb diff --git a/kernel/test_palu_attention.py b/kernel/test_palu_attention.py deleted file mode 100644 index 356eb47..0000000 --- a/kernel/test_palu_attention.py +++ /dev/null @@ -1,207 +0,0 @@ -import torch -import torch.nn as nn -import pytest - -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaConfig, DynamicCache - -from palu_attention import HeadwiseLowRankModule, LlamaPaluAttention - -@pytest.fixture(autouse=True) -def set_random_seed(): - torch.manual_seed(0) - torch.cuda.manual_seed(0) - torch.cuda.manual_seed_all(0) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - -@pytest.fixture() -def config(): - config = LlamaConfig() - config.group_size = 4 - config.num_groups = config.num_attention_heads // 4 - config.total_rank_k = 4096 - config.total_rank_v = 4096 - return config - - -def _set_random_seed(): - torch.manual_seed(0) - torch.cuda.manual_seed(0) - torch.cuda.manual_seed_all(0) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - -def test_lr_layer_init(): - batch_size = 1 - seq_len = 5 - ranks = [2, 2] - in_features = 6 - out_features = 6 - bias = False 
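# With ranks=[2, 2], in_features=6 and out_features=6, the module has two head
# groups of group_dim=3: VT is a Linear(6, 4) producing the concatenated latents,
# and each U_list[i] is a Linear(2, 3) that reconstructs its group's slice of the output.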
- - module = HeadwiseLowRankModule(ranks, in_features, out_features, bias) - - hidden_states = torch.randn(batch_size, seq_len, in_features) - - # forward - forward_output = module(hidden_states) - - # project_to_latent & reconstruct - latent_states = module.project_to_latent(hidden_states) - reconstructed_output = module.reconstruct(latent_states) - - torch.testing.assert_close(forward_output, reconstructed_output) - -def test_lr_layer_from_linear(): - batch_size = 1 - seq_len = 5 - ranks = [3, 3] - in_features = 10 - out_features = 6 - bias = False - - linear = nn.Linear(in_features, out_features, bias) - svd_linear = HeadwiseLowRankModule.from_linear(linear, ranks) - - inputs = torch.randn(batch_size, seq_len, in_features) - - # Golden linear - linear_output = linear(inputs) - - # Low-Ranl linear - svd_linear_output = svd_linear(inputs) - - torch.testing.assert_close(linear_output, svd_linear_output) - -def test_palu_attention_inherit_no_fusion(config): - batch_size = 1 - seq_len = 64 - - attention = LlamaAttention(config, 0) - palu_attention = LlamaPaluAttention.from_attention(attention, config, no_fusion=True) - - # q, o proj - torch.testing.assert_close(attention.q_proj.weight, palu_attention.q_proj.weight) - torch.testing.assert_close(attention.o_proj.weight, palu_attention.o_proj.weight) - - # k, v proj - inputs = torch.randn(batch_size, seq_len, config.hidden_size) - torch.testing.assert_close(attention.k_proj(inputs), palu_attention.k_proj(inputs)) - torch.testing.assert_close(attention.v_proj(inputs), palu_attention.v_proj(inputs)) - -def test_palu_attention_inherit_fusion(config): - batch_size = 1 - seq_len = 64 - - attention = LlamaAttention(config, 0) - palu_attention = LlamaPaluAttention.from_attention(attention, config) - - # q, v proj - torch.testing.assert_close(attention.q_proj.weight, palu_attention.q_proj.weight) - - # k, v proj - inputs = torch.randn(batch_size, seq_len, config.hidden_size) - torch.testing.assert_close(attention.k_proj(inputs), palu_attention.k_proj(inputs)) - torch.testing.assert_close(attention.v_proj(inputs), palu_attention.v_proj(inputs)) - - # o proj - q_len = seq_len - group_size = config.group_size - num_heads = config.num_attention_heads - hidden_dim = config.hidden_size - num_groups = num_heads // group_size - head_dim = hidden_dim // num_heads - group_rank = config.total_rank_v // num_groups - - # original - inputs = torch.randn(1, q_len, hidden_dim) - attn_weight = torch.randn(1, num_heads, q_len, q_len) - v_states = attention.v_proj(inputs).view(1, q_len, num_heads, head_dim).transpose(1, 2) - attn_output = torch.matmul(attn_weight, v_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(1, q_len, -1) - ori_output = attention.o_proj(attn_output) - - # fusion - attn_weight = attn_weight.reshape(1, num_groups, q_len * group_size, q_len) - v_h_states = palu_attention.v_proj.project_to_latent(inputs).reshape(1, q_len, num_groups, group_rank).transpose(1, 2) - - attn_h_output = torch.matmul(attn_weight, v_h_states) - attn_h_output = attn_h_output.reshape(1, num_heads, q_len, group_rank) - - final_fused_o_output = palu_attention.o_proj(attn_h_output.transpose(1, 2).reshape(1, q_len, -1)) - torch.testing.assert_close(ori_output, final_fused_o_output) - -def test_palu_attention_fusion(config): - batch_size = 1 - seq_len = 64 - dev = 'cuda:0' - dtype = torch.float16 - - attention = LlamaAttention(config, 0) - palu_attention = LlamaPaluAttention.from_attention(attention, config) - - attention = 
attention.to(dev, dtype) - palu_attention = palu_attention.to(dev, dtype) - - inputs = torch.randn(batch_size, seq_len, config.hidden_size).to(dev, dtype) - - # Golden - golden_output, golden_attn_weights, _ = attention(inputs, output_attentions = True) - - # Fusion - fusion_output, fusion_attn_weights, _ = palu_attention(inputs, output_attentions = True, golden_kernel=True) - - torch.testing.assert_close(golden_attn_weights, fusion_attn_weights, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(golden_output, fusion_output, rtol=1e-3, atol=1e-3) - -def test_palu_attention_kernel(config): - batch_size = 1 - prompt_len = 63 - seq_len = 1 - dev = 'cuda:0' - dtype = torch.float16 - - attention = LlamaAttention(config, 0) - palu_attention = LlamaPaluAttention.from_attention(attention, config) - - attention = attention.to(dev, dtype) - palu_attention = palu_attention.to(dev, dtype) - prompt_inputs = torch.randn(batch_size, prompt_len, config.hidden_size).to(dev, dtype) - inputs = torch.randn(batch_size, seq_len, config.hidden_size).to(dev, dtype) - prompt_position_ids = torch.arange(prompt_len).unsqueeze(0) # Shape: [1, seq_length] - generate_position_ids = torch.arange(prompt_len, prompt_len+seq_len).unsqueeze(0) # Shape: [1, seq_length] - kv_cache = DynamicCache() - palu_kv_cache = DynamicCache() - - # Prompt - attn_output, attn_weights, kv_cache = attention(prompt_inputs, output_attentions=True, - past_key_value=kv_cache, position_ids=prompt_position_ids) - palu_attn_output, palu_attn_weights, palu_kv_cache = palu_attention(prompt_inputs, output_attentions=True, - past_key_value=palu_kv_cache, position_ids=prompt_position_ids) - - torch.testing.assert_close(attn_output, palu_attn_output, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(attn_weights, palu_attn_weights, rtol=1e-3, atol=1e-3) - - # Generate - # Golden - golden_attn_output, golden_attn_weights, _ = attention(inputs, output_attentions=True, - past_key_value=kv_cache, position_ids=generate_position_ids) - # Kernel - palu_attn_output, palu_attn_weights, _ = palu_attention(inputs, output_attentions=True, - past_key_value=palu_kv_cache, position_ids=generate_position_ids) - - torch.testing.assert_close(golden_attn_weights, palu_attn_weights, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(golden_attn_output, palu_attn_output, rtol=1e-3, atol=1e-3) - - -if __name__ == "__main__": - _set_random_seed() - _config = LlamaConfig() - _config.group_size = 4 - _config.num_groups = _config.num_attention_heads // 4 - _config.total_rank_k = 4096 - _config.total_rank_v = 4096 - test_palu_attention_fusion(_config) - test_palu_attention_kernel(_config) - \ No newline at end of file diff --git a/kernel/__init__.py b/palu/backend/__init__.py similarity index 100% rename from kernel/__init__.py rename to palu/backend/__init__.py diff --git a/palu/backend/fused_recompute.py b/palu/backend/fused_recompute.py new file mode 100644 index 0000000..cf291b0 --- /dev/null +++ b/palu/backend/fused_recompute.py @@ -0,0 +1,158 @@ +"""We want triton==3.0.0 for this script +""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def get_freq_multi_tokens(starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr): + DIM: tl.constexpr = 128 # in model, dim = self.params.dim // self.params.n_heads + DIM_2: tl.constexpr = 64 + freqs = tl.arange(0, DIM_2) * 2 + freqs = freqs.to(tl.float32) / DIM + freqs = tl.extra.cuda.libdevice.fast_powf(theta, freqs) + freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :] + return 
tl.extra.cuda.libdevice.fast_cosf(freqs), tl.extra.cuda.libdevice.fast_sinf(freqs) + + +def get_configs(): + configs = [] + for block_l in [16, 32, 64, 128]: + for block_r in [16, 32]: + for num_warps in [1, 4, 8, 16]: + for num_stages in [1, 2, 3]: + configs.append( + triton.Config({'BLOCK_SIZE_L': block_l, 'BLOCK_SIZE_R': block_r}, + num_stages=num_stages, num_warps=num_warps)) + # return configs + # return [triton.Config({'BLOCK_SIZE_L': 128, 'BLOCK_SIZE_R': 32}, num_warps=4, num_stages=3)] # for gs=4 + # return [triton.Config({'BLOCK_SIZE_L': 64, 'BLOCK_SIZE_R': 32}, num_warps=4, num_stages=3)] # for gs=2 + return [triton.Config({'BLOCK_SIZE_L': 64, 'BLOCK_SIZE_R': 32}, num_warps=4, num_stages=1)] # for gs=1 + +@triton.autotune( + configs= get_configs(), + key=["seq_len"] +) +@triton.jit +def _abx_fwd( + a_ptr, b_ptr, x_ptr, out_ptr, + stride_az, stride_aa, stride_ad, + stride_bz, stride_br, stride_bd, + stride_xhg, stride_xl, stride_xr, + stride_oz, stride_oa, stride_ol, + R, D, seq_len, + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_R: tl.constexpr, + BLOCK_SIZE_L: tl.constexpr, + NUM_GROUPS: tl.constexpr, + THETA: tl.constexpr, +): + pid_h = tl.program_id(axis=0) # number of heads + pid_l = tl.program_id(axis=1) # nubmer of block along seq_length dimension + + # Assuming NUM_GROUPS = 4, then pid_h = 0, 1, 2, 3 will be assigned to head group 0 + HEAD_GROUPS_ID = pid_h // (32 // NUM_GROUPS) + offs_ds = tl.arange(0, BLOCK_SIZE_D) # same as offs_bds + offs_rs = tl.arange(0, BLOCK_SIZE_R) + offs_ls = (pid_l * BLOCK_SIZE_L) + tl.arange(0, BLOCK_SIZE_L) + + A_ptrs = a_ptr + pid_h * stride_az + (0*stride_aa + offs_ds[None, :]*stride_ad) # assume a is always (bs, 1, d) + B_ptrs = b_ptr + pid_h * stride_bz + (offs_rs[:, None]*stride_br + offs_ds[None, :]*stride_bd) + X_ptrs = x_ptr + HEAD_GROUPS_ID * stride_xhg + (offs_ls[:, None]*stride_xl + offs_rs[None, :]*stride_xr) + O_ptrs = out_ptr + pid_h * stride_oz + (0*stride_oa + offs_ls[None, :]*stride_ol) + + # Fix BLOCK_SIZE_D = 64, and head_dim = 128 + xb_0 = tl.zeros((BLOCK_SIZE_L, BLOCK_SIZE_D), dtype=tl.float32) + xb_1 = tl.zeros((BLOCK_SIZE_L, BLOCK_SIZE_D), dtype=tl.float32) + for _ in range(0, tl.cdiv(R, BLOCK_SIZE_R)): + # Load next block of B, X + x = tl.load(X_ptrs, mask=offs_ls[:, None] < seq_len, other=0.0) + b_0 = tl.load(B_ptrs) + b_1 = tl.load(B_ptrs + BLOCK_SIZE_D * stride_bd) + # Accumulate along R dimension. + xb_0 = tl.dot(x, b_0, xb_0) + xb_1 = tl.dot(x, b_1, xb_1) + # Advance the pointers to next blocks + B_ptrs += BLOCK_SIZE_R * stride_br + X_ptrs += BLOCK_SIZE_R * stride_xr + + xb_0 = xb_0.to(tl.float16) + xb_1 = xb_1.to(tl.float16) + + # RoPE + start_block = pid_l * BLOCK_SIZE_L + cos, sin = get_freq_multi_tokens(starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_L) + cos = cos.to(tl.float16) + sin = sin.to(tl.float16) + + xb_rope_0 = xb_0 * cos - xb_1 * sin + xb_rope_1 = xb_1 * cos + xb_0 * sin + xb_0 = xb_rope_0.to(tl.float16) + xb_1 = xb_rope_1.to(tl.float16) + + # GEMV + a_0 = tl.load(A_ptrs) + a_1 = tl.load(A_ptrs + BLOCK_SIZE_D * stride_ad) + abx_0 = tl.sum(a_0 * xb_0, 1) + abx_1 = tl.sum(a_1 * xb_1, 1) + abx = abx_0 + abx_1 + tl.store(O_ptrs, abx[None, :], mask=offs_ls[None, :] < seq_len) + + +def abx(a: torch.Tensor, b: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Computes the operation A x B x X using a custom Triton kernel. + + Args: + a (torch.Tensor): Tensor of shape (num_heads, 1, head_dim). + b (torch.Tensor): Tensor of shape (num_heads, rank_per_head_groups, head_dim). 
+ x (torch.Tensor): Tensor of shape (num_groups, seq_len, rank_per_head_groups). + + Returns: + torch.Tensor: Output tensor of shape (num_heads, 1, seq_len). + """ + # U x V x X + assert a.dim() == 3 + assert b.dim() == 3 + assert x.dim() == 3 + + num_heads, _, head_dim = a.shape + num_heads,rank_per_head_groups, head_dim = b.shape + num_groups, seq_len, rank_per_head_groups = x.shape + # Allocate output tensor + out = torch.empty((num_heads, 1, seq_len), dtype=x.dtype, device=x.device) + BLOCK_SIZE_D = 64 + # BLOCK_SIZE_R = 32 + # BLOCK_SIZE_L = 128 + # num_stages = 1 + # num_warps = 8 + NUM_GROUPS = num_groups + grid = lambda META: (32, triton.cdiv(seq_len, META["BLOCK_SIZE_L"])) + _abx_fwd[grid]( + a, b, x, out, + a.stride(0), a.stride(1), a.stride(2), + b.stride(0), b.stride(1), b.stride(2), + x.stride(0), x.stride(1), x.stride(2), + out.stride(0), out.stride(1), out.stride(2), + R = rank_per_head_groups, + D = head_dim, + seq_len = seq_len, + BLOCK_SIZE_D = BLOCK_SIZE_D, + # BLOCK_SIZE_L = BLOCK_SIZE_L, + # BLOCK_SIZE_R = BLOCK_SIZE_R, + # num_stages=num_stages, + # num_warps=num_warps, + NUM_GROUPS = NUM_GROUPS, + THETA = 10000., + ) + return out + + + + + + + + + diff --git a/palu/backend/q_matmul.py b/palu/backend/q_matmul.py new file mode 100644 index 0000000..c807658 --- /dev/null +++ b/palu/backend/q_matmul.py @@ -0,0 +1,82 @@ +import torch +import palu.palu_cuda as palu_kernel + +def cuda_bmm_fA_qB_inner(group_size: int, + fA: torch.FloatTensor, + qB: torch.IntTensor, + scales: torch.FloatTensor, + zeros: torch.FloatTensor, + bits: int) -> torch.FloatTensor: + """ + fA is of shape (B, nh, M, K) float16 + qB is of shape (B, nh, K // feat_per_int, N) int32 + scales is of shape (B, nh, G, N) float16 + zeros is of shape (B, nh, G, N) float16 + + groupsize is the number of inner dimension in each groups. + G = K // groupsize + + Returns C of shape (B, nh, M, N) float16 + """ + assert len(fA.shape) == 4 and len(qB.shape) == 4 + B, nh, M, K = fA.shape + #assert M == 1, "Currently only supoort M=1" + feat_per_int = 32 // bits + # flatten to a 3D tensor + #print(fA.view(-1, M, K).is_contiguous()) + fA = fA.view(-1, M, K).contiguous() + #print(qB.view(-1, K // feat_per_int, qB.shape[-1]).transpose(1, 2).is_contiguous()) + qB = qB.view(-1, K // feat_per_int, qB.shape[-1]).transpose(1, 2).contiguous() + flatten_B = B * nh + scales = scales.view(flatten_B, scales.shape[-2], scales.shape[-1]).transpose(1, 2).contiguous() + zeros = zeros.view(flatten_B, zeros.shape[-2], zeros.shape[-1]).transpose(1, 2).contiguous() + #print(scales.shape, zeros.shape) + assert bits in [4] + c = palu_kernel.batched_gemm_forward_cuda(fA, qB, scales, zeros, bits, group_size) + c = c.view(B, nh, c.shape[-2], c.shape[-1]) + return c + +def cuda_bmm_fA_qB_outer(group_size: int, + fA: torch.FloatTensor, + qB: torch.IntTensor, + scales: torch.FloatTensor, + zeros: torch.FloatTensor, + bits: int, + mqa: bool=False) -> torch.FloatTensor: + """ + Compute the matrix multiplication C = query x key. + Where key is quantized into 2-bit values. + + fA is of shape (B, nh, M, K) float16 + qB is of shape (B, nh, K, N // feat_per_int) int32 + scales is of shape (B, nh, K, G) float16 + zeros is of shape (B, nh, K, G) float16 + + groupsize is the number of outer dimensions in each group. 
+ G = N // groupsize + + Returns C of shape (B, nh, M, N) float16 + """ + assert len(fA.shape) == 4 and len(qB.shape) == 4 + B, nh, M, K = fA.shape + feat_per_int = 32 // bits + # flatten to a 3D tensor + fA = fA.view(-1, M, K).contiguous() + N = qB.shape[-1] * feat_per_int + qB = qB.reshape(-1, K, qB.shape[-1]).transpose(1, 2).contiguous() + # This is based on the possible BLOCK_SIZE_Ks + # assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" + # This is based on the possible BLOCK_SIZE_Ns + # assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" + # This is based on the possible BLOCK_SIZE_Ks + # assert group_size % 64 == 0, "groupsize must be a multiple of 64, and 128" + flatten_B = B * nh + # if mqa: + # flatten_B = B + scales = scales.view(flatten_B, scales.shape[-2], scales.shape[-1]).transpose(1, 2).contiguous() + zeros = zeros.view(flatten_B, zeros.shape[-2], zeros.shape[-1]).transpose(1, 2).contiguous() + assert bits in [4] + #c = palu_kernel.gemv_forward_cuda_outer_dim(fA, qB, scales, zeros, bits, group_size, nh, mqa) + c = palu_kernel.batched_gemm_forward_outer_cuda(fA, qB, scales, zeros, bits, group_size) + c = c.view(B, nh, c.shape[-2], c.shape[-1]) + return c \ No newline at end of file diff --git a/palu/compress/__init__.py b/palu/compress/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/palu/decomposition.py b/palu/compress/decomposition.py similarity index 99% rename from palu/decomposition.py rename to palu/compress/decomposition.py index 0fe6d6f..7f13b8a 100644 --- a/palu/decomposition.py +++ b/palu/compress/decomposition.py @@ -4,8 +4,8 @@ import os import click from tqdm import tqdm -from .data_utils import get_calib_data -from .model import HeadwiseLowRankModule +from ..utils.data_utils import get_calib_data +from ..model import HeadwiseLowRankModule def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): if type(module) in layers: diff --git a/palu/rank_search.py b/palu/compress/rank_search.py similarity index 99% rename from palu/rank_search.py rename to palu/compress/rank_search.py index 76a0369..2b24107 100644 --- a/palu/rank_search.py +++ b/palu/compress/rank_search.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn from loguru import logger -from .model import AVAILABLE_MODELS -from .data_utils import get_calib_data +from ..model import AVAILABLE_MODELS +from ..utils.data_utils import get_calib_data import math from tqdm import tqdm diff --git a/palu/csrc/palu_gemm_cuda.cu b/palu/csrc/palu_gemm_cuda.cu new file mode 100644 index 0000000..94a6d74 --- /dev/null +++ b/palu/csrc/palu_gemm_cuda.cu @@ -0,0 +1,170 @@ +#include +#include +#include +#include "palu_gemm_cuda.h" + +#define BM 4 // Tile size in the M dimension (number of rows per tile) +#define BK 128 // Tile size in the K dimension (group size for quantization) +#define BN 16 // Tile size in the N dimension (number of columns per tile) +#define PACK_FACTOR 8 // Number of elements packed together (for vectorized loads) + +// Kernel function to perform batched GEMM with quantized weights +__global__ void batched_gemm_kernel_quantized( + const float4* __restrict__ A_packed, // Packed input activations [B, M, K / PACK_FACTOR] + const uint32_t* __restrict__ qB_packed, // Packed quantized weights [B, N, K / PACK_FACTOR] + const half* __restrict__ zeros, // Zero offsets for quantization [B, N, K / group_size] + const half* __restrict__ scaling_factors, // Scaling factors 
for quantization [B, N, K / group_size] + half* __restrict__ C, // Output matrix [B, M, N] + const int B, const int M, const int N, const int K, + const int group_size) { + + // Compute batch index + const int batch_idx = blockIdx.z; + + // Tile indices along the N (columns) and M (rows) dimensions + const int tile_idx_N = blockIdx.x; // Tile index along the N dimension + const int tile_idx_M = blockIdx.y; // Tile index along the M dimension + + // Each thread computes one element in the output tile of size BM x BN + const int thread_id = threadIdx.x; // Thread index within the block + const int thread_row = thread_id / BN; // Row index within the tile [0, BM) + const int thread_col = thread_id % BN; // Column index within the tile [0, BN) + + // Compute global indices in the M and N dimensions + const int global_row = tile_idx_M * BM + thread_row; // Global row index in the output matrix + const int global_col = tile_idx_N * BN + thread_col; // Global column index in the output matrix + + // Bounds check to prevent out-of-bounds memory access + if (global_row >= M || global_col >= N || batch_idx >= B) + return; + + // Pointer to the output element computed by this thread + half* output_ptr = C + batch_idx * (M * N) + global_row * N + global_col; + + // Initialize the partial sum for the dot product + float partial_sum = 0.0f; + + // Number of groups along the K dimension (since we process group_size elements at a time) + const int num_groups = K / group_size; + + // Loop over each group along the K dimension + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + + // Compute index for scaling factors and zeros + int sf_zero_idx = batch_idx * (N * num_groups) + global_col * num_groups + group_idx; + + // Load the scaling factor and zero point for the current group + float scaling_factor = __half2float(scaling_factors[sf_zero_idx]); + float zero_point = __half2float(zeros[sf_zero_idx]); + + // Number of iterations within the group (since we process 32 elements per iteration) + const int iterations = group_size / 32; + + // Loop over iterations within the group + #pragma unroll + for (int iter = 0; iter < iterations; iter++) { + // Calculate offsets for qB_packed and A_packed + int qB_offset = batch_idx * (N * K / PACK_FACTOR) + + global_col * (K / PACK_FACTOR) + + group_idx * (group_size / PACK_FACTOR) + + iter * 4; // 4 because we load 4 uint32_t (128 bits) at a time + + int A_offset = batch_idx * (M * K / PACK_FACTOR) + + global_row * (K / PACK_FACTOR) + + group_idx * (group_size / PACK_FACTOR) + + iter * 4; // Matching offset for A_packed + + // Load 128 bits (4 uint32_t) of packed quantized weights from qB_packed + uint32_t qB_values[4]; + *((float4*)(qB_values)) = *((float4*)(qB_packed + qB_offset)); + + // Process each of the 4 uint32_t values + #pragma unroll + for (int j = 0; j < 4; j++) { + uint32_t packed_weights = qB_values[j]; + + // Load 128 bits (8 half-precision floats) of activations from A_packed + float4 A_packed_value = A_packed[A_offset + j]; + half* A_values = reinterpret_cast(&A_packed_value); // Access as half* + + // Process 8 elements (since each uint32_t packs 8 quantized weights) + #pragma unroll + for (int k = 0; k < PACK_FACTOR; k++) { + // Extract a 4-bit quantized weight + float quantized_weight = static_cast(packed_weights & 0xF); + + // Dequantize the weight + float dequantized_weight = quantized_weight * scaling_factor + zero_point; + + // Multiply with the activation and accumulate the result + partial_sum += dequantized_weight * 
__half2float(A_values[k]); + + // Shift to the next 4-bit quantized weight + packed_weights >>= 4; + } + } + } + } + + // Store the computed partial sum as the output element + *output_ptr = __float2half(partial_sum); +} + +// Host function to launch the kernel +torch::Tensor batched_gemm_forward_cuda( + torch::Tensor _A, // Input activations tensor [B, M, K] + torch::Tensor _qB, // Packed quantized weights tensor [B, N, K / PACK_FACTOR] + torch::Tensor _scaling_factors, // Scaling factors tensor [B, N, K / group_size] + torch::Tensor _zeros, // Zero points tensor [B, N, K / group_size] + const int bit, // Bit-width for quantization (e.g., 4 bits) + const int group_size) { // Group size used for quantization + + // Extract input tensor dimensions + int B = _A.size(0); // Batch size + int M = _A.size(1); // Number of rows in A (and C) + int K = _A.size(2); // Number of columns in A and rows in qB + int N = _qB.size(1); // Number of columns in qB (and C) + + // Ensure that K is divisible by PACK_FACTOR and group_size + TORCH_CHECK(K % PACK_FACTOR == 0, "K must be divisible by PACK_FACTOR"); + TORCH_CHECK(K % group_size == 0, "K must be divisible by group_size"); + + // Ensure that input tensors are on CUDA + TORCH_CHECK(_A.is_cuda(), "Input tensor A must be a CUDA tensor"); + TORCH_CHECK(_qB.is_cuda(), "Input tensor qB must be a CUDA tensor"); + TORCH_CHECK(_scaling_factors.is_cuda(), "Input tensor scaling_factors must be a CUDA tensor"); + TORCH_CHECK(_zeros.is_cuda(), "Input tensor zeros must be a CUDA tensor"); + + // Cast input tensors to appropriate data types + auto A_packed = reinterpret_cast(_A.data_ptr()); + auto qB_packed = reinterpret_cast(_qB.data_ptr()); + auto zeros_ptr = reinterpret_cast(_zeros.data_ptr()); + auto scaling_factors_ptr = reinterpret_cast(_scaling_factors.data_ptr()); + + // Create an output tensor + auto options = torch::TensorOptions().dtype(_A.dtype()).device(_A.device()); + at::Tensor _C = torch::empty({B, M, N}, options); + auto C_ptr = reinterpret_cast(_C.data_ptr()); + + // Calculate grid and block dimensions for kernel launch + dim3 blockDim(BN * BM); // Total threads per block (BM x BN) + dim3 gridDim((N + BN - 1) / BN, // Number of blocks along the N dimension + (M + BM - 1) / BM, // Number of blocks along the M dimension + B); // Number of blocks along the batch dimension + + // Ensure that blockDim.x does not exceed the maximum threads per block + TORCH_CHECK(blockDim.x <= 1024, "blockDim.x exceeds the maximum number of threads per block"); + + // Launch the CUDA kernel + batched_gemm_kernel_quantized<<>>( + A_packed, qB_packed, zeros_ptr, scaling_factors_ptr, C_ptr, B, M, N, K, group_size + ); + + // Check for CUDA errors + cudaError_t err = cudaGetLastError(); + TORCH_CHECK(err == cudaSuccess, "CUDA Error: ", cudaGetErrorString(err)); + + // Return the output tensor + return _C; +} + diff --git a/palu/csrc/palu_gemm_cuda.h b/palu/csrc/palu_gemm_cuda.h new file mode 100644 index 0000000..541b288 --- /dev/null +++ b/palu/csrc/palu_gemm_cuda.h @@ -0,0 +1,19 @@ +#pragma once +#include + +torch::Tensor batched_gemm_forward_cuda( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + const int bit, + const int group_size); + +torch::Tensor batched_gemm_forward_outer_cuda( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + const int bit, + const int group_size +); \ No newline at end of file diff --git a/palu/csrc/palu_gemm_outer_cuda.cu 
b/palu/csrc/palu_gemm_outer_cuda.cu new file mode 100644 index 0000000..afc31fd --- /dev/null +++ b/palu/csrc/palu_gemm_outer_cuda.cu @@ -0,0 +1,172 @@ +#include +#include +#include +#include "palu_gemm_cuda.h" + +#define BM 4 // Tile size in the M dimension (number of rows per tile) +#define BK 128 // Tile size in the K dimension (group size for quantization) +#define BN 8 // Tile size in the N dimension (number of columns per tile) +#define PACK_FACTOR 8 // Number of elements packed together (for vectorized loads) + +// Reduce sum within the warp using the tree reduction algorithm. +__device__ __forceinline__ float warp_reduce_sum(float sum) { + #pragma unroll + for(int i = 4; i >= 0; i--){ + sum += __shfl_down_sync(0xffffffff, sum, 1<> 4; + psum[j] += __half2float(cur_a) * dequant_b; + } + } + } + // Write back the results + for(int i=0; i(fA.data_ptr()); + const uint32_t* qB_ptr = reinterpret_cast(qB.data_ptr()); + const __half* scaling_factors_ptr = reinterpret_cast(scaling_factors.data_ptr()); + const __half* zeros_ptr = reinterpret_cast(zeros.data_ptr()); + + // Create output tensor + auto options = torch::TensorOptions().dtype(torch::kHalf).device(fA.device()); + at::Tensor output = torch::empty({B, M, N}, options); + __half* output_ptr = reinterpret_cast<__half*>(output.data_ptr()); + + // Define block and grid dimensions + dim3 gridDim((N + BN - 1) / BN, + (M + BM - 1) / BM, + B); // (Batches, N tiles) + dim3 blockDim(32 * BM); // M warps per block (BM=4) + + // Launch the kernel + batched_gemm_kernel_quantized_outer<<>>( + fA_ptr, qB_ptr, scaling_factors_ptr, zeros_ptr, output_ptr, B, M, N, K, group_size + ); + + // Check for CUDA errors + cudaError_t err = cudaGetLastError(); + TORCH_CHECK(err == cudaSuccess, "CUDA Error: ", cudaGetErrorString(err)); + + return output; +} \ No newline at end of file diff --git a/palu/csrc/pybind.cpp b/palu/csrc/pybind.cpp new file mode 100644 index 0000000..d1c7974 --- /dev/null +++ b/palu/csrc/pybind.cpp @@ -0,0 +1,9 @@ +#include +#include +#include "palu_gemm_cuda.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("batched_gemm_forward_cuda", &batched_gemm_forward_cuda); + m.def("batched_gemm_forward_outer_cuda", &batched_gemm_forward_outer_cuda); +} \ No newline at end of file diff --git a/palu/model/modules/svd_linear.py b/palu/model/modules/svd_linear.py index 64372d4..1d9ce25 100644 --- a/palu/model/modules/svd_linear.py +++ b/palu/model/modules/svd_linear.py @@ -47,7 +47,6 @@ def _per_head_decomposition_from_weight(weight, rank): # Fuse the SVD components L = torch.matmul(U, sqrtSigma).to(original_dtype) R = torch.matmul(sqrtSigma, Vt).to(original_dtype) - assert torch.allclose(torch.matmul(L, R), weight, atol=1e-3), "SVD decomposition failed" return L, R class HeadwiseLowRankModule(nn.Module): diff --git a/palu/model/svd_llama/configuration_palu_llama.py b/palu/model/svd_llama/configuration_palu_llama.py index b33104b..653f68a 100644 --- a/palu/model/svd_llama/configuration_palu_llama.py +++ b/palu/model/svd_llama/configuration_palu_llama.py @@ -109,6 +109,9 @@ def __init__( rope_scaling=None, attention_bias=False, head_wise_ranks=None, + k_bits=16, + v_bits=16, + palu_attn_linear_only=True, **kwargs, ): self.vocab_size = vocab_size @@ -143,6 +146,11 @@ def __init__( # for avsd self.head_wise_ranks = head_wise_ranks + # for quantization + self.k_bits = k_bits + self.v_bits = v_bits + self.palu_attn_linear_only = palu_attn_linear_only + def _rope_scaling_validation(self): """ diff --git 
a/palu/model/svd_llama/modeling_palu_llama.py b/palu/model/svd_llama/modeling_palu_llama.py index cf3820a..975d542 100644 --- a/palu/model/svd_llama/modeling_palu_llama.py +++ b/palu/model/svd_llama/modeling_palu_llama.py @@ -1,7 +1,12 @@ -from transformers import LlamaForCausalLM +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM +) import torch.nn as nn +import torch from types import SimpleNamespace from .configuration_palu_llama import PaluLlamaConfig +from .palu_llama_attention import LlamaPaluAttention from ..modules.svd_linear import HeadwiseLowRankModule class PaluLlamaForCausalLM(LlamaForCausalLM): @@ -9,31 +14,9 @@ class PaluLlamaForCausalLM(LlamaForCausalLM): def __init__(self, config:PaluLlamaConfig): super().__init__(config) self.head_wise_ranks=config.head_wise_ranks - - full_name_dict = {module: name for name, module in self.named_modules()} - linear_info = {} - modules = [self] - while len(modules) > 0: - submodule = modules.pop() - for name, raw_linear in submodule.named_children(): - if isinstance(raw_linear, nn.Linear): - full_name = full_name_dict[raw_linear] - linear_info[raw_linear] = { - "father": submodule, - "name": name, - "full_name": full_name, - } - else: - modules.append(raw_linear) - - - for name,module in self.named_modules(): - if name in self.head_wise_ranks: - info=linear_info[module] - new_layer=HeadwiseLowRankModule(self.head_wise_ranks[name],module.in_features,module.out_features,bias=module.bias is not None) - setattr(info["father"], info["name"], new_layer) - - + self.palu_attn_linear_only = config.palu_attn_linear_only + self._replace_modules(self.palu_attn_linear_only) + @staticmethod def get_kv_info(llama: LlamaForCausalLM, num_heads_in_lr_groups: int): num_lr_groups = llama.config.num_attention_heads // num_heads_in_lr_groups @@ -57,3 +40,58 @@ def get_kv_info(llama: LlamaForCausalLM, num_heads_in_lr_groups: int): num_lr_groups=num_lr_kv_groups, lr_group_dims=lr_group_dims, ) + + + def _replace_modules(self, linear_only=True): + if linear_only: + # Mode 1: Only replace the linear layers to simulate the low-rank approximation + full_name_dict = {module: name for name, module in self.named_modules()} + linear_info = {} + modules = [self] + while len(modules) > 0: + submodule = modules.pop() + for name, raw_linear in submodule.named_children(): + if isinstance(raw_linear, nn.Linear): + full_name = full_name_dict[raw_linear] + linear_info[raw_linear] = { + "father": submodule, + "name": name, + "full_name": full_name, + } + else: + modules.append(raw_linear) + + + for name,module in self.named_modules(): + if name in self.head_wise_ranks: + info=linear_info[module] + new_layer=HeadwiseLowRankModule(self.head_wise_ranks[name],module.in_features,module.out_features,bias=module.bias is not None) + setattr(info["father"], info["name"], new_layer) + else: + # Mode 2: Replace all the attention modules with update forward path + #FIXME (brian1009): Could be simplified further + full_name_dict = {module: name for name, module in self.named_modules()} + attn_info = {} + modules = [self] + while len(modules) > 0: + submodule = modules.pop() + for name, child in submodule.named_children(): + if isinstance(child, LlamaAttention): + full_name = full_name_dict[child] + attn_info[child] = { + "father": submodule, + "name": name, + "full_name": full_name, + } + else: + modules.append(child) + + layer_id_counter = 0 + for name, module in self.named_modules(): + if isinstance(module, LlamaAttention): + info = attn_info[module] 
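# layer_id_counter supplies the layer_idx that the attention module uses to
# address its per-layer entry in the KV cache (past_key_value).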
+ new_layer = LlamaPaluAttention(self.config, layer_id_counter) + setattr(info["father"], info["name"], new_layer) + layer_id_counter += 1 + + torch.cuda.empty_cache() \ No newline at end of file diff --git a/palu/model/svd_llama/palu_llama_attention.py b/palu/model/svd_llama/palu_llama_attention.py new file mode 100644 index 0000000..4bc570c --- /dev/null +++ b/palu/model/svd_llama/palu_llama_attention.py @@ -0,0 +1,285 @@ +import math +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from transformers.models.llama.modeling_llama import ( + apply_rotary_pos_emb, + LlamaAttention, + LlamaConfig, +) + +from transformers.cache_utils import Cache + +from .configuration_palu_llama import PaluLlamaConfig +from ..modules.svd_linear import HeadwiseLowRankModule + +from ...backend.fused_recompute import abx as recompute_k_gemv +from ...backend.q_matmul import cuda_bmm_fA_qB_outer + +from ...quant.quant_kv_cache import ValueQuantizedCacheV2 + +class LlamaPaluAttention(LlamaAttention): + """ + Llama Attention with Low-Rank KV-Cache with Palu. This module inherits from + `LlamaAttention` but change linear layer and add custom Triton kernel. + """ + def __init__(self, config: Union[PaluLlamaConfig, LlamaConfig], + rank_k_list: list[int], + rank_v_list: list[int], + layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + #FIXME(brian1009): The interface design here is somehow inconsistent, need to be reviewed further + self.rank_k_list = rank_k_list + self.rank_v_list = rank_v_list + #NOTE(brian1009): We assume all groups sharing the same rank for now + self.group_rank_k = self.rank_k_list[0] + self.group_rank_v = self.rank_v_list[0] + self.total_rank_k = sum(self.rank_k_list) + self.total_rank_v = sum(self.rank_v_list) + assert len(self.rank_k_list) == len(self.rank_v_list), "The number of groups for k and v should be the same so far" + self.num_groups = len(self.rank_k_list) + self.group_size = self.num_heads // self.num_groups + + self.k_bits = config.k_bits + self.v_bits = config.v_bits + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = HeadwiseLowRankModule(self.rank_k_list, self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = HeadwiseLowRankModule(self.rank_v_list, self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + self.v_recompution_fused = False + self.has_prepared_for_k_recompute_kernel = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + # assert self.is_prepared, "Please call palu_prepare() method before forward" + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_h_states = self.k_proj.project_to_latent(hidden_states) + value_h_states = self.v_proj.project_to_latent(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_h_states = key_h_states.view(bsz, q_len, self.num_groups, self.group_rank_k).transpose(1, 2) + value_h_states = value_h_states.view(bsz, q_len, self.num_groups, self.group_rank_v).transpose(1, 2) + + #key_h_states_quant, key_scales, key_zeros = quant_and_pack_vcache(key_h_states, self.group_rank_k, self.k_bits) + #if self.v_bits != 16: + # value_h_states_quant, value_scales, value_zeros = quant_and_pack_vcache(value_h_states, self.group_rank_v, self.v_bits) + # kv_seq_len = key_states.shape[-2] + + kv_seq_len = key_h_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + #key_h_states, value_h_states = past_key_value.update(key_h_states, value_h_states, self.layer_idx) + if self.v_bits == 16: + key_h_states, value_h_states = past_key_value.update( + key_h_states, value_h_states, self.layer_idx + ) + else: + # key_h_states, value_h_states_quant, value_scales, value_zeros = past_key_value.update( + # key_h_states, value_h_states_quant, self.layer_idx, value_scales, value_zeros + # ) + assert isinstance(past_key_value, ValueQuantizedCacheV2), "When the value is not 16-bit, we assume past_key_value to be ValueQuantizedCacheV2" + key_h_states, value_h_states_quant, value_scales, value_zeros, value_h_states_full = past_key_value.update( + key_h_states, value_h_states, self.layer_idx + ) + #NOTE(brian1009): We already transposed the value_h_states_quant in the update function. + # Now the shape of value_h_states_quant is (bsz, num_heads, group_rank_v, seq_len) and its contiguous. + # This is for saving the an extra contigous operation in the kernel. 
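+                # The transpose below only restores the logical (bsz, groups, seq_len, packed_rank)
+                # view; the underlying storage keeps the contiguous (bsz, groups, packed_rank, seq_len)
+                # layout prepared by the cache, which the fused low-bit matmul kernel consumes later
+                # in this forward pass without an extra copy.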
+ value_h_states_quant = value_h_states_quant.transpose(2, 3) + + if q_len > 1: + # Prompting + # Recompute the key states + key_h_states = key_h_states.transpose(1, 2).reshape(bsz, kv_seq_len, self.total_rank_k) + key_states = self.k_proj.reconstruct(key_h_states) + key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # Apply RoPE after recomputing the key states + cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + else: + # Generating (Apply our reconsturction kernel) + # A: (num_heads, 1, head_dim) + # B: (num_heads, rank_per_groups, head_dim) + # X: (num_head_groups, seq_len, rank_per_groups) + # TODO: Optimize RoPE & sqrt(head_dim) into kernel + # TODO: Check if sin & cos are share among different blocks + cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len) + #query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin, position_ids) + assert bsz == 1, "Only support batch size 1 for now" + A = query_states.squeeze(0) + + assert self.has_prepared_for_k_recompute_kernel, "Please call prepared_k_merged_U() before forward using customized Triton Kernel" + B = self.concated_B_for_k_recompute + #B = self.k_proj.B + X = key_h_states.squeeze(0) + attn_weights = recompute_k_gemv(A, B, X).unsqueeze(0) / math.sqrt(self.head_dim) + + # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + + # Upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + if self.v_recompution_fused: + # Fusion version + # attn_weights: (bsz, num_groups, q_len * group_size, kv_seq_len) + assert self.v_recompution_fused, "Please call fused_v_recompute_to_o() before forward" + attn_h_weights = attn_weights.reshape(1, self.num_groups, q_len * self.group_size, kv_seq_len) + if self.v_bits == 16: + attn_h_output = torch.matmul(attn_h_weights, value_h_states) + else: + if value_h_states_full is None: + value_full_length = 0 + attn_h_output = cuda_bmm_fA_qB_outer( + group_size=self.group_rank_v, fA=attn_h_weights[:, :, :, :], qB=value_h_states_quant, + scales=value_scales, zeros=value_zeros, + bits = self.v_bits + ) + else: + value_full_length = value_h_states_full.shape[-2] + attn_h_output = cuda_bmm_fA_qB_outer( + group_size=self.group_rank_v, fA=attn_h_weights[:, :, :, :-value_full_length], qB=value_h_states_quant, + scales=value_scales, zeros=value_zeros, + bits = self.v_bits + ) + attn_h_output += torch.matmul(attn_h_weights[:, :, :, -value_full_length:], value_h_states_full) + # attn_h_output: (bsz, num_heads, q_len * group_size, group_rank) + attn_output = attn_h_output.reshape(1, self.num_heads, q_len, self.group_rank_v) + else: + # 
Original version + value_h_states = value_h_states.transpose(1, 2).reshape(bsz, kv_seq_len, self.total_rank_v) + value_states = self.v_proj.reconstruct(value_h_states) + value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def fused_v_recompute_to_o(self): + """ + Fuses the `v_proj` components into `o_proj` for optimized computation. + Creates a new `o_proj` layer with fused weights. + """ + + # Calculate the required dimensions for the new o_proj layer + fused_hidden_dim_o = self.group_rank_v * self.num_heads + new_o_proj = nn.Linear(fused_hidden_dim_o, self.hidden_size, bias=self.o_proj.bias is not None) + new_o_weight = torch.zeros_like(new_o_proj.weight) # Ensure same device and dtype + + # Perform fusion by iterating over groups and ranks + total_dims_2, total_fused_dims = 0, 0 + head_dim = self.head_dim + for group_idx in range(self.num_groups): + total_dims = 0 + for _ in range(self.group_size): + v_proj_U = self.v_proj.U[group_idx].weight # Assume U_list holds the decomposed matrices + # Perform the fusion by multiplying corresponding segments + new_o_weight[:, total_fused_dims:total_fused_dims + self.group_rank_v] = ( + self.o_proj.weight[:, total_dims_2:total_dims_2 + head_dim] + @ v_proj_U[total_dims:total_dims + head_dim, :] + ) + total_dims += head_dim + total_dims_2 += head_dim + total_fused_dims += self.group_rank_v + + # Copy fused weights to new o_proj + with torch.no_grad(): + new_o_proj.weight.copy_(new_o_weight) + + # Update the o_proj to the fused version while preserving the device and dtype + new_o_proj = new_o_proj.to(self.o_proj.weight.device, self.o_proj.weight.dtype) + self.o_proj = new_o_proj + + #TODO(brian1009): Remove the U_list to CPU to save GPU memory + self.v_recompution_fused = True + torch.cuda.empty_cache() + + def prepared_k_merged_U(self): + U_list_T = [x.weight.data.T for x in self.k_proj.U] + b = torch.stack(U_list_T) + b = b.reshape(self.num_groups, self.group_rank_k, self.group_size, self.head_dim) + b = b.transpose(1, 2) + b = b.reshape(self.num_heads, self.rank_k_list[0], self.head_dim) + self.concated_B_for_k_recompute = nn.Parameter(b) + #TODO(brian1009): Remove the U_list to CPU to save GPU memory + self.has_prepared_for_k_recompute_kernel = True + + @staticmethod + def from_attention( + module: LlamaAttention, + config: LlamaConfig, + rank_k_list: list[int], + rank_v_list: list[int], + no_fusion: bool = False, + ): + new_module = LlamaPaluAttention(config, + rank_k_list, + rank_v_list, + module.layer_idx) + new_module.q_proj = module.q_proj + new_module.k_proj = HeadwiseLowRankModule.from_linear(module.k_proj, new_module.rank_k_list) + new_module.v_proj = HeadwiseLowRankModule.from_linear(module.v_proj, new_module.rank_v_list) + new_module.o_proj = module.o_proj + + + # No fusion version + if not no_fusion: + new_module.fused_v_recompute_to_o() + + return new_module diff --git a/palu/quant/__init__.py b/palu/quant/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/palu/quant/q_packing.py b/palu/quant/q_packing.py new file mode 100644 index 0000000..4c7c20b --- /dev/null +++ b/palu/quant/q_packing.py @@ -0,0 +1,223 @@ +# Adapted from 
https://github.com/jy-yuan/KIVI/blob/main/quant/new_pack.py +import triton +import triton.language as tl +import random +import numpy as np +import torch + + +def quant_and_pack_kcache(k: torch.FloatTensor, group_size: int, bits: int): + assert len(k.shape) == 4 + shape = k.shape + B, nh, T, D = shape + # ================== Get Scale & Zeros =============== + assert T % group_size == 0 + num_groups = T // group_size + new_shape = (B, nh, num_groups, group_size, D) + # Quantize + max_int = 2 ** bits - 1 + data = k.view(new_shape) + mn = torch.min(data, dim=-2, keepdim=True)[0] + mx = torch.max(data, dim=-2, keepdim=True)[0] + scale = (mx - mn) / max_int + data = data - mn + data.div_(scale) + data = data.clamp_(0, max_int).round_().to(torch.int32) + data = data.view(shape) + code = pack_tensor(data, bits, pack_dim=2) + return code, scale, mn + + +def quant_and_pack_vcache(v: torch.FloatTensor, group_size: int, bits: int): + shape = v.shape + assert len(shape) == 4 + assert v.shape[-1] % group_size == 0 + num_groups = shape[-1] // group_size + new_shape = (shape[:-1] + (num_groups, group_size)) + new_scales_shape = shape[:-1] + (num_groups,) + # Quantize + max_int = 2 ** bits - 1 + data = v.view(new_shape) + mn = torch.min(data, dim=-1, keepdim=True)[0] + mx = torch.max(data, dim=-1, keepdim=True)[0] + #print(mx.shape) + scale = (mx - mn) / max_int + data = data - mn + data.div_(scale) + data = data.clamp_(0, max_int).round_().to(torch.int32) + data = data.view(shape) + #print(data) + # Pack + code = pack_tensor(data, bits, pack_dim=3) + #print(code) + return code, scale.reshape(new_scales_shape), mn.reshape(new_scales_shape) + + +def unpack_and_dequant_kcache(k_code: torch.FloatTensor, + scale: torch.FloatTensor, + mn: torch.FloatTensor, + group_size: int, + bits: int, + ): + pack_dim = 2 + assert bits in [2, 4, 8] + assert len(k_code.shape) == 4 + data = unpack_tensor(k_code, bits, pack_dim=pack_dim) + shape = data.shape + num_groups = shape[pack_dim] // group_size + data = data.view(shape[:pack_dim] + (num_groups, group_size,) + shape[pack_dim+1:]) + data = data.to(torch.float16) + data = data * scale + mn + return data.view(shape) + + +def unpack_and_dequant_vcache(v_code: torch.FloatTensor, + scale: torch.FloatTensor, + mn: torch.FloatTensor, + group_size: int, + bits: int, + ): + assert bits in [2, 4, 8] + assert len(v_code.shape) == 4 + data = unpack_tensor(v_code, bits, pack_dim=3) + #print(data.shape) + shape = data.shape + num_groups = shape[-1] // group_size + data = data.view(shape[:-1] + (num_groups, group_size,)) + #print(data.shape) + data = data.to(torch.float16) + data = data * scale.unsqueeze(-1) + mn.unsqueeze(-1) + #print(data.shape) + return data.view(shape) + + +def pack_tensor(data, bits, pack_dim): + # Pack + shape = data.shape + feat_per_int = 32 // bits + assert bits in [2,4,8], "Only 2, 4, 8 bits are supported" + assert shape[pack_dim] % feat_per_int == 0, "Dimension length must be divisible by number of features per int" + # BS, nh, T, nd // 16 # 16 is for 2bit + code = torch.zeros(shape[:pack_dim] + (shape[pack_dim] // feat_per_int,)+shape[pack_dim+1:], + dtype=torch.int32, + device=data.device) + i = 0 + row = 0 + unpacked_indices = [slice(None)] * len(data.shape) + packed_indices = [slice(None)] * len(data.shape) + while row < code.shape[pack_dim]: + packed_indices[pack_dim] = row + for j in range(i, i + (32 // bits)): + unpacked_indices[pack_dim] = j + code[packed_indices] |= data[unpacked_indices] << (bits * (j - i)) + i += 32 // bits + row += 1 + return code 
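+
+# Example (illustrative): round-tripping a 4-bit value cache through pack/unpack.
+#
+#   v = torch.rand(1, 8, 16, 1024, dtype=torch.float16, device="cuda")
+#   code, scale, mn = quant_and_pack_vcache(v, group_size=128, bits=4)   # code: (1, 8, 16, 128) int32
+#   v_hat = unpack_and_dequant_vcache(code, scale, mn, group_size=128, bits=4)
+#   # v_hat matches v up to the 4-bit quantization error of each 128-element group.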
+ + +def unpack_tensor(v_code: torch.FloatTensor, + bits: int, + pack_dim: int): + assert bits in [2,4,8] + shape = v_code.shape + feat_per_int = 32 // bits + new_shape = shape[:pack_dim] + (shape[pack_dim] * feat_per_int,) + shape[pack_dim+1:] + unpacked_v_code = torch.zeros(new_shape, dtype=torch.int8, device=v_code.device) + i = torch.arange(new_shape[pack_dim], device=v_code.device) // feat_per_int + j = torch.arange(new_shape[pack_dim], device=v_code.device) % feat_per_int + num = 0xFF >> (8 - bits) + packed_indices = [slice(None)] * len(new_shape) + packed_indices[pack_dim] = i + if pack_dim == 2: + unpacked_v_code = ((v_code[packed_indices] >> (j * bits)[None, None, :, None]).to(torch.int16)) & num + elif pack_dim == 3: + unpacked_v_code = ((v_code[packed_indices] >> (j * bits)).to(torch.int16)) & num + else: + raise NotImplementedError + return unpacked_v_code + + +@triton.jit +def _pack_along_last_dim( + bits: tl.constexpr, + intensor_ptr, + code_ptr, + N, + num_feats: tl.constexpr, + feat_per_int: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr +): + num_int_per_y_dim = num_feats // feat_per_int + bid = tl.program_id(axis=0) + yid = tl.program_id(axis=1) + offs_N = bid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + block_start = intensor_ptr + offs_N * num_feats + yid * feat_per_int # offset of the first element at current tile + packed = tl.zeros((BLOCK_SIZE_N,), dtype=tl.int32) + for i in range(feat_per_int): + ptr = block_start + i + element = tl.load(ptr, mask=offs_N None: + super().__init__() # Initialize the base class + # Only quantization factors and full precision cache for values, not keys + self.value_scales_cache: List[torch.Tensor] = [] + self.value_zeros_cache: List[torch.Tensor] = [] + self.value_full_precision_cache: List[torch.Tensor] = [] # Full precision storage for values + self.residual_length = residual_length # Pre-defined residual length for quantization + self.bits = bits # Number of bits for quantization + def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Get the cache and quantization factors for a specific layer. + + Returns: + A tuple containing: + - key tensor + - value tensor + - value scales tensor + - value zeros tensor + - value full precision tensor + """ + if layer_idx < len(self): + return ( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + self.value_scales_cache[layer_idx], + self.value_zeros_cache[layer_idx], + self.value_full_precision_cache[layer_idx] + ) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def update( + self, + key_states: torch.Tensor, + value_full_precision: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Updates the cache with new `key_states`, `value_states`, `value_full_precision` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + layer_idx (`int`): The index of the layer to cache the states for. + value_full_precision (`torch.Tensor`): The full precision value states. + cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. + + Returns: + A tuple containing the updated key and value states, value scales, value zeros, and value full precision. 
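+
+        Example (illustrative sketch; the latent tensors follow `LlamaPaluAttention.forward`):
+
+            # First call, with a prompt of at least `residual_length` tokens:
+            cache = ValueQuantizedCacheV2(residual_length=128, bits=4)
+            keys, v_quant, v_scales, v_zeros, v_full = cache.update(key_h_states, value_h_states, layer_idx=0)
+            # `v_quant`/`v_scales`/`v_zeros` cover full multiples of `residual_length` tokens in
+            # packed low-bit form; `v_full` keeps the most recent remainder (< residual_length
+            # tokens) in full precision, or is None right after a quantization flush.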
+ """ + # Update the number of seen tokens + if layer_idx == 0: + self.seen_tokens += key_states.shape[-2] + + # Update the key cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + + # Update the full precision cache + if len(self.value_full_precision_cache) <= layer_idx: + self.value_full_precision_cache.append(value_full_precision) + elif self.value_full_precision_cache[layer_idx] is None: + self.value_full_precision_cache[layer_idx] = value_full_precision + else: + self.value_full_precision_cache[layer_idx] = torch.cat( + [self.value_full_precision_cache[layer_idx], value_full_precision], dim=-2 + ) + + # Perform quantization if full precision cache exceeds the residual length + if self.value_full_precision_cache[layer_idx].shape[2] >= self.residual_length: + self.quantize_and_store(layer_idx) + + # Ensure value and quantization caches are updated + return ( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + self.value_scales_cache[layer_idx], + self.value_zeros_cache[layer_idx], + self.value_full_precision_cache[layer_idx] + ) + + def quantize_and_store(self, layer_idx: int) -> None: + """ + Quantizes the value_full_precision_cache if it exceeds the residual length and stores + the quantized values, scales, and zeros into the corresponding caches. + + Parameters: + layer_idx (`int`): The index of the layer to perform quantization on. + """ + value_full_precision = self.value_full_precision_cache[layer_idx] + assert len(value_full_precision.shape) == 4, "Value tensor must have 4 dimensions" + + current_length = value_full_precision.shape[2] + # Calculate how much to quantize: leave the remainder (mod residual_length) + quantize_length = (current_length // self.residual_length) * self.residual_length + remainder_length = current_length % self.residual_length + + + # Split the tensor into the part to quantize and the remainder + to_quantize = value_full_precision[:, :, :quantize_length, :].contiguous() # Part to be quantized + if remainder_length > 0: + remainder = value_full_precision[:, :, quantize_length:, :].contiguous() # Part to remain in full precision + + # Perform quantization + #print(value_full_precision.shape[-1]) + #quantized_value, scales, zeros = triton_quantize_and_pack_along_last_dim(to_quantize, value_full_precision.shape[-1], self.bits) + quantized_value, scales, zeros = quant_and_pack_vcache(to_quantize, value_full_precision.shape[-1], self.bits) + #NOTE(brian1009): # Transpose and make it contiguous to match the requirements of Kernel that is going to consume this tensor + quantized_value = quantized_value.transpose(3, 2).contiguous() + # Store quantized outputs + if len(self.value_cache) <= layer_idx: + self.value_cache.append(quantized_value) + self.value_scales_cache.append(scales) + self.value_zeros_cache.append(zeros) + else: + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], quantized_value], dim=-1) + self.value_scales_cache[layer_idx] = torch.cat([self.value_scales_cache[layer_idx], scales], dim=-2) + self.value_zeros_cache[layer_idx] = torch.cat([self.value_zeros_cache[layer_idx], zeros], dim=-2) + + # Update the full precision cache with the remainder + if remainder_length > 0: + self.value_full_precision_cache[layer_idx] = remainder + else: + self.value_full_precision_cache[layer_idx] = None # Clear the full precision cache + + def reorder_cache(self, beam_idx: torch.LongTensor): + raise 
NotImplementedError("Method `reorder_cache` is not implemented for ValueQuantizedCache.") + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("Method `to_legacy_cache` is not implemented for ValueQuantizedCache.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "ValueQuantizedCache": + raise NotImplementedError("Method `from_legacy_cache` is not implemented for ValueQuantizedCache.") + +class KeyValueQuantizedCacheV2(DynamicCache): + def __init__(self, residual_length: Union[int, Tuple[int, int]], bits: int) -> None: + super().__init__() # Initialize the base class + # Only quantization factors and full precision cache for values, not keys + self.key_scales_cache: List[torch.Tensor] = [] + self.key_zeros_cache: List[torch.Tensor] = [] + self.key_full_precision_cache: List[torch.Tensor] = [] # Full precision storage for values + self.value_scales_cache: List[torch.Tensor] = [] + self.value_zeros_cache: List[torch.Tensor] = [] + self.value_full_precision_cache: List[torch.Tensor] = [] # Full precision storage for values + # Pre-defined residual length for quantization + if isinstance(residual_length, int): + self.residual_length_k = self.residual_length_v = residual_length + else: + self.residual_length_k, self.residual_length_v = residual_length + + self.bits = bits # Number of bits for quantization + + def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Get the cache and quantization factors for a specific layer. + + Returns: + A tuple containing: + - key tensor + - key tensor + - key scales tensor + - key zeros tensor + - key full precision tensor + - value tensor + - value scales tensor + - value zeros tensor + - value full precision tensor + """ + if layer_idx < len(self): + return ( + self.key_cache[layer_idx], + self.key_scales_cache[layer_idx], + self.key_zeros_cache[layer_idx], + self.key_full_precision_cache[layer_idx], + self.value_cache[layer_idx], + self.value_scales_cache[layer_idx], + self.value_zeros_cache[layer_idx], + self.value_full_precision_cache[layer_idx] + ) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def update( + self, + key_full_precision: torch.Tensor, + value_full_precision: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Updates the cache with new `key_states`, `value_states`, `key_full_precision`, `value_full_precision` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + layer_idx (`int`): The index of the layer to cache the states for. + key_full_precision (`torch.Tensor`): The full precision key states. + value_full_precision (`torch.Tensor`): The full precision value states. + cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. + + Returns: + A tuple containing the updated key and value states, value scales, value zeros, and value full precision. 
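+
+        Example (illustrative): with `residual_length=(32, 128)`, keys are quantized and moved into
+        the packed key cache once more than 32 full-precision key tokens have accumulated for a
+        layer, while values are only flushed once more than 128 full-precision value tokens have
+        accumulated; passing a single int applies the same threshold to both.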
+ """ + # Update the number of seen tokens + if layer_idx == 0: + # TODO: Check the correctness + self.seen_tokens += key_full_precision.shape[-2] + + # Update the full precision cache + if len(self.key_full_precision_cache) <= layer_idx: + self.key_full_precision_cache.append(key_full_precision) + elif self.key_full_precision_cache[layer_idx] is None: + self.key_full_precision_cache[layer_idx] = key_full_precision + else: + self.key_full_precision_cache[layer_idx] = torch.cat( + [self.key_full_precision_cache[layer_idx], key_full_precision], dim=-2 + ) + if len(self.value_full_precision_cache) <= layer_idx: + self.value_full_precision_cache.append(value_full_precision) + elif self.value_full_precision_cache[layer_idx] is None: + self.value_full_precision_cache[layer_idx] = value_full_precision + else: + self.value_full_precision_cache[layer_idx] = torch.cat( + [self.value_full_precision_cache[layer_idx], value_full_precision], dim=-2 + ) + + # Perform quantization if full precision cache exceeds the residual length + if self.key_full_precision_cache[layer_idx].shape[-2] > self.residual_length_k: + self.quantize_and_store_key(layer_idx) + if self.value_full_precision_cache[layer_idx].shape[-2] > self.residual_length_v: + self.quantize_and_store_value(layer_idx) + + # Ensure value and quantization caches are updated + return ( + self.key_cache[layer_idx], + self.key_scales_cache[layer_idx], + self.key_zeros_cache[layer_idx], + self.key_full_precision_cache[layer_idx], + self.value_cache[layer_idx], + self.value_scales_cache[layer_idx], + self.value_zeros_cache[layer_idx], + self.value_full_precision_cache[layer_idx] + ) + + def quantize_and_store_key(self, layer_idx: int) -> None: + """ + Quantizes the key_full_precision_cache if it exceeds the residual length and stores + the quantized values, scales, and zeros into the corresponding caches. + + Parameters: + layer_idx (`int`): The index of the layer to perform quantization on. 
+ """ + key_full_precision = self.key_full_precision_cache[layer_idx] + assert len(key_full_precision.shape) == 4, "Key tensor must have 4 dimensions" + + current_length = key_full_precision.shape[-2] + # Calculate how much to quantize: leave the remainder (mod residual_length) + quantize_length = (current_length // self.residual_length_k) * self.residual_length_k + remainder_length = current_length % self.residual_length_k + + + # Split the tensor into the part to quantize and the remainder + k_to_quantize = key_full_precision[:, :, :quantize_length, :].contiguous() # Part to be quantized + if remainder_length > 0: + k_remiander = key_full_precision[:, :, quantize_length:, :].contiguous() # Part to remain in full precision + + # Perform quantization + # NOTE(max410011): Use `quant_and_pack_vcache` because the key tensor is quantized along the last dimension too + # k_quantized, k_scales, k_zeros = quant_and_pack_vcache(k_to_quantize, key_full_precision.shape[-1], self.bits) + k_quantized, k_scales, k_zeros = triton_quantize_and_pack_along_last_dim(k_to_quantize, key_full_precision.shape[-1], self.bits) + + # NOTE(brian1009): Transpose and make it contiguous to match the requirements of Kernel that is going to consume this tensor + # NOTE(max410011): No need for key + # k_quantized = k_quantized.transpose(3, 2).contiguous().transpose(2, 3) + + # Store quantized outputs + if len(self.value_cache) <= layer_idx: + self.key_cache.append(k_quantized) + self.key_scales_cache.append(k_scales) + self.key_zeros_cache.append(k_zeros) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], k_quantized], dim=-2) + self.key_scales_cache[layer_idx] = torch.cat([self.key_scales_cache[layer_idx], k_scales], dim=-2) + self.key_zeros_cache[layer_idx] = torch.cat([self.key_zeros_cache[layer_idx], k_zeros], dim=-2) + + # Update the full precision cache with the remainder + if remainder_length > 0: + self.key_full_precision_cache[layer_idx] = k_remiander + else: + self.key_full_precision_cache[layer_idx] = None # Clear the full precision cache + + def quantize_and_store_value(self, layer_idx: int) -> None: + """ + Quantizes the value_full_precision_cache if it exceeds the residual length and stores + the quantized values, scales, and zeros into the corresponding caches. + + Parameters: + layer_idx (`int`): The index of the layer to perform quantization on. 
+ """ + value_full_precision = self.value_full_precision_cache[layer_idx] + assert len(value_full_precision.shape) == 4, "Value tensor must have 4 dimensions" + + current_length = value_full_precision.shape[-2] + # Calculate how much to quantize: leave the remainder (mod residual_length) + quantize_length = (current_length // self.residual_length_v) * self.residual_length_v + remainder_length = current_length % self.residual_length_v + + + # Split the tensor into the part to quantize and the remainder + v_to_quantize = value_full_precision[:, :, :quantize_length, :].contiguous() # Part to be quantized + if remainder_length > 0: + v_remainder = value_full_precision[:, :, quantize_length:, :].contiguous() # Part to remain in full precision + + # Perform quantization + # NOTE(max410011): Use `quant_and_pack_vcache` because the key tensor is quantized along the last dimension too + v_quantized, v_scales, v_zeros = quant_and_pack_vcache(v_to_quantize, value_full_precision.shape[-1], self.bits) + # v_quantized, v_scales, v_zeros = triton_quantize_and_pack_along_last_dim(v_to_quantize, value_full_precision.shape[-1], self.bits) + + # NOTE(brian1009): Transpose and make it contiguous to match the requirements of Kernel that is going to consume this tensor + # NOTE(max410011): No need for key + # k_quantized = k_quantized.transpose(3, 2).contiguous().transpose(2, 3) + + # Store quantized outputs + if len(self.value_cache) <= layer_idx: + # FIXME(max410011): Move to above? + v_quantized = v_quantized.transpose(3, 2).contiguous().transpose(2, 3) + v_scales = v_scales.transpose(3, 2).contiguous().transpose(2, 3) + v_zeros = v_zeros.transpose(3, 2).contiguous().transpose(2, 3) + + self.value_cache.append(v_quantized) + self.value_scales_cache.append(v_scales) + self.value_zeros_cache.append(v_zeros) + else: + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], v_quantized], dim=-2) + self.value_scales_cache[layer_idx] = torch.cat([self.value_scales_cache[layer_idx], v_scales], dim=-2) + self.value_zeros_cache[layer_idx] = torch.cat([self.value_zeros_cache[layer_idx], v_zeros], dim=-2) + + # FIXME(max410011): Move to above? + self.value_cache[layer_idx] = self.value_cache[layer_idx].transpose(3, 2).contiguous().transpose(2, 3) + self.value_scales_cache[layer_idx] = self.value_scales_cache[layer_idx].transpose(3, 2).contiguous().transpose(2, 3) + self.value_zeros_cache[layer_idx] = self.value_zeros_cache[layer_idx].transpose(3, 2).contiguous().transpose(2, 3) + + # Update the full precision cache with the remainder + if remainder_length > 0: + self.value_full_precision_cache[layer_idx] = v_remainder + else: + self.value_full_precision_cache[layer_idx] = None # Clear the full precision cache + + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + seq_length = 0 + if self.key_cache[layer_idx] is not None: + seq_length += self.key_cache[layer_idx].shape[-2] + if self.key_full_precision_cache[layer_idx] is not None: + seq_length += self.key_full_precision_cache[layer_idx].shape[-2] + return seq_length + + # def get_max_length(self) -> Optional[int]: + # raise NotImplementedError("Method `get_max_length` is not implemented for KeyValueQuantizedCacheV2.") + + def reorder_cache(self, beam_idx: torch.LongTensor): + raise NotImplementedError("Method `reorder_cache` is not implemented for KeyValueQuantizedCacheV2.") + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("Method `to_legacy_cache` is not implemented for KeyValueQuantizedCacheV2.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "KeyValueQuantizedCacheV2": + raise NotImplementedError("Method `from_legacy_cache` is not implemented for KeyValueQuantizedCacheV2.") + +class FastValueDynamicCache(DynamicCache): + def __init__(self, residual_length: int) -> None: + super().__init__() # Initialize the base class + self.residual_length = residual_length # Pre-defined residual length for managing the cache + + # Maintain a single cache for keys + self.key_cache: List[Optional[torch.Tensor]] = [] + + # Maintain two caches for values: `Recent` and `Main` + self.value_recent_cache: List[Optional[torch.Tensor]] = [] # Temporary cache for recent tokens (not yet moved to `Main`) + self.value_main_cache: List[Optional[torch.Tensor]] = [] # Cumulative cache for storing moved tokens + + def __getitem__(self, layer_idx: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Get the cache for a specific layer. + + Returns: + A tuple containing: + - key tensor + - recent value tensor + - main value tensor + """ + if layer_idx < len(self.key_cache): + return ( + self.key_cache[layer_idx], + self.value_recent_cache[layer_idx], + self.value_main_cache[layer_idx] + ) + else: + raise KeyError(f"Cache only has {len(self.key_cache)} layers, attempted to access layer with index {layer_idx}") + + def update( + self, + key_full_precision: torch.Tensor, + value_full_precision: torch.Tensor, + layer_idx: int + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Updates the cache with new `key_full_precision`, `value_full_precision` for the layer `layer_idx`. + + If the number of tokens in the recent value cache exceeds the residual length, they are moved to the main value cache. + + Parameters: + key_full_precision (`torch.Tensor`): The full precision key states. + value_full_precision (`torch.Tensor`): The full precision value states. + layer_idx (`int`): The index of the layer to cache the states for. 
+ + Returns: + A tuple containing the updated key and value states: + - key tensor + - recent value tensor + - main value tensor + """ + # Initialize the caches for the layer if they don't exist + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_full_precision) + self.value_recent_cache.append(value_full_precision) + self.value_main_cache.append(None) # None placeholder for the main value cache + else: + # Update key cache + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_full_precision + else: + # Concatenate the new keys into the key cache + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx], key_full_precision], dim=-2 + ) + + # Update recent value cache + if self.value_recent_cache[layer_idx] is None: + self.value_recent_cache[layer_idx] = value_full_precision + else: + # Concatenate the new values into the recent value cache + self.value_recent_cache[layer_idx] = torch.cat( + [self.value_recent_cache[layer_idx], value_full_precision], dim=-2 + ) + + # If the number of tokens in the recent value cache exceeds the residual length, move them to the main cache + if self.value_recent_cache[layer_idx].shape[-2] >= self.residual_length: + self.move_to_main_cache(layer_idx) + + # Return the updated caches + return ( + self.key_cache[layer_idx], + self.value_main_cache[layer_idx], + self.value_recent_cache[layer_idx] + ) + + def move_to_main_cache(self, layer_idx: int) -> None: + """ + Move the recent value cache tokens to the main value cache for the specified layer. + + This method is called when the number of tokens in the recent value cache exceeds the residual length. + """ + value_recent = self.value_recent_cache[layer_idx] + + if value_recent is None: + return + + # Calculate how many tokens to move (multiple of residual_length) + move_length = (value_recent.shape[-2] // self.residual_length) * self.residual_length + remainder_length = value_recent.shape[-2] % self.residual_length + + if move_length > 0: + value_to_move = value_recent[:, :, :move_length, :].contiguous() + + # Concatenate the tokens to the main value cache + if self.value_main_cache[layer_idx] is None: + self.value_main_cache[layer_idx] = value_to_move + else: + self.value_main_cache[layer_idx] = torch.cat( + [self.value_main_cache[layer_idx], value_to_move], dim=-2 + ) + + # Update the recent value cache with the remainder + if remainder_length > 0: + self.value_recent_cache[layer_idx] = value_recent[:, :, move_length:, :].contiguous() + else: + # Reset the recent value cache to None if no remainder + self.value_recent_cache[layer_idx] = None + + # Note: The key cache remains unchanged as we maintain a single key cache + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states.""" + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def reorder_cache(self, beam_idx: torch.LongTensor): + raise NotImplementedError("Method `reorder_cache` is not implemented for FastDynamicCache.") + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("Method `to_legacy_cache` is not implemented for FastDynamicCache.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "FastDynamicCache": + raise NotImplementedError("Method `from_legacy_cache` is not implemented for FastDynamicCache.") diff --git a/palu/quant_utils.py 
b/palu/quant/quant_utils.py similarity index 89% rename from palu/quant_utils.py rename to palu/quant/quant_utils.py index 6b8a6ae..d0216f9 100644 --- a/palu/quant_utils.py +++ b/palu/quant/quant_utils.py @@ -1,4 +1,4 @@ -from .model.modules import HeadwiseLowRankModule +from ..model.modules import HeadwiseLowRankModule import torch.nn as nn def configure_latent_quantizer( diff --git a/palu/data_utils.py b/palu/utils/data_utils.py similarity index 100% rename from palu/data_utils.py rename to palu/utils/data_utils.py diff --git a/run_lm_eval.py b/run_lm_eval.py index 920fd12..ddcc458 100644 --- a/run_lm_eval.py +++ b/run_lm_eval.py @@ -8,7 +8,7 @@ from lm_eval.models.huggingface import HFLM from lm_eval.utils import make_table from lm_eval.utils import eval_logger as logger -from palu.quant_utils import configure_latent_quantizer +from palu.quant.quant_utils import configure_latent_quantizer import os import json diff --git a/run_long_bench.py b/run_long_bench.py index f6a3b98..85fcc08 100755 --- a/run_long_bench.py +++ b/run_long_bench.py @@ -15,7 +15,7 @@ from longbench_utils import scorer, MODEL2MAXLEN, DATASET2PROMPT, DATASET2MAXLEN from utils import load_model_and_tokenizer, add_common_args -from palu.quant_utils import configure_latent_quantizer +from palu.quant.quant_utils import configure_latent_quantizer import palu.model def post_process(response, model_name): diff --git a/run_ppl_eval.py b/run_ppl_eval.py index aaed175..d121b1b 100644 --- a/run_ppl_eval.py +++ b/run_ppl_eval.py @@ -5,7 +5,7 @@ import argparse import os from utils import load_model_and_tokenizer, add_common_args -from palu.quant_utils import configure_latent_quantizer +from palu.quant.quant_utils import configure_latent_quantizer from loguru import logger def get_ppl_eval_loaders(name, tokenizer, seqlen=2048): diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..2b5e7a1 --- /dev/null +++ b/setup.py @@ -0,0 +1,52 @@ +from setuptools import find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +extra_compile_args = { + "cxx": [ + "-g", + "-O3", + "-fopenmp", + "-lgomp", + "-std=c++17", + "-DENABLE_BF16" + ], + "nvcc": [ + "-O3", + "-std=c++17", + "-DENABLE_BF16", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--threads=8" + ], +} + +# Read requirements.txt +with open("requirements.txt") as f: + requirements = f.read().splitlines() + +setup( + name="palu", + version="0.1", + description="Palu package with CUDA extension", + packages=find_packages(), + ext_modules=[ + CUDAExtension( + name="palu.palu_cuda", + sources=[ + "palu/csrc/palu_gemm_cuda.cu", + "palu/csrc/palu_gemm_outer_cuda.cu", + "palu/csrc/pybind.cpp", + ], + extra_compile_args=extra_compile_args, + ), + ], + cmdclass={"build_ext": BuildExtension}, + install_requires=requirements, # Load requirements from requirements.txt +) diff --git a/tests/module/test_palu_attention_fp16.py b/tests/module/test_palu_attention_fp16.py new file mode 100644 index 0000000..46f3686 --- /dev/null +++ b/tests/module/test_palu_attention_fp16.py @@ -0,0 +1,97 @@ +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) + +import torch +import pytest +from palu.model.svd_llama.palu_llama_attention import LlamaPaluAttention +from 
transformers.models.llama.modeling_llama import LlamaAttention +from transformers.cache_utils import DynamicCache +from transformers import AutoConfig + + +@pytest.fixture(scope="module", autouse=True) +def set_random_seed(): + """Fixture to set a fixed random seed for reproducibility.""" + torch.manual_seed(0) + torch.cuda.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +@pytest.fixture +def llama_config(): + """Fixture to initialize Llama configuration.""" + config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf") + config.k_bits = 16 + config.v_bits = 16 + return config + + +def test_prefilling(llama_config): + """ + Test prefilling attention with both the original and PALU attention modules. + Verifies that the attention weights and outputs are nearly identical. + """ + bsz, seq_len, hidden_dim = 1, 256, 4096 + attn = LlamaAttention(llama_config, layer_idx=0).half().cuda() + palu_attn = LlamaPaluAttention.from_attention( + module=attn, + config=llama_config, + rank_k_list=[512 for _ in range(8)], + rank_v_list=[512 for _ in range(8)], + no_fusion=True + ).cuda() + + inputs = torch.rand(bsz, seq_len, hidden_dim, dtype=torch.float16).cuda() + orig_output, orig_attn_weights, _ = attn(inputs, output_attentions=True) + palu_output, palu_attn_weights, _ = palu_attn(inputs, output_attentions=True) + + # Assert closeness in prefilling + torch.testing.assert_close(orig_attn_weights, palu_attn_weights, rtol=1e-3, atol=7.5e-3) + torch.testing.assert_close(orig_output, palu_output, rtol=1e-3, atol=7.5e-3) + + #Merge the recomputation matrix into o_projection + palu_attn.fused_v_recompute_to_o() + orig_output, _, _ = attn(inputs) + palu_output, _, _ = palu_attn(inputs) + + torch.testing.assert_close(orig_output, palu_output, rtol=1e-3, atol=1e-3) + + +def test_decoding(llama_config): + """ + Test decoding with both the original and PALU attention modules. + Verifies that output and attention weights are nearly identical. 
+ """ + bsz, prompt_len, decode_len = 1, 63, 1 + device, dtype = "cuda", torch.float16 + + attn = LlamaAttention(llama_config, layer_idx=0).to(device, dtype) + palu_attn = LlamaPaluAttention.from_attention( + module=attn, + config=llama_config, + rank_k_list=[512 for _ in range(8)], + rank_v_list=[512 for _ in range(8)], + no_fusion=False # Directly running in fusion mode + ).to(device, dtype) + + prompt_inputs = torch.rand(bsz, prompt_len, llama_config.hidden_size).to(device, dtype) + generate_inputs = torch.rand(bsz, decode_len, llama_config.hidden_size).to(device, dtype) + prompt_position_ids = torch.arange(prompt_len).unsqueeze(0) # Shape: [1, seq_length] + generate_position_ids = torch.arange(prompt_len, prompt_len + decode_len).unsqueeze(0) # Shape: [1, seq_length] + kv_cache, palu_kv_cache = DynamicCache(), DynamicCache() + + # Run prompting + attn_output, _, kv_cache = attn(prompt_inputs, output_attentions=False, past_key_value=kv_cache, position_ids=prompt_position_ids) + palu_attn_output, _, palu_kv_cache = palu_attn(prompt_inputs, output_attentions=False, past_key_value=palu_kv_cache, position_ids=prompt_position_ids) + + torch.testing.assert_close(attn_output, palu_attn_output, rtol=1e-3, atol=1e-3) + + # Run generation step + palu_attn.prepared_k_merged_U() + attn_output, attn_weights, kv_cache = attn(generate_inputs, output_attentions=True, past_key_value=kv_cache, position_ids=generate_position_ids) + palu_attn_output, palu_attn_weights, palu_kv_cache = palu_attn(generate_inputs, output_attentions=True, past_key_value=palu_kv_cache, position_ids=generate_position_ids) + + torch.testing.assert_close(attn_weights, palu_attn_weights, rtol=5e-3, atol=3e-2) + torch.testing.assert_close(attn_output, palu_attn_output, rtol=5e-3, atol=3e-2) diff --git a/tests/module/test_palu_attention_int4.py b/tests/module/test_palu_attention_int4.py new file mode 100644 index 0000000..2c4c5c3 --- /dev/null +++ b/tests/module/test_palu_attention_int4.py @@ -0,0 +1,93 @@ +import warnings +import torch +import pytest +from palu.model.svd_llama.palu_llama_attention import LlamaPaluAttention +from palu.quant.quant_kv_cache import ValueQuantizedCacheV2 +from palu.quant.q_packing import unpack_and_dequant_vcache +from transformers.models.llama.modeling_llama import LlamaAttention +from transformers import AutoConfig + +# Suppress FutureWarnings +warnings.simplefilter(action='ignore', category=FutureWarning) + +@pytest.fixture(scope="module", autouse=True) +def set_random_seed(): + """Fixture to set a fixed random seed for reproducibility.""" + torch.manual_seed(0) + torch.cuda.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +@pytest.fixture +def llama_config(): + """Fixture to initialize Llama configuration.""" + config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf") + config.k_bits = 16 + config.v_bits = 4 + return config + +# Define parameter values and custom IDs +params = [(129, 128), (256, 256), (512, 512)] +param_ids = [f"prompt_len={p[0]}, residual_length={p[1]}" for p in params] + +@pytest.mark.parametrize("prompt_len, residual_length", params, ids=param_ids) +def test_decoding(llama_config, prompt_len, residual_length): + """ + Test decoding with both the original and PALU attention modules. + Verifies that output and attention weights are nearly identical. 
+ """ + bsz, decode_len = 1, 1 + device, dtype = "cuda", torch.float16 + + # Initialize attention modules + attn = LlamaAttention(llama_config, layer_idx=0).to(device, dtype) + palu_attn = LlamaPaluAttention.from_attention( + module=attn, + config=llama_config, + rank_k_list=[512 for _ in range(8)], + rank_v_list=[512 for _ in range(8)], + no_fusion=False # Directly running in fusion mode + ).to(device, dtype) + palu_attn.prepared_k_merged_U() + + # Generate random inputs + prompt_inputs = torch.rand(bsz, prompt_len, llama_config.hidden_size).to(device, dtype) + generate_inputs = torch.rand(bsz, decode_len, llama_config.hidden_size).to(device, dtype) + prompt_position_ids = torch.arange(prompt_len).unsqueeze(0) # Shape: [1, seq_length] + generate_position_ids = torch.arange(prompt_len, prompt_len + decode_len).unsqueeze(0) # Shape: [1, seq_length] + palu_kv_cache = ValueQuantizedCacheV2(bits=llama_config.v_bits, residual_length=residual_length) + + # Run prompting + _, _, palu_kv_cache = palu_attn( + prompt_inputs, + output_attentions=False, + past_key_value=palu_kv_cache, + position_ids=prompt_position_ids + ) + + # Run generation step (Palu) + palu_attn_output, palu_attn_weights, palu_kv_cache = palu_attn( + generate_inputs, + output_attentions=True, + past_key_value=palu_kv_cache, + position_ids=generate_position_ids + ) + + # Extract and process attention weights and value states for validation + _, value_quant, scales, zeros, value_full = palu_kv_cache[0] + value_quant = value_quant.transpose(2, 3) + value_states = unpack_and_dequant_vcache(value_quant, scales, zeros, value_quant.shape[-1] * 8, 4) + + if value_full is not None: + value_states = torch.cat([value_states, value_full], dim=-2) + + palu_attn_h_weights = palu_attn_weights.reshape(1, palu_attn.num_groups, decode_len * palu_attn.group_size, -1) + attn_h_output_golden = torch.matmul(palu_attn_h_weights, value_states) + attn_output_golden = attn_h_output_golden.reshape(1, palu_attn.num_heads, decode_len, -1) + attn_output_golden = attn_output_golden.transpose(1, 2).contiguous() + attn_output_golden = attn_output_golden.view(bsz, decode_len, -1) + attn_output_golden = palu_attn.o_proj(attn_output_golden) + + # Assert that the outputs are close + torch.testing.assert_close(palu_attn_output, attn_output_golden, rtol=1e-3, atol=1e-3) diff --git a/tests/ops/test_bgemm_outer.py b/tests/ops/test_bgemm_outer.py new file mode 100644 index 0000000..4387080 --- /dev/null +++ b/tests/ops/test_bgemm_outer.py @@ -0,0 +1,61 @@ +import torch +import pytest +from palu.quant.q_packing import unpack_and_dequant_vcache, quant_and_pack_vcache +from palu.backend.q_matmul import cuda_bmm_fA_qB_outer + +# Set the seed for reproducibility +torch.random.manual_seed(1234) + +# Define tolerance levels for comparison +ATOL = 1e-1 +RTOL = 1e-3 + +# Define test cases for different tensor shapes +# Each tuple represents (q_len, seq_len, num_heads, grouped_rank, block_size, num_bits) +TEST_CASES = [ + (4, 1024, 8, 128, 128, 4), + (8, 1024, 8, 128, 128, 4), + (256, 1024, 8, 128, 128, 4), + (1024, 1024, 8, 128, 128, 4), + (2048, 1024, 8, 128, 128, 4), + (4096, 1024, 8, 128, 128, 4), +] + +def run_orig_matmul(attn_weights, Value): + """Run the original matmul implementation for verification.""" + return torch.matmul(attn_weights, Value) + +def run_palu_matmul(attn_weights, V_quant, V_scales, V_mn, block_size, num_bits): + """Run the PALU custom matmul implementation.""" + return cuda_bmm_fA_qB_outer(block_size, attn_weights, V_quant, V_scales, V_mn, num_bits) + 
+# Define test case names for improved readability in test output +param_ids = [f"q_len={case[0]}, seq_len={case[1]}, num_heads={case[2]}" for case in TEST_CASES] + +@pytest.mark.parametrize("q_len, seq_len, num_heads, grouped_rank, block_size, num_bits", TEST_CASES, ids=param_ids) +def test_palu_matmul(q_len, seq_len, num_heads, grouped_rank, block_size, num_bits): + """Parameterized test to validate correctness across multiple input shapes.""" + + # Generate random attention weights and value tensors with updated shapes + attn_weights = torch.rand(1, num_heads, q_len, seq_len, dtype=torch.float16).cuda() + Value = torch.rand(1, num_heads, seq_len, num_heads * grouped_rank, dtype=torch.float16).cuda() + + # Quantize Value tensor with the updated shape + V_quant, V_scales, V_mn = quant_and_pack_vcache(Value, block_size, num_bits) + + # Dequantize to get the 'golden' reference + V_dequant = unpack_and_dequant_vcache(V_quant, V_scales, V_mn, block_size, num_bits) + + # Change the underlying memory layout implicitly to simulate memory layout adjustments + V_quant = V_quant.transpose(-2, -1).contiguous().transpose(-2, -1) + V_scales = V_scales.transpose(-2, -1).contiguous().transpose(-2, -1) + V_mn = V_mn.transpose(-2, -1).contiguous().transpose(-2, -1) + + # Run the original and PALU implementations + golden_output = run_orig_matmul(attn_weights, V_dequant) + palu_output = run_palu_matmul(attn_weights, V_quant, V_scales, V_mn, block_size, num_bits) + + # Check for correctness within tolerance + assert torch.allclose(palu_output, golden_output, atol=ATOL, rtol=RTOL), ( + f"Test failed for shape (q_len={q_len}, seq_len={seq_len}, num_heads={num_heads}, grouped_rank={grouped_rank})" + ) diff --git a/tests/ops/test_triton_recompute.py b/tests/ops/test_triton_recompute.py new file mode 100644 index 0000000..bd7dfc0 --- /dev/null +++ b/tests/ops/test_triton_recompute.py @@ -0,0 +1,89 @@ +import torch +import pytest +from palu.backend.fused_recompute import abx + +# Define tolerance levels for comparison +ATOL = 1 +RTOL = 1e-3 + +# Set random seed for reproducibility +def set_random_seed(seed=0): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +# Rotary Embedding - defined here as a helper function +def LlamaRotaryEmbedding(dim: int, end: int, theta: float = 10000.0): + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + t = torch.arange(end, dtype=torch.int64).type_as(inv_freq) + freqs = torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos(), emb.sin() + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb_pytorch(x, cos, sin, unsqueeze_dim=0): + cos = cos.unsqueeze(unsqueeze_dim).to(x.device) + sin = sin.unsqueeze(unsqueeze_dim).to(x.device) + x_emb = (x * cos) + (rotate_half(x) * sin) + return x_emb + +def torch_abx(a, b, x): + x_expand = x.unsqueeze(1) + b_reshape = b.reshape(-1, b.shape[0] // x.shape[0], b.shape[-2], b.shape[-1]) + xb = x_expand @ b_reshape + xb = xb.reshape(b.shape[0], -1, b.shape[-1]) + cos, sin = LlamaRotaryEmbedding(dim=128, end=x.shape[1]) + xb_rope = apply_rotary_pos_emb_pytorch(x=xb, cos=cos, sin=sin) + axb = a @ xb_rope.transpose(-1, -2).to(torch.float16) + return axb + +# Define test cases with varied parameters +# Each tuple represents (num_heads, head_dim, total_rank, 
num_groups, seq_len) +TEST_CASES = [ + (32, 128, 1024, 8, 64), + (32, 128, 2048, 8, 64), + (32, 128, 1024, 8, 256), + (32, 128, 1024, 8, 1024), + (32, 128, 1024, 8, 4096), + # test arbitary output length + (32, 128, 1024, 8, 65), + (32, 128, 1024, 8, 78), + (32, 128, 1024, 8, 4099), +] + +@pytest.mark.parametrize("num_heads, head_dim, total_rank, num_groups, seq_len", TEST_CASES) +def test_abx(num_heads, head_dim, total_rank, num_groups, seq_len): + """Test the abx function for various configurations.""" + set_random_seed(0) + rank_per_groups = total_rank // num_groups + dtype = torch.float16 + device = "cuda" + + # Create test tensors with configurable seq_len + A = torch.randn(num_heads, 1, head_dim, dtype=dtype, device=device) + B = torch.randn(num_heads, rank_per_groups, head_dim, dtype=dtype, device=device) + X = torch.randn(num_groups, seq_len, rank_per_groups, dtype=dtype, device=device) + + # Run the original and custom implementations + axb = torch_abx(A, B, X) + ours = abx(A, B, X) + + # Check for correctness within tolerance + max_diff = torch.max(torch.abs(axb - ours)) + assert torch.allclose(axb, ours, atol=ATOL, rtol=RTOL), f"Test failed: Max diff {max_diff.item()} exceeded tolerance" + + print(f"Test passed for (num_heads={num_heads}, head_dim={head_dim}, total_rank={total_rank}, num_groups={num_groups}, seq_len={seq_len}) with max diff: {max_diff.item()}") + +# For manual testing without pytest, you could include a simple runner: +if __name__ == '__main__': + set_random_seed(0) + for case in TEST_CASES: + num_heads, head_dim, total_rank, num_groups, seq_len = case + print(f"Running test for (num_heads={num_heads}, head_dim={head_dim}, total_rank={total_rank}, num_groups={num_groups}, seq_len={seq_len})") + test_abx(num_heads, head_dim, total_rank, num_groups, seq_len) diff --git a/utils.py b/utils.py index 142cf21..a9fcfa2 100644 --- a/utils.py +++ b/utils.py @@ -47,7 +47,6 @@ def get_module_by_name(module, module_name): def dump_to_huggingface_repos(model, tokenizer, save_path, args): tokenizer.save_pretrained(save_path) - #model.generation_config = Gene #if "vicuna" in model.config._name_or_path.lower() or "longchat" in model.config._name_or_path.lower(): #NOTE(brian1009): Ad-hoc fixing the bug in Vicuna # model.config.generation_config = GenerationConfig(temperature=1.0, top_p=1.0)