From a5b7137b0cd77a367142445e9a95d252d5fe62eb Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 18:30:54 +0000 Subject: [PATCH 01/15] add dora kernels --- test/dora/test_dora_fusion.py | 183 ++++++ test/dora/test_dora_layer.py | 107 ++++ torchao/prototype/dora/kernels/__init__.py | 0 torchao/prototype/dora/kernels/common.py | 176 ++++++ .../prototype/dora/kernels/custom_autotune.py | 395 +++++++++++++ torchao/prototype/dora/kernels/matmul.py | 259 +++++++++ torchao/prototype/dora/kernels/smallk.py | 545 ++++++++++++++++++ 7 files changed, 1665 insertions(+) create mode 100644 test/dora/test_dora_fusion.py create mode 100644 test/dora/test_dora_layer.py create mode 100644 torchao/prototype/dora/kernels/__init__.py create mode 100644 torchao/prototype/dora/kernels/common.py create mode 100644 torchao/prototype/dora/kernels/custom_autotune.py create mode 100644 torchao/prototype/dora/kernels/matmul.py create mode 100644 torchao/prototype/dora/kernels/smallk.py diff --git a/test/dora/test_dora_fusion.py b/test/dora/test_dora_fusion.py new file mode 100644 index 000000000..78a9d2e6e --- /dev/null +++ b/test/dora/test_dora_fusion.py @@ -0,0 +1,183 @@ +import itertools + +import pytest +import torch + +from torchao.prototype.dora.kernels.matmul import triton_mm +from torchao.prototype.dora.kernels.smallk import triton_mm_small_k + +torch.manual_seed(0) + +# Test configs +M = 4096 +N = 4096 +Ks = [int(2**i) for i in range(4, 7)] + +FUSED_DORA_SHAPES = [(M, N, K) for K in Ks[:1]] + +DTYPES = [torch.float32, torch.float16, torch.bfloat16] + +STORE_ACC = [False] +EPILOGUE_NORM = [True, False] +ADD_SOURCE = [True] +MAGNITUDE_VECTOR = [True] +FUSED_DORA_TEST_CONFIGS = list( + itertools.product( + FUSED_DORA_SHAPES, + STORE_ACC, + EPILOGUE_NORM, + ADD_SOURCE, + MAGNITUDE_VECTOR, + DTYPES, + ) +) + + +def _arg_to_id(arg): + if isinstance(arg, (tuple, list)): + return "x".join([str(x) for x in arg]) + return str(arg) + + +def check(expected, actual, dtype): + if dtype == torch.float32: + atol = 1e-4 + elif dtype == torch.float16: + atol = 1e-3 + elif dtype == torch.bfloat16: + atol = 1e-2 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + diff = (expected - actual).abs().max() + print(f"diff: {diff}") + # assert diff < atol + return diff + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize( + "shape, store_acc, epilogue_norm, add_source, magnitude_vector, dtype", + FUSED_DORA_TEST_CONFIGS, + ids=_arg_to_id, +) +def test_dora_column_norm( + shape, store_acc, epilogue_norm, add_source, magnitude_vector, dtype +): + if not (store_acc or epilogue_norm): + pytest.skip("Either store_acc or epilogue_norm must be True") + + M, N, K = shape + A = torch.randn(M, K, device="cuda", dtype=dtype) + B = torch.randn(K, N, device="cuda", dtype=dtype) + source = torch.randn(M, N, device="cuda", dtype=dtype) + magnitude = torch.randn(M, device="cuda", dtype=dtype) + + c_ref = torch.matmul(A, B) + norm2_ref = 1 / c_ref.norm(2, dim=1) + source_ref = source + c_ref + source_norm2_ref = 1 / (source + c_ref).norm(2, dim=1) + source_norm2_magnitude_ref = magnitude * source_norm2_ref + + # First test small K only kernel, no epilogue + # source = None # source # None + # magnitude = None # magnitude # None + + tt_out = triton_mm_small_k( + A, + B, + source=source if add_source else None, + magnitude=magnitude if magnitude_vector else None, + epilogue_norm=epilogue_norm, + store_acc=store_acc, + ) + + if store_acc: + c_test = tt_out[0] if epilogue_norm else tt_out + if add_source: + check(source_ref, c_test, dtype) + else: + check(c_ref, c_test, dtype) + + if epilogue_norm: + norm2_test = tt_out[1] if store_acc else tt_out + if add_source: + if magnitude_vector: + check(source_norm2_magnitude_ref, norm2_test, dtype) + else: + check(source_norm2_ref, norm2_test, dtype) + else: + check(norm2_ref, norm2_test, dtype) + + +BATCH_SIZES = [int(2**i) for i in range(6)] +SEQ_LENS = [512] +IN_FEATURES = [4096] +OUT_FEATURES = [4096] +FUSED_MATMUL_SHAPES = [ + (bs * seqlen, in_features, out_features) + for bs, seqlen, in_features, out_features in zip( + BATCH_SIZES, SEQ_LENS, IN_FEATURES, OUT_FEATURES + ) +] +EPILOGUE_ELEMENTWISE_ADD = [True] +EPILOGUE_BROADCAST_SCALE = [True] + +FUSED_MATMUL_TEST_CONFIGS = list( + itertools.product( + FUSED_MATMUL_SHAPES[:1], + DTYPES, + EPILOGUE_ELEMENTWISE_ADD, + EPILOGUE_BROADCAST_SCALE, + ) +) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize( + "shape, dtype, epilogue_add, epilogue_scale", + FUSED_MATMUL_TEST_CONFIGS, + ids=_arg_to_id, +) +def test_dora_matmul(shape, dtype, epilogue_add, epilogue_scale): + M, K, N = shape + A = torch.randn(M, K, device="cuda", dtype=dtype) + B = torch.randn(K, N, device="cuda", dtype=dtype) + C = torch.randn(M, N, device="cuda", dtype=dtype) if epilogue_add else None + scale = torch.randn(N, device="cuda", dtype=dtype) if epilogue_scale else None + + D_ref = torch.matmul(A, B) + if epilogue_add: + D_ref += C + if epilogue_scale: + D_ref *= scale.unsqueeze(0) + + D_test = triton_mm(A, B, epilogue_source=C, epilogue_scale=scale) + check(D_ref, D_test, dtype) + + +MODES = ["default"] + + +@pytest.mark.skip("TODO: torch.compile does not work with custom kernel") +@pytest.mark.parametrize( + "shape, dtype, epilogue_add, epilogue_scale, mode", + [[*cfg, mode] for cfg in FUSED_MATMUL_TEST_CONFIGS for mode in MODES][:1], + ids=_arg_to_id, +) +def test_dora_matmul_compile(shape, dtype, epilogue_add, epilogue_scale, mode): + M, K, N = shape + A = torch.randn(M, K, device="cuda", dtype=dtype) + B = torch.randn(K, N, device="cuda", dtype=dtype) + C = torch.randn(M, N, device="cuda", dtype=dtype) if epilogue_add else None + scale = torch.randn(N, device="cuda", dtype=dtype) if epilogue_scale else None + + D_ref = torch.matmul(A, B) + if epilogue_add: + D_ref += C + if epilogue_scale: + D_ref *= scale.unsqueeze(0) + + D_test = triton_mm(A, B, epilogue_source=C, epilogue_scale=scale) + check(D_ref, D_test, dtype) + + triton_compiled = torch.compile(triton_mm, mode=mode) + D_compiled = triton_compiled(A, B, epilogue_source=C, epilogue_scale=scale) + check(D_ref, D_compiled, dtype) diff --git a/test/dora/test_dora_layer.py b/test/dora/test_dora_layer.py new file mode 100644 index 000000000..161d81e39 --- /dev/null +++ b/test/dora/test_dora_layer.py @@ -0,0 +1,107 @@ +# Linear4bit = pytest.importorskip( +# "bitsandbytes.nn.Linear4bit", reason="requires bitsandbytes" +# ) +# HQQLinear = pytest.importorskip("hqq.core.quantize.HQQLinear", reason="requires hqq") +# BaseQuantizeConfig = pytest.importorskip( +# "hqq.core.quantize.BaseQuantizeConfig", reason="requires hqq" +# ) +import itertools + +import pytest +import torch +from bitsandbytes.nn import Linear4bit +from hqq.core.quantize import BaseQuantizeConfig, HQQLinear + +from prototypes.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear + + +def check(expected, actual, dtype): + if dtype == torch.float32: + atol = 1e-4 + elif dtype == torch.float16: + atol = 1e-3 + elif dtype == torch.bfloat16: + atol = 1e-2 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + diff = (expected - actual).abs().max() + print(f"diff: {diff}") + # assert diff < atol + return diff + + +def _arg_to_id(arg): + if isinstance(arg, (tuple, list)): + return "x".join([str(x) for x in arg]) + return str(arg) + + +BATCH_SIZES = [1] +SEQ_LENS = [512] +DTYPES = [torch.float32, torch.float16, torch.bfloat16] +IN_FEATURES = [4096] +OUT_FEATURES = [4096, 11008] +LORA_RANKS = [16] +MODEL_TYPES = ["DoRALinear", "BNBDoRALinear", "HQQDoRALinear"] + +TEST_CONFIGS = list( + itertools.product( + BATCH_SIZES, + SEQ_LENS, + IN_FEATURES, + OUT_FEATURES, + LORA_RANKS, + DTYPES, + MODEL_TYPES, + ) +) + + +@pytest.mark.parametrize( + "bs, seqlen, in_features, out_features, lora_rank, dtype, model_type", + TEST_CONFIGS, + ids=_arg_to_id, +) +def test_dora_layer( + bs, seqlen, in_features, out_features, lora_rank, dtype, model_type +): + x = torch.randn(bs, seqlen, in_features, dtype=dtype).cuda() + + if model_type == "DoRALinear": + base_layer = torch.nn.Linear( + in_features, out_features, dtype=dtype, bias=False + ).cuda() + dora_cls = DoRALinear + + elif model_type == "BNBDoRALinear": + base_layer = Linear4bit( + input_features=in_features, + output_features=out_features, + bias=False, + quant_type="nf4", + compute_dtype=dtype, + ).cuda() + base_layer.quant_state.dtype = base_layer.compute_dtype + dora_cls = BNBDoRALinear + + elif model_type == "HQQDoRALinear": + quant_config = BaseQuantizeConfig( + nbits=4, + group_size=64, + quant_zero=False, + quant_scale=False, + offload_meta=True, + view_as_float=True, + ) + torch_base = torch.nn.Linear(in_features, out_features, dtype=dtype, bias=False) + base_layer = HQQLinear( + torch_base, + quant_config, + compute_dtype=dtype, + ) + dora_cls = HQQDoRALinear + dora_layer = dora_cls(base_layer, lora_rank).cuda() + + ref = dora_layer.forward(x) + test = dora_layer.forward_fused(x) + check(ref, test, dtype) diff --git a/torchao/prototype/dora/kernels/__init__.py b/torchao/prototype/dora/kernels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchao/prototype/dora/kernels/common.py b/torchao/prototype/dora/kernels/common.py new file mode 100644 index 000000000..cd0950d4c --- /dev/null +++ b/torchao/prototype/dora/kernels/common.py @@ -0,0 +1,176 @@ +from enum import Enum, StrEnum, unique + +import torch +import triton +import triton.language as tl + +# Re-exports +from triton.ops.matmul import ( + early_config_prune, + estimate_matmul_time, + get_configs_io_bound, + get_higher_dtype, +) +from triton.runtime import Config + + +@unique +class SwizzleType(Enum): + GROUPED = 0 + COLUMN_MAJOR = 1 + ROW_MAJOR = 2 + + +class TritonInputPrecision(StrEnum): + IEEE: str = "ieee" + TF32: str = "tf32" + TF32X3: str = "tf32x3" + + +TRITON_SUPPORTED_ACC_TYPES = { + torch.float16: (torch.float32, torch.float16), + torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32,), + torch.int8: (torch.int32,), +} + +MATMUL_HEURISTICS = { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + "SPLIT_K": lambda args: 1 + if (args["A"].dtype == torch.bfloat16 or args["B"].dtype == torch.bfloat16) + else args["SPLIT_K"], # atomic add not supported for bfloat16 +} + + +def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + +def get_compute_bound_configs(): + configs = [ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + # good for int8 + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + ] + return configs + + +@triton.jit() +def swizzle_tile( + pid, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_M: tl.constexpr, + SWIZZLE: tl.constexpr, +): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + + if SWIZZLE == tl.constexpr(SwizzleType.GROUPED): + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + elif SWIZZLE == tl.constexpr(SwizzleType.COLUMN_MAJOR): + pid_m = pid % grid_m + pid_n = pid // grid_m + elif SWIZZLE == tl.constexpr(SwizzleType.ROW_MAJOR): + pid_m = pid // grid_n + pid_n = pid % grid_n + else: + tl.static_assert(False, "swizzle type not supported") + + return pid_m, pid_n diff --git a/torchao/prototype/dora/kernels/custom_autotune.py b/torchao/prototype/dora/kernels/custom_autotune.py new file mode 100644 index 000000000..f67152068 --- /dev/null +++ b/torchao/prototype/dora/kernels/custom_autotune.py @@ -0,0 +1,395 @@ +from __future__ import annotations + +import builtins +import logging +import os +import time +from typing import Dict + +import numpy as np +from triton.runtime.cache import default_cache_dir +from triton.runtime.errors import OutOfResources +from triton.runtime.jit import KernelInterface +from triton.testing import do_bench + +logger = logging.getLogger(__file__) + + +class Autotuner(KernelInterface): + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_idx = [] + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + self.restore_idx = [] + if restore_value is not None: + self.restore_idx = [arg_names.index(k) for k in restore_value] + + # Hook to reset or restore for required tensors + self.pre_hook = lambda args, reset_only=False: 0 + self.post_hook = lambda args: 0 + if len(self.reset_idx) > 0 or len(self.restore_idx) > 0: + + def _pre_hook(args, reset_only=False): + for i in self.reset_idx: + args[i].zero_() + if not reset_only: + self.restore_copies = [args[i].clone() for i in self.restore_idx] + + self.pre_hook = _pre_hook + if len(self.restore_idx) > 0: + + def _post_hook(args): + for i, j in enumerate(self.restore_idx): + args[j].copy_(self.restore_copies[i]) + self.restore_copies = [] + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get( + "early_config_prune", self.early_config_prune + ) + + self.fn = fn + self.num_warmups = warmup + self.num_reps = rep + # self.autotune_log_path = os.path.join(default_cache_dir(), autotune_log_file) + self.kernel_name = self._find_kernel_name() + + def _find_kernel_name(self): + try: + kernel_name = self.fn.__name__ + except AttributeError: + try: # in case JITfn is wrapped in both autotune and heuristic + kernel_name = self.fn.fn.__name__ + except: # noqa + kernel_name = self.fn.__name__ + return kernel_name + + def _get_key_combination(self, args, as_str=True, sep=" "): + key_vals = [f"{self.arg_names[i]}={args[i]}" for i in self.key_idx] + return f"{sep}".join(key_vals) if as_str else key_vals + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + num_ctas=config.num_ctas, + **current, + ) + self.post_hook(args) + + try: + return do_bench( + kernel_call, + warmup=self.num_warmups, + rep=self.num_reps, + quantiles=(0.5, 0.2, 0.8), + ) + except OutOfResources: + return [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + logger.debug(f"Autotune Num Configs: {len(self.configs)}") + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = [] + for name in self.arg_names: + if name in all_args: + _args.append(all_args[name]) + key = [_args[i] for i in self.key_idx] + for arg in _args: + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + logger.debug( + f"\n==== Autotune ====\nRunning autotune for {self.kernel_name} for {len(self.configs)} total configs" + f" for key combination {self._get_key_combination(args)}..." + ) + # prune configs + pruned_configs = self.prune_configs(kwargs) + logger.debug(f"\nNum configs after pruning {len(pruned_configs)}") + bench_start = time.time() + timings = {} + for config in pruned_configs: + timings[config] = self._bench(*args, config=config, **kwargs) + # timings = { + # config: self._bench(*args, config=config, **kwargs) + # for config in pruned_configs + # } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.pre_hook(args, reset_only=True) + self.configs_timings = timings + + sorted_timings = dict( + sorted(timings.items(), key=lambda x: np.mean(x[1])) + ) + _key_suffix = self._get_key_combination(args, sep="-") + autotune_file = f"autotune_{self.kernel_name}_{_key_suffix}.log" + autotune_log_path = os.path.join(default_cache_dir(), autotune_file) + + logger.debug(f"\nFinished autotune, writing log to {autotune_log_path}") + + with open(f"{autotune_log_path}", "w") as f: + f.write( + f" ==== Autotune Results ====\nKernel name: {self.kernel_name}\nArgs: {self.arg_names}\nKeys: {self._get_key_combination(args)}\n" + ) + f.write(f"\nPruned configs:\n") + for cfg in pruned_configs: + f.write(f"{cfg}\n") + f.write(f"Timings:\n") + for cfg, timing in sorted_timings.items(): + f.write(f"{cfg} {timing} \n") + f.write(f"Best config: {self.cache[key]}\n") + else: + logger.debug( + f"Key {key} for {self.kernel_name} already in cache, skipping autotune\n" + ) + + config = self.cache[key] + # logger.debug(f"\nAutotune: Cache hit! Running best config...") + else: + config = self.configs[0] + self.best_config = config + logger.debug(f"\nAutotune Best Config: {config}\n") + + full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} + if config.pre_hook is not None: + config.pre_hook(full_nargs) + ret = self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + num_ctas=config.num_ctas, + **kwargs, + **config.kwargs, + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + num_ctas=config.num_ctas, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ + :top_k + ] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append( + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_ctas=config.num_ctas, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + ) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type meta: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.pre_hook = pre_hook + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + return ", ".join(res) + + +def autotune( + configs, + key, + prune_configs_by=None, + reset_to_zero=None, + restore_value=None, + warmup=25, + rep=100, +): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. + :type warmup: int + :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. + :type rep: int + """ + + def decorator(fn): + return Autotuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + restore_value, + prune_configs_by, + warmup, + rep, + ) + + return decorator + + +class Heuristics(KernelInterface): + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/torchao/prototype/dora/kernels/matmul.py b/torchao/prototype/dora/kernels/matmul.py new file mode 100644 index 000000000..66e5ef77e --- /dev/null +++ b/torchao/prototype/dora/kernels/matmul.py @@ -0,0 +1,259 @@ +import logging + +import torch +import triton +import triton.language as tl + +from .common import ( + MATMUL_HEURISTICS, + TRITON_SUPPORTED_ACC_TYPES, + SwizzleType, + TritonInputPrecision, + early_config_prune, + estimate_matmul_time, + get_compute_bound_configs, + get_configs_io_bound, + get_higher_dtype, + swizzle_tile, + to_tl_type, +) +from .custom_autotune import autotune + +logger = logging.getLogger(__name__) + + +_AUTOTUNE_TOPK = 10 + + +@autotune( + get_compute_bound_configs() + get_configs_io_bound(), + key=["M", "N", "K"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": _AUTOTUNE_TOPK, + }, +) +@triton.heuristics( + { + "EVEN_K": MATMUL_HEURISTICS["EVEN_K"], + "SPLIT_K": MATMUL_HEURISTICS["SPLIT_K"], + } +) +@triton.jit +def _matmul_kernel( + A, + B, + C, + M, + N, + K, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, # + acc_dtype: tl.constexpr, # + input_precision: tl.constexpr, # + fp8_fast_accum: tl.constexpr, # + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, # + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + AB_DTYPE: tl.constexpr, # + SWIZZLE: tl.constexpr, + EPILOGUE_ELEMENTWISE_ADD: tl.constexpr = False, + Epilogue_source=None, + EPILOGUE_BROADCAST_SCALE: tl.constexpr = False, + Epilogue_scale=None, +): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + + # Threadblock swizzle + pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_M, BLOCK_N, GROUP_M, SWIZZLE) + + # Operand offsets + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + + # Operand pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # Allocate accumulator + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + + # MAC Loop + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot( + a, b, acc, out_dtype=acc_dtype, input_precision=input_precision + ) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) + + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + # Convert acc to output dtype + acc = acc.to(C.dtype.element_ty) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + # mask = (rm < M)[:, None] & (rn < N)[None, :] + mask_m = (rm < M)[:, None] + mask_n = (rn < N)[None, :] + if EPILOGUE_ELEMENTWISE_ADD: + Epilogue_source = Epilogue_source + ( + rm[:, None] * stride_cm + rn[None, :] * stride_cn + ) + source = tl.load(Epilogue_source, mask=mask_m & mask_n) + acc += source + if EPILOGUE_BROADCAST_SCALE: + Epilogue_scale = Epilogue_scale + (rn[None, :]) + scale = tl.load(Epilogue_scale, mask=mask_n) + acc *= scale + + if SPLIT_K == 1: + tl.store(C, acc, mask=mask_m & mask_n) + else: + tl.atomic_add(C, acc, mask=mask_m & mask_n) + + +def triton_mm( + a, + b, + epilogue_source=None, + epilogue_scale=None, + acc_dtype=None, + input_precision=TritonInputPrecision.IEEE, + fp8_fast_accum=False, + output_dtype=None, + swizzle: SwizzleType = SwizzleType.GROUPED, + GROUP_M: int = 8, +): + """Triton GEMM implementation, `D = AB + C` + + Based on `triton.ops.matmul`, with the addition of epilogue. + + Args: + a (torch.Tensor): operand A + b (torch.Tensor): operand B + epilogue_source(optional, torch.Tensor): operand C in `D = AB + C` + epilogue_scale(optional, torch.Tensor): row-wise scale-vector of dim `N` in `D = scale * (AB + C)` + acc_dtype (torch.DType): accumulator type in MAC loop + input_precision (TritonInputPrecision): precision to use for fp32 matmul + fp8_fast_accum (bool) + output_dtype (optional, torch.DType): output type of the GEMM, defaults to higher dtype of A / B + + Returns: + torch.Tensor: `D = AB + C` + """ + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if output_dtype is None: + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + # Epilogue pre-conditions + # TODO Check strides? + if epilogue_source is not None: + assert epilogue_source.shape == (M, N), "incompatible dimensions" + assert epilogue_source.dtype == c.dtype, "incompatible dtype" + + if epilogue_scale is not None: + assert ( + epilogue_scale.ndim == 1 and epilogue_scale.shape[0] == N + ), "incompatible dimensions" + assert epilogue_scale.dtype == c.dtype, "incompatible dtype" + + # choose accumulator type + if acc_dtype is None: + acc_dtype = TRITON_SUPPORTED_ACC_TYPES[ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert ( + acc_dtype in TRITON_SUPPORTED_ACC_TYPES[a.dtype] + ), "acc_dtype not compatible with the type of a" + assert ( + acc_dtype in TRITON_SUPPORTED_ACC_TYPES[b.dtype] + ), "acc_dtype not compatible with the type of b" + + # convert to triton types + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ + tl.float8e4nv, + tl.float8e5, + ]: + ab_dtype = None + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + ) + + _matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + acc_dtype=acc_dtype, # + input_precision=input_precision, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=GROUP_M, + AB_DTYPE=ab_dtype, + SWIZZLE=swizzle, + EPILOGUE_ELEMENTWISE_ADD=epilogue_source is not None, + Epilogue_source=epilogue_source, + EPILOGUE_BROADCAST_SCALE=epilogue_scale is not None, + Epilogue_scale=epilogue_scale, + ) + return c diff --git a/torchao/prototype/dora/kernels/smallk.py b/torchao/prototype/dora/kernels/smallk.py new file mode 100644 index 000000000..227264cf0 --- /dev/null +++ b/torchao/prototype/dora/kernels/smallk.py @@ -0,0 +1,545 @@ +import heapq +import logging +from enum import Enum, StrEnum, unique + +import torch +import triton +import triton.language as tl +from triton.ops.matmul import ( + estimate_matmul_time, + get_configs_io_bound, + get_higher_dtype, +) +from triton.runtime import driver + +from .custom_autotune import Config, autotune + +logger = logging.getLogger(__name__) + + +@unique +class SwizzleType(Enum): + GROUPED = 0 + COLUMN_MAJOR = 1 + ROW_MAJOR = 2 + + +class TritonInputPrecision(StrEnum): + IEEE: str = "ieee" + TF32: str = "tf32" + TF32X3: str = "tf32x3" + + +TRITON_SUPPORTED_ACC_TYPES = { + torch.float16: (torch.float32, torch.float16), + torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32,), + torch.int8: (torch.int32,), +} + + +def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + +def get_compute_bound_configs(): + configs = [ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + # good for int8 + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + ] + return configs + + +@triton.jit() +def swizzle_tile( + pid, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_M: tl.constexpr, + SWIZZLE: tl.constexpr, +): + if SWIZZLE == tl.constexpr(SwizzleType.GROUPED): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + else: + tl.static_assert(False, "swizzle type not supported") + + return pid_m, pid_n + + +def get_small_k_configs(): + configs = get_compute_bound_configs() + get_configs_io_bound() + KEYS_TO_REMOVE = ["BLOCK_K", "SPLIT_K"] + for cfg in configs: + for key in KEYS_TO_REMOVE: + del cfg.kwargs[key] + + return configs + + +def small_k_early_config_prune(configs, named_args, **kwargs): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args["A"].element_size() + dtype = named_args["A"].dtype + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + named_args["K"], + config.num_stages, + ) + + max_shared_memory = driver.active.utils.get_device_properties(device)[ + "max_shared_mem" + ] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + + # Some dtypes do not allow atomic_add + # if dtype not in [torch.float16, torch.float32]: + # configs = [config for config in configs if config.kwargs["SPLIT_K"] == 1] + + # group configs by (BLOCK_M,_N,_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + named_args["K"], + # kw["SPLIT_K"], + config.num_warps, + config.num_stages, + ) + + key = (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, num_warps = k + if capability[0] >= 8: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? + optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest( + 2, + v, + key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 + else x[1] - optimal_num_stages, + ) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs + + +SMALLK_HEURISTICS = { + "BLOCK_K": lambda args: args["K"], +} + +_AUTOTUNE_TOPK = 10 + + +# @heuristics(SMALLK_HEURISTICS) +@autotune( + get_small_k_configs()[:10], + key=["M", "N", "K"], + prune_configs_by={ + "early_config_prune": small_k_early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": _AUTOTUNE_TOPK, + }, +) +@triton.jit +def _mm_small_k_kernel( + A, + B, + M, + N, + K, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + acc_dtype: tl.constexpr, # + input_precision: tl.constexpr, # + fp8_fast_accum: tl.constexpr, # + BLOCK_K: tl.constexpr, # + AB_DTYPE: tl.constexpr, # + BLOCK_M: tl.constexpr = 256, + BLOCK_N: tl.constexpr = 64, + C=None, + stride_cm=None, + stride_cn=None, # + Norm2=None, + Source=None, + stride_sourcem=None, + stride_sourcen=None, + Magnitude=None, + ADD_SOURCE: tl.constexpr = False, + EPILOGUE_NORM: tl.constexpr = False, + EPILOGUE_MAGNITUDE: tl.constexpr = False, + STORE_ACC: tl.constexpr = False, +): + pid_m = tl.program_id(0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rk = tl.arange(0, BLOCK_K) + + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + a = tl.load(A) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + + rn = tl.arange(0, BLOCK_N) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + if STORE_ACC: + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + + if ADD_SOURCE: + Source = Source + (rm[:, None] * stride_sourcem + rn[None, :] * stride_sourcen) + + if EPILOGUE_NORM: + norm_vec = tl.zeros((BLOCK_M,), dtype=acc_dtype) + + if EPILOGUE_MAGNITUDE: + Magnitude = Magnitude + ram + + mask_m = rm < M + + for n in range(0, tl.cdiv(N, BLOCK_N)): + # Advance B over N + + b = tl.load(B) + + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + + if fp8_fast_accum: + acc = tl.dot( + a, b, acc, out_dtype=acc_dtype, input_precision=input_precision + ) + else: + acc = tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) + + if ADD_SOURCE: + mask_n = (n * BLOCK_N + rn < N)[None, :] + source = tl.load(Source, mask=mask_m[:, None] & mask_n) + acc += source.to(acc_dtype) + Source += BLOCK_N * stride_sourcen + + # 2-norm = tl.sqrt(tl.sum(acc * acc, axis=1)) + if EPILOGUE_NORM: + norm_vec += tl.sum(acc * acc, axis=1) + + if STORE_ACC: + mask_n = (n * BLOCK_N + rn < N)[None, :] + tl.store(C, acc.to(C.dtype.element_ty), mask=mask_m[:, None] & mask_n) + C += BLOCK_N * stride_cn + + B += BLOCK_N * stride_bn + + if EPILOGUE_NORM: + Norm2 = Norm2 + rm + norm_vec = tl.rsqrt(norm_vec).to(Norm2.dtype.element_ty) + + if EPILOGUE_MAGNITUDE: + magnitude = tl.load(Magnitude, mask=mask_m) + norm_vec *= magnitude + + tl.store(Norm2, norm_vec, mask=mask_m) + + +def triton_mm_small_k( + a: torch.Tensor, + b: torch.Tensor, + epilogue_norm: bool = True, + source: torch.Tensor = None, + magnitude: torch.Tensor = None, + store_acc: bool = False, + acc_dtype: torch.dtype = None, + input_precision: TritonInputPrecision = TritonInputPrecision.IEEE, + fp8_fast_accum: bool = False, + output_dtype: torch.dtype = None, +): + """Computes GEMM for small K {16, 32, 64} + + Assumes that K is small enough that the MAC loop within each block is a single iteration. + Instead of iterating over K, we iterate over N per block such that each block computes a BLK_M x N row of C. Kernel grid is ceildiv(M, BLOCK_M). + + This specialized GEMM is primarily useful for low-rank projections and fusing grid-wide reductions into the epilogue. + + Currently, the following fusions are implemented: + - `epilogue_norm` - when set to True, the kernel computes the reverse 2-norm along axis=1 of AB ( `1 / 2-norm(AB, axis=1)` ) + - `source=torch.Tensor` - when passed a tensor of shape `M x N`, the kernel computes `D = AB + source` + - `magnitude=torch.Tensor` - when passed a tensor of shape `M`, the kernel additionally multiplies the epilogue norm by the magnitude vector + + Hence, when the above fusions are enabled, the kernel can be used to compute DoRA layer magnitude normalization: `magnitude * (base_weight + lora_B(lora_A(x))).norm(2, axis=1)` + + Args: + a (torch.Tensor): operand A + b (torch.Tensor): operand B + source (torch.Tensor): Operand C in `D = AB + C` + epilogue_norm (bool, optional): Whether to calculate 1 / 2-norm(AB, axis=1) + magnitude (torch.Tensor): vector to multiply epilogue norm by + store_acc (bool): whether to store `AB`, if False, then `epilogue_norm` must be True, in which case only the `2-norm` is stored + acc_dtype (torch.DType): accumulator type in MAC loop + input_precision (TritonInputPrecision): precision to use for fp32 matmul + fp8_fast_accum (bool) + output_dtype (torch.DType): type for output tensors (`D`, `2-norm`, etc.) + + Returns: + torch.Tensor + """ + assert store_acc or epilogue_norm, "Must use store_acc or epilogue_norm" + + device = a.device + + # Make sure inputs are contiguous + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + + assert a.shape[1] == b.shape[0], "Incompatible operand dimensions" + M, K = a.shape + _, N = b.shape + + assert K < 128, "K must be < 128 to use this kernel" + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + if output_dtype is None: + output_dtype = ab_dtype + + if epilogue_norm: + norm2 = torch.zeros(M, device=device, dtype=output_dtype) + + # Must set out_dtype before converting dtypes to tl types + if store_acc: + c = torch.empty((M, N), device=device, dtype=output_dtype) + + if acc_dtype is None: + acc_dtype = TRITON_SUPPORTED_ACC_TYPES[ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert ( + acc_dtype in TRITON_SUPPORTED_ACC_TYPES[a.dtype] + ), "acc_dtype not compatible with the type of a" + assert ( + acc_dtype in TRITON_SUPPORTED_ACC_TYPES[b.dtype] + ), "acc_dtype not compatible with the type of b" + + # Convert dtypes to tl types + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Use fp8 types in MAC loop + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ + tl.float8e4nv, + tl.float8e5, + ]: + ab_dtype = None + + logger.debug( + f"triton_mm_small_k: {ab_dtype=} {acc_dtype=} {input_precision=} {fp8_fast_accum=} {output_dtype=}" + ) + + # Set the fusion and other GEMM kwargs + # IMPORTANT: BLOCK_K must be equal to K + kwargs = { + "BLOCK_K": K, + "acc_dtype": acc_dtype, + "input_precision": input_precision, + "fp8_fast_accum": fp8_fast_accum, + "AB_DTYPE": ab_dtype, + "EPILOGUE_NORM": epilogue_norm, + "ADD_SOURCE": source is not None, + "EPILOGUE_MAGNITUDE": magnitude is not None, + "STORE_ACC": store_acc, + } + + # 2-norm params + if epilogue_norm: + kwargs["Norm2"] = norm2 + + # source params + if source is not None: + assert source.shape == (M, N) + kwargs["Source"] = source + kwargs["stride_sourcem"] = source.stride(0) + kwargs["stride_sourcen"] = source.stride(1) + else: + kwargs["Source"] = None + kwargs["stride_sourcem"] = 0 + kwargs["stride_sourcen"] = 0 + + # magnitude params, epilogue_norm must be True + if magnitude is not None: + assert epilogue_norm, "magnitude requires epilogue_norm" + assert magnitude.ndim == 1 and magnitude.shape[0] == M + kwargs["Magnitude"] = magnitude + + # store_acc, whether to store the intermediate AB + if store_acc: + kwargs["C"] = c + kwargs["stride_cm"] = c.stride(0) + kwargs["stride_cn"] = c.stride(1) + else: + kwargs["C"] = None + kwargs["stride_cm"] = 0 + kwargs["stride_cn"] = 0 + + # kwargs_str = " ".join( + # f"{k}={v}" for k, v in kwargs.items() if not isinstance(v, torch.Tensor) + # ) + # print(f"triton_mm_small_k: {kwargs_str}") + + # launch kernel + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),) + _mm_small_k_kernel[grid]( + a, + b, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + **kwargs, + ) + + if store_acc: + if epilogue_norm: + return c, norm2 + else: + return c + return norm2 From cebc05c6e27f99627290b474efaf5c4fe067a212 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 18:42:41 +0000 Subject: [PATCH 02/15] fix test imports --- test/dora/test_dora_layer.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/test/dora/test_dora_layer.py b/test/dora/test_dora_layer.py index 161d81e39..c72693930 100644 --- a/test/dora/test_dora_layer.py +++ b/test/dora/test_dora_layer.py @@ -1,18 +1,19 @@ -# Linear4bit = pytest.importorskip( -# "bitsandbytes.nn.Linear4bit", reason="requires bitsandbytes" -# ) -# HQQLinear = pytest.importorskip("hqq.core.quantize.HQQLinear", reason="requires hqq") -# BaseQuantizeConfig = pytest.importorskip( -# "hqq.core.quantize.BaseQuantizeConfig", reason="requires hqq" -# ) +import pytest + +bnbnn = pytest.importorskip( + "bitsandbytes.nn", reason="requires bitsandbytes" +) +hqq_core = pytest.importorskip("hqq.core.quantize", reason="requires hqq") + import itertools -import pytest import torch -from bitsandbytes.nn import Linear4bit -from hqq.core.quantize import BaseQuantizeConfig, HQQLinear -from prototypes.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear +#Import modules as opposed to classes directly, otherwise pytest.importorskip always skips +Linear4bit = bnbnn.Linear4bit +BaseQuantizeConfig = hqq_core.BaseQuantizeConfig +HQQLinear = hqq_core.HQQLinear +from torchao.prototype.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear def check(expected, actual, dtype): From e87fde627e347fd6b4a2a88bc8a75ae7dcfc465d Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 19:04:14 +0000 Subject: [PATCH 03/15] add dora layers --- torchao/prototype/common/__init__.py | 0 torchao/prototype/common/profiling_tools.py | 300 ++++++++++++++++++++ torchao/prototype/dora/__init__.py | 0 torchao/prototype/dora/dora_layer.py | 260 +++++++++++++++++ 4 files changed, 560 insertions(+) create mode 100644 torchao/prototype/common/__init__.py create mode 100644 torchao/prototype/common/profiling_tools.py create mode 100644 torchao/prototype/dora/__init__.py create mode 100644 torchao/prototype/dora/dora_layer.py diff --git a/torchao/prototype/common/__init__.py b/torchao/prototype/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchao/prototype/common/profiling_tools.py b/torchao/prototype/common/profiling_tools.py new file mode 100644 index 000000000..dbaa398cb --- /dev/null +++ b/torchao/prototype/common/profiling_tools.py @@ -0,0 +1,300 @@ +import os +import types +from datetime import datetime +from functools import partial + +import pandas as pd +import torch +import torch.autograd.profiler_util +from tabulate import tabulate +from torch.autograd.profiler import record_function +from torch.cuda.nvtx import range as nvtx_range +from triton.testing import do_bench + +# from torch.cuda.nvtx import range_pop, range_push + +TIME_FORMAT_STR: str = "%m_%d" +PROFILE_DIR = "./profiles" + + +def simple_bench(fn, *args, **kwargs): + t = do_bench(lambda: fn(*args, **kwargs)) + return t + + +def check(expected, actual, atol=1e-3): + diff = (expected - actual).abs().max() + print(f"diff: {diff}") + # assert diff < atol + + +def benchmark_mm( + test_fn, xs, weight, ref_fn=torch.matmul, headers=["M", "K", "N", "test", "ref"] +): + timings = [] + for x in xs: + M, K = x.shape + _, N = weight.shape + assert x.shape[1] == weight.shape[0] + print(f"Benchmarking {(M, K, N)}") + test_times = do_bench(lambda: test_fn(x, weight)) + ref_times = do_bench(lambda: ref_fn(x, weight)) + timings.append([M, K, N, test_times, ref_times]) + return pd.DataFrame(timings, columns=headers) + + +def run_bench(xs, weight): + df = benchmark_mm(xs, weight) + print(tabulate(df, headers="keys", floatfmt=".4f")) + return df + + +class CudaProfilerCtx: + def __enter__(self): + print("Starting cuda profiler") + torch.cuda.cudart().cudaProfilerStart() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + print("Stopping cuda profiler") + torch.cuda.cudart().cudaProfilerStop() + if exc_type is not None: + print(f"Exception occurred: {exc_type}, {exc_value}") + # Return True to suppress the exception + return True + + def step(self): + pass + + +def get_torch_profiler( + name, + with_stack=True, + with_flops=True, + with_modules=True, + record_shapes=False, + export_events=False, + export_trace=True, + export_memory_timeline=True, + out_dir=None, + warmup=1, + active=5, +): + if not os.path.exists(out_dir): + os.makedirs(out_dir) + callback = partial( + trace_handler, + prefix=name, + out_dir=out_dir, + group_by_input_shapes=record_shapes, + group_by_stack=5 if export_events else None, + export_events=export_events, + export_trace=export_trace, + export_memory_timeline=export_memory_timeline, + ) + return torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=record_shapes, + with_stack=with_stack, + with_flops=with_flops, + with_modules=with_modules, + profile_memory=export_memory_timeline, + schedule=torch.profiler.schedule(wait=0, warmup=warmup, active=active), + on_trace_ready=callback, + ) + + +class TorchProfilerCtx: + @staticmethod + def profiler( + name, + out_dir, + warmup=1, + active=5, + record_shapes=False, + with_stack=True, + export_events=False, + export_trace=True, + export_memory_timeline=True, + ): + return get_torch_profiler( + name, + with_stack=with_stack, + record_shapes=export_memory_timeline or record_shapes, + export_events=export_events, + export_trace=export_trace, + export_memory_timeline=export_memory_timeline, + out_dir=out_dir, + warmup=warmup, + active=active, + ) + + def __init__( + self, + name, + out_dir, + warmup=1, + active=5, + record_shapes=False, + with_stack=True, + export_events=False, + export_trace=True, + export_memory_timeline=True, + ): + self.profiler = get_torch_profiler( + name, + with_stack=with_stack, + record_shapes=export_memory_timeline or record_shapes, + export_events=export_events, + export_trace=export_trace, + export_memory_timeline=export_memory_timeline, + out_dir=out_dir, + warmup=warmup, + active=active, + ) + + def __enter__(self): + return self.profiler.__enter__() + + def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + return self.profiler.__exit__(exc_type, exc_value, exc_traceback) + + # def step(self): + # return self.profiler.step() + + +def get_annotation_ctx(profiler_type): + assert profiler_type in ["nsys", "torch"] + if profiler_type == "nsys": + return nvtx_range + else: + return record_function + + +def trace_handler( + prof: torch.profiler.profile, + group_by_stack: int = 5, + group_by_input_shapes: bool = False, + prefix="", + out_dir=None, + export_events=False, + export_trace=True, + export_memory_timeline=True, +): + # Prefix for file names. + out_dir = out_dir or PROFILE_DIR + timestamp = datetime.now().strftime(TIME_FORMAT_STR) + file_prefix = os.path.join(out_dir, f"{prefix}-{timestamp}") + + if export_events: + evt_list = prof.key_averages( + group_by_stack_n=group_by_stack, group_by_input_shape=group_by_input_shapes + ) + torch.save(evt_list, f"{file_prefix}-key_averages.pt") + + # Construct the trace file. + if export_trace: + prof.export_chrome_trace(f"{file_prefix}-chrome-trace.json") + + # Construct the memory timeline file. + if export_memory_timeline: + prof.export_memory_timeline( + f"{file_prefix}-memory-timeline.html", device="cuda:0" + ) + prof.export_memory_timeline( + f"{file_prefix}-memory-timeline.json", device="cuda:0" + ) + print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) + + +_PERF_COLUMNS = [ + "key", + "count", + "cpu_children", + "cpu_parent", + "self_device_time_total", + "cuda_time", + "flops", + "self_cpu_time", + "self_cpu_time_total", + "cpu_time", + "cpu_time_total" "self_device_memory_usage", + "device_memory_usage", + "self_cpu_memory_usage", + "cpu_memory_usage", +] +PERF_COLS_SELECT = [ + "key", + "cpu_parent", + "cpu_children", + # "self_cpu_time", + # "self_cpu_time_total", + "cpu_time", + "cpu_time_total", + "cuda_time", + "self_device_time_total", +] + + +# cuda_time, cpu_time are avg times -- corresponds to CUDA time avg and CPU time avg in table() above +# "self" times is not meaningful for annotated regions, since they only have child regions +def is_function(obj): + return isinstance(obj, types.FunctionType) + + +def is_method(obj): + return isinstance(obj, types.MethodType) + + +def is_private(prop): + return prop.startswith("_") + + +def should_exclude(obj, prop): + return ( + is_function(getattr(obj, prop)) + or is_method(getattr(obj, prop)) + or is_private(prop) + ) + + +def _get_event_props(event: torch.autograd.profiler_util.FunctionEvent): + props = [p for p in dir(event) if not should_exclude(event, p)] + return props + + +def get_events_df(events: torch.autograd.profiler_util.EventList): + event_props = _get_event_props(events[0]) + data = [{p: getattr(e, p) for p in event_props} for e in events] + return pd.DataFrame(data) + + +def get_perf_df(events: torch.autograd.profiler_util.EventList, sort=True): + df = get_events_df(events).filter(PERF_COLS_SELECT) + if sort: + df = df.sort_values(["cpu_time", "cuda_time"], ascending=False) + return df + + +def pivot_df( + df, + id_cols: str | list[str], + columns: str | list[str], + values: str | list[str], + column_order: list[str] = None, + show: bool = True, +): + df = df.pivot_table( + index=id_cols, + columns=columns, + values=values, + ).reset_index() + if column_order is not None: + df = df[column_order] + if show: + print(df.to_string(index=False)) + return df diff --git a/torchao/prototype/dora/__init__.py b/torchao/prototype/dora/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchao/prototype/dora/dora_layer.py b/torchao/prototype/dora/dora_layer.py new file mode 100644 index 000000000..23840ea21 --- /dev/null +++ b/torchao/prototype/dora/dora_layer.py @@ -0,0 +1,260 @@ +import logging + +import bitsandbytes as bnb +import torch +import torch.nn as nn +from bitsandbytes.nn import Linear4bit +from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear + +from prototypes.dora.kernels.matmul import triton_mm +from prototypes.dora.kernels.smallk import triton_mm_small_k + +logger = logging.getLogger(__name__) + + +# Adapted from https://github.com/AnswerDotAI/fsdp_qlora/blob/dora/scripts/dora.py +class DoRALayer(nn.Module): + """DoRA Update""" + + def __init__( + self, in_features, out_features, lora_rank, device, dtype, *args, **kwargs + ): + super().__init__() + + # LoRA layers + std_dev = 1 / torch.sqrt(torch.tensor(lora_rank).float()) + lora_A_param = nn.Parameter( + torch.randn(lora_rank, in_features).to(device=device, dtype=dtype) * std_dev + ) + self.lora_A = nn.Linear( + in_features, lora_rank, bias=False, device=device, dtype=dtype + ) + setattr(self.lora_A, "weight", lora_A_param) + + self.lora_B = nn.Linear( + lora_rank, out_features, bias=False, device=device, dtype=dtype + ) + self.lora_B.weight.data.zero_() + + def forward(self, x, base_weight): + # LoRA update, shape `bs x seq_len x in-features @ in-features x lora-rank @ lora-rank x out-features = bs x seq_len x out-features` + output = self.lora_B(self.lora_A(x)) + + # DoRA Section 4.3. Column norm no gradient update. + column_norm = ( + (base_weight + self.lora_B.weight @ self.lora_A.weight) + .norm(p=2, dim=1) + .detach() + ) + + return output, column_norm + + +# class MagnitudeLayer(nn.Module): +# "FSDP doesn't work with nn.ParameterDict hence this module: https://github.com/pytorch/pytorch/issues/79605" + +# def __init__(self, vector_data, device, dtype): +# super().__init__() +# self.magnitude = nn.Parameter(vector_data.to(device=device, dtype=dtype)) + +# def forward(self, x): +# return x * self.magnitude.view(1, 1, -1) + + +class DoRALinear(nn.Module): + """Reference DoRA Update Layer + + out = (x @ base_weight + lora_out) * magnitude_scale + where: + `lora_out = lora_B(lora_A(x)` + `magnitude_scale = (base_weight + lora_B @ lora_A).norm(p=2, dim=1) * magnitude_vector` + + base_weight is the weight of the frozen `linear` layer of shape `out_features x in_features`. + + In QDoRA, the base weight is quantized and needs an additional dequantization step. + In this base DoRA layer, a placeholder (no-op) `dequantize` method stub is provided, which simply + returns the base weight. + + For `bnb` and `hqq`, the respective `dequantize` method can be substituted. + """ + + def __init__(self, base_layer, lora_rank, *args, **kwargs): + super().__init__() + + # Get original (dequantized) weight dtype + dtype = getattr( + base_layer, "compute_dtype", next(base_layer.parameters()).dtype + ) + device = next(base_layer.parameters()).device + self.base_layer = base_layer + + # Initialize magnitude vec - TODO: this is clunky, better way to init? + base_weight = self.dequantize().clone().cuda() + self.magnitude_vec = nn.Parameter(base_weight.norm(p=2, dim=1)) + + del base_weight + torch.cuda.empty_cache() + + # DoRA layer + self.dora_layer = DoRALayer( + base_layer.in_features, + base_layer.out_features, + lora_rank, + device, + dtype, + *args, + **kwargs, + ) + + def dequantize(self): + return self.base_layer.weight + + def forward(self, x, *args, **kwargs): + # Out shape is either bs, seqlen, out_features or bs * seqlen, out_features + assert x.ndim == 2 or x.ndim == 3, "Expected 2D or 3D input" + dq_base_weight = self.dequantize() + out_shape = [*x.shape[:-1], dq_base_weight.shape[0]] + # Reshape to (bs * seqlen, out_features) + x = x.reshape(-1, x.shape[-1]) + + # LoRA update + lora_A_weight = self.dora_layer.lora_A.weight + lora_B_weight = self.dora_layer.lora_B.weight + lora_out = (x @ lora_A_weight.T) @ lora_B_weight.T + + # DoRA magnitude scale + column_norm = (dq_base_weight + lora_B_weight @ lora_A_weight).norm(p=2, dim=1) + magnitude_scale = self.magnitude_vec / column_norm + + # DoRA update + dora_out = (x @ dq_base_weight.T + lora_out) * magnitude_scale[None, :] + dora_out = dora_out.reshape(*out_shape) + + return dora_out + + def forward_fused(self, x, *args, **kwargs): + """Reorders computation as well employs two fused kernels to speed up computation. + + See README.md for description of fused kernels. + """ + assert x.ndim == 2 or x.ndim == 3, "Expected 2D or 3D input" + + dq_base_weight = self.dequantize() + # Out shape is either bs, seqlen, out_features or bs * seqlen, out_features + out_shape = [*x.shape[:-1], dq_base_weight.shape[0]] + # Reshape to (bs * seqlen, out_features) + x = x.reshape(-1, x.shape[-1]) + + # LoRA update + lora_A_weight = self.dora_layer.lora_A.weight + lora_B_weight = self.dora_layer.lora_B.weight + lora_out = (x @ lora_A_weight.T) @ lora_B_weight.T + + # DoRA magnitude + # Fused kernel #1: `magnitude_scale = (base_weight + lora_B @ lora_A).norm(p=2, dim=1) * magnitude_vector` + magnitude_scale = triton_mm_small_k( + lora_B_weight, + lora_A_weight, + epilogue_norm=True, + source=dq_base_weight, + magnitude=self.magnitude_vec, + store_acc=False, + ) + # DoRA update + # Fused kernel #2: `out = (x @ base_weight + lora_out) * magnitude_scale` + dora_out = triton_mm( + x, + dq_base_weight.T, + epilogue_source=lora_out, + epilogue_scale=magnitude_scale, + ) + dora_out = dora_out.reshape(out_shape) + + return dora_out + + # For profiling + def forward_instrumented(self, x, *args, **kwargs): + annotation_ctx = kwargs.pop("annotation_ctx") + with annotation_ctx("##dora_forward"): + with annotation_ctx("##base_layer"): + result = self.base_layer(x, *args, **kwargs) + + with annotation_ctx("##dora_layer"): + dq_weight = self.dequantize() + output, column_norm = self.dora_layer(x, dq_weight) + + with annotation_ctx("##dora_rescale"): + result += output + result = result / column_norm.view(1, 1, -1) + result = result * self.magnitude_vec.view(1, 1, -1) + + return result + + +class BNBDoRALinear(DoRALinear): + def dequantize(self): + return bnb.functional.dequantize_4bit( + self.base_layer.weight.data, self.base_layer.weight.quant_state + ) + + +class HQQDoRALinear(DoRALinear): + def dequantize(self): + return self.base_layer.dequantize() + + +if __name__ == "__main__": + # bnb_dora_layer = BNBDoraLayer(in_features=128, out_features=32, lora_rank=16) + + bs, seqlen = 1, 16 + in_features, out_features = 128, 256 + x = torch.randn(bs, seqlen, in_features).cuda().to(torch.float32) + + torch_base = nn.Linear(128, 256, bias=False).cuda() + torch_dora = DoRALinear(torch_base, lora_rank=16).cuda() + + bnb_base = Linear4bit( + input_features=in_features, + output_features=out_features, + bias=False, + quant_type="nf4", + compute_dtype=torch.float32, + quant_storage=torch.float32, + ) + bnb_base.load_state_dict(torch_base.state_dict()) + # print((bnb_base.weight - torch_base.weight).abs().max()) + bnb_base = bnb_base.to(0) + # print((W_dq - torch_base.weight.data).abs().max()) + # y = torch_base(x) + # y_bnb = bnb_base(x) + # y_bnb_ref = x @ W_dq.T + # print((y - y_bnb_ref).abs().max()) + # print((y - y_bnb).abs().max()) + # bnb_dora = BNBDoRALinear(bnb_base, lora_rank=16).cuda() + # y = torch_dora.forward(x) + # y_bnb = bnb_dora.forward(x) + # print((y - y_bnb).abs().max()) + # print((torch_base(x) - bnb_base(x)).abs().max()) + quant_config = BaseQuantizeConfig( + nbits=4, + group_size=64, + quant_zero=False, + quant_scale=False, + offload_meta=True, + view_as_float=True, + ) + + hqq_base = HQQLinear( + torch_base, + quant_config, + compute_dtype=torch.float32, + ) + + print(hqq_base.meta.keys()) + hqq_base.set_backend(HQQBackend.PYTORCH) + hqq_dora = HQQDoRALinear(hqq_base, lora_rank=16) + # print(hqq_dora.base_layer.meta) + # print(hqq_dora.base_layer.meta["nbits"]) + # print(hqq_dora.base_layer.meta["zero_scale"]) + # print(hqq_dora.dequantize().shape) + print(hqq_dora(x).shape) From 49b957b3413b8661ce4e045b332aaca7a6f1f676 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 19:11:09 +0000 Subject: [PATCH 04/15] add benchmark --- benchmarks/dora/bench_utils.py | 131 +++++++++++++ benchmarks/dora/dora_bench.py | 348 +++++++++++++++++++++++++++++++++ 2 files changed, 479 insertions(+) create mode 100644 benchmarks/dora/bench_utils.py create mode 100644 benchmarks/dora/dora_bench.py diff --git a/benchmarks/dora/bench_utils.py b/benchmarks/dora/bench_utils.py new file mode 100644 index 000000000..2de4fa637 --- /dev/null +++ b/benchmarks/dora/bench_utils.py @@ -0,0 +1,131 @@ +import torch +from bitsandbytes.nn import Linear4bit +from hqq.core.quantize import BaseQuantizeConfig, HQQLinear + +from prototypes.dora.dora_layer import BNBDoRALinear, HQQDoRALinear +from prototypes.dora.kernels.matmul import triton_mm +from prototypes.dora.kernels.smallk import triton_mm_small_k + + +def make_lora_weights(ranks, in_features, out_features, dtype): + As = [torch.randn(rank, in_features, device="cuda", dtype=dtype) for rank in ranks] + Bs = [torch.randn(out_features, rank, device="cuda", dtype=dtype) for rank in ranks] + return As, Bs + + +def make_dora_source_and_magnitude(in_features, out_features, dtype): + source = torch.randn(out_features, in_features, device="cuda", dtype=dtype) + magnitude = torch.randn(out_features, device="cuda", dtype=dtype) + return source, magnitude + + +def make_inputs(batch_sizes, seqlen, in_features, dtype): + xs = [ + torch.randn(bs * seqlen, in_features, device="cuda", dtype=dtype) + for bs in batch_sizes + ] + return xs + + +def make_weights(batch_sizes, in_features, out_features, dtype): + weights = [ + torch.randn(in_features, out_features, device="cuda", dtype=dtype) + for _ in range(len(batch_sizes)) + ] + return weights + + +def make_epilogue_sources(batch_sizes, seqlen, out_features, dtype): + epilogue_sources = [ + torch.randn(bs * seqlen, out_features, device="cuda", dtype=dtype) + for bs in batch_sizes + ] + return epilogue_sources + + +def make_epilogue_scales(batch_sizes, out_features, dtype): + epilogue_scales = [ + torch.randn(out_features, device="cuda", dtype=dtype) + for _ in range(len(batch_sizes)) + ] + return epilogue_scales + + +def dora_colnorm_ref( + A: torch.Tensor, + B: torch.Tensor, + base_weight: torch.Tensor, + magnitude_vector: torch.Tensor, +): + column_norm = (base_weight + B @ A).norm(p=2, dim=1) + return magnitude_vector / column_norm + + +def dora_mm_epilogue_ref( + A: torch.Tensor, + B: torch.Tensor, + epilogue_source: torch.Tensor, + epilogue_scale: torch.Tensor, +): + out = (A @ B + epilogue_source) * epilogue_scale[None, :] + return out + + +def dora_ref(x, w, lora_A, lora_B, magnitude_vector): + # (bs x seq_len x out_features) = (bs x seq_len x in_features) @ (in_features x rank) @ (rank x out_features) + lora_out = (x @ lora_A.T) @ lora_B.T + # (out_features) + magnitude_scale = dora_colnorm_ref(lora_A, lora_B, w, magnitude_vector) + # (bs x seq_len x out_features) + dora_out_ref = dora_mm_epilogue_ref(x, w, lora_out, magnitude_scale) + return dora_out_ref + + +def dora_triton(x, w, lora_A, lora_B, magnitude_vector): + lora_out = (x @ lora_A.T) @ lora_B.T + magnitude_scale = triton_mm_small_k( + lora_B, + lora_A, + epilogue_norm=True, + source=w, + magnitude=magnitude_vector, + store_acc=False, + ) + dora_out = triton_mm(x, w, epilogue_source=lora_out, epilogue_scale=magnitude_scale) + return dora_out + + +def setup_dora_base_layers(layer_type, in_features, out_features, dtype): + if "bnb" in layer_type: + # BitsandBytes + base_layer = Linear4bit( + input_features=in_features, + output_features=out_features, + bias=False, + quant_type="nf4", + compute_dtype=dtype, + ).cuda() + base_layer.quant_state.dtype = base_layer.compute_dtype + dora_cls = BNBDoRALinear + elif "hqq" in layer_type: + # HQQ + quant_config = BaseQuantizeConfig( + nbits=4, + group_size=64, + quant_zero=False, + quant_scale=False, + offload_meta=True, + view_as_float=True, + ) + linear = torch.nn.Linear( + in_features, out_features, dtype=dtype, bias=False + ).cuda() + base_layer = HQQLinear( + linear, + quant_config, + compute_dtype=dtype, + ) + dora_cls = HQQDoRALinear + else: + raise ValueError(f"Unknown layer type: {layer_type}") + return base_layer, dora_cls diff --git a/benchmarks/dora/dora_bench.py b/benchmarks/dora/dora_bench.py new file mode 100644 index 000000000..305cfbdb1 --- /dev/null +++ b/benchmarks/dora/dora_bench.py @@ -0,0 +1,348 @@ +import argparse + +import pandas as pd +import torch +from bench_utils import ( + dora_colnorm_ref, + dora_mm_epilogue_ref, + dora_ref, + dora_triton, + make_dora_source_and_magnitude, + make_epilogue_scales, + make_epilogue_sources, + make_inputs, + make_lora_weights, + make_weights, + setup_dora_base_layers, +) +from triton.testing import do_bench + +from torchao.prototype.dora.kernels.matmul import triton_mm +from torchao.prototype.dora.kernels.smallk import triton_mm_small_k +from torchao.prototype.common.profiling_tools import pivot_df + + +def run_colnorm_bench(args): + in_features, out_features = args.in_features, args.out_features + + dtype = getattr(torch, args.dtype) + + # Inputs + As, Bs = make_lora_weights(args.dora_ranks, in_features, out_features, dtype) + source, magnitude = make_dora_source_and_magnitude(in_features, out_features, dtype) + + # torch.compile + dora_colnorm_compiled = torch.compile(dora_colnorm_ref, mode=args.compile_mode) + compiled_key = f"compiled_{args.compile_mode}" + + # Benchmark + timings = [] + + for a, b in zip(As, Bs): + ref_t = do_bench(lambda: dora_colnorm_ref(a, b, source, magnitude)) + compiled_t = do_bench(lambda: dora_colnorm_compiled(a, b, source, magnitude)) + + test_t = do_bench( + lambda: triton_mm_small_k( + b, + a, + epilogue_norm=True, + source=source, + magnitude=magnitude, + store_acc=False, + ), + ) + common_args = [a.shape[0], a.shape[1], b.shape[0], args.dtype] + timings.append([*common_args, "ref", ref_t]) + timings.append([*common_args, compiled_key, compiled_t]) + timings.append([*common_args, "triton", test_t]) + + # Group results for kernel type + headers = ["rank", "in_features", "out_features", "dtype", "kernel", "time(ms)"] + df = pd.DataFrame(timings, columns=headers) + id_cols = ["rank", "in_features", "out_features"] + pivot_df( + df, + id_cols=id_cols, + columns="kernel", + values="time(ms)", + column_order=[*id_cols, "ref", compiled_key, "triton"], + show=True, + ) + + +def run_epilogue_bench(args): + in_features, out_features = args.in_features, args.out_features + seqlen = args.seqlen + batch_sizes = ( + args.batch_sizes if isinstance(args.batch_sizes, list) else [args.batch_sizes] + ) + dtype = getattr(torch, args.dtype) + + # Inputs + xs = make_inputs(batch_sizes, seqlen, in_features, dtype) + weights = make_weights(batch_sizes, in_features, out_features, dtype) + epilogue_sources = make_epilogue_sources(batch_sizes, seqlen, out_features, dtype) + epilogue_scales = make_epilogue_scales(batch_sizes, out_features, dtype) + + # torch.compile + dora_mm_epilogue_compiled = torch.compile( + dora_mm_epilogue_ref, mode=args.compile_mode + ) + compiled_key = f"compiled_{args.compile_mode}" + + # Benchmark + timings = [] + for bs, x, w, e1, e2 in zip( + batch_sizes, xs, weights, epilogue_sources, epilogue_scales + ): + ref_t = do_bench(lambda: dora_mm_epilogue_ref(x, w, e1, e2)) + compiled_t = do_bench(lambda: dora_mm_epilogue_compiled(x, w, e1, e2)) + + test_t = do_bench( + lambda: triton_mm( + x, + w, + epilogue_source=e1, + epilogue_scale=e2, + ) + ) + common_args = [bs, seqlen, w.shape[0], w.shape[1], args.dtype] + timings.append([*common_args, "ref", ref_t]) + timings.append([*common_args, compiled_key, compiled_t]) + timings.append([*common_args, "triton", test_t]) + + # Group results for kernel type + headers = [ + "bs", + "seqlen", + "in_features", + "out_features", + "dtype", + "kernel", + "time(ms)", + ] + df = pd.DataFrame(timings, columns=headers) + id_cols = ["bs", "seqlen", "in_features", "out_features", "dtype"] + + pivot_df( + df, + id_cols=id_cols, + columns="kernel", + values="time(ms)", + column_order=[*id_cols, "ref", compiled_key, "triton"], + show=True, + ) + + +def run_full_dora(args): + """Dora Layer + + out = (x @ base_weight + lora_out) * magnitude_scale + where: + `lora_out = lora_B(lora_A(x)` + `magnitude_scale = (base_weight + lora_B @ lora_A).norm(p=2, dim=1) * magnitude_vector` + """ + + dtype = getattr(torch, args.dtype) + xs = make_inputs(args.batch_sizes, args.seqlen, args.in_features, dtype) + weights = make_weights(args.batch_sizes, args.in_features, args.out_features, dtype) + lora_As, lora_Bs = make_lora_weights( + args.dora_ranks, args.in_features, args.out_features, dtype + ) + _, magnitude_vector = make_dora_source_and_magnitude( + args.in_features, args.out_features, dtype + ) + + # torch.compile + dora_compiled = torch.compile(dora_ref, mode=args.compile_mode) + # triton_compiled = torch.compile(dora_triton, mode=args.compile_mode) + + compiled_key = f"compiled_{args.compile_mode}" + # triton_compiled_key = f"triton_compiled_{args.compile_mode}" + + # Benchmark + timings = [] + for lora_A, lora_B in zip(lora_As, lora_Bs): + for bs, x, w in zip(args.batch_sizes, xs, weights): + # ref = dora_ref(x, w, lora_A, lora_B, magnitude_vector) + # test = dora_triton(x, w, lora_A, lora_B, magnitude_vector) + # compiled = dora_compiled(x, w, lora_A, lora_B, magnitude_vector) + # test_compiled = triton_compiled(x, w, lora_A, lora_B, magnitude_vector) + # print(f"triton diff: {(ref - test).abs().max()}") + # print(f"compiled diff: {(ref - compiled).abs().max()}") + # print(f"triton compiled diff: {(ref - test_compiled).abs().max()}") + ref_t = do_bench(lambda: dora_ref(x, w, lora_A, lora_B, magnitude_vector)) + compiled_t = do_bench( + lambda: dora_compiled(x, w, lora_A, lora_B, magnitude_vector) + ) + triton_t = do_bench( + lambda: dora_triton(x, w, lora_A, lora_B, magnitude_vector) + ) + # triton_compiled_t = do_bench( + # lambda: triton_compiled(x, w, lora_A, lora_B, magnitude_vector) + # ) + + # batch_size, seq_len, rank, in_features, out_features, dtype + common_args = [ + bs, + args.seqlen, + lora_A.shape[0], + args.in_features, + args.out_features, + args.dtype, + ] + timings.append([*common_args, "ref", ref_t]) + timings.append([*common_args, compiled_key, compiled_t]) + timings.append([*common_args, "triton", triton_t]) + # timings.append([*common_args, triton_compiled_key, triton_compiled_t]) + + headers = [ + "bs", + "seqlen", + "rank", + "in_features", + "out_features", + "dtype", + "kernel", + "time(ms)", + ] + df = pd.DataFrame(timings, columns=headers) + id_cols = ["bs", "seqlen", "rank", "in_features", "out_features", "dtype"] + + pivot_df( + df, + id_cols=id_cols, + columns="kernel", + values="time(ms)", + column_order=[ + *id_cols, + "ref", + compiled_key, + "triton", + ], # , triton_compiled_key], + show=True, + ) + + +def run_dora_layer_bench(args): + dtype = getattr(torch, args.dtype) + in_features, out_features = args.in_features, args.out_features + xs = make_inputs(args.batch_sizes, args.seqlen, args.in_features, dtype) + base_layer, dora_cls = setup_dora_base_layers( + args.kernel, in_features, out_features, dtype + ) + + timings = [] + layer_key = f"{args.kernel}" + layer_key_fused = f"{args.kernel}-fused" + + for bs, x in zip(args.batch_sizes, xs): + for rank in args.dora_ranks: + dora_layer = dora_cls(base_layer, rank).cuda() + common_args = [ + bs, + args.seqlen, + rank, + args.in_features, + args.out_features, + args.dtype, + ] + ref_t = do_bench(lambda: dora_layer.forward(x)) + fused_t = do_bench(lambda: dora_layer.forward_fused(x)) + timings.append([*common_args, layer_key, ref_t]) + timings.append([*common_args, layer_key_fused, fused_t]) + + headers = [ + "bs", + "seqlen", + "rank", + "in_features", + "out_features", + "dtype", + "layer", + "time(ms)", + ] + df = pd.DataFrame(timings, columns=headers) + id_cols = ["bs", "seqlen", "rank", "in_features", "out_features", "dtype"] + + pivot_df( + df, + id_cols=id_cols, + columns="layer", + values="time(ms)", + column_order=[ + *id_cols, + layer_key, + layer_key_fused, + ], + show=True, + ) + + +def run_bench(args): + print(f"""Running {args.kernel} benchmark with dtype={args.dtype}, batch_sizes={args.batch_sizes}, seqlen={args.seqlen}, + in_features={args.in_features}, out_features={args.out_features}, dora_ranks={args.dora_ranks}""") + if args.kernel == "dora-colnorm": + return run_colnorm_bench(args) + elif args.kernel == "dora-mm-epilogue": + return run_epilogue_bench(args) + elif args.kernel == "dora-full": + return run_full_dora(args) + elif args.kernel == "dora-bnb" or args.kernel == "dora-hqq": + return run_dora_layer_bench(args) + else: + raise ValueError(f"Unknown kernel: {args.kernel}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--kernel", + type=str, + default="dora-mm-epilogue", + choices=( + "dora-colnorm", + "dora-mm-epilogue", + "dora-full", + "dora-bnb", + "dora-hqq", + ), + help="""The kernel to benchmark + + dora-colnorm: Small K GEMM with fused column-norm and magnitude vector multiplication + dora-mm-epilogue: GEMM with fused epilogue elementwise addition and broadcasted scale + dora-full: Full DORA kernel (dora-colnorm + dora-mm-epilogue) + dora-bnb: BNBDoRALinear layer with fused kernels + dora-hqq: HQQDoRALinear layer with fused kernels + """, + ) + parser.add_argument("--seqlen", type=int, default=512) + parser.add_argument( + "--batch_sizes", type=int, nargs="+", default=[1, 2, 4, 8, 16, 32] + ) + parser.add_argument("--dora_ranks", type=int, nargs="+", default=[16, 32, 64]) + parser.add_argument("--in_features", type=int, default=4096) + parser.add_argument("--out_features", type=int, default=4096) + parser.add_argument( + "--dtype", + type=str, + default="float16", + choices=("float16", "bfloat16", "float32"), + ) + parser.add_argument( + "--compile_mode", + type=str, + default="default", + choices=( + "default", + "reduce-overhead", + "max-autotune-no-cudagraphs", + "max-autotune", + ), + ) + + args = parser.parse_args() + run_bench(args) From f0df7646433b1ac9eb35547a352c0c54c6422a7c Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 19:49:57 +0000 Subject: [PATCH 05/15] add readme --- torchao/prototype/common/profiling_tools.py | 112 +++++--------- torchao/prototype/dora/README.md | 154 ++++++++++++++++++++ torchao/prototype/dora/dora_profile.py | 124 ++++++++++++++++ 3 files changed, 318 insertions(+), 72 deletions(-) create mode 100644 torchao/prototype/dora/README.md create mode 100644 torchao/prototype/dora/dora_profile.py diff --git a/torchao/prototype/common/profiling_tools.py b/torchao/prototype/common/profiling_tools.py index dbaa398cb..607895d4e 100644 --- a/torchao/prototype/common/profiling_tools.py +++ b/torchao/prototype/common/profiling_tools.py @@ -67,6 +67,44 @@ def step(self): pass +def trace_handler( + prof: torch.profiler.profile, + group_by_stack: int = 5, + group_by_input_shapes: bool = False, + prefix="", + out_dir=None, + export_events=False, + export_trace=True, + export_memory_timeline=False, +): + # Prefix for file names. + out_dir = out_dir or PROFILE_DIR + timestamp = datetime.now().strftime(TIME_FORMAT_STR) + file_prefix = os.path.join(out_dir, f"{prefix}-{timestamp}") + + if export_events: + evt_list = prof.key_averages( + group_by_stack_n=group_by_stack, group_by_input_shape=group_by_input_shapes + ) + torch.save(evt_list, f"{file_prefix}-key_averages.pt") + + # Construct the trace file. + if export_trace: + prof.export_chrome_trace(f"{file_prefix}-chrome-trace.json") + + # Construct the memory timeline file. + if export_memory_timeline: + prof.export_memory_timeline( + f"{file_prefix}-memory-timeline.html", device="cuda:0" + ) + prof.export_memory_timeline( + f"{file_prefix}-memory-timeline.json", device="cuda:0" + ) + + +# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)) + + def get_torch_profiler( name, with_stack=True, @@ -75,7 +113,7 @@ def get_torch_profiler( record_shapes=False, export_events=False, export_trace=True, - export_memory_timeline=True, + export_memory_timeline=False, out_dir=None, warmup=1, active=5, @@ -118,7 +156,7 @@ def profiler( with_stack=True, export_events=False, export_trace=True, - export_memory_timeline=True, + export_memory_timeline=False, ): return get_torch_profiler( name, @@ -132,39 +170,6 @@ def profiler( active=active, ) - def __init__( - self, - name, - out_dir, - warmup=1, - active=5, - record_shapes=False, - with_stack=True, - export_events=False, - export_trace=True, - export_memory_timeline=True, - ): - self.profiler = get_torch_profiler( - name, - with_stack=with_stack, - record_shapes=export_memory_timeline or record_shapes, - export_events=export_events, - export_trace=export_trace, - export_memory_timeline=export_memory_timeline, - out_dir=out_dir, - warmup=warmup, - active=active, - ) - - def __enter__(self): - return self.profiler.__enter__() - - def __exit__(self, exc_type, exc_value, exc_traceback) -> None: - return self.profiler.__exit__(exc_type, exc_value, exc_traceback) - - # def step(self): - # return self.profiler.step() - def get_annotation_ctx(profiler_type): assert profiler_type in ["nsys", "torch"] @@ -174,43 +179,6 @@ def get_annotation_ctx(profiler_type): return record_function -def trace_handler( - prof: torch.profiler.profile, - group_by_stack: int = 5, - group_by_input_shapes: bool = False, - prefix="", - out_dir=None, - export_events=False, - export_trace=True, - export_memory_timeline=True, -): - # Prefix for file names. - out_dir = out_dir or PROFILE_DIR - timestamp = datetime.now().strftime(TIME_FORMAT_STR) - file_prefix = os.path.join(out_dir, f"{prefix}-{timestamp}") - - if export_events: - evt_list = prof.key_averages( - group_by_stack_n=group_by_stack, group_by_input_shape=group_by_input_shapes - ) - torch.save(evt_list, f"{file_prefix}-key_averages.pt") - - # Construct the trace file. - if export_trace: - prof.export_chrome_trace(f"{file_prefix}-chrome-trace.json") - - # Construct the memory timeline file. - if export_memory_timeline: - prof.export_memory_timeline( - f"{file_prefix}-memory-timeline.html", device="cuda:0" - ) - prof.export_memory_timeline( - f"{file_prefix}-memory-timeline.json", device="cuda:0" - ) - print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)) - print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) - - _PERF_COLUMNS = [ "key", "count", diff --git a/torchao/prototype/dora/README.md b/torchao/prototype/dora/README.md new file mode 100644 index 000000000..ae423ce43 --- /dev/null +++ b/torchao/prototype/dora/README.md @@ -0,0 +1,154 @@ +### Fused DoRA Kernels + +Fused kernels for DoRA layer optimization. + +#### Background + +[DoRA](https://arxiv.org/abs/2402.09353) (weight-decomposed low-rank adaptation) is a variant of LoRA that decomposes the LoRA update into magnitude and vector components. + +The DoRA layer is roughly as follows: + +```python + dora_out = (x @ base_weight.T + lora_out) * magnitude_scale +``` + +where: + +```python + lora_out = lora_B(lora_A(x)) + magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1) +``` + +Additionally: + +- `lora_A` and `lora_B` are `linear` layers with weight shapes `rank x in_features` and `out_features x rank`. +- `base_weight` is the weight of the frozen `linear` layer of shape `out_features x in_features`. +- `magnitude_vector` is initialized as the columnwise `2-norm` of the frozen weight (shape `out-features`). +- `x` are the inputs of shape `batch_size x seqlen x in_features` + +#### Key Contributions + +After initial profiling, and as outlined above, the `DoRA` computation requires multiple kernels, listed here in order of compute intensity: + +- 4 GEMMs + - `x @ base_weight` + - `lora_B(lora_A(x))` + - `lora_B.weight @ lora_A.weight` +- 1 Reduction - `2-norm` +- 4 Elementwise - matrix-matrix additions and broadcasted matrix-vector multiplications. + +While `torch.compile` (and `CUDA` graphs) can partially mitigate the overhead of multiple small kernels and improve compute efficiency of individual kernels, there remains room for additional optimization by reordering the computations to facilitate fusions, and more importantly, exploiting the unique shapes of the GEMMs, thereby decreasing the number of kernel launches and increasing the compute intensity of each kernel. + +**1 - Small K Fused Kernel** + +Note that the `lora_B.weight @ lora_A.weight` is an extreme case of skinny by fat matmul. That is, `lora_B.weight` is `out_features x lora_rank` and `lora_A.weight` is `lora_rank x in_features`. Since `lora_rank` is typically `< 64` while `{in,out}-features` are typically `> 4096` (e.g., `Llama MLP / QKV projections`), this `GEMM` is inefficient, since each `CTA` loads a block, only to perform a few `MAC` iterations given small `K`. + +Moreover, note that the result of this `GEMM` is not needed -- we only need the `2-norm` of this computation. + +Combining these two observations, we can write a fused kernel where: + +1. Each `CTA` computes an _entire_ row of the output matrix, with the key assumption that `BLOCK_K = K`. That is, each `CTA` does a single MAC iteration to compute a `BLOCK_M x BLOCK_N` output, then iterates across dimension `N`. +2. Since each block processes an entire row, we can now additionally fuse a grid-wise reduction along `axis=1` into the kernel. In this case, we can directly fold the `2-norm` computation into the `GEMM`. +3. As an added bonus, we can also include the `base_weight` elementwise addition and `magnitude_vector` multiplication into the `GEMM` epilogue. + +Altogether, this allows us to fuse the following computation into a single kernel: + +```python + magnitude_scale = (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1) * magnitude_vector +``` + +**2 - Fused Epilogue GEMM** + +Additionally, instead of computing the base layer output before the `DoRA / LoRA` updates, we can compute the latter (`loRA layer` and `magnitude_scale`) first, and fold these into the epilogue of the base layer `GEMM`: + +```python + + #DoRA / LoRA updates + lora_out = lora_B(lora_A(x)) + magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1) + + #This is now a single kernel + final_out = (x @ base_weight.T + lora_out) * magnitude_scale +``` + +#### Usage + +The fused kernels can be used to implement either DoRA or QDoRA layers. + +A reference implementation lives in `dora.dora_layer.DoRALinear`, which defines a base QDoRA linear layer (with a stub `dequantize` method) along with corresponding `BNBDoRALinear` and `HQQDoRALinear` subclasses, which override `dequantize` with their respective methods. + +_Example_ + +```python + import torch + from bitsandbytes.nn import Linear4bit + + bs, seqlen = 1, 512 + dtype = torch.float16 + in_features, out_features, lora_rank = 4096, 4096, 16 + x = torch.randn(bs, seqlen, in_features, dtype=dtype, device="cuda") + + #Construct bitsnbytes QDoRA layer + base_layer = Linear4bit( + input_features=in_features, + output_features=out_features, + bias=False, + quant_type="nf4", + compute_dtype=dtype, + ).cuda() + base_layer.quant_state.dtype = base_layer.compute_dtype + dora_layer = BNBDoRALinear(base_layer, lora_rank) + + #Run reference forward pass + ref = dora_layer.forward(x) + + #Run fused forward pass + fused_out = dora_layer.forward_fused(x) +``` + +See `test/test_dora_layer.py` and `benchmarks/dora_bench.py` for more detailed usage. + +#### Tests + +See `test/test_dora_fusion.py`, which checks the 2 fused kernels across a range of dtypes and shapes. + +#### Benchmarks + +See `benchmarks/dora_bench.py`. + +```python +python benchmarks/dora_bench.py --help +``` + +Run with flag `--kernel` set to one of `{dora-colnorm,dora-mm-epilogue}`, to benchmark the respective fused kernels against a reference `torch` / `torch.compile` implementation, or `--kernel=dora-full` to bench against the entire `DoRA` computation. + +Additionally, passing either `--kernel={dora-bnb, dora-hqq}` will bench a reference QDoRA layer against their fused implementations (see `Usage` below). + +#### Profiling + +The reference `DoRALinear` layer described above also has an instrumented forward pass with annotated regions for each of the ops in the `dora` layer. + +An example script for running a profiled forward pass is provided in `dora/dora_profile.py`. + +To run with `torch.profiler`: + +``` +python dora_profile.py +``` + +To run with `nsys`: + +``` +nsys profile --capture_range=cudaProfilerApi ... python dora_profile.py --profiler=nsys +``` + +where `...` are other desired `nsys` options. Note that `--capture_range=cudaProfilerApi` is required. + +#### Next steps + +- [ ] `torch.compile` entire DoRA layer with fused k + ernels + +- [ ] Implement backwards pass + +- [ ] Refactor! Lots of repeated profiling / kernel functions across `galore`, `hqq`, and `dora` can now be refactored into single module. Separate PR? diff --git a/torchao/prototype/dora/dora_profile.py b/torchao/prototype/dora/dora_profile.py new file mode 100644 index 000000000..bf8776974 --- /dev/null +++ b/torchao/prototype/dora/dora_profile.py @@ -0,0 +1,124 @@ +import argparse + +import torch +from bitsandbytes.nn import Linear4bit +from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear + +from torchao.prototype.common.profiling_tools import ( + CudaProfilerCtx, + TorchProfilerCtx, + get_annotation_ctx, +) +from torchao.prototype.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear + + +def run_profile(args, dora_forward): + if args.profiler == "nsys": + profiler = CudaProfilerCtx() + else: + profiler = TorchProfilerCtx.profiler( + f"dora_layer-{args.layer_type}", + active=max(5, args.num_iterations), + warmup=0, + out_dir=args.outdir, + ) + + annotation_ctx = get_annotation_ctx(args.profiler) + + x = torch.randn( + args.bs, args.seqlen, args.in_features, dtype=getattr(torch, args.dtype) + ).cuda() + for _ in range(args.warmup): + _ = dora_forward(x, annotation_ctx=annotation_ctx) + + with profiler as prof: + for _ in range(args.num_iterations): + _ = dora_forward(x, annotation_ctx=annotation_ctx) + prof.step() + print(f"Finished profiling, saving results to {args.outdir}") + + +def run(args): + in_features, out_features = args.in_features, args.out_features + dora_rank = args.dora_rank + dtype = getattr(torch, args.dtype) + + base_layer = torch.nn.Linear( + in_features, out_features, dtype=dtype, bias=False + ).cuda() + + if args.layer_type == "torch": + dora_layer = DoRALinear(base_layer=base_layer, lora_rank=dora_rank) + elif args.layer_type == "bnb": + base_layer = Linear4bit( + input_features=in_features, + output_features=out_features, + bias=False, + quant_type="nf4", + compute_dtype=dtype, + ) + base_layer.quant_state.dtype = base_layer.compute_dtype + dora_layer = BNBDoRALinear(base_layer=base_layer, lora_rank=dora_rank) + elif args.layer_type == "hqq": + quant_config = BaseQuantizeConfig( + nbits=4, + group_size=64, + quant_zero=False, + quant_scale=False, + offload_meta=True, + view_as_float=True, + ) + + base_layer = HQQLinear( + base_layer, + quant_config, + compute_dtype=dtype, + ) + + base_layer.set_backend(HQQBackend.PYTORCH) + dora_layer = HQQDoRALinear(base_layer=base_layer, lora_rank=dora_rank) + + run_profile(args, dora_layer.forward_instrumented) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--profiler", + type=str, + default="torch", + choices=("nsys", "torch"), + help=""" + Which profiler to use + + Default is the torch.profiler + + If using `nsys`, run the nsys profiler as so, substituting with other desired nsys options: + `nsys profile --capture-range=cudaProfilerApi ... python dora_profile.py --profiler=nsys` + + Note that `--capture-range=cudaProfilerApi` is required + """, + ) + parser.add_argument( + "--layer_type", + type=str, + default="torch", + choices=("torch", "bnb", "hqq"), + ) + parser.add_argument("--in_features", type=int, default=4096) + parser.add_argument("--out_features", type=int, default=4096) + parser.add_argument("--dora_rank", type=int, default=16) + parser.add_argument("--bs", type=int, default=1) + parser.add_argument("--seqlen", type=int, default=512) + parser.add_argument( + "--dtype", + type=str, + default="float16", + choices=("float16", "bfloat16", "float32"), + ) + parser.add_argument("--num_iterations", type=int, default=10) + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--outdir", type=str, default="./dora_profiles") + run(parser.parse_args()) From c92d17f67f25b836736f8379faf85f587045ea27 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 20:56:41 +0000 Subject: [PATCH 06/15] update readme --- torchao/prototype/dora/README.md | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torchao/prototype/dora/README.md b/torchao/prototype/dora/README.md index ae423ce43..c8cde461e 100644 --- a/torchao/prototype/dora/README.md +++ b/torchao/prototype/dora/README.md @@ -143,12 +143,3 @@ nsys profile --capture_range=cudaProfilerApi ... python dora_profile.py --profil ``` where `...` are other desired `nsys` options. Note that `--capture_range=cudaProfilerApi` is required. - -#### Next steps - -- [ ] `torch.compile` entire DoRA layer with fused k - ernels - -- [ ] Implement backwards pass - -- [ ] Refactor! Lots of repeated profiling / kernel functions across `galore`, `hqq`, and `dora` can now be refactored into single module. Separate PR? From f5cf2772891b8afa5885eb298c968df1b80ec240 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 21:08:16 +0000 Subject: [PATCH 07/15] update readme --- torchao/prototype/dora/README.md | 39 ++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/torchao/prototype/dora/README.md b/torchao/prototype/dora/README.md index c8cde461e..78c0ef5df 100644 --- a/torchao/prototype/dora/README.md +++ b/torchao/prototype/dora/README.md @@ -1,6 +1,6 @@ ### Fused DoRA Kernels -Fused kernels for DoRA layer optimization. +Fused DoRA layer implementation that reduces number of individual kernels from ~10 -> 5. #### Background @@ -19,29 +19,33 @@ where: magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1) ``` -Additionally: - - `lora_A` and `lora_B` are `linear` layers with weight shapes `rank x in_features` and `out_features x rank`. - `base_weight` is the weight of the frozen `linear` layer of shape `out_features x in_features`. - `magnitude_vector` is initialized as the columnwise `2-norm` of the frozen weight (shape `out-features`). - `x` are the inputs of shape `batch_size x seqlen x in_features` -#### Key Contributions +#### Optimization -After initial profiling, and as outlined above, the `DoRA` computation requires multiple kernels, listed here in order of compute intensity: +After initial profiling, and as outlined above, the `DoRA` update layer requires multiple kernels. -- 4 GEMMs +In order of compute intensity: + +- 4 GEMMs: - `x @ base_weight` - `lora_B(lora_A(x))` - `lora_B.weight @ lora_A.weight` -- 1 Reduction - `2-norm` -- 4 Elementwise - matrix-matrix additions and broadcasted matrix-vector multiplications. +- 1 Reduction: `2-norm` +- 4 Elementwise: matrix-matrix additions (2) and broadcasted matrix-vector multiplications (2). While `torch.compile` (and `CUDA` graphs) can partially mitigate the overhead of multiple small kernels and improve compute efficiency of individual kernels, there remains room for additional optimization by reordering the computations to facilitate fusions, and more importantly, exploiting the unique shapes of the GEMMs, thereby decreasing the number of kernel launches and increasing the compute intensity of each kernel. +#### Key Contributions + **1 - Small K Fused Kernel** -Note that the `lora_B.weight @ lora_A.weight` is an extreme case of skinny by fat matmul. That is, `lora_B.weight` is `out_features x lora_rank` and `lora_A.weight` is `lora_rank x in_features`. Since `lora_rank` is typically `< 64` while `{in,out}-features` are typically `> 4096` (e.g., `Llama MLP / QKV projections`), this `GEMM` is inefficient, since each `CTA` loads a block, only to perform a few `MAC` iterations given small `K`. +Note that the `lora_B.weight @ lora_A.weight` has a specific shape, where `K << {M, N}`. That is, `lora_B.weight` is `out_features x lora_rank` and `lora_A.weight` is `lora_rank x in_features`. + +Since `lora_rank` is typically `< 64` while `{in,out}-features` are typically `> 4096` (e.g., `Llama MLP / QKV projections`), this `GEMM` is inefficient, since each `CTA` loads a block, only to perform a few `MAC` iterations given small `K`. Moreover, note that the result of this `GEMM` is not needed -- we only need the `2-norm` of this computation. @@ -73,15 +77,16 @@ Additionally, instead of computing the base layer output before the `DoRA / LoRA #### Usage -The fused kernels can be used to implement either DoRA or QDoRA layers. +The fused kernels can be used to implement `DoRA` / `QDoRA` layers. -A reference implementation lives in `dora.dora_layer.DoRALinear`, which defines a base QDoRA linear layer (with a stub `dequantize` method) along with corresponding `BNBDoRALinear` and `HQQDoRALinear` subclasses, which override `dequantize` with their respective methods. +A reference implementation is provided in `dora.dora_layer.DoRALinear`, which defines a base `QDoRA` linear layer (with a stub `dequantize` method) along with corresponding `BNBDoRALinear` and `HQQDoRALinear` subclasses, which override `dequantize` with their respective methods. _Example_ ```python import torch from bitsandbytes.nn import Linear4bit + from torchao.prototypes.dora.dora_layer import BNBDoRALinear bs, seqlen = 1, 512 dtype = torch.float16 @@ -110,7 +115,7 @@ See `test/test_dora_layer.py` and `benchmarks/dora_bench.py` for more detailed u #### Tests -See `test/test_dora_fusion.py`, which checks the 2 fused kernels across a range of dtypes and shapes. +See `test/dora/test*`, for correctness checks of the fused kernels and layers. #### Benchmarks @@ -122,11 +127,11 @@ python benchmarks/dora_bench.py --help Run with flag `--kernel` set to one of `{dora-colnorm,dora-mm-epilogue}`, to benchmark the respective fused kernels against a reference `torch` / `torch.compile` implementation, or `--kernel=dora-full` to bench against the entire `DoRA` computation. -Additionally, passing either `--kernel={dora-bnb, dora-hqq}` will bench a reference QDoRA layer against their fused implementations (see `Usage` below). +Additionally, passing either `--kernel={dora-bnb, dora-hqq}` will bench a reference `QDoRA` layer against their fused implementations. #### Profiling -The reference `DoRALinear` layer described above also has an instrumented forward pass with annotated regions for each of the ops in the `dora` layer. +The reference `DoRALinear` layer described above also has an instrumented forward pass with annotated regions for each of the `DoRA` ops. An example script for running a profiled forward pass is provided in `dora/dora_profile.py`. @@ -136,10 +141,14 @@ To run with `torch.profiler`: python dora_profile.py ``` +which outputs chrome trace to default folder `dora_profiles`. + To run with `nsys`: ``` nsys profile --capture_range=cudaProfilerApi ... python dora_profile.py --profiler=nsys ``` -where `...` are other desired `nsys` options. Note that `--capture_range=cudaProfilerApi` is required. +where `...` are other desired `nsys` options. + +Note that `--capture_range=cudaProfilerApi` is required. From ad6ded8fe353d495c85ba0c291fb5264765ec0af Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 21:11:55 +0000 Subject: [PATCH 08/15] small readme fix --- torchao/prototype/dora/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/dora/README.md b/torchao/prototype/dora/README.md index 78c0ef5df..34cd44b4e 100644 --- a/torchao/prototype/dora/README.md +++ b/torchao/prototype/dora/README.md @@ -58,7 +58,7 @@ Combining these two observations, we can write a fused kernel where: Altogether, this allows us to fuse the following computation into a single kernel: ```python - magnitude_scale = (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1) * magnitude_vector + magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1) ``` **2 - Fused Epilogue GEMM** From 893b003d2036e91ddb1a26141bc2fcfa299457e2 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 21:24:49 +0000 Subject: [PATCH 09/15] skip test without triton --- test/dora/test_dora_fusion.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/dora/test_dora_fusion.py b/test/dora/test_dora_fusion.py index 78a9d2e6e..1e53071f4 100644 --- a/test/dora/test_dora_fusion.py +++ b/test/dora/test_dora_fusion.py @@ -1,6 +1,8 @@ +import pytest + +triton = pytest.importorskip("triton", reason="requires triton") import itertools -import pytest import torch from torchao.prototype.dora.kernels.matmul import triton_mm @@ -53,6 +55,7 @@ def check(expected, actual, dtype): # assert diff < atol return diff + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize( "shape, store_acc, epilogue_norm, add_source, magnitude_vector, dtype", @@ -130,6 +133,7 @@ def test_dora_column_norm( ) ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize( "shape, dtype, epilogue_add, epilogue_scale", From 9792db00c8fc5de94d9d18e19578216f3b4bc707 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 22:46:40 +0000 Subject: [PATCH 10/15] update test require python >= 3.11 --- test/dora/test_dora_fusion.py | 4 ++++ test/dora/test_dora_layer.py | 9 +++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/test/dora/test_dora_fusion.py b/test/dora/test_dora_fusion.py index 1e53071f4..7d2907214 100644 --- a/test/dora/test_dora_fusion.py +++ b/test/dora/test_dora_fusion.py @@ -1,3 +1,5 @@ +import sys + import pytest triton = pytest.importorskip("triton", reason="requires triton") @@ -56,6 +58,7 @@ def check(expected, actual, dtype): return diff +@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires Python >= 3.11") @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize( "shape, store_acc, epilogue_norm, add_source, magnitude_vector, dtype", @@ -134,6 +137,7 @@ def test_dora_column_norm( ) +@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires Python >= 3.11") @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize( "shape, dtype, epilogue_add, epilogue_scale", diff --git a/test/dora/test_dora_layer.py b/test/dora/test_dora_layer.py index c72693930..d0e97c142 100644 --- a/test/dora/test_dora_layer.py +++ b/test/dora/test_dora_layer.py @@ -1,15 +1,15 @@ +import sys + import pytest -bnbnn = pytest.importorskip( - "bitsandbytes.nn", reason="requires bitsandbytes" -) +bnbnn = pytest.importorskip("bitsandbytes.nn", reason="requires bitsandbytes") hqq_core = pytest.importorskip("hqq.core.quantize", reason="requires hqq") import itertools import torch -#Import modules as opposed to classes directly, otherwise pytest.importorskip always skips +# Import modules as opposed to classes directly, otherwise pytest.importorskip always skips Linear4bit = bnbnn.Linear4bit BaseQuantizeConfig = hqq_core.BaseQuantizeConfig HQQLinear = hqq_core.HQQLinear @@ -58,6 +58,7 @@ def _arg_to_id(arg): ) +@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires Python >= 3.11") @pytest.mark.parametrize( "bs, seqlen, in_features, out_features, lora_rank, dtype, model_type", TEST_CONFIGS, From 901ac116ccfa36cf3dcc03c0fa9f47eda17fea61 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 5 May 2024 23:08:53 +0000 Subject: [PATCH 11/15] fix test skip python < 3.11 --- test/dora/test_dora_fusion.py | 6 ++++-- test/dora/test_dora_layer.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/test/dora/test_dora_fusion.py b/test/dora/test_dora_fusion.py index 7d2907214..a7959f85a 100644 --- a/test/dora/test_dora_fusion.py +++ b/test/dora/test_dora_fusion.py @@ -2,7 +2,11 @@ import pytest +if sys.version_info < (3, 11): + pytest.skip("requires Python >= 3.11", allow_module_level=True) + triton = pytest.importorskip("triton", reason="requires triton") + import itertools import torch @@ -58,7 +62,6 @@ def check(expected, actual, dtype): return diff -@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires Python >= 3.11") @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize( "shape, store_acc, epilogue_norm, add_source, magnitude_vector, dtype", @@ -137,7 +140,6 @@ def test_dora_column_norm( ) -@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires Python >= 3.11") @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize( "shape, dtype, epilogue_add, epilogue_scale", diff --git a/test/dora/test_dora_layer.py b/test/dora/test_dora_layer.py index d0e97c142..dd38cc8d6 100644 --- a/test/dora/test_dora_layer.py +++ b/test/dora/test_dora_layer.py @@ -2,6 +2,9 @@ import pytest +if sys.version_info < (3, 11): + pytest.skip("requires Python >= 3.11", allow_module_level=True) + bnbnn = pytest.importorskip("bitsandbytes.nn", reason="requires bitsandbytes") hqq_core = pytest.importorskip("hqq.core.quantize", reason="requires hqq") @@ -58,7 +61,6 @@ def _arg_to_id(arg): ) -@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires Python >= 3.11") @pytest.mark.parametrize( "bs, seqlen, in_features, out_features, lora_rank, dtype, model_type", TEST_CONFIGS, From 4b4d1552fbe76f4b5fffb3fef46130f772976629 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 6 May 2024 10:43:36 +0000 Subject: [PATCH 12/15] fix autotuning configs --- torchao/prototype/dora/kernels/smallk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/dora/kernels/smallk.py b/torchao/prototype/dora/kernels/smallk.py index 227264cf0..fc24ea223 100644 --- a/torchao/prototype/dora/kernels/smallk.py +++ b/torchao/prototype/dora/kernels/smallk.py @@ -262,7 +262,7 @@ def small_k_early_config_prune(configs, named_args, **kwargs): # @heuristics(SMALLK_HEURISTICS) @autotune( - get_small_k_configs()[:10], + get_small_k_configs(), key=["M", "N", "K"], prune_configs_by={ "early_config_prune": small_k_early_config_prune, From 40d05507687e1e9b3b2afde4df9404ce0e4cca87 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 6 May 2024 13:03:29 +0000 Subject: [PATCH 13/15] make readme more readable --- torchao/prototype/dora/README.md | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/torchao/prototype/dora/README.md b/torchao/prototype/dora/README.md index 34cd44b4e..3a2b6db78 100644 --- a/torchao/prototype/dora/README.md +++ b/torchao/prototype/dora/README.md @@ -1,8 +1,19 @@ -### Fused DoRA Kernels +## Fused DoRA Kernels Fused DoRA layer implementation that reduces number of individual kernels from ~10 -> 5. -#### Background +## Contents + +- [Background](#background) +- [Optimization](#optimization) +- [Key Contributions](#key-contributions) +- [Usage](#usage) +- [Tests](#tests) +- [Benchmarks](#benchmarks) +- [Profiling](#profiling) +- [Next Steps](#next-steps) + +## Background [DoRA](https://arxiv.org/abs/2402.09353) (weight-decomposed low-rank adaptation) is a variant of LoRA that decomposes the LoRA update into magnitude and vector components. @@ -24,7 +35,7 @@ where: - `magnitude_vector` is initialized as the columnwise `2-norm` of the frozen weight (shape `out-features`). - `x` are the inputs of shape `batch_size x seqlen x in_features` -#### Optimization +## Optimization After initial profiling, and as outlined above, the `DoRA` update layer requires multiple kernels. @@ -39,7 +50,7 @@ In order of compute intensity: While `torch.compile` (and `CUDA` graphs) can partially mitigate the overhead of multiple small kernels and improve compute efficiency of individual kernels, there remains room for additional optimization by reordering the computations to facilitate fusions, and more importantly, exploiting the unique shapes of the GEMMs, thereby decreasing the number of kernel launches and increasing the compute intensity of each kernel. -#### Key Contributions +## Key Contributions **1 - Small K Fused Kernel** @@ -75,7 +86,7 @@ Additionally, instead of computing the base layer output before the `DoRA / LoRA final_out = (x @ base_weight.T + lora_out) * magnitude_scale ``` -#### Usage +## Usage The fused kernels can be used to implement `DoRA` / `QDoRA` layers. @@ -113,11 +124,11 @@ _Example_ See `test/test_dora_layer.py` and `benchmarks/dora_bench.py` for more detailed usage. -#### Tests +### Tests See `test/dora/test*`, for correctness checks of the fused kernels and layers. -#### Benchmarks +## Benchmarks See `benchmarks/dora_bench.py`. @@ -129,7 +140,7 @@ Run with flag `--kernel` set to one of `{dora-colnorm,dora-mm-epilogue}`, to ben Additionally, passing either `--kernel={dora-bnb, dora-hqq}` will bench a reference `QDoRA` layer against their fused implementations. -#### Profiling +## Profiling The reference `DoRALinear` layer described above also has an instrumented forward pass with annotated regions for each of the `DoRA` ops. From 43d387cec2275e77dcecbbe546cdd685291833d0 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 6 May 2024 13:04:36 +0000 Subject: [PATCH 14/15] remove next steps from readme --- torchao/prototype/dora/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/prototype/dora/README.md b/torchao/prototype/dora/README.md index 3a2b6db78..d5bebc68a 100644 --- a/torchao/prototype/dora/README.md +++ b/torchao/prototype/dora/README.md @@ -11,7 +11,6 @@ Fused DoRA layer implementation that reduces number of individual kernels from ~ - [Tests](#tests) - [Benchmarks](#benchmarks) - [Profiling](#profiling) -- [Next Steps](#next-steps) ## Background From 197b3724c2052c0fdbb6f3d9ca0ac330a27f5d23 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 6 May 2024 15:34:39 +0000 Subject: [PATCH 15/15] clean up dora layer --- torchao/prototype/dora/dora_layer.py | 66 ---------------------------- 1 file changed, 66 deletions(-) diff --git a/torchao/prototype/dora/dora_layer.py b/torchao/prototype/dora/dora_layer.py index 23840ea21..e0c97cdcb 100644 --- a/torchao/prototype/dora/dora_layer.py +++ b/torchao/prototype/dora/dora_layer.py @@ -50,17 +50,6 @@ def forward(self, x, base_weight): return output, column_norm -# class MagnitudeLayer(nn.Module): -# "FSDP doesn't work with nn.ParameterDict hence this module: https://github.com/pytorch/pytorch/issues/79605" - -# def __init__(self, vector_data, device, dtype): -# super().__init__() -# self.magnitude = nn.Parameter(vector_data.to(device=device, dtype=dtype)) - -# def forward(self, x): -# return x * self.magnitude.view(1, 1, -1) - - class DoRALinear(nn.Module): """Reference DoRA Update Layer @@ -203,58 +192,3 @@ def dequantize(self): return self.base_layer.dequantize() -if __name__ == "__main__": - # bnb_dora_layer = BNBDoraLayer(in_features=128, out_features=32, lora_rank=16) - - bs, seqlen = 1, 16 - in_features, out_features = 128, 256 - x = torch.randn(bs, seqlen, in_features).cuda().to(torch.float32) - - torch_base = nn.Linear(128, 256, bias=False).cuda() - torch_dora = DoRALinear(torch_base, lora_rank=16).cuda() - - bnb_base = Linear4bit( - input_features=in_features, - output_features=out_features, - bias=False, - quant_type="nf4", - compute_dtype=torch.float32, - quant_storage=torch.float32, - ) - bnb_base.load_state_dict(torch_base.state_dict()) - # print((bnb_base.weight - torch_base.weight).abs().max()) - bnb_base = bnb_base.to(0) - # print((W_dq - torch_base.weight.data).abs().max()) - # y = torch_base(x) - # y_bnb = bnb_base(x) - # y_bnb_ref = x @ W_dq.T - # print((y - y_bnb_ref).abs().max()) - # print((y - y_bnb).abs().max()) - # bnb_dora = BNBDoRALinear(bnb_base, lora_rank=16).cuda() - # y = torch_dora.forward(x) - # y_bnb = bnb_dora.forward(x) - # print((y - y_bnb).abs().max()) - # print((torch_base(x) - bnb_base(x)).abs().max()) - quant_config = BaseQuantizeConfig( - nbits=4, - group_size=64, - quant_zero=False, - quant_scale=False, - offload_meta=True, - view_as_float=True, - ) - - hqq_base = HQQLinear( - torch_base, - quant_config, - compute_dtype=torch.float32, - ) - - print(hqq_base.meta.keys()) - hqq_base.set_backend(HQQBackend.PYTORCH) - hqq_dora = HQQDoRALinear(hqq_base, lora_rank=16) - # print(hqq_dora.base_layer.meta) - # print(hqq_dora.base_layer.meta["nbits"]) - # print(hqq_dora.base_layer.meta["zero_scale"]) - # print(hqq_dora.dequantize().shape) - print(hqq_dora(x).shape)