From d8884e6f6a294fc8f1a325665d86a07603d43864 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 5 Jul 2024 08:54:26 +0000
Subject: [PATCH 1/8] Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

---
 bitblas/ops/impl/base.py              |  16 +++
 bitblas/ops/impl/batch_matmul_impl.py | 166 ++++++++++++++++----
 2 files changed, 119 insertions(+), 63 deletions(-)
 create mode 100644 bitblas/ops/impl/base.py

diff --git a/bitblas/ops/impl/base.py b/bitblas/ops/impl/base.py
new file mode 100644
index 000000000..6d510f7da
--- /dev/null
+++ b/bitblas/ops/impl/base.py
@@ -0,0 +1,16 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from abc import ABC, abstractmethod
+
+# TODO: Refactor all the tir script implementations to use this base class
+# Abstract base class for TIR script emitters
+class TIRScriptEmitter(ABC):
+    @abstractmethod
+    def emit(self):
+        raise NotImplementedError
+
+# Abstract base class for TIR script selectors
+class TIRScriptSelector(ABC):
+    @abstractmethod
+    def select(self):
+        raise NotImplementedError

diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py
index 09b536afa..75449ea4b 100644
--- a/bitblas/ops/impl/batch_matmul_impl.py
+++ b/bitblas/ops/impl/batch_matmul_impl.py
@@ -4,63 +4,117 @@
 from bitblas import tvm
 from tvm import te
 from bitblas.ops.operator import TransformKind
+from .base import TIRScriptEmitter, TIRScriptSelector
+from bitblas import tvm
+from tvm import te
+from bitblas.ops.operator import TransformKind
 
+class BatchMatMulEmitter(TIRScriptEmitter):
+    def __init__(
+        self,
+        batch,
+        M,
+        N,
+        K,
+        in_dtype="float16",
+        out_dtype="float16",
+        accum_dtype="float16",
+        with_bias=False,
+        layout="nt",
+    ):
+        self.batch = batch
+        self.M = self._validate_dimension(M, "M")
+        self.N = self._validate_dimension(N, "N")
+        self.K = self._validate_dimension(K, "K")
+        self.in_dtype = in_dtype
+        self.out_dtype = out_dtype
+        self.accum_dtype = accum_dtype
+        self.with_bias = with_bias
+        self.layout = layout
+        self._validate_layout()
+
+    @staticmethod
+    def _validate_dimension(dim, name):
+        if not isinstance(dim, int):
+            return tvm.te.var(name.lower())
+        return dim
 
-def matmul_nt(
-    Batch,
-    M,
-    N,
-    K,
-    in_dtype="float16",
-    out_dtype="float16",
-    accum_dtype="float16",
-    with_bias=False,
-):
-    if not isinstance(M, int):
-        M = tvm.te.var("m")
-    A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype)
-    B = te.placeholder((Batch, N, K), name="B", dtype=in_dtype)
-    Bias = te.placeholder((N,), name="Bias", dtype=in_dtype)
+    def _validate_layout(self):
+        if self.layout not in ["nn", "nt"]:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        if self.layout == "nn":
+            raise ValueError("Currently only support layout=nt")
 
-    # Describe the matrix multiplication in TE
-    k = te.reduce_axis((0, K), name="k")
-    C = te.compute(
-        (Batch, M, N),
-        lambda b, i, j: te.sum(
-            A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k),
-        name="C",
-    )
-    last_output = C
-    if accum_dtype != out_dtype:
-        D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D")
-        last_output = D
+    def _create_placeholders(self):
+        A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype)
+        B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype)
+        Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None
+        return A, B, Bias
 
-    if with_bias:
-        E = te.compute((Batch, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E")
-        last_output = E
+    def _compute_matmul(self, A, B):
+        k = te.reduce_axis((0, self.K), name="k")
+        C = te.compute(
+            (self.batch, self.M, self.N),
+            lambda b, i, j: te.sum(
+                A[b, i, k].astype(self.accum_dtype) * B[b, j, k].astype(self.accum_dtype), axis=k),
+            name="C",
+        )
+        return C
 
-    args = [A, B, Bias, last_output] if with_bias else [A, B, last_output]
+    def _apply_bias(self, C, Bias):
+        if self.with_bias:
+            return te.compute((self.batch, self.M, self.N), lambda b, i, j: C[b, i, j] + Bias[j], name="E")
+        return C
 
-    func = te.create_prim_func(args)
+    def _convert_dtype(self, tensor):
+        if self.accum_dtype != self.out_dtype:
+            return te.compute((self.batch, self.M, self.N), lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), name="D")
+        return tensor
 
-    return tvm.IRModule.from_expr(func)
+    def emit(self):
+        A, B, Bias = self._create_placeholders()
+        C = self._compute_matmul(A, B)
+        last_output = self._convert_dtype(C)
+        if self.with_bias:
+            last_output = self._apply_bias(last_output, Bias)
+        args = [A, B, Bias, last_output] if self.with_bias else [A, B, last_output]
+        func = te.create_prim_func(args)
+        return tvm.IRModule.from_expr(func)
 
-def matmul(
-    Batch,
-    M,
-    N,
-    K,
-    in_dtype="float16",
-    out_dtype="float16",
-    accum_dtype="float16",
-    with_bias=False,
-    layout="nt",
-):
-    if layout == "nn":
-        raise ValueError("Currently only support layout=nt")
-    return matmul_nt(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias)
+class BatchMatMulSelector(TIRScriptSelector):
+    def __init__(self, propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform):
+        self.propagate_a = propagate_a
+        self.propagate_b = propagate_b
+
+    def select(
+        self,
+        batch=1,
+        M=None,
+        N=16384,
+        K=16384,
+        in_dtype="float16",
+        out_dtype="float16",
+        accum_dtype="float16",
+        with_bias=False,
+        layout="nt",
+    ):
+        if layout == "nn":
+            if self.propagate_a or self.propagate_b:
+                raise ValueError("Currently only support propagate_a=False and propagate_b=False for layout=nn")
+            return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit()
+        elif layout == "nt":
+            if self.propagate_a and self.propagate_b:
+                raise ValueError("Currently only support propagate_a or propagate_b for layout=nt")
+            elif self.propagate_a:
+                raise ValueError("Currently only support propagate_a=False for layout=nt")
+            elif self.propagate_b:
+                raise ValueError("Currently only support propagate_b=False for layout=nt")
+            else:
+                return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit()
+        else:
+            raise ValueError(f"Unsupported layout: {layout}")
 
 def select_implementation(
     Batch=1,
     M=None,
@@ -75,19 +129,5 @@ def select_implementation(
     propagate_a: TransformKind = TransformKind.NonTransform,
     propagate_b: TransformKind = TransformKind.NonTransform,
 ):
-    if layout == "nn":
-        if propagate_a or propagate_b:
-            raise ValueError(
-                "Currently only support propagate_a=False and propagate_b=False for layout=nn")
-        return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)
-    elif layout == "nt":
-        if propagate_a and propagate_b:
-            raise ValueError("Currently only support propagate_a or propagate_b for layout=nt")
-        elif propagate_a:
-            raise ValueError("Currently only support propagate_a=False for layout=nt")
-        elif propagate_b:
-            raise ValueError("Currently only support propagate_b=False for layout=nt")
-    else:
-        return matmul(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)
+    selector = BatchMatMulSelector(propagate_a, propagate_b)
+    return selector.select(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)

From fc84173f22d2f4867a8e6413117b5cd8e830ab27 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 5 Jul 2024 08:57:43 +0000
Subject: [PATCH 2/8] Refactor import statements for improved readability and maintainability

---
 bitblas/ops/impl/__init__.py          |  2 +-
 bitblas/ops/impl/base.py              |  4 ++++
 bitblas/ops/impl/batch_matmul_impl.py | 33 ++++++++++++++++++---------
 3 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py
index a254dc7fb..8a9bbd2a5 100644
--- a/bitblas/ops/impl/__init__.py
+++ b/bitblas/ops/impl/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from .lop3_permutate_impl import tir_interleave_weight
+from .lop3_permutate_impl import tir_interleave_weight # noqa: F401

diff --git a/bitblas/ops/impl/base.py b/bitblas/ops/impl/base.py
index 6d510f7da..4a67987be 100644
--- a/bitblas/ops/impl/base.py
+++ b/bitblas/ops/impl/base.py
@@ -2,15 +2,19 @@
 # Licensed under the MIT License.
 from abc import ABC, abstractmethod
 
+
 # TODO: Refactor all the tir script implementations to use this base class
 # Abstract base class for TIR script emitters
 class TIRScriptEmitter(ABC):
+
     @abstractmethod
     def emit(self):
         raise NotImplementedError
 
+
 # Abstract base class for TIR script selectors
 class TIRScriptSelector(ABC):
+
     @abstractmethod
     def select(self):
         raise NotImplementedError

diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py
index 75449ea4b..3904f36e6 100644
--- a/bitblas/ops/impl/batch_matmul_impl.py
+++ b/bitblas/ops/impl/batch_matmul_impl.py
@@ -5,11 +5,10 @@
 from tvm import te
 from bitblas.ops.operator import TransformKind
 from .base import TIRScriptEmitter, TIRScriptSelector
-from bitblas import tvm
-from tvm import te
-from bitblas.ops.operator import TransformKind
+
 
 class BatchMatMulEmitter(TIRScriptEmitter):
+
     def __init__(
         self,
         batch,
@@ -32,7 +31,7 @@ def __init__(
         self.with_bias = with_bias
         self.layout = layout
         self._validate_layout()
-    
+
     @staticmethod
     def _validate_dimension(dim, name):
         if not isinstance(dim, int):
@@ -48,7 +47,8 @@ def _validate_layout(self):
 
     def _create_placeholders(self):
         A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype)
         B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype)
-        Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None
+        Bias = te.placeholder(
+            (self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None
         return A, B, Bias
 
@@ -63,12 +63,16 @@ def _compute_matmul(self, A, B):
 
     def _apply_bias(self, C, Bias):
         if self.with_bias:
-            return te.compute((self.batch, self.M, self.N), lambda b, i, j: C[b, i, j] + Bias[j], name="E")
+            return te.compute((self.batch, self.M, self.N),
+                              lambda b, i, j: C[b, i, j] + Bias[j],
+                              name="E")
         return C
 
     def _convert_dtype(self, tensor):
         if self.accum_dtype != self.out_dtype:
-            return te.compute((self.batch, self.M, self.N), lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), name="D")
+            return te.compute((self.batch, self.M, self.N),
+                              lambda b, i, j: tensor[b, i, j].astype(self.out_dtype),
+                              name="D")
         return tensor
 
@@ -84,7 +88,10 @@ def emit(self):
 
 
 class BatchMatMulSelector(TIRScriptSelector):
-    def __init__(self, propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform):
+
+    def __init__(self,
+                 propagate_a: TransformKind = TransformKind.NonTransform,
+                 propagate_b: TransformKind = TransformKind.NonTransform):
         self.propagate_a = propagate_a
         self.propagate_b = propagate_b
 
@@ -102,8 +109,10 @@ def select(
     ):
         if layout == "nn":
             if self.propagate_a or self.propagate_b:
-                raise ValueError("Currently only support propagate_a=False and propagate_b=False for layout=nn")
-            return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit()
+                raise ValueError(
+                    "Currently only support propagate_a=False and propagate_b=False for layout=nn")
+            return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias,
+                                      layout).emit()
         elif layout == "nt":
             if self.propagate_a and self.propagate_b:
                 raise ValueError("Currently only support propagate_a or propagate_b for layout=nt")
@@ -112,10 +121,12 @@ def select(
             elif self.propagate_b:
                 raise ValueError("Currently only support propagate_b=False for layout=nt")
             else:
-                return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit()
+                return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype,
+                                          with_bias, layout).emit()
         else:
             raise ValueError(f"Unsupported layout: {layout}")
 
+
 def select_implementation(
     Batch=1,
     M=None,

From 02f64de6cf2d338c092dcf29ec55b69804fda892 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 5 Jul 2024 08:58:06 +0000
Subject: [PATCH 3/8] Refactor import statements for improved readability and maintainability

---
 bitblas/ops/impl/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py
index 8a9bbd2a5..67e49b2ae 100644
--- a/bitblas/ops/impl/__init__.py
+++ b/bitblas/ops/impl/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from .lop3_permutate_impl import tir_interleave_weight # noqa: F401
+from .lop3_permutate_impl import tir_interleave_weight  # noqa: F401

From 397eee6141599e84b509594bb99a0531e409c266 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 5 Jul 2024 16:25:47 +0000
Subject: [PATCH 4/8] disable failure email for ci

---
 .github/workflows/ci.yml | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ceb69fcc7..1fbdf19dd 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -64,4 +64,13 @@ jobs:
       run: |
         source bitblas_ci/bin/activate
         cd testing/python
-        python -m pytest
\ No newline at end of file
+        python -m pytest
+
+  # Control notifications
+  notify:
+    runs-on: self-hosted
+    needs: [format-check, build-test]
+    if: failure()
+    steps:
+      - name: Notification
+        run: echo "Jobs failed, but no email will be sent."

From 20f6ad1e7ca4e6e1ca9e13ad7c1bbc8c430a8e51 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Sat, 6 Jul 2024 03:23:50 +0000
Subject: [PATCH 5/8] remove email notifications.
---
 .github/workflows/ci.yml | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 1fbdf19dd..511b95833 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -65,12 +65,3 @@ jobs:
         source bitblas_ci/bin/activate
         cd testing/python
         python -m pytest
-
-  # Control notifications
-  notify:
-    runs-on: self-hosted
-    needs: [format-check, build-test]
-    if: failure()
-    steps:
-      - name: Notification
-        run: echo "Jobs failed, but no email will be sent."

From b93c39431c803e22b12f71b555939785da36b96a Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Sat, 6 Jul 2024 03:25:05 +0000
Subject: [PATCH 6/8] move relax pass from testing to mlc_llm

---
 .../mlc_llm}/test_weight_only_transform.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename {testing/python/transform => integration/mlc_llm}/test_weight_only_transform.py (100%)

diff --git a/testing/python/transform/test_weight_only_transform.py b/integration/mlc_llm/test_weight_only_transform.py
similarity index 100%
rename from testing/python/transform/test_weight_only_transform.py
rename to integration/mlc_llm/test_weight_only_transform.py

From 257693a7c3cb3083aac144182f58d38bfe3bcfdd Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Sat, 6 Jul 2024 05:51:01 +0000
Subject: [PATCH 7/8] Refactor scripts with the check_eual_ref_scripts_with_emitter function

---
 bitblas/ops/impl/matmul_dequantize_impl.py | 224 ++++++++++++++----
 .../operators/test_tir_script_emitter.py   |  52 +++-
 2 files changed, 216 insertions(+), 60 deletions(-)

diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py
index 1ed6b3404..e69e8fcfb 100644
--- a/bitblas/ops/impl/matmul_dequantize_impl.py
+++ b/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -15,8 +15,10 @@
     _tir_packed_to_unsigned_convert_with_zeros,
 )
 
+
 # TODO: The following code should be refactored.
 class MatMulNTDequantizeEmitter:
+
     def __init__(
         self,
         M,
@@ -52,8 +54,8 @@ def __init__(
         self.fast_decoding = fast_decoding
         self.with_bias = with_bias
         self.zeros_mode = zeros_mode
-        self.propagate_a = propagate_a
-        self.propagate_b = propagate_b
+        self.propagate_a = self._legalize_transform_kind(propagate_a)
+        self.propagate_b = self._legalize_transform_kind(propagate_b)
 
         self._validate_bit()
         self._validate_layout()
@@ -69,62 +71,169 @@ def _validate_bit(self):
             raise ValueError(f"Unsupported bit: {self.bit}")
 
     def _validate_layout(self):
-        if self.layout not in ["nt"]:
-            raise ValueError(f"Unsupported layout: {self.layout}")
+        # TODO: extend the dequantize operators into General Layout
+        pass
+
+    def _legalize_group_size(self):
+        if self.group_size == -1:
+            self.group_size = self.K
+
+    def _legalize_transform_kind(self, propagate):
+        if propagate is None:
+            return TransformKind.NonTransform
+        if isinstance(propagate, bool):
+            return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform)
+        elif isinstance(propagate, int):
+            return TransformKind(propagate)
 
     def _create_placeholders(self):
-        storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit()))
-        n_float_per_elem = storage_nbit // self.bit
-
-        A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype)
-        B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype)
-        LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype)
-        Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype)
-        Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype)
-        QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit),
+        storage_dtype = self.storage_dtype
+        storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
+        in_dtype = self.in_dtype
+        bit = self.bit
+        l = r = 16  # noqa: E741
+        if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
+            l, r = 16, 32  # noqa: E741
+
+        A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype)
+        B = te.placeholder((self.N, self.K // storage_nbit * bit),
+                           name="B",
+                           dtype=storage_dtype)
+        if self.propagate_a:
+            A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype)
+        if self.propagate_b:
+            target_dtype = DataType(in_dtype)
+            scaling_factor = 1
+            if bit > 0 and bit < target_dtype.bits:
+                scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // target_dtype.bits)
+            qr = r * bit // storage_nbit
+            B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype)
+
+        LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype)
+        Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype)
+        Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=in_dtype)
+        QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * bit),
                                 name="QZeros",
                                 dtype=self.storage_dtype)
-        Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype)
-        return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem
+        Bias = te.placeholder((self.N,), name="Bias", dtype=in_dtype)
+        return A, B, LUT, Scale, Zeros, QZeros, Bias
+
+    def _propagate_input(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="A"):
+        if transform_kind == TransformKind.NonTransform:
+            return tensor
+        in_dtype = self.in_dtype
+        l = r = 16  # noqa: E741
+        if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
+            l, r = 16, 32  # noqa: E741
+        _, inversed_index_map = get_propagate_map(
+            trans=False, dtype=in_dtype, matrix_name=matrix_name)
+
+        def fcompute(i, j):
+            warp_i, warp_j = i % l, j % r
+            spatial_args = i // l, j // r
+            if transform_kind >= TransformKind.IntraWarpTransform:
+                warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j])
+            new_index = (*spatial_args, warp_i, warp_j)
+            return tensor[new_index]
+
+        return te.compute(
+            (self.M, self.K),
+            fcompute,
+            name=f"{matrix_name}_reindex",
+        )
+
+    def _propagage_weight(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="B"):
+        if transform_kind == TransformKind.NonTransform:
+            return tensor
+        in_dtype = self.in_dtype
+        bit = self.bit
+        storage_dtype = self.storage_dtype
+        storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit()))
+
+        l = r = 16  # noqa: E741
+        if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
+            l, r = 16, 32  # noqa: E741
+        _, inversed_index_map = get_propagate_map(
+            trans=True, dtype=in_dtype, matrix_name=matrix_name)
+        target_dtype = DataType(in_dtype)
+        scaling_factor = 1
+        if bit > 0 and bit < target_dtype.bits:
+            scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits //
+                              target_dtype.bits)
+        initial_indices = inversed_index_map.initial_indices
+        scaling_final_indices = inversed_index_map.map_indices(
+            initial_indices[:-1] + [initial_indices[-1] * scaling_factor])
+        scaling_final_indices = scaling_final_indices[:-1] + [
+            scaling_final_indices[-1] // scaling_factor
+        ]
+        inversed_index_map = IndexMap(
+            initial_indices,
+            scaling_final_indices,
+            None,
+        )
+
+        qr = r * bit // storage_nbit
 
-    def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem):
-        w = None
+        def fcompute(i, j):
+            warp_i, warp_j = i % l, j % qr
+            spatial_args = i // l, j // qr
+            if transform_kind >= TransformKind.IntraWarpTransform:
+                warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j])
+            new_index = (*spatial_args, warp_i, warp_j)
+            return tensor[new_index]
+
+        return te.compute(
+            (self.N, self.K // storage_nbit * bit),
+            fcompute,
+            name=f"{matrix_name}_reindex",
+        )
+
+    def _decode_func(self, B, LUT, Scale, Zeros, QZeros):
+        bit = self.bit
+        in_dtype = self.in_dtype
+        storage_dtype = self.storage_dtype
+        storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
+        storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
+        n_float_per_elem = storage_nbit // bit
 
+        # TODO: Move the decode function into a more general place
         def decode(n, k):
+            w = None
             if self.with_zeros and self.zeros_mode == "quantized":
-                qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
-                    self.bit,
+                qzeros_dequantize = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
+                    bit,
                     QZeros[k, n // n_float_per_elem],
                     n % n_float_per_elem,
                     dtype=self.storage_dtype,
                 )
-                w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)(
-                    self.bit,
+                w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)(
+                    bit,
                     B[n, k // n_float_per_elem],
                     k % n_float_per_elem,
                     qzeros_dequantize,
-                    dtype=self.in_dtype,
+                    dtype=in_dtype,
                 )
             elif self.source_format == "uint":
-                if self.bit == 8:
-                    w = B[n, k].astype(self.in_dtype)
-                w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
-                    self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
+                if bit == 8:
+                    w = B[n, k].astype(in_dtype)
+                w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
+                    bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
             elif self.source_format == "int":
-                if self.bit == 1:
-                    w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)(
-                        self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
-                if self.bit == 8:
-                    w = B[n, k].astype(self.in_dtype)
-                w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)(
-                    self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
+                if bit == 1:
+                    w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
+                        bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+                if bit == 8:
+                    w = B[n, k].astype(in_dtype)
+                w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
+                    bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
             elif self.source_format == "fp":
                 w = _tir_u32_to_f4_to_f16(
-                    self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
+                    bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
             elif self.source_format == "fp_e4m3":
-                w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype)
+                w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype)
             elif self.source_format == "nf":
-                index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
-                    self.bit,
+                index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
+                    bit,
                     B[n, k // n_float_per_elem],
                     k % n_float_per_elem,
                     dtype="int32",
@@ -132,7 +241,9 @@ def decode(n, k):
                 w = LUT[index]
             else:
                 raise ValueError(f"Unsupported source_format: {self.source_format}")
-
+
+        assert w is not None, "w is None"
+
         group_size = self.group_size
         zeros_mode = self.zeros_mode
 
@@ -167,7 +278,9 @@ def _compute_matmul(self, A, B_decode):
 
     def _convert_dtype(self, tensor):
         if self.accum_dtype != self.out_dtype:
-            return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D")
+            return te.compute((self.M, self.N),
+                              lambda i, j: tensor[i, j].astype(self.out_dtype),
+                              name="D")
         return tensor
 
     def _apply_bias(self, tensor, Bias):
@@ -176,9 +289,12 @@ def _apply_bias(self, tensor, Bias):
         return tensor
 
     def emit(self):
-        A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders()
-        B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem)
-        C = self._compute_matmul(A, B_decode)
+        A, B, LUT, Scale, Zeros, QZeros, Bias = self._create_placeholders()
+        A_reindex = self._propagate_input(A, self.propagate_a, "A")
+        B_reindex = self._propagage_weight(B, self.propagate_b, "B")
+
+        B_decode = self._decode_func(B_reindex, LUT, Scale, Zeros, QZeros)
+        C = self._compute_matmul(A_reindex, B_decode)
         D = self._convert_dtype(C)
         last_output = self._apply_bias(D, Bias)
 
@@ -212,8 +328,13 @@ def emit(self):
                 }
             },
         )
+        if self.propagate_a:
+            func = func.with_attr("input_transform_kind", self.propagate_a.value)
+        if self.propagate_b:
+            func = func.with_attr("weight_transform_kind", self.propagate_b.value)
         return tvm.IRModule.from_expr(func)
 
+
 def matmul_nt_dequantize_b(
     M,
     N,
@@ -335,9 +456,12 @@ def decode_func(n, k):
             A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k),
         name="C",
     )
-    D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
+
+    last_output = C
+    if accum_dtype != out_dtype:
+        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
+        last_output = D
     args = [A, B]
-    last_output = D
     if source_format == "nf":
         args.append(LUT)
     if with_scaling:
@@ -517,9 +641,11 @@ def decode_func(n, k):
             A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k),
         name="C",
    )
-    D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
+    last_output = C
+    if accum_dtype != out_dtype:
+        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
+        last_output = D
     args = [A, B]
-    last_output = D
     if source_format == "nf":
         args.append(LUT)
     if with_scaling:
@@ -715,9 +841,11 @@ def decode_func(n, k):
         ),
         name="C",
     )
-    D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
+    last_output = C
+    if accum_dtype != out_dtype:
+        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
+        last_output = D
     args = [A, B]
-    last_output = D
     if source_format == "nf":
         args.append(LUT)
     if with_scaling:

diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py
index cec56b473..fcfa7d9af 100644
--- a/testing/python/operators/test_tir_script_emitter.py
+++ b/testing/python/operators/test_tir_script_emitter.py
@@ -1,18 +1,13 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from bitblas.ops.impl.matmul_dequantize_impl import (
-    MatMulNTDequantizeEmitter,
-    matmul_nt_dequantize_b,
-    matmul_nt_dequantize_b_propagate_b,
-    matmul_nt_dequantize_b_propagate_a_propagate_b,
-)
 from bitblas import tvm
 import logging
 from bitblas import set_log_level
 
 set_log_level(logging.DEBUG)
 
-def compare_tir_scripts_and_emitter(
+
+def check_eual_ref_scripts_with_emitter(
     M,
     N,
     K,
@@ -28,8 +23,26 @@ def check_eual_ref_scripts_with_emitter(
     fast_decoding,
     with_bias,
     zeros_mode,
+    propagate_a,
+    propagate_b,
 ):
-    tir_script_func = matmul_nt_dequantize_b(
+    from bitblas.ops.impl.matmul_dequantize_impl import (
+        MatMulNTDequantizeEmitter,
+        matmul_nt_dequantize_b,
+        matmul_nt_dequantize_b_propagate_b,
+        matmul_nt_dequantize_b_propagate_a_propagate_b,
+    )
+    func = None
+    if propagate_a and propagate_b:
+        func = matmul_nt_dequantize_b_propagate_a_propagate_b
+    elif propagate_b:
+        func = matmul_nt_dequantize_b_propagate_b
+    else:
+        func = matmul_nt_dequantize_b
+
+    assert func is not None, "No function found for the given configuration"
+
+    ref_func = func(
         M,
         N,
         K,
@@ -46,8 +59,8 @@ def check_eual_ref_scripts_with_emitter(
         with_bias,
         zeros_mode,
     )
-    
-    emitter_func = MatMulNTDequantizeEmitter(
+
+    emit_func = MatMulNTDequantizeEmitter(
         M,
         N,
         K,
@@ -63,6 +76,21 @@ def check_eual_ref_scripts_with_emitter(
         fast_decoding,
         with_bias,
         zeros_mode,
+        propagate_a=propagate_a,
+        propagate_b=propagate_b,
     ).emit()
-    
-    tvm.ir.assert_structural_equal(tir_script_func, emitter_func)
+
+    tvm.ir.assert_structural_equal(ref_func, emit_func)
+
+
+def test_check_eual_ref_scripts_with_emitter():
+    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False)
+    check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False)
+    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False)
+    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False)
+    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True)
+    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True)
+    check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", True, True)
+
+if __name__ == "__main__":
+    test_check_eual_ref_scripts_with_emitter()

From 9bb7f49a968d4c71dbbc12121b4b7cb8258b2136 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Sat, 6 Jul 2024 05:51:15 +0000
Subject: [PATCH 8/8] Lint Fix

---
 bitblas/ops/impl/matmul_dequantize_impl.py | 13 +++++----
 .../operators/test_tir_script_emitter.py   | 29 ++++++++++++++-----
 2 files changed, 29 insertions(+), 13 deletions(-)

diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py
index e69e8fcfb..7b91764ca 100644
--- a/bitblas/ops/impl/matmul_dequantize_impl.py
+++ b/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -73,7 +73,7 @@ def _validate_bit(self):
     def _validate_layout(self):
         # TODO: extend the dequantize operators into General Layout
         pass
-    
+
     def _legalize_group_size(self):
         if self.group_size == -1:
             self.group_size = self.K
@@ -96,18 +96,19 @@ def _create_placeholders(self):
            l, r = 16, 32  # noqa: E741
 
         A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype)
-        B = te.placeholder((self.N, self.K // storage_nbit * bit),
-                           name="B",
-                           dtype=storage_dtype)
+        B = te.placeholder((self.N, self.K // storage_nbit * bit), name="B", dtype=storage_dtype)
         if self.propagate_a:
             A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype)
         if self.propagate_b:
             target_dtype = DataType(in_dtype)
             scaling_factor = 1
             if bit > 0 and bit < target_dtype.bits:
-                scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // target_dtype.bits)
+                scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits //
+                                  target_dtype.bits)
             qr = r * bit // storage_nbit
-            B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype)
+            B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr),
+                               name="B",
+                               dtype=storage_dtype)
 
         LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype)
         Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype)

diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py
index fcfa7d9af..b2c7a8d4f 100644
--- a/testing/python/operators/test_tir_script_emitter.py
+++ b/testing/python/operators/test_tir_script_emitter.py
@@ -84,13 +84,28 @@ def check_eual_ref_scripts_with_emitter(
 
 
 def test_check_eual_ref_scripts_with_emitter():
-    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False)
-    check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False)
-    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False)
-    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False)
-    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True)
-    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True)
-    check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", True, True)
+    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8",
+                                        "nf", True, False, -1, False, False, "original", False,
+                                        False)
+    check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4,
+                                        "int8", "nf", True, False, -1, False, False, "original",
+                                        False, False)
+    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8",
+                                        "uint", True, False, -1, False, False, "original", False,
+                                        False)
+    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8",
+                                        "uint", True, False, -1, False, False, "original", False,
+                                        False)
+    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8",
+                                        "uint", True, False, -1, False, False, "original", False,
+                                        True)
+    check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8",
+                                        "uint", True, False, -1, False, False, "original", False,
+                                        True)
+    check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4,
+                                        "int8", "uint", True, False, -1, False, False, "original",
+                                        True, True)
+
 if __name__ == "__main__":
     test_check_eual_ref_scripts_with_emitter()