[Dev] Refactor codebase to save import time #262

Merged · 11 commits · Dec 13, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 8e2f4b to 5ec617
215 changes: 74 additions & 141 deletions benchmark/operators/benchmark_bitblas_matmul.py
@@ -1,141 +1,67 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import bitblas
from bitblas.utils.target_detector import auto_detect_nvidia_target
from bitblas import Matmul, MatmulConfig
import argparse
import json


# Initialize the parser
parser = argparse.ArgumentParser(
    description="Benchmark BitBLAS int4 on a specific target."
)

# Add arguments to the parser
parser.add_argument(
    "--target",
    type=str,
    default=auto_detect_nvidia_target(),
    help="Specify the target device for benchmarking."
)
parser.add_argument(
    "--group_size",
    type=int,
    default=None,
    help="Group size for grouped quantization."
)
parser.add_argument(
    "--A_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "float64", "int32", "int8"],  # Assuming these are the valid choices
    help="Data type of activation A."
)
parser.add_argument(
    "--W_dtype",
    type=str,
    default="int4",
    choices=["float16", "float32", "float64", "int32", "int8", "int4", "int2", "int1", "nf4", "fp4_e2m1"],  # Assuming these are the valid choices
    help="Data type of weight W."
)
parser.add_argument(
    "--accum_dtype",
    type=str,
    default="float16",
    choices=["float16", "int32"],  # Assuming these are the valid choices
    help="Data type for accumulation."
)
parser.add_argument(
    "--out_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "int32", "int8"],  # Assuming these are the valid choices
    help="Data type for output."
)
parser.add_argument(
    "--layout",
    type=str,
    default="nt",
    choices=["nt", "nn"],  # Assuming these are the valid choices
    help="Matrix layout, 'nt' for non-transpose A and transpose W."
)
parser.add_argument(
    "--with_bias",
    action="store_true",
    help="Include bias in the benchmark."
)
parser.add_argument(
    "--with_scaling",
    action="store_true",
    help="Include scaling factor in the quantization."
)
parser.add_argument(
    "--with_zeros",
    action="store_true",
    help="Include zeros in the quantization."
)
parser.add_argument(
    "--zeros_mode",
    type=str,
    default=None,
    choices=["original", "rescale", "quantized"],
    help="Specify the mode for calculating zeros."
)

# Parse the arguments
args = parser.parse_args()

# Assign arguments to variables
target = args.target
group_size = args.group_size
A_dtype = args.A_dtype
W_dtype = args.W_dtype
accum_dtype = args.accum_dtype
out_dtype = args.out_dtype
layout = args.layout
with_bias = args.with_bias
with_scaling = args.with_scaling
with_zeros = args.with_zeros
zeros_mode = args.zeros_mode

test_shapes = [
    # square test
    (MatmulConfig, Matmul, (1, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # BLOOM-176B
    (MatmulConfig, Matmul, (1, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (1, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (1, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (1, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # OPT-65B
    (MatmulConfig, Matmul, (1, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (1, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (1, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (1, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # LLAMA-70B/65B
    (MatmulConfig, Matmul, (1, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (1, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (1, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (1, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),

    # square test
    (MatmulConfig, Matmul, (16384, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # BLOOM-176B
    (MatmulConfig, Matmul, (8192, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (8192, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (8192, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (8192, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # OPT-65B
    (MatmulConfig, Matmul, (8192, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (8192, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (8192, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (8192, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # LLAMA-70B/65B
    (MatmulConfig, Matmul, (8192, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (8192, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (8192, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (8192, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
]
bitblas.set_log_level("DEBUG")
# Initialize the parser
parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.")

# Add arguments to the parser
parser.add_argument(
    "--target",
    type=str,
    default=auto_detect_nvidia_target(),
    help="Specify the target device for benchmarking.")

parser.add_argument(
    "--backend",
    type=str,
    default="tir",
    choices=["tir", "tl"],
    help="Specify the backend used for compilation: 'tir' or 'tl' (TileLang).")

parser.add_argument(
    "--verbose",
    action=argparse.BooleanOptionalAction,  # type=bool would parse any non-empty string, even "False", as True
    default=True,
    help="Enable verbose logging.")

# [M, N, K, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode]
default_test_shapes = json.dumps([
    # ["MatmulConfig", "Matmul", [1, 16384, 16384, "float16", "int4", "float16", "float16", "nt", False, None, False, False, None]]
    [
        "MatmulConfig", "Matmul",
        [
            16384, 16384, 16384, "float16", "float16", "float16", "float16", "nt", False, None,
            False, False, None
        ]
    ]
])

parser.add_argument(
    "--test_shapes",
    type=str,
    default=default_test_shapes,
    help="JSON string defining test shapes. Example format: '[[\"MatmulConfig\", \"Matmul\", [1,16384,16384,\"float16\",\"int4\",\"float16\",\"float16\",\"nt\",false,null,false,false,null]]]'"
)

# Parse the arguments
args = parser.parse_args()

# Assign arguments to variables
target = args.target
backend = args.backend
verbose = args.verbose

parsed_test_shapes = json.loads(args.test_shapes)
name_to_class = {"MatmulConfig": MatmulConfig, "Matmul": Matmul}

test_shapes = []
for item in parsed_test_shapes:
    config_class_name, operator_class_name, input_args = item
    config_class = name_to_class[config_class_name]
    operator_class = name_to_class[operator_class_name]
    test_shapes.append((config_class, operator_class, tuple(input_args)))

benchmark_sets = []
benchmark_sets.extend(test_shapes)
@@ -145,12 +71,17 @@
benchmark_results = {}
for config, operator, input_args in benchmark_sets:
    config = config(*input_args)
    matmul = operator(config, target=target, enable_tuning=True)
    kernel_latency = matmul.profile_latency()
    if matmul.input_transform is not None:
        kernel_latency += matmul.ladder_permutate_a.profile_latency()

    print("Time cost is: {:.3f} ms".format(kernel_latency))
    print(f"Running benchmark for {operator.__name__} with config: {config}")
    op_inst = operator(config, target=target, enable_tuning=True, backend=backend)
    kernel_latency = op_inst.profile_latency()
    if op_inst.input_transform is not None:
        kernel_latency += op_inst.ladder_permutate_a.profile_latency()

    print("Time cost of {} is: {:.3f} ms".format(str(config), kernel_latency))

    if verbose:
        print(op_inst.scheduled_ir_module)
        print(op_inst.get_source())

    profile_config = {
        f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
@@ -160,7 +91,7 @@

    benchmark_results.update(profile_config)

# Define headers for the table
headers = [
    "PrimFunc",
    "Input Arguments",
@@ -174,7 +105,9 @@
input_args = "-".join(args[1:])
col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), col_widths[0])
col_widths[1] = max((max(len(str(headers[1])), len(input_args)) + 2, col_widths[1]))
col_widths[2] = max(max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, col_widths[2])
col_widths[2] = max(
max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2,
col_widths[2])
break

for i, header in enumerate(headers):
@@ -193,4 +126,4 @@
        input_args,
        f"{values['BitBLAS_top20_latency']:.3f} ms",
    ]
    print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)]) + "\n")
    print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)]))
76 changes: 17 additions & 59 deletions bitblas/base/arch/__init__.py
@@ -1,15 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .arch_base import TileDevice
from .cuda import *
from .cpu import *
from .cdna import *
from .cuda import CUDA
from .cpu import CPU
from .cdna import CDNA
from typing import Union
from tvm.target import Target


def get_arch(target: Union[str, tvm.target.Target] = "cuda") -> TileDevice:
def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
    if isinstance(target, str):
        target = tvm.target.Target(target)
        target = Target(target)

    if target.kind.name == "cuda":
        return CUDA(target)
@@ -27,57 +28,14 @@ def auto_infer_current_arch() -> TileDevice:
    return get_arch("cuda")


def is_cpu_arch(arch: TileDevice) -> bool:
    return isinstance(arch, CPU)


def is_cuda_arch(arch: TileDevice) -> bool:
    return isinstance(arch, CUDA)


def is_ampere_arch(arch: TileDevice) -> bool:
    conditions = [True]
    conditions.append(is_cuda_arch(arch))
    conditions.append(arch.sm_version >= 80 and arch.sm_version < 90)
    return all(conditions)


def is_volta_arch(arch: TileDevice) -> bool:
    conditions = [True]
    conditions.append(is_cuda_arch(arch))
    conditions.append(arch.sm_version >= 70)
    conditions.append(arch.sm_version < 80)
    return all(conditions)


def is_cdna_arch(arch: TileDevice) -> bool:
    return isinstance(arch, CDNA)


def has_mma_support(arch: TileDevice) -> bool:
    conditions = [True]
    conditions.append(is_cuda_arch(arch))
    conditions.append(arch.sm_version >= 80)
    return all(conditions)


def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool:
    volta_tensorcore_supported = [
        ("float16", "float32"),
        ("float16", "float16"),
    ]
    ampere_tensorcore_supported = [
        ("float16", "float32"),
        ("float16", "float16"),
        ("int8", "int32"),
        ("int4", "int32"),
        ("int2", "int32"),
        ("int1", "int32"),
    ]

    if is_volta_arch(arch):
        return (in_dtype, accum_dtype) in volta_tensorcore_supported
    elif is_ampere_arch(arch):
        return (in_dtype, accum_dtype) in ampere_tensorcore_supported
    else:
        raise ValueError(f"Unsupported architecture: {arch}")
from .cpu import is_cpu_arch # noqa: F401
from .cuda import (
    is_cuda_arch,  # noqa: F401
    is_volta_arch,  # noqa: F401
    is_ampere_arch,  # noqa: F401
    is_ada_arch,  # noqa: F401
    is_hopper_arch,  # noqa: F401
    is_tensorcore_supported_precision,  # noqa: F401
    has_mma_support,  # noqa: F401
)
from .cdna import is_cdna_arch # noqa: F401
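For context, here is a minimal usage sketch of the refactored import surface. The names come from the re-exports above; the commented expectations assume an Ampere-class NVIDIA GPU and mirror the behavior of the deleted helpers (the new `cuda.py` versions may differ slightly):

```python
from bitblas.base.arch import (
    get_arch,
    is_cuda_arch,
    is_ampere_arch,
    is_tensorcore_supported_precision,
)

arch = get_arch("cuda")  # accepts a target string or a tvm.target.Target
print(is_cuda_arch(arch))  # True on an NVIDIA device
print(is_ampere_arch(arch))  # True for sm_80..sm_89 in the helper deleted above
# (int8, int32) was an Ampere-only pair in the deleted precision table.
print(is_tensorcore_supported_precision("int8", "int32", arch))
```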
4 changes: 4 additions & 0 deletions bitblas/base/arch/cdna.py
@@ -7,6 +7,10 @@
from typing import List, Union


def is_cdna_arch(arch: TileDevice) -> bool:
    return isinstance(arch, CDNA)


class CDNA(TileDevice):

    def __init__(self, target: Union[Target, str]):
4 changes: 4 additions & 0 deletions bitblas/base/arch/cpu.py
@@ -6,6 +6,10 @@
from .arch_base import TileDevice


def is_cpu_arch(arch: TileDevice) -> bool:
    return isinstance(arch, CPU)


# For the LLVM backend, we do not provide detailed CPU information.
# Since the LLVM backend does not require tuning, this class exists only for consistency.
class CPU(TileDevice):