
[Dev] Potentially improve performance through block reduction #63

Merged
LeiWang1999 merged 38 commits into microsoft:main on Jun 30, 2024

Conversation

LeiWang1999
Contributor

In our recent evaluations, we observed that batch inference on the matrix multiplication shapes used by 7B/13B models did not achieve the expected theoretical speedup. This bottleneck appears to be linked to the absence of block-reduction support (i.e., splitting the reduction dimension across additional threads) in our schedule template and in TVM.

This pull request includes a variety of changes across multiple files, mainly focusing on introducing new features related to block reduction.

Here are the most important changes:
* python/bitblas/base/roller/policy/tensorcore.py: A block_reduction_depth field was added to the Policy class, and this value is now taken into account in several functions, including _check_small_tile, _enlarge, and _score.
* python/bitblas/gpu/matmul_analysis.py: The check_last_trait function was updated to set block_reduction_depth to 2 for small M values.
* python/bitblas/gpu/matmul_mma.py: The apply_config function was updated to call a new function, apply_block_reduction_with_config, when block_reduction_depth is not None (a minimal sketch of this dispatch follows the list).
* 3rdparty/tvm: The subproject commit was updated.
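
To illustrate the dispatch described in the third item, here is a minimal sketch, not the actual implementation: the block_reduction_depth field and the apply_block_reduction_with_config entry point come from this PR, while the class name, the apply_default_schedule fallback, and the method bodies are hypothetical placeholders.

class MatmulScheduler:
    # Hypothetical wrapper around the MMA matmul schedule template.

    def apply_config(self, func, config):
        # Hints produced by the roller policy may now carry a
        # block_reduction_depth (e.g. 2 for small M). When it is set, the
        # block-reduction schedule path is taken instead of the default one.
        if getattr(config, "block_reduction_depth", None) is not None:
            return self.apply_block_reduction_with_config(func, config)
        return self.apply_default_schedule(func, config)  # hypothetical fallback

    def apply_block_reduction_with_config(self, func, config):
        # Split the reduction (k) axis across an extra thread dimension and
        # accumulate per-block partial sums; the real schedule lives in
        # python/bitblas/gpu/matmul_mma.py.
        ...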

LeiWang1999 and others added 30 commits May 21, 2024 11:51
@LeiWang1999
Contributor Author

BitBLAS performance on small shapes has not yet met our expectations and requires further investigation, so the block-reduction-related items were disabled when auto-tuning is enabled. cc @tzj-fxz if you have time.
Code to reproduce the performance:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas.utils.target_detector import auto_detect_nvidia_target
from bitblas import Matmul, MatmulConfig
import argparse
import bitblas
import tvm
from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.base.roller.arch import CUDA
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
from bitblas.base.utils import apply_and_build
import time
from tvm import te, tir

bitblas.set_log_level("DEBUG")
# Initialize the parser  
parser = argparse.ArgumentParser(  
    description="Benchmark BitBLAS int4 on a specific target."  
)  
  
# Add arguments to the parser  
parser.add_argument(  
    "--target",  
    type=str,  
    default=auto_detect_nvidia_target(),  
    help="Specify the target device for benchmarking."  
)  
parser.add_argument(  
    "--group_size",  
    type=int,  
    default=None,  
    help="Group size for grouped quantization."  
)  
parser.add_argument(  
    "--A_dtype",  
    type=str,  
    default="float16",  
    choices=["float16", "float32", "float64", "int32", "int8"],  # Assuming these are the valid choices  
    help="Data type of activation A."  
)  
parser.add_argument(  
    "--W_dtype",  
    type=str,  
    default="uint4",  
    help="Data type of weight W."  
)  
parser.add_argument(  
    "--accum_dtype",  
    type=str,  
    default="float16",  
    help="Data type for accumulation."  
)  
parser.add_argument(  
    "--out_dtype",  
    type=str,  
    default="float16",  
    choices=["float16", "float32", "int32", "int8"],  # Assuming these are the valid choices  
    help="Data type for output."  
)  
parser.add_argument(  
    "--layout",  
    type=str,  
    default="nt",  
    choices=["nt", "nn"],  # Assuming these are the valid choices  
    help="Matrix layout, 'nt' for non-transpose A and transpose W."  
)  
parser.add_argument(  
    "--with_bias",  
    action="store_true",  
    help="Include bias in the benchmark."  
)  
parser.add_argument(  
    "--with_scaling",  
    action="store_true",  
    help="Include scaling factor in the quantization."  
)  
parser.add_argument(  
    "--with_zeros",  
    action="store_true",  
    help="Include zeros in the quantization."  
)  
parser.add_argument(  
    "--zeros_mode",  
    type=str,  
    default=None,  
    choices=["original", "rescale", "quantized"],  # Replace with actual modes if applicable  
    help="Specify the mode for calculating zeros."  
)
parser.add_argument(
    "--propagate_a",
    type=lambda x: str(x).lower() in ("1", "true"),
    default=True,
    help="Whether to apply layout propagation (transform) to activation A."
)
parser.add_argument(
    "--propagate_b",
    type=lambda x: str(x).lower() in ("1", "true"),
    default=True,
    help="Whether to apply layout propagation (transform) to weight W."
)
  
# Parse the arguments  
args = parser.parse_args()  
  
# Assign arguments to variables  
target = args.target  
group_size = args.group_size  
A_dtype = args.A_dtype  
W_dtype = args.W_dtype  
accum_dtype = args.accum_dtype  
out_dtype = args.out_dtype  
layout = args.layout  
with_bias = args.with_bias  
with_scaling = args.with_scaling  
with_zeros = args.with_zeros  
zeros_mode = args.zeros_mode 
propagate_a = args.propagate_a
propagate_b = args.propagate_b

test_shapes = [
    (MatmulConfig, Matmul, (16, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
]

benchmark_sets = []
benchmark_sets.extend(test_shapes)

# fmt:on

benchmark_results = {}
for config, operator, input_args in benchmark_sets:
    matmul_config = config(*input_args, propagate_a=True, propagate_b=True, fast_decoding=True)
    matmul = operator(matmul_config, target=target, enable_tuning=False)
    func = matmul.prim_func

    intrin_info = bitblas.base.roller.hint.IntrinInfo(
        in_dtype="float16",
        out_dtype="float16",
        trans_b=True,
        input_transform_kind=2,
        weight_transform_kind=2,
    )


    sch_normal = bitblas.gpu.MatmulTensorizationMMAWithDequantizeInfo().sch_shared_memory_prefetch_with_config(
        func,
        bitblas.base.roller.hint.Hint().from_dict({
            "warp": [16, 16],
            "block": [16, 64],
            "rstep": [128],
            "pipeline_stage": 2,
            "use_async": True,
            "intrin_info": intrin_info,
            "shared_scope": "shared",
            "vectorize": {
                "A": 8,
                "B": 8,
            },
            "rasterization_plan": bitblas.base.roller.Rasterization2DColumn(10)
        })
    )
    with tvm.transform.PassContext(config={"tir.use_async_copy": True, "tir.merge_static_smem": False, "cuda.kernels_output_dir": "./debug/bitblas_fp16xint4_fp16_pb_noscale_with_default"}):
        rt_mod = tvm.build(sch_normal.mod, target=matmul.target)
    time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, tvm.cuda(), number=10)
    profile_tensors = matmul.get_profile_tensors()
    latency = time_evaluator(*profile_tensors).mean * 1e3
    # print(rt_mod.imported_modules[0].get_source())
    print(f"Time cost is: {latency:.3f} ms")

    sch_reduce = bitblas.gpu.MatmulTensorizationMMAWithDequantizeInfo().sch_shared_memory_prefetch_with_config(
        func,
        bitblas.base.roller.hint.Hint().from_dict({
            "warp": [16, 16],
            "block": [16, 64],
            "rstep": [128],
            "pipeline_stage": 2,
            "use_async": True,
            "intrin_info": intrin_info,
            "shared_scope": "shared",
            "vectorize": {
                "A": 8,
                "B": 8,
            },
            "block_reduction_depth": 2,
            "rasterization_plan": bitblas.base.roller.Rasterization2DColumn(10)
        })
    )
    with tvm.transform.PassContext(config={"tir.use_async_copy": True, "tir.merge_static_smem": False, "cuda.kernels_output_dir": "./debug/bitblas_fp16xint4_fp16_pb_noscale_with_default"}):
        rt_mod = tvm.build(sch_reduce.mod, target=matmul.target)
    time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, tvm.cuda(), number=10)
    latency = time_evaluator(*profile_tensors).mean * 1e3
    # print(rt_mod.imported_modules[0].get_source())
    print(f"Time cost is: {latency:.3f} ms")
 

@LeiWang1999 LeiWang1999 marked this pull request as ready for review June 30, 2024 11:40
@LeiWang1999 LeiWang1999 merged commit e7ed676 into microsoft:main Jun 30, 2024
4 checks passed