diff --git a/benchmarks/dora/bench_utils.py b/benchmarks/dora/bench_utils.py new file mode 100644 index 000000000..2de4fa637 --- /dev/null +++ b/benchmarks/dora/bench_utils.py @@ -0,0 +1,131 @@ +import torch +from bitsandbytes.nn import Linear4bit +from hqq.core.quantize import BaseQuantizeConfig, HQQLinear + +from prototypes.dora.dora_layer import BNBDoRALinear, HQQDoRALinear +from prototypes.dora.kernels.matmul import triton_mm +from prototypes.dora.kernels.smallk import triton_mm_small_k + + +def make_lora_weights(ranks, in_features, out_features, dtype): + As = [torch.randn(rank, in_features, device="cuda", dtype=dtype) for rank in ranks] + Bs = [torch.randn(out_features, rank, device="cuda", dtype=dtype) for rank in ranks] + return As, Bs + + +def make_dora_source_and_magnitude(in_features, out_features, dtype): + source = torch.randn(out_features, in_features, device="cuda", dtype=dtype) + magnitude = torch.randn(out_features, device="cuda", dtype=dtype) + return source, magnitude + + +def make_inputs(batch_sizes, seqlen, in_features, dtype): + xs = [ + torch.randn(bs * seqlen, in_features, device="cuda", dtype=dtype) + for bs in batch_sizes + ] + return xs + + +def make_weights(batch_sizes, in_features, out_features, dtype): + weights = [ + torch.randn(in_features, out_features, device="cuda", dtype=dtype) + for _ in range(len(batch_sizes)) + ] + return weights + + +def make_epilogue_sources(batch_sizes, seqlen, out_features, dtype): + epilogue_sources = [ + torch.randn(bs * seqlen, out_features, device="cuda", dtype=dtype) + for bs in batch_sizes + ] + return epilogue_sources + + +def make_epilogue_scales(batch_sizes, out_features, dtype): + epilogue_scales = [ + torch.randn(out_features, device="cuda", dtype=dtype) + for _ in range(len(batch_sizes)) + ] + return epilogue_scales + + +def dora_colnorm_ref( + A: torch.Tensor, + B: torch.Tensor, + base_weight: torch.Tensor, + magnitude_vector: torch.Tensor, +): + column_norm = (base_weight + B @ A).norm(p=2, dim=1) + return magnitude_vector / column_norm + + +def dora_mm_epilogue_ref( + A: torch.Tensor, + B: torch.Tensor, + epilogue_source: torch.Tensor, + epilogue_scale: torch.Tensor, +): + out = (A @ B + epilogue_source) * epilogue_scale[None, :] + return out + + +def dora_ref(x, w, lora_A, lora_B, magnitude_vector): + # (bs x seq_len x out_features) = (bs x seq_len x in_features) @ (in_features x rank) @ (rank x out_features) + lora_out = (x @ lora_A.T) @ lora_B.T + # (out_features) + magnitude_scale = dora_colnorm_ref(lora_A, lora_B, w, magnitude_vector) + # (bs x seq_len x out_features) + dora_out_ref = dora_mm_epilogue_ref(x, w, lora_out, magnitude_scale) + return dora_out_ref + + +def dora_triton(x, w, lora_A, lora_B, magnitude_vector): + lora_out = (x @ lora_A.T) @ lora_B.T + magnitude_scale = triton_mm_small_k( + lora_B, + lora_A, + epilogue_norm=True, + source=w, + magnitude=magnitude_vector, + store_acc=False, + ) + dora_out = triton_mm(x, w, epilogue_source=lora_out, epilogue_scale=magnitude_scale) + return dora_out + + +def setup_dora_base_layers(layer_type, in_features, out_features, dtype): + if "bnb" in layer_type: + # BitsandBytes + base_layer = Linear4bit( + input_features=in_features, + output_features=out_features, + bias=False, + quant_type="nf4", + compute_dtype=dtype, + ).cuda() + base_layer.quant_state.dtype = base_layer.compute_dtype + dora_cls = BNBDoRALinear + elif "hqq" in layer_type: + # HQQ + quant_config = BaseQuantizeConfig( + nbits=4, + group_size=64, + quant_zero=False, + quant_scale=False, + offload_meta=True, + view_as_float=True, + ) + linear = torch.nn.Linear( + in_features, out_features, dtype=dtype, bias=False + ).cuda() + base_layer = HQQLinear( + linear, + quant_config, + compute_dtype=dtype, + ) + dora_cls = HQQDoRALinear + else: + raise ValueError(f"Unknown layer type: {layer_type}") + return base_layer, dora_cls diff --git a/benchmarks/dora/dora_bench.py b/benchmarks/dora/dora_bench.py new file mode 100644 index 000000000..305cfbdb1 --- /dev/null +++ b/benchmarks/dora/dora_bench.py @@ -0,0 +1,348 @@ +import argparse + +import pandas as pd +import torch +from bench_utils import ( + dora_colnorm_ref, + dora_mm_epilogue_ref, + dora_ref, + dora_triton, + make_dora_source_and_magnitude, + make_epilogue_scales, + make_epilogue_sources, + make_inputs, + make_lora_weights, + make_weights, + setup_dora_base_layers, +) +from triton.testing import do_bench + +from torchao.prototype.dora.kernels.matmul import triton_mm +from torchao.prototype.dora.kernels.smallk import triton_mm_small_k +from torchao.prototype.common.profiling_tools import pivot_df + + +def run_colnorm_bench(args): + in_features, out_features = args.in_features, args.out_features + + dtype = getattr(torch, args.dtype) + + # Inputs + As, Bs = make_lora_weights(args.dora_ranks, in_features, out_features, dtype) + source, magnitude = make_dora_source_and_magnitude(in_features, out_features, dtype) + + # torch.compile + dora_colnorm_compiled = torch.compile(dora_colnorm_ref, mode=args.compile_mode) + compiled_key = f"compiled_{args.compile_mode}" + + # Benchmark + timings = [] + + for a, b in zip(As, Bs): + ref_t = do_bench(lambda: dora_colnorm_ref(a, b, source, magnitude)) + compiled_t = do_bench(lambda: dora_colnorm_compiled(a, b, source, magnitude)) + + test_t = do_bench( + lambda: triton_mm_small_k( + b, + a, + epilogue_norm=True, + source=source, + magnitude=magnitude, + store_acc=False, + ), + ) + common_args = [a.shape[0], a.shape[1], b.shape[0], args.dtype] + timings.append([*common_args, "ref", ref_t]) + timings.append([*common_args, compiled_key, compiled_t]) + timings.append([*common_args, "triton", test_t]) + + # Group results for kernel type + headers = ["rank", "in_features", "out_features", "dtype", "kernel", "time(ms)"] + df = pd.DataFrame(timings, columns=headers) + id_cols = ["rank", "in_features", "out_features"] + pivot_df( + df, + id_cols=id_cols, + columns="kernel", + values="time(ms)", + column_order=[*id_cols, "ref", compiled_key, "triton"], + show=True, + ) + + +def run_epilogue_bench(args): + in_features, out_features = args.in_features, args.out_features + seqlen = args.seqlen + batch_sizes = ( + args.batch_sizes if isinstance(args.batch_sizes, list) else [args.batch_sizes] + ) + dtype = getattr(torch, args.dtype) + + # Inputs + xs = make_inputs(batch_sizes, seqlen, in_features, dtype) + weights = make_weights(batch_sizes, in_features, out_features, dtype) + epilogue_sources = make_epilogue_sources(batch_sizes, seqlen, out_features, dtype) + epilogue_scales = make_epilogue_scales(batch_sizes, out_features, dtype) + + # torch.compile + dora_mm_epilogue_compiled = torch.compile( + dora_mm_epilogue_ref, mode=args.compile_mode + ) + compiled_key = f"compiled_{args.compile_mode}" + + # Benchmark + timings = [] + for bs, x, w, e1, e2 in zip( + batch_sizes, xs, weights, epilogue_sources, epilogue_scales + ): + ref_t = do_bench(lambda: dora_mm_epilogue_ref(x, w, e1, e2)) + compiled_t = do_bench(lambda: dora_mm_epilogue_compiled(x, w, e1, e2)) + + test_t = do_bench( + lambda: triton_mm( + x, + w, + epilogue_source=e1, + epilogue_scale=e2, + ) + ) + common_args = [bs, seqlen, w.shape[0], w.shape[1], args.dtype] + timings.append([*common_args, "ref", ref_t]) + timings.append([*common_args, compiled_key, compiled_t]) + timings.append([*common_args, "triton", test_t]) + + # Group results for kernel type + headers = [ + "bs", + "seqlen", + "in_features", + "out_features", + "dtype", + "kernel", + "time(ms)", + ] + df = pd.DataFrame(timings, columns=headers) + id_cols = ["bs", "seqlen", "in_features", "out_features", "dtype"] + + pivot_df( + df, + id_cols=id_cols, + columns="kernel", + values="time(ms)", + column_order=[*id_cols, "ref", compiled_key, "triton"], + show=True, + ) + + +def run_full_dora(args): + """Dora Layer + + out = (x @ base_weight + lora_out) * magnitude_scale + where: + `lora_out = lora_B(lora_A(x)` + `magnitude_scale = (base_weight + lora_B @ lora_A).norm(p=2, dim=1) * magnitude_vector` + """ + + dtype = getattr(torch, args.dtype) + xs = make_inputs(args.batch_sizes, args.seqlen, args.in_features, dtype) + weights = make_weights(args.batch_sizes, args.in_features, args.out_features, dtype) + lora_As, lora_Bs = make_lora_weights( + args.dora_ranks, args.in_features, args.out_features, dtype + ) + _, magnitude_vector = make_dora_source_and_magnitude( + args.in_features, args.out_features, dtype + ) + + # torch.compile + dora_compiled = torch.compile(dora_ref, mode=args.compile_mode) + # triton_compiled = torch.compile(dora_triton, mode=args.compile_mode) + + compiled_key = f"compiled_{args.compile_mode}" + # triton_compiled_key = f"triton_compiled_{args.compile_mode}" + + # Benchmark + timings = [] + for lora_A, lora_B in zip(lora_As, lora_Bs): + for bs, x, w in zip(args.batch_sizes, xs, weights): + # ref = dora_ref(x, w, lora_A, lora_B, magnitude_vector) + # test = dora_triton(x, w, lora_A, lora_B, magnitude_vector) + # compiled = dora_compiled(x, w, lora_A, lora_B, magnitude_vector) + # test_compiled = triton_compiled(x, w, lora_A, lora_B, magnitude_vector) + # print(f"triton diff: {(ref - test).abs().max()}") + # print(f"compiled diff: {(ref - compiled).abs().max()}") + # print(f"triton compiled diff: {(ref - test_compiled).abs().max()}") + ref_t = do_bench(lambda: dora_ref(x, w, lora_A, lora_B, magnitude_vector)) + compiled_t = do_bench( + lambda: dora_compiled(x, w, lora_A, lora_B, magnitude_vector) + ) + triton_t = do_bench( + lambda: dora_triton(x, w, lora_A, lora_B, magnitude_vector) + ) + # triton_compiled_t = do_bench( + # lambda: triton_compiled(x, w, lora_A, lora_B, magnitude_vector) + # ) + + # batch_size, seq_len, rank, in_features, out_features, dtype + common_args = [ + bs, + args.seqlen, + lora_A.shape[0], + args.in_features, + args.out_features, + args.dtype, + ] + timings.append([*common_args, "ref", ref_t]) + timings.append([*common_args, compiled_key, compiled_t]) + timings.append([*common_args, "triton", triton_t]) + # timings.append([*common_args, triton_compiled_key, triton_compiled_t]) + + headers = [ + "bs", + "seqlen", + "rank", + "in_features", + "out_features", + "dtype", + "kernel", + "time(ms)", + ] + df = pd.DataFrame(timings, columns=headers) + id_cols = ["bs", "seqlen", "rank", "in_features", "out_features", "dtype"] + + pivot_df( + df, + id_cols=id_cols, + columns="kernel", + values="time(ms)", + column_order=[ + *id_cols, + "ref", + compiled_key, + "triton", + ], # , triton_compiled_key], + show=True, + ) + + +def run_dora_layer_bench(args): + dtype = getattr(torch, args.dtype) + in_features, out_features = args.in_features, args.out_features + xs = make_inputs(args.batch_sizes, args.seqlen, args.in_features, dtype) + base_layer, dora_cls = setup_dora_base_layers( + args.kernel, in_features, out_features, dtype + ) + + timings = [] + layer_key = f"{args.kernel}" + layer_key_fused = f"{args.kernel}-fused" + + for bs, x in zip(args.batch_sizes, xs): + for rank in args.dora_ranks: + dora_layer = dora_cls(base_layer, rank).cuda() + common_args = [ + bs, + args.seqlen, + rank, + args.in_features, + args.out_features, + args.dtype, + ] + ref_t = do_bench(lambda: dora_layer.forward(x)) + fused_t = do_bench(lambda: dora_layer.forward_fused(x)) + timings.append([*common_args, layer_key, ref_t]) + timings.append([*common_args, layer_key_fused, fused_t]) + + headers = [ + "bs", + "seqlen", + "rank", + "in_features", + "out_features", + "dtype", + "layer", + "time(ms)", + ] + df = pd.DataFrame(timings, columns=headers) + id_cols = ["bs", "seqlen", "rank", "in_features", "out_features", "dtype"] + + pivot_df( + df, + id_cols=id_cols, + columns="layer", + values="time(ms)", + column_order=[ + *id_cols, + layer_key, + layer_key_fused, + ], + show=True, + ) + + +def run_bench(args): + print(f"""Running {args.kernel} benchmark with dtype={args.dtype}, batch_sizes={args.batch_sizes}, seqlen={args.seqlen}, + in_features={args.in_features}, out_features={args.out_features}, dora_ranks={args.dora_ranks}""") + if args.kernel == "dora-colnorm": + return run_colnorm_bench(args) + elif args.kernel == "dora-mm-epilogue": + return run_epilogue_bench(args) + elif args.kernel == "dora-full": + return run_full_dora(args) + elif args.kernel == "dora-bnb" or args.kernel == "dora-hqq": + return run_dora_layer_bench(args) + else: + raise ValueError(f"Unknown kernel: {args.kernel}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--kernel", + type=str, + default="dora-mm-epilogue", + choices=( + "dora-colnorm", + "dora-mm-epilogue", + "dora-full", + "dora-bnb", + "dora-hqq", + ), + help="""The kernel to benchmark + + dora-colnorm: Small K GEMM with fused column-norm and magnitude vector multiplication + dora-mm-epilogue: GEMM with fused epilogue elementwise addition and broadcasted scale + dora-full: Full DORA kernel (dora-colnorm + dora-mm-epilogue) + dora-bnb: BNBDoRALinear layer with fused kernels + dora-hqq: HQQDoRALinear layer with fused kernels + """, + ) + parser.add_argument("--seqlen", type=int, default=512) + parser.add_argument( + "--batch_sizes", type=int, nargs="+", default=[1, 2, 4, 8, 16, 32] + ) + parser.add_argument("--dora_ranks", type=int, nargs="+", default=[16, 32, 64]) + parser.add_argument("--in_features", type=int, default=4096) + parser.add_argument("--out_features", type=int, default=4096) + parser.add_argument( + "--dtype", + type=str, + default="float16", + choices=("float16", "bfloat16", "float32"), + ) + parser.add_argument( + "--compile_mode", + type=str, + default="default", + choices=( + "default", + "reduce-overhead", + "max-autotune-no-cudagraphs", + "max-autotune", + ), + ) + + args = parser.parse_args() + run_bench(args) diff --git a/test/dora/test_dora_fusion.py b/test/dora/test_dora_fusion.py new file mode 100644 index 000000000..a7959f85a --- /dev/null +++ b/test/dora/test_dora_fusion.py @@ -0,0 +1,193 @@ +import sys + +import pytest + +if sys.version_info < (3, 11): + pytest.skip("requires Python >= 3.11", allow_module_level=True) + +triton = pytest.importorskip("triton", reason="requires triton") + +import itertools + +import torch + +from torchao.prototype.dora.kernels.matmul import triton_mm +from torchao.prototype.dora.kernels.smallk import triton_mm_small_k + +torch.manual_seed(0) + +# Test configs +M = 4096 +N = 4096 +Ks = [int(2**i) for i in range(4, 7)] + +FUSED_DORA_SHAPES = [(M, N, K) for K in Ks[:1]] + +DTYPES = [torch.float32, torch.float16, torch.bfloat16] + +STORE_ACC = [False] +EPILOGUE_NORM = [True, False] +ADD_SOURCE = [True] +MAGNITUDE_VECTOR = [True] +FUSED_DORA_TEST_CONFIGS = list( + itertools.product( + FUSED_DORA_SHAPES, + STORE_ACC, + EPILOGUE_NORM, + ADD_SOURCE, + MAGNITUDE_VECTOR, + DTYPES, + ) +) + + +def _arg_to_id(arg): + if isinstance(arg, (tuple, list)): + return "x".join([str(x) for x in arg]) + return str(arg) + + +def check(expected, actual, dtype): + if dtype == torch.float32: + atol = 1e-4 + elif dtype == torch.float16: + atol = 1e-3 + elif dtype == torch.bfloat16: + atol = 1e-2 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + diff = (expected - actual).abs().max() + print(f"diff: {diff}") + # assert diff < atol + return diff + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize( + "shape, store_acc, epilogue_norm, add_source, magnitude_vector, dtype", + FUSED_DORA_TEST_CONFIGS, + ids=_arg_to_id, +) +def test_dora_column_norm( + shape, store_acc, epilogue_norm, add_source, magnitude_vector, dtype +): + if not (store_acc or epilogue_norm): + pytest.skip("Either store_acc or epilogue_norm must be True") + + M, N, K = shape + A = torch.randn(M, K, device="cuda", dtype=dtype) + B = torch.randn(K, N, device="cuda", dtype=dtype) + source = torch.randn(M, N, device="cuda", dtype=dtype) + magnitude = torch.randn(M, device="cuda", dtype=dtype) + + c_ref = torch.matmul(A, B) + norm2_ref = 1 / c_ref.norm(2, dim=1) + source_ref = source + c_ref + source_norm2_ref = 1 / (source + c_ref).norm(2, dim=1) + source_norm2_magnitude_ref = magnitude * source_norm2_ref + + # First test small K only kernel, no epilogue + # source = None # source # None + # magnitude = None # magnitude # None + + tt_out = triton_mm_small_k( + A, + B, + source=source if add_source else None, + magnitude=magnitude if magnitude_vector else None, + epilogue_norm=epilogue_norm, + store_acc=store_acc, + ) + + if store_acc: + c_test = tt_out[0] if epilogue_norm else tt_out + if add_source: + check(source_ref, c_test, dtype) + else: + check(c_ref, c_test, dtype) + + if epilogue_norm: + norm2_test = tt_out[1] if store_acc else tt_out + if add_source: + if magnitude_vector: + check(source_norm2_magnitude_ref, norm2_test, dtype) + else: + check(source_norm2_ref, norm2_test, dtype) + else: + check(norm2_ref, norm2_test, dtype) + + +BATCH_SIZES = [int(2**i) for i in range(6)] +SEQ_LENS = [512] +IN_FEATURES = [4096] +OUT_FEATURES = [4096] +FUSED_MATMUL_SHAPES = [ + (bs * seqlen, in_features, out_features) + for bs, seqlen, in_features, out_features in zip( + BATCH_SIZES, SEQ_LENS, IN_FEATURES, OUT_FEATURES + ) +] +EPILOGUE_ELEMENTWISE_ADD = [True] +EPILOGUE_BROADCAST_SCALE = [True] + +FUSED_MATMUL_TEST_CONFIGS = list( + itertools.product( + FUSED_MATMUL_SHAPES[:1], + DTYPES, + EPILOGUE_ELEMENTWISE_ADD, + EPILOGUE_BROADCAST_SCALE, + ) +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize( + "shape, dtype, epilogue_add, epilogue_scale", + FUSED_MATMUL_TEST_CONFIGS, + ids=_arg_to_id, +) +def test_dora_matmul(shape, dtype, epilogue_add, epilogue_scale): + M, K, N = shape + A = torch.randn(M, K, device="cuda", dtype=dtype) + B = torch.randn(K, N, device="cuda", dtype=dtype) + C = torch.randn(M, N, device="cuda", dtype=dtype) if epilogue_add else None + scale = torch.randn(N, device="cuda", dtype=dtype) if epilogue_scale else None + + D_ref = torch.matmul(A, B) + if epilogue_add: + D_ref += C + if epilogue_scale: + D_ref *= scale.unsqueeze(0) + + D_test = triton_mm(A, B, epilogue_source=C, epilogue_scale=scale) + check(D_ref, D_test, dtype) + + +MODES = ["default"] + + +@pytest.mark.skip("TODO: torch.compile does not work with custom kernel") +@pytest.mark.parametrize( + "shape, dtype, epilogue_add, epilogue_scale, mode", + [[*cfg, mode] for cfg in FUSED_MATMUL_TEST_CONFIGS for mode in MODES][:1], + ids=_arg_to_id, +) +def test_dora_matmul_compile(shape, dtype, epilogue_add, epilogue_scale, mode): + M, K, N = shape + A = torch.randn(M, K, device="cuda", dtype=dtype) + B = torch.randn(K, N, device="cuda", dtype=dtype) + C = torch.randn(M, N, device="cuda", dtype=dtype) if epilogue_add else None + scale = torch.randn(N, device="cuda", dtype=dtype) if epilogue_scale else None + + D_ref = torch.matmul(A, B) + if epilogue_add: + D_ref += C + if epilogue_scale: + D_ref *= scale.unsqueeze(0) + + D_test = triton_mm(A, B, epilogue_source=C, epilogue_scale=scale) + check(D_ref, D_test, dtype) + + triton_compiled = torch.compile(triton_mm, mode=mode) + D_compiled = triton_compiled(A, B, epilogue_source=C, epilogue_scale=scale) + check(D_ref, D_compiled, dtype) diff --git a/test/dora/test_dora_layer.py b/test/dora/test_dora_layer.py new file mode 100644 index 000000000..dd38cc8d6 --- /dev/null +++ b/test/dora/test_dora_layer.py @@ -0,0 +1,111 @@ +import sys + +import pytest + +if sys.version_info < (3, 11): + pytest.skip("requires Python >= 3.11", allow_module_level=True) + +bnbnn = pytest.importorskip("bitsandbytes.nn", reason="requires bitsandbytes") +hqq_core = pytest.importorskip("hqq.core.quantize", reason="requires hqq") + +import itertools + +import torch + +# Import modules as opposed to classes directly, otherwise pytest.importorskip always skips +Linear4bit = bnbnn.Linear4bit +BaseQuantizeConfig = hqq_core.BaseQuantizeConfig +HQQLinear = hqq_core.HQQLinear +from torchao.prototype.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear + + +def check(expected, actual, dtype): + if dtype == torch.float32: + atol = 1e-4 + elif dtype == torch.float16: + atol = 1e-3 + elif dtype == torch.bfloat16: + atol = 1e-2 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + diff = (expected - actual).abs().max() + print(f"diff: {diff}") + # assert diff < atol + return diff + + +def _arg_to_id(arg): + if isinstance(arg, (tuple, list)): + return "x".join([str(x) for x in arg]) + return str(arg) + + +BATCH_SIZES = [1] +SEQ_LENS = [512] +DTYPES = [torch.float32, torch.float16, torch.bfloat16] +IN_FEATURES = [4096] +OUT_FEATURES = [4096, 11008] +LORA_RANKS = [16] +MODEL_TYPES = ["DoRALinear", "BNBDoRALinear", "HQQDoRALinear"] + +TEST_CONFIGS = list( + itertools.product( + BATCH_SIZES, + SEQ_LENS, + IN_FEATURES, + OUT_FEATURES, + LORA_RANKS, + DTYPES, + MODEL_TYPES, + ) +) + + +@pytest.mark.parametrize( + "bs, seqlen, in_features, out_features, lora_rank, dtype, model_type", + TEST_CONFIGS, + ids=_arg_to_id, +) +def test_dora_layer( + bs, seqlen, in_features, out_features, lora_rank, dtype, model_type +): + x = torch.randn(bs, seqlen, in_features, dtype=dtype).cuda() + + if model_type == "DoRALinear": + base_layer = torch.nn.Linear( + in_features, out_features, dtype=dtype, bias=False + ).cuda() + dora_cls = DoRALinear + + elif model_type == "BNBDoRALinear": + base_layer = Linear4bit( + input_features=in_features, + output_features=out_features, + bias=False, + quant_type="nf4", + compute_dtype=dtype, + ).cuda() + base_layer.quant_state.dtype = base_layer.compute_dtype + dora_cls = BNBDoRALinear + + elif model_type == "HQQDoRALinear": + quant_config = BaseQuantizeConfig( + nbits=4, + group_size=64, + quant_zero=False, + quant_scale=False, + offload_meta=True, + view_as_float=True, + ) + torch_base = torch.nn.Linear(in_features, out_features, dtype=dtype, bias=False) + base_layer = HQQLinear( + torch_base, + quant_config, + compute_dtype=dtype, + ) + dora_cls = HQQDoRALinear + dora_layer = dora_cls(base_layer, lora_rank).cuda() + + ref = dora_layer.forward(x) + test = dora_layer.forward_fused(x) + check(ref, test, dtype) diff --git a/torchao/prototype/common/__init__.py b/torchao/prototype/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchao/prototype/common/profiling_tools.py b/torchao/prototype/common/profiling_tools.py new file mode 100644 index 000000000..607895d4e --- /dev/null +++ b/torchao/prototype/common/profiling_tools.py @@ -0,0 +1,268 @@ +import os +import types +from datetime import datetime +from functools import partial + +import pandas as pd +import torch +import torch.autograd.profiler_util +from tabulate import tabulate +from torch.autograd.profiler import record_function +from torch.cuda.nvtx import range as nvtx_range +from triton.testing import do_bench + +# from torch.cuda.nvtx import range_pop, range_push + +TIME_FORMAT_STR: str = "%m_%d" +PROFILE_DIR = "./profiles" + + +def simple_bench(fn, *args, **kwargs): + t = do_bench(lambda: fn(*args, **kwargs)) + return t + + +def check(expected, actual, atol=1e-3): + diff = (expected - actual).abs().max() + print(f"diff: {diff}") + # assert diff < atol + + +def benchmark_mm( + test_fn, xs, weight, ref_fn=torch.matmul, headers=["M", "K", "N", "test", "ref"] +): + timings = [] + for x in xs: + M, K = x.shape + _, N = weight.shape + assert x.shape[1] == weight.shape[0] + print(f"Benchmarking {(M, K, N)}") + test_times = do_bench(lambda: test_fn(x, weight)) + ref_times = do_bench(lambda: ref_fn(x, weight)) + timings.append([M, K, N, test_times, ref_times]) + return pd.DataFrame(timings, columns=headers) + + +def run_bench(xs, weight): + df = benchmark_mm(xs, weight) + print(tabulate(df, headers="keys", floatfmt=".4f")) + return df + + +class CudaProfilerCtx: + def __enter__(self): + print("Starting cuda profiler") + torch.cuda.cudart().cudaProfilerStart() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + print("Stopping cuda profiler") + torch.cuda.cudart().cudaProfilerStop() + if exc_type is not None: + print(f"Exception occurred: {exc_type}, {exc_value}") + # Return True to suppress the exception + return True + + def step(self): + pass + + +def trace_handler( + prof: torch.profiler.profile, + group_by_stack: int = 5, + group_by_input_shapes: bool = False, + prefix="", + out_dir=None, + export_events=False, + export_trace=True, + export_memory_timeline=False, +): + # Prefix for file names. + out_dir = out_dir or PROFILE_DIR + timestamp = datetime.now().strftime(TIME_FORMAT_STR) + file_prefix = os.path.join(out_dir, f"{prefix}-{timestamp}") + + if export_events: + evt_list = prof.key_averages( + group_by_stack_n=group_by_stack, group_by_input_shape=group_by_input_shapes + ) + torch.save(evt_list, f"{file_prefix}-key_averages.pt") + + # Construct the trace file. + if export_trace: + prof.export_chrome_trace(f"{file_prefix}-chrome-trace.json") + + # Construct the memory timeline file. + if export_memory_timeline: + prof.export_memory_timeline( + f"{file_prefix}-memory-timeline.html", device="cuda:0" + ) + prof.export_memory_timeline( + f"{file_prefix}-memory-timeline.json", device="cuda:0" + ) + + +# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)) + + +def get_torch_profiler( + name, + with_stack=True, + with_flops=True, + with_modules=True, + record_shapes=False, + export_events=False, + export_trace=True, + export_memory_timeline=False, + out_dir=None, + warmup=1, + active=5, +): + if not os.path.exists(out_dir): + os.makedirs(out_dir) + callback = partial( + trace_handler, + prefix=name, + out_dir=out_dir, + group_by_input_shapes=record_shapes, + group_by_stack=5 if export_events else None, + export_events=export_events, + export_trace=export_trace, + export_memory_timeline=export_memory_timeline, + ) + return torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=record_shapes, + with_stack=with_stack, + with_flops=with_flops, + with_modules=with_modules, + profile_memory=export_memory_timeline, + schedule=torch.profiler.schedule(wait=0, warmup=warmup, active=active), + on_trace_ready=callback, + ) + + +class TorchProfilerCtx: + @staticmethod + def profiler( + name, + out_dir, + warmup=1, + active=5, + record_shapes=False, + with_stack=True, + export_events=False, + export_trace=True, + export_memory_timeline=False, + ): + return get_torch_profiler( + name, + with_stack=with_stack, + record_shapes=export_memory_timeline or record_shapes, + export_events=export_events, + export_trace=export_trace, + export_memory_timeline=export_memory_timeline, + out_dir=out_dir, + warmup=warmup, + active=active, + ) + + +def get_annotation_ctx(profiler_type): + assert profiler_type in ["nsys", "torch"] + if profiler_type == "nsys": + return nvtx_range + else: + return record_function + + +_PERF_COLUMNS = [ + "key", + "count", + "cpu_children", + "cpu_parent", + "self_device_time_total", + "cuda_time", + "flops", + "self_cpu_time", + "self_cpu_time_total", + "cpu_time", + "cpu_time_total" "self_device_memory_usage", + "device_memory_usage", + "self_cpu_memory_usage", + "cpu_memory_usage", +] +PERF_COLS_SELECT = [ + "key", + "cpu_parent", + "cpu_children", + # "self_cpu_time", + # "self_cpu_time_total", + "cpu_time", + "cpu_time_total", + "cuda_time", + "self_device_time_total", +] + + +# cuda_time, cpu_time are avg times -- corresponds to CUDA time avg and CPU time avg in table() above +# "self" times is not meaningful for annotated regions, since they only have child regions +def is_function(obj): + return isinstance(obj, types.FunctionType) + + +def is_method(obj): + return isinstance(obj, types.MethodType) + + +def is_private(prop): + return prop.startswith("_") + + +def should_exclude(obj, prop): + return ( + is_function(getattr(obj, prop)) + or is_method(getattr(obj, prop)) + or is_private(prop) + ) + + +def _get_event_props(event: torch.autograd.profiler_util.FunctionEvent): + props = [p for p in dir(event) if not should_exclude(event, p)] + return props + + +def get_events_df(events: torch.autograd.profiler_util.EventList): + event_props = _get_event_props(events[0]) + data = [{p: getattr(e, p) for p in event_props} for e in events] + return pd.DataFrame(data) + + +def get_perf_df(events: torch.autograd.profiler_util.EventList, sort=True): + df = get_events_df(events).filter(PERF_COLS_SELECT) + if sort: + df = df.sort_values(["cpu_time", "cuda_time"], ascending=False) + return df + + +def pivot_df( + df, + id_cols: str | list[str], + columns: str | list[str], + values: str | list[str], + column_order: list[str] = None, + show: bool = True, +): + df = df.pivot_table( + index=id_cols, + columns=columns, + values=values, + ).reset_index() + if column_order is not None: + df = df[column_order] + if show: + print(df.to_string(index=False)) + return df diff --git a/torchao/prototype/dora/README.md b/torchao/prototype/dora/README.md new file mode 100644 index 000000000..d5bebc68a --- /dev/null +++ b/torchao/prototype/dora/README.md @@ -0,0 +1,164 @@ +## Fused DoRA Kernels + +Fused DoRA layer implementation that reduces number of individual kernels from ~10 -> 5. + +## Contents + +- [Background](#background) +- [Optimization](#optimization) +- [Key Contributions](#key-contributions) +- [Usage](#usage) +- [Tests](#tests) +- [Benchmarks](#benchmarks) +- [Profiling](#profiling) + +## Background + +[DoRA](https://arxiv.org/abs/2402.09353) (weight-decomposed low-rank adaptation) is a variant of LoRA that decomposes the LoRA update into magnitude and vector components. + +The DoRA layer is roughly as follows: + +```python + dora_out = (x @ base_weight.T + lora_out) * magnitude_scale +``` + +where: + +```python + lora_out = lora_B(lora_A(x)) + magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1) +``` + +- `lora_A` and `lora_B` are `linear` layers with weight shapes `rank x in_features` and `out_features x rank`. +- `base_weight` is the weight of the frozen `linear` layer of shape `out_features x in_features`. +- `magnitude_vector` is initialized as the columnwise `2-norm` of the frozen weight (shape `out-features`). +- `x` are the inputs of shape `batch_size x seqlen x in_features` + +## Optimization + +After initial profiling, and as outlined above, the `DoRA` update layer requires multiple kernels. + +In order of compute intensity: + +- 4 GEMMs: + - `x @ base_weight` + - `lora_B(lora_A(x))` + - `lora_B.weight @ lora_A.weight` +- 1 Reduction: `2-norm` +- 4 Elementwise: matrix-matrix additions (2) and broadcasted matrix-vector multiplications (2). + +While `torch.compile` (and `CUDA` graphs) can partially mitigate the overhead of multiple small kernels and improve compute efficiency of individual kernels, there remains room for additional optimization by reordering the computations to facilitate fusions, and more importantly, exploiting the unique shapes of the GEMMs, thereby decreasing the number of kernel launches and increasing the compute intensity of each kernel. + +## Key Contributions + +**1 - Small K Fused Kernel** + +Note that the `lora_B.weight @ lora_A.weight` has a specific shape, where `K << {M, N}`. That is, `lora_B.weight` is `out_features x lora_rank` and `lora_A.weight` is `lora_rank x in_features`. + +Since `lora_rank` is typically `< 64` while `{in,out}-features` are typically `> 4096` (e.g., `Llama MLP / QKV projections`), this `GEMM` is inefficient, since each `CTA` loads a block, only to perform a few `MAC` iterations given small `K`. + +Moreover, note that the result of this `GEMM` is not needed -- we only need the `2-norm` of this computation. + +Combining these two observations, we can write a fused kernel where: + +1. Each `CTA` computes an _entire_ row of the output matrix, with the key assumption that `BLOCK_K = K`. That is, each `CTA` does a single MAC iteration to compute a `BLOCK_M x BLOCK_N` output, then iterates across dimension `N`. +2. Since each block processes an entire row, we can now additionally fuse a grid-wise reduction along `axis=1` into the kernel. In this case, we can directly fold the `2-norm` computation into the `GEMM`. +3. As an added bonus, we can also include the `base_weight` elementwise addition and `magnitude_vector` multiplication into the `GEMM` epilogue. + +Altogether, this allows us to fuse the following computation into a single kernel: + +```python + magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1) +``` + +**2 - Fused Epilogue GEMM** + +Additionally, instead of computing the base layer output before the `DoRA / LoRA` updates, we can compute the latter (`loRA layer` and `magnitude_scale`) first, and fold these into the epilogue of the base layer `GEMM`: + +```python + + #DoRA / LoRA updates + lora_out = lora_B(lora_A(x)) + magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1) + + #This is now a single kernel + final_out = (x @ base_weight.T + lora_out) * magnitude_scale +``` + +## Usage + +The fused kernels can be used to implement `DoRA` / `QDoRA` layers. + +A reference implementation is provided in `dora.dora_layer.DoRALinear`, which defines a base `QDoRA` linear layer (with a stub `dequantize` method) along with corresponding `BNBDoRALinear` and `HQQDoRALinear` subclasses, which override `dequantize` with their respective methods. + +_Example_ + +```python + import torch + from bitsandbytes.nn import Linear4bit + from torchao.prototypes.dora.dora_layer import BNBDoRALinear + + bs, seqlen = 1, 512 + dtype = torch.float16 + in_features, out_features, lora_rank = 4096, 4096, 16 + x = torch.randn(bs, seqlen, in_features, dtype=dtype, device="cuda") + + #Construct bitsnbytes QDoRA layer + base_layer = Linear4bit( + input_features=in_features, + output_features=out_features, + bias=False, + quant_type="nf4", + compute_dtype=dtype, + ).cuda() + base_layer.quant_state.dtype = base_layer.compute_dtype + dora_layer = BNBDoRALinear(base_layer, lora_rank) + + #Run reference forward pass + ref = dora_layer.forward(x) + + #Run fused forward pass + fused_out = dora_layer.forward_fused(x) +``` + +See `test/test_dora_layer.py` and `benchmarks/dora_bench.py` for more detailed usage. + +### Tests + +See `test/dora/test*`, for correctness checks of the fused kernels and layers. + +## Benchmarks + +See `benchmarks/dora_bench.py`. + +```python +python benchmarks/dora_bench.py --help +``` + +Run with flag `--kernel` set to one of `{dora-colnorm,dora-mm-epilogue}`, to benchmark the respective fused kernels against a reference `torch` / `torch.compile` implementation, or `--kernel=dora-full` to bench against the entire `DoRA` computation. + +Additionally, passing either `--kernel={dora-bnb, dora-hqq}` will bench a reference `QDoRA` layer against their fused implementations. + +## Profiling + +The reference `DoRALinear` layer described above also has an instrumented forward pass with annotated regions for each of the `DoRA` ops. + +An example script for running a profiled forward pass is provided in `dora/dora_profile.py`. + +To run with `torch.profiler`: + +``` +python dora_profile.py +``` + +which outputs chrome trace to default folder `dora_profiles`. + +To run with `nsys`: + +``` +nsys profile --capture_range=cudaProfilerApi ... python dora_profile.py --profiler=nsys +``` + +where `...` are other desired `nsys` options. + +Note that `--capture_range=cudaProfilerApi` is required. diff --git a/torchao/prototype/dora/__init__.py b/torchao/prototype/dora/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchao/prototype/dora/dora_layer.py b/torchao/prototype/dora/dora_layer.py new file mode 100644 index 000000000..e0c97cdcb --- /dev/null +++ b/torchao/prototype/dora/dora_layer.py @@ -0,0 +1,194 @@ +import logging + +import bitsandbytes as bnb +import torch +import torch.nn as nn +from bitsandbytes.nn import Linear4bit +from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear + +from prototypes.dora.kernels.matmul import triton_mm +from prototypes.dora.kernels.smallk import triton_mm_small_k + +logger = logging.getLogger(__name__) + + +# Adapted from https://github.com/AnswerDotAI/fsdp_qlora/blob/dora/scripts/dora.py +class DoRALayer(nn.Module): + """DoRA Update""" + + def __init__( + self, in_features, out_features, lora_rank, device, dtype, *args, **kwargs + ): + super().__init__() + + # LoRA layers + std_dev = 1 / torch.sqrt(torch.tensor(lora_rank).float()) + lora_A_param = nn.Parameter( + torch.randn(lora_rank, in_features).to(device=device, dtype=dtype) * std_dev + ) + self.lora_A = nn.Linear( + in_features, lora_rank, bias=False, device=device, dtype=dtype + ) + setattr(self.lora_A, "weight", lora_A_param) + + self.lora_B = nn.Linear( + lora_rank, out_features, bias=False, device=device, dtype=dtype + ) + self.lora_B.weight.data.zero_() + + def forward(self, x, base_weight): + # LoRA update, shape `bs x seq_len x in-features @ in-features x lora-rank @ lora-rank x out-features = bs x seq_len x out-features` + output = self.lora_B(self.lora_A(x)) + + # DoRA Section 4.3. Column norm no gradient update. + column_norm = ( + (base_weight + self.lora_B.weight @ self.lora_A.weight) + .norm(p=2, dim=1) + .detach() + ) + + return output, column_norm + + +class DoRALinear(nn.Module): + """Reference DoRA Update Layer + + out = (x @ base_weight + lora_out) * magnitude_scale + where: + `lora_out = lora_B(lora_A(x)` + `magnitude_scale = (base_weight + lora_B @ lora_A).norm(p=2, dim=1) * magnitude_vector` + + base_weight is the weight of the frozen `linear` layer of shape `out_features x in_features`. + + In QDoRA, the base weight is quantized and needs an additional dequantization step. + In this base DoRA layer, a placeholder (no-op) `dequantize` method stub is provided, which simply + returns the base weight. + + For `bnb` and `hqq`, the respective `dequantize` method can be substituted. + """ + + def __init__(self, base_layer, lora_rank, *args, **kwargs): + super().__init__() + + # Get original (dequantized) weight dtype + dtype = getattr( + base_layer, "compute_dtype", next(base_layer.parameters()).dtype + ) + device = next(base_layer.parameters()).device + self.base_layer = base_layer + + # Initialize magnitude vec - TODO: this is clunky, better way to init? + base_weight = self.dequantize().clone().cuda() + self.magnitude_vec = nn.Parameter(base_weight.norm(p=2, dim=1)) + + del base_weight + torch.cuda.empty_cache() + + # DoRA layer + self.dora_layer = DoRALayer( + base_layer.in_features, + base_layer.out_features, + lora_rank, + device, + dtype, + *args, + **kwargs, + ) + + def dequantize(self): + return self.base_layer.weight + + def forward(self, x, *args, **kwargs): + # Out shape is either bs, seqlen, out_features or bs * seqlen, out_features + assert x.ndim == 2 or x.ndim == 3, "Expected 2D or 3D input" + dq_base_weight = self.dequantize() + out_shape = [*x.shape[:-1], dq_base_weight.shape[0]] + # Reshape to (bs * seqlen, out_features) + x = x.reshape(-1, x.shape[-1]) + + # LoRA update + lora_A_weight = self.dora_layer.lora_A.weight + lora_B_weight = self.dora_layer.lora_B.weight + lora_out = (x @ lora_A_weight.T) @ lora_B_weight.T + + # DoRA magnitude scale + column_norm = (dq_base_weight + lora_B_weight @ lora_A_weight).norm(p=2, dim=1) + magnitude_scale = self.magnitude_vec / column_norm + + # DoRA update + dora_out = (x @ dq_base_weight.T + lora_out) * magnitude_scale[None, :] + dora_out = dora_out.reshape(*out_shape) + + return dora_out + + def forward_fused(self, x, *args, **kwargs): + """Reorders computation as well employs two fused kernels to speed up computation. + + See README.md for description of fused kernels. + """ + assert x.ndim == 2 or x.ndim == 3, "Expected 2D or 3D input" + + dq_base_weight = self.dequantize() + # Out shape is either bs, seqlen, out_features or bs * seqlen, out_features + out_shape = [*x.shape[:-1], dq_base_weight.shape[0]] + # Reshape to (bs * seqlen, out_features) + x = x.reshape(-1, x.shape[-1]) + + # LoRA update + lora_A_weight = self.dora_layer.lora_A.weight + lora_B_weight = self.dora_layer.lora_B.weight + lora_out = (x @ lora_A_weight.T) @ lora_B_weight.T + + # DoRA magnitude + # Fused kernel #1: `magnitude_scale = (base_weight + lora_B @ lora_A).norm(p=2, dim=1) * magnitude_vector` + magnitude_scale = triton_mm_small_k( + lora_B_weight, + lora_A_weight, + epilogue_norm=True, + source=dq_base_weight, + magnitude=self.magnitude_vec, + store_acc=False, + ) + # DoRA update + # Fused kernel #2: `out = (x @ base_weight + lora_out) * magnitude_scale` + dora_out = triton_mm( + x, + dq_base_weight.T, + epilogue_source=lora_out, + epilogue_scale=magnitude_scale, + ) + dora_out = dora_out.reshape(out_shape) + + return dora_out + + # For profiling + def forward_instrumented(self, x, *args, **kwargs): + annotation_ctx = kwargs.pop("annotation_ctx") + with annotation_ctx("##dora_forward"): + with annotation_ctx("##base_layer"): + result = self.base_layer(x, *args, **kwargs) + + with annotation_ctx("##dora_layer"): + dq_weight = self.dequantize() + output, column_norm = self.dora_layer(x, dq_weight) + + with annotation_ctx("##dora_rescale"): + result += output + result = result / column_norm.view(1, 1, -1) + result = result * self.magnitude_vec.view(1, 1, -1) + + return result + + +class BNBDoRALinear(DoRALinear): + def dequantize(self): + return bnb.functional.dequantize_4bit( + self.base_layer.weight.data, self.base_layer.weight.quant_state + ) + + +class HQQDoRALinear(DoRALinear): + def dequantize(self): + return self.base_layer.dequantize() + + diff --git a/torchao/prototype/dora/dora_profile.py b/torchao/prototype/dora/dora_profile.py new file mode 100644 index 000000000..bf8776974 --- /dev/null +++ b/torchao/prototype/dora/dora_profile.py @@ -0,0 +1,124 @@ +import argparse + +import torch +from bitsandbytes.nn import Linear4bit +from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear + +from torchao.prototype.common.profiling_tools import ( + CudaProfilerCtx, + TorchProfilerCtx, + get_annotation_ctx, +) +from torchao.prototype.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear + + +def run_profile(args, dora_forward): + if args.profiler == "nsys": + profiler = CudaProfilerCtx() + else: + profiler = TorchProfilerCtx.profiler( + f"dora_layer-{args.layer_type}", + active=max(5, args.num_iterations), + warmup=0, + out_dir=args.outdir, + ) + + annotation_ctx = get_annotation_ctx(args.profiler) + + x = torch.randn( + args.bs, args.seqlen, args.in_features, dtype=getattr(torch, args.dtype) + ).cuda() + for _ in range(args.warmup): + _ = dora_forward(x, annotation_ctx=annotation_ctx) + + with profiler as prof: + for _ in range(args.num_iterations): + _ = dora_forward(x, annotation_ctx=annotation_ctx) + prof.step() + print(f"Finished profiling, saving results to {args.outdir}") + + +def run(args): + in_features, out_features = args.in_features, args.out_features + dora_rank = args.dora_rank + dtype = getattr(torch, args.dtype) + + base_layer = torch.nn.Linear( + in_features, out_features, dtype=dtype, bias=False + ).cuda() + + if args.layer_type == "torch": + dora_layer = DoRALinear(base_layer=base_layer, lora_rank=dora_rank) + elif args.layer_type == "bnb": + base_layer = Linear4bit( + input_features=in_features, + output_features=out_features, + bias=False, + quant_type="nf4", + compute_dtype=dtype, + ) + base_layer.quant_state.dtype = base_layer.compute_dtype + dora_layer = BNBDoRALinear(base_layer=base_layer, lora_rank=dora_rank) + elif args.layer_type == "hqq": + quant_config = BaseQuantizeConfig( + nbits=4, + group_size=64, + quant_zero=False, + quant_scale=False, + offload_meta=True, + view_as_float=True, + ) + + base_layer = HQQLinear( + base_layer, + quant_config, + compute_dtype=dtype, + ) + + base_layer.set_backend(HQQBackend.PYTORCH) + dora_layer = HQQDoRALinear(base_layer=base_layer, lora_rank=dora_rank) + + run_profile(args, dora_layer.forward_instrumented) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--profiler", + type=str, + default="torch", + choices=("nsys", "torch"), + help=""" + Which profiler to use + + Default is the torch.profiler + + If using `nsys`, run the nsys profiler as so, substituting with other desired nsys options: + `nsys profile --capture-range=cudaProfilerApi ... python dora_profile.py --profiler=nsys` + + Note that `--capture-range=cudaProfilerApi` is required + """, + ) + parser.add_argument( + "--layer_type", + type=str, + default="torch", + choices=("torch", "bnb", "hqq"), + ) + parser.add_argument("--in_features", type=int, default=4096) + parser.add_argument("--out_features", type=int, default=4096) + parser.add_argument("--dora_rank", type=int, default=16) + parser.add_argument("--bs", type=int, default=1) + parser.add_argument("--seqlen", type=int, default=512) + parser.add_argument( + "--dtype", + type=str, + default="float16", + choices=("float16", "bfloat16", "float32"), + ) + parser.add_argument("--num_iterations", type=int, default=10) + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--outdir", type=str, default="./dora_profiles") + run(parser.parse_args()) diff --git a/torchao/prototype/dora/kernels/__init__.py b/torchao/prototype/dora/kernels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchao/prototype/dora/kernels/common.py b/torchao/prototype/dora/kernels/common.py new file mode 100644 index 000000000..cd0950d4c --- /dev/null +++ b/torchao/prototype/dora/kernels/common.py @@ -0,0 +1,176 @@ +from enum import Enum, StrEnum, unique + +import torch +import triton +import triton.language as tl + +# Re-exports +from triton.ops.matmul import ( + early_config_prune, + estimate_matmul_time, + get_configs_io_bound, + get_higher_dtype, +) +from triton.runtime import Config + + +@unique +class SwizzleType(Enum): + GROUPED = 0 + COLUMN_MAJOR = 1 + ROW_MAJOR = 2 + + +class TritonInputPrecision(StrEnum): + IEEE: str = "ieee" + TF32: str = "tf32" + TF32X3: str = "tf32x3" + + +TRITON_SUPPORTED_ACC_TYPES = { + torch.float16: (torch.float32, torch.float16), + torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32,), + torch.int8: (torch.int32,), +} + +MATMUL_HEURISTICS = { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + "SPLIT_K": lambda args: 1 + if (args["A"].dtype == torch.bfloat16 or args["B"].dtype == torch.bfloat16) + else args["SPLIT_K"], # atomic add not supported for bfloat16 +} + + +def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + +def get_compute_bound_configs(): + configs = [ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + # good for int8 + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + ] + return configs + + +@triton.jit() +def swizzle_tile( + pid, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_M: tl.constexpr, + SWIZZLE: tl.constexpr, +): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + + if SWIZZLE == tl.constexpr(SwizzleType.GROUPED): + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + elif SWIZZLE == tl.constexpr(SwizzleType.COLUMN_MAJOR): + pid_m = pid % grid_m + pid_n = pid // grid_m + elif SWIZZLE == tl.constexpr(SwizzleType.ROW_MAJOR): + pid_m = pid // grid_n + pid_n = pid % grid_n + else: + tl.static_assert(False, "swizzle type not supported") + + return pid_m, pid_n diff --git a/torchao/prototype/dora/kernels/custom_autotune.py b/torchao/prototype/dora/kernels/custom_autotune.py new file mode 100644 index 000000000..f67152068 --- /dev/null +++ b/torchao/prototype/dora/kernels/custom_autotune.py @@ -0,0 +1,395 @@ +from __future__ import annotations + +import builtins +import logging +import os +import time +from typing import Dict + +import numpy as np +from triton.runtime.cache import default_cache_dir +from triton.runtime.errors import OutOfResources +from triton.runtime.jit import KernelInterface +from triton.testing import do_bench + +logger = logging.getLogger(__file__) + + +class Autotuner(KernelInterface): + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_idx = [] + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + self.restore_idx = [] + if restore_value is not None: + self.restore_idx = [arg_names.index(k) for k in restore_value] + + # Hook to reset or restore for required tensors + self.pre_hook = lambda args, reset_only=False: 0 + self.post_hook = lambda args: 0 + if len(self.reset_idx) > 0 or len(self.restore_idx) > 0: + + def _pre_hook(args, reset_only=False): + for i in self.reset_idx: + args[i].zero_() + if not reset_only: + self.restore_copies = [args[i].clone() for i in self.restore_idx] + + self.pre_hook = _pre_hook + if len(self.restore_idx) > 0: + + def _post_hook(args): + for i, j in enumerate(self.restore_idx): + args[j].copy_(self.restore_copies[i]) + self.restore_copies = [] + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get( + "early_config_prune", self.early_config_prune + ) + + self.fn = fn + self.num_warmups = warmup + self.num_reps = rep + # self.autotune_log_path = os.path.join(default_cache_dir(), autotune_log_file) + self.kernel_name = self._find_kernel_name() + + def _find_kernel_name(self): + try: + kernel_name = self.fn.__name__ + except AttributeError: + try: # in case JITfn is wrapped in both autotune and heuristic + kernel_name = self.fn.fn.__name__ + except: # noqa + kernel_name = self.fn.__name__ + return kernel_name + + def _get_key_combination(self, args, as_str=True, sep=" "): + key_vals = [f"{self.arg_names[i]}={args[i]}" for i in self.key_idx] + return f"{sep}".join(key_vals) if as_str else key_vals + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + num_ctas=config.num_ctas, + **current, + ) + self.post_hook(args) + + try: + return do_bench( + kernel_call, + warmup=self.num_warmups, + rep=self.num_reps, + quantiles=(0.5, 0.2, 0.8), + ) + except OutOfResources: + return [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + logger.debug(f"Autotune Num Configs: {len(self.configs)}") + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = [] + for name in self.arg_names: + if name in all_args: + _args.append(all_args[name]) + key = [_args[i] for i in self.key_idx] + for arg in _args: + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + logger.debug( + f"\n==== Autotune ====\nRunning autotune for {self.kernel_name} for {len(self.configs)} total configs" + f" for key combination {self._get_key_combination(args)}..." + ) + # prune configs + pruned_configs = self.prune_configs(kwargs) + logger.debug(f"\nNum configs after pruning {len(pruned_configs)}") + bench_start = time.time() + timings = {} + for config in pruned_configs: + timings[config] = self._bench(*args, config=config, **kwargs) + # timings = { + # config: self._bench(*args, config=config, **kwargs) + # for config in pruned_configs + # } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.pre_hook(args, reset_only=True) + self.configs_timings = timings + + sorted_timings = dict( + sorted(timings.items(), key=lambda x: np.mean(x[1])) + ) + _key_suffix = self._get_key_combination(args, sep="-") + autotune_file = f"autotune_{self.kernel_name}_{_key_suffix}.log" + autotune_log_path = os.path.join(default_cache_dir(), autotune_file) + + logger.debug(f"\nFinished autotune, writing log to {autotune_log_path}") + + with open(f"{autotune_log_path}", "w") as f: + f.write( + f" ==== Autotune Results ====\nKernel name: {self.kernel_name}\nArgs: {self.arg_names}\nKeys: {self._get_key_combination(args)}\n" + ) + f.write(f"\nPruned configs:\n") + for cfg in pruned_configs: + f.write(f"{cfg}\n") + f.write(f"Timings:\n") + for cfg, timing in sorted_timings.items(): + f.write(f"{cfg} {timing} \n") + f.write(f"Best config: {self.cache[key]}\n") + else: + logger.debug( + f"Key {key} for {self.kernel_name} already in cache, skipping autotune\n" + ) + + config = self.cache[key] + # logger.debug(f"\nAutotune: Cache hit! Running best config...") + else: + config = self.configs[0] + self.best_config = config + logger.debug(f"\nAutotune Best Config: {config}\n") + + full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} + if config.pre_hook is not None: + config.pre_hook(full_nargs) + ret = self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + num_ctas=config.num_ctas, + **kwargs, + **config.kwargs, + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + num_ctas=config.num_ctas, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ + :top_k + ] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append( + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_ctas=config.num_ctas, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + ) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type meta: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.pre_hook = pre_hook + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + return ", ".join(res) + + +def autotune( + configs, + key, + prune_configs_by=None, + reset_to_zero=None, + restore_value=None, + warmup=25, + rep=100, +): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. + :type warmup: int + :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. + :type rep: int + """ + + def decorator(fn): + return Autotuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + restore_value, + prune_configs_by, + warmup, + rep, + ) + + return decorator + + +class Heuristics(KernelInterface): + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/torchao/prototype/dora/kernels/matmul.py b/torchao/prototype/dora/kernels/matmul.py new file mode 100644 index 000000000..66e5ef77e --- /dev/null +++ b/torchao/prototype/dora/kernels/matmul.py @@ -0,0 +1,259 @@ +import logging + +import torch +import triton +import triton.language as tl + +from .common import ( + MATMUL_HEURISTICS, + TRITON_SUPPORTED_ACC_TYPES, + SwizzleType, + TritonInputPrecision, + early_config_prune, + estimate_matmul_time, + get_compute_bound_configs, + get_configs_io_bound, + get_higher_dtype, + swizzle_tile, + to_tl_type, +) +from .custom_autotune import autotune + +logger = logging.getLogger(__name__) + + +_AUTOTUNE_TOPK = 10 + + +@autotune( + get_compute_bound_configs() + get_configs_io_bound(), + key=["M", "N", "K"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": _AUTOTUNE_TOPK, + }, +) +@triton.heuristics( + { + "EVEN_K": MATMUL_HEURISTICS["EVEN_K"], + "SPLIT_K": MATMUL_HEURISTICS["SPLIT_K"], + } +) +@triton.jit +def _matmul_kernel( + A, + B, + C, + M, + N, + K, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, # + acc_dtype: tl.constexpr, # + input_precision: tl.constexpr, # + fp8_fast_accum: tl.constexpr, # + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, # + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + AB_DTYPE: tl.constexpr, # + SWIZZLE: tl.constexpr, + EPILOGUE_ELEMENTWISE_ADD: tl.constexpr = False, + Epilogue_source=None, + EPILOGUE_BROADCAST_SCALE: tl.constexpr = False, + Epilogue_scale=None, +): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + + # Threadblock swizzle + pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_M, BLOCK_N, GROUP_M, SWIZZLE) + + # Operand offsets + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + + # Operand pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # Allocate accumulator + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + + # MAC Loop + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot( + a, b, acc, out_dtype=acc_dtype, input_precision=input_precision + ) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) + + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + # Convert acc to output dtype + acc = acc.to(C.dtype.element_ty) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + # mask = (rm < M)[:, None] & (rn < N)[None, :] + mask_m = (rm < M)[:, None] + mask_n = (rn < N)[None, :] + if EPILOGUE_ELEMENTWISE_ADD: + Epilogue_source = Epilogue_source + ( + rm[:, None] * stride_cm + rn[None, :] * stride_cn + ) + source = tl.load(Epilogue_source, mask=mask_m & mask_n) + acc += source + if EPILOGUE_BROADCAST_SCALE: + Epilogue_scale = Epilogue_scale + (rn[None, :]) + scale = tl.load(Epilogue_scale, mask=mask_n) + acc *= scale + + if SPLIT_K == 1: + tl.store(C, acc, mask=mask_m & mask_n) + else: + tl.atomic_add(C, acc, mask=mask_m & mask_n) + + +def triton_mm( + a, + b, + epilogue_source=None, + epilogue_scale=None, + acc_dtype=None, + input_precision=TritonInputPrecision.IEEE, + fp8_fast_accum=False, + output_dtype=None, + swizzle: SwizzleType = SwizzleType.GROUPED, + GROUP_M: int = 8, +): + """Triton GEMM implementation, `D = AB + C` + + Based on `triton.ops.matmul`, with the addition of epilogue. + + Args: + a (torch.Tensor): operand A + b (torch.Tensor): operand B + epilogue_source(optional, torch.Tensor): operand C in `D = AB + C` + epilogue_scale(optional, torch.Tensor): row-wise scale-vector of dim `N` in `D = scale * (AB + C)` + acc_dtype (torch.DType): accumulator type in MAC loop + input_precision (TritonInputPrecision): precision to use for fp32 matmul + fp8_fast_accum (bool) + output_dtype (optional, torch.DType): output type of the GEMM, defaults to higher dtype of A / B + + Returns: + torch.Tensor: `D = AB + C` + """ + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if output_dtype is None: + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + # Epilogue pre-conditions + # TODO Check strides? + if epilogue_source is not None: + assert epilogue_source.shape == (M, N), "incompatible dimensions" + assert epilogue_source.dtype == c.dtype, "incompatible dtype" + + if epilogue_scale is not None: + assert ( + epilogue_scale.ndim == 1 and epilogue_scale.shape[0] == N + ), "incompatible dimensions" + assert epilogue_scale.dtype == c.dtype, "incompatible dtype" + + # choose accumulator type + if acc_dtype is None: + acc_dtype = TRITON_SUPPORTED_ACC_TYPES[ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert ( + acc_dtype in TRITON_SUPPORTED_ACC_TYPES[a.dtype] + ), "acc_dtype not compatible with the type of a" + assert ( + acc_dtype in TRITON_SUPPORTED_ACC_TYPES[b.dtype] + ), "acc_dtype not compatible with the type of b" + + # convert to triton types + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ + tl.float8e4nv, + tl.float8e5, + ]: + ab_dtype = None + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + ) + + _matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + acc_dtype=acc_dtype, # + input_precision=input_precision, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=GROUP_M, + AB_DTYPE=ab_dtype, + SWIZZLE=swizzle, + EPILOGUE_ELEMENTWISE_ADD=epilogue_source is not None, + Epilogue_source=epilogue_source, + EPILOGUE_BROADCAST_SCALE=epilogue_scale is not None, + Epilogue_scale=epilogue_scale, + ) + return c diff --git a/torchao/prototype/dora/kernels/smallk.py b/torchao/prototype/dora/kernels/smallk.py new file mode 100644 index 000000000..fc24ea223 --- /dev/null +++ b/torchao/prototype/dora/kernels/smallk.py @@ -0,0 +1,545 @@ +import heapq +import logging +from enum import Enum, StrEnum, unique + +import torch +import triton +import triton.language as tl +from triton.ops.matmul import ( + estimate_matmul_time, + get_configs_io_bound, + get_higher_dtype, +) +from triton.runtime import driver + +from .custom_autotune import Config, autotune + +logger = logging.getLogger(__name__) + + +@unique +class SwizzleType(Enum): + GROUPED = 0 + COLUMN_MAJOR = 1 + ROW_MAJOR = 2 + + +class TritonInputPrecision(StrEnum): + IEEE: str = "ieee" + TF32: str = "tf32" + TF32X3: str = "tf32x3" + + +TRITON_SUPPORTED_ACC_TYPES = { + torch.float16: (torch.float32, torch.float16), + torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32,), + torch.int8: (torch.int32,), +} + + +def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + +def get_compute_bound_configs(): + configs = [ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + # good for int8 + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + ] + return configs + + +@triton.jit() +def swizzle_tile( + pid, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_M: tl.constexpr, + SWIZZLE: tl.constexpr, +): + if SWIZZLE == tl.constexpr(SwizzleType.GROUPED): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + else: + tl.static_assert(False, "swizzle type not supported") + + return pid_m, pid_n + + +def get_small_k_configs(): + configs = get_compute_bound_configs() + get_configs_io_bound() + KEYS_TO_REMOVE = ["BLOCK_K", "SPLIT_K"] + for cfg in configs: + for key in KEYS_TO_REMOVE: + del cfg.kwargs[key] + + return configs + + +def small_k_early_config_prune(configs, named_args, **kwargs): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args["A"].element_size() + dtype = named_args["A"].dtype + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + named_args["K"], + config.num_stages, + ) + + max_shared_memory = driver.active.utils.get_device_properties(device)[ + "max_shared_mem" + ] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + + # Some dtypes do not allow atomic_add + # if dtype not in [torch.float16, torch.float32]: + # configs = [config for config in configs if config.kwargs["SPLIT_K"] == 1] + + # group configs by (BLOCK_M,_N,_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + named_args["K"], + # kw["SPLIT_K"], + config.num_warps, + config.num_stages, + ) + + key = (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, num_warps = k + if capability[0] >= 8: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? + optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest( + 2, + v, + key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 + else x[1] - optimal_num_stages, + ) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs + + +SMALLK_HEURISTICS = { + "BLOCK_K": lambda args: args["K"], +} + +_AUTOTUNE_TOPK = 10 + + +# @heuristics(SMALLK_HEURISTICS) +@autotune( + get_small_k_configs(), + key=["M", "N", "K"], + prune_configs_by={ + "early_config_prune": small_k_early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": _AUTOTUNE_TOPK, + }, +) +@triton.jit +def _mm_small_k_kernel( + A, + B, + M, + N, + K, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + acc_dtype: tl.constexpr, # + input_precision: tl.constexpr, # + fp8_fast_accum: tl.constexpr, # + BLOCK_K: tl.constexpr, # + AB_DTYPE: tl.constexpr, # + BLOCK_M: tl.constexpr = 256, + BLOCK_N: tl.constexpr = 64, + C=None, + stride_cm=None, + stride_cn=None, # + Norm2=None, + Source=None, + stride_sourcem=None, + stride_sourcen=None, + Magnitude=None, + ADD_SOURCE: tl.constexpr = False, + EPILOGUE_NORM: tl.constexpr = False, + EPILOGUE_MAGNITUDE: tl.constexpr = False, + STORE_ACC: tl.constexpr = False, +): + pid_m = tl.program_id(0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rk = tl.arange(0, BLOCK_K) + + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + a = tl.load(A) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + + rn = tl.arange(0, BLOCK_N) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + if STORE_ACC: + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + + if ADD_SOURCE: + Source = Source + (rm[:, None] * stride_sourcem + rn[None, :] * stride_sourcen) + + if EPILOGUE_NORM: + norm_vec = tl.zeros((BLOCK_M,), dtype=acc_dtype) + + if EPILOGUE_MAGNITUDE: + Magnitude = Magnitude + ram + + mask_m = rm < M + + for n in range(0, tl.cdiv(N, BLOCK_N)): + # Advance B over N + + b = tl.load(B) + + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + + if fp8_fast_accum: + acc = tl.dot( + a, b, acc, out_dtype=acc_dtype, input_precision=input_precision + ) + else: + acc = tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) + + if ADD_SOURCE: + mask_n = (n * BLOCK_N + rn < N)[None, :] + source = tl.load(Source, mask=mask_m[:, None] & mask_n) + acc += source.to(acc_dtype) + Source += BLOCK_N * stride_sourcen + + # 2-norm = tl.sqrt(tl.sum(acc * acc, axis=1)) + if EPILOGUE_NORM: + norm_vec += tl.sum(acc * acc, axis=1) + + if STORE_ACC: + mask_n = (n * BLOCK_N + rn < N)[None, :] + tl.store(C, acc.to(C.dtype.element_ty), mask=mask_m[:, None] & mask_n) + C += BLOCK_N * stride_cn + + B += BLOCK_N * stride_bn + + if EPILOGUE_NORM: + Norm2 = Norm2 + rm + norm_vec = tl.rsqrt(norm_vec).to(Norm2.dtype.element_ty) + + if EPILOGUE_MAGNITUDE: + magnitude = tl.load(Magnitude, mask=mask_m) + norm_vec *= magnitude + + tl.store(Norm2, norm_vec, mask=mask_m) + + +def triton_mm_small_k( + a: torch.Tensor, + b: torch.Tensor, + epilogue_norm: bool = True, + source: torch.Tensor = None, + magnitude: torch.Tensor = None, + store_acc: bool = False, + acc_dtype: torch.dtype = None, + input_precision: TritonInputPrecision = TritonInputPrecision.IEEE, + fp8_fast_accum: bool = False, + output_dtype: torch.dtype = None, +): + """Computes GEMM for small K {16, 32, 64} + + Assumes that K is small enough that the MAC loop within each block is a single iteration. + Instead of iterating over K, we iterate over N per block such that each block computes a BLK_M x N row of C. Kernel grid is ceildiv(M, BLOCK_M). + + This specialized GEMM is primarily useful for low-rank projections and fusing grid-wide reductions into the epilogue. + + Currently, the following fusions are implemented: + - `epilogue_norm` - when set to True, the kernel computes the reverse 2-norm along axis=1 of AB ( `1 / 2-norm(AB, axis=1)` ) + - `source=torch.Tensor` - when passed a tensor of shape `M x N`, the kernel computes `D = AB + source` + - `magnitude=torch.Tensor` - when passed a tensor of shape `M`, the kernel additionally multiplies the epilogue norm by the magnitude vector + + Hence, when the above fusions are enabled, the kernel can be used to compute DoRA layer magnitude normalization: `magnitude * (base_weight + lora_B(lora_A(x))).norm(2, axis=1)` + + Args: + a (torch.Tensor): operand A + b (torch.Tensor): operand B + source (torch.Tensor): Operand C in `D = AB + C` + epilogue_norm (bool, optional): Whether to calculate 1 / 2-norm(AB, axis=1) + magnitude (torch.Tensor): vector to multiply epilogue norm by + store_acc (bool): whether to store `AB`, if False, then `epilogue_norm` must be True, in which case only the `2-norm` is stored + acc_dtype (torch.DType): accumulator type in MAC loop + input_precision (TritonInputPrecision): precision to use for fp32 matmul + fp8_fast_accum (bool) + output_dtype (torch.DType): type for output tensors (`D`, `2-norm`, etc.) + + Returns: + torch.Tensor + """ + assert store_acc or epilogue_norm, "Must use store_acc or epilogue_norm" + + device = a.device + + # Make sure inputs are contiguous + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + + assert a.shape[1] == b.shape[0], "Incompatible operand dimensions" + M, K = a.shape + _, N = b.shape + + assert K < 128, "K must be < 128 to use this kernel" + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + if output_dtype is None: + output_dtype = ab_dtype + + if epilogue_norm: + norm2 = torch.zeros(M, device=device, dtype=output_dtype) + + # Must set out_dtype before converting dtypes to tl types + if store_acc: + c = torch.empty((M, N), device=device, dtype=output_dtype) + + if acc_dtype is None: + acc_dtype = TRITON_SUPPORTED_ACC_TYPES[ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert ( + acc_dtype in TRITON_SUPPORTED_ACC_TYPES[a.dtype] + ), "acc_dtype not compatible with the type of a" + assert ( + acc_dtype in TRITON_SUPPORTED_ACC_TYPES[b.dtype] + ), "acc_dtype not compatible with the type of b" + + # Convert dtypes to tl types + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Use fp8 types in MAC loop + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ + tl.float8e4nv, + tl.float8e5, + ]: + ab_dtype = None + + logger.debug( + f"triton_mm_small_k: {ab_dtype=} {acc_dtype=} {input_precision=} {fp8_fast_accum=} {output_dtype=}" + ) + + # Set the fusion and other GEMM kwargs + # IMPORTANT: BLOCK_K must be equal to K + kwargs = { + "BLOCK_K": K, + "acc_dtype": acc_dtype, + "input_precision": input_precision, + "fp8_fast_accum": fp8_fast_accum, + "AB_DTYPE": ab_dtype, + "EPILOGUE_NORM": epilogue_norm, + "ADD_SOURCE": source is not None, + "EPILOGUE_MAGNITUDE": magnitude is not None, + "STORE_ACC": store_acc, + } + + # 2-norm params + if epilogue_norm: + kwargs["Norm2"] = norm2 + + # source params + if source is not None: + assert source.shape == (M, N) + kwargs["Source"] = source + kwargs["stride_sourcem"] = source.stride(0) + kwargs["stride_sourcen"] = source.stride(1) + else: + kwargs["Source"] = None + kwargs["stride_sourcem"] = 0 + kwargs["stride_sourcen"] = 0 + + # magnitude params, epilogue_norm must be True + if magnitude is not None: + assert epilogue_norm, "magnitude requires epilogue_norm" + assert magnitude.ndim == 1 and magnitude.shape[0] == M + kwargs["Magnitude"] = magnitude + + # store_acc, whether to store the intermediate AB + if store_acc: + kwargs["C"] = c + kwargs["stride_cm"] = c.stride(0) + kwargs["stride_cn"] = c.stride(1) + else: + kwargs["C"] = None + kwargs["stride_cm"] = 0 + kwargs["stride_cn"] = 0 + + # kwargs_str = " ".join( + # f"{k}={v}" for k, v in kwargs.items() if not isinstance(v, torch.Tensor) + # ) + # print(f"triton_mm_small_k: {kwargs_str}") + + # launch kernel + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),) + _mm_small_k_kernel[grid]( + a, + b, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + **kwargs, + ) + + if store_acc: + if epilogue_norm: + return c, norm2 + else: + return c + return norm2