[Dev] Potentially improve performance through block reduction #63

Merged (38 commits, Jun 30, 2024)

Commits:
75d2f3d
improve e4m3 decoding.
May 21, 2024
dd744d0
Merge branch 'main' of https://github.com/microsoft/BitBLAS into main
May 23, 2024
00bfa31
append fp16xint1
May 25, 2024
8cd8b10
Update submodule commit reference
Jun 1, 2024
9122ff7
chore: Update shared memory scope for float32 output dtype
Jun 1, 2024
b508acc
BUGFIX: UINT8/INT8 Decoding
Jun 2, 2024
58d55b7
feat: Add rasterization options for roller module
Jun 5, 2024
e7547ce
Refactor tensorcore_legalization method to optimize tensor core usage
Jun 5, 2024
678a2e1
feat: Add function to collect variables from expression, improve for …
Jun 5, 2024
3088b35
chore: Update typing import in __init__.py
Jun 5, 2024
5d206b3
chore: Refactor CPU execution of operators
Jun 5, 2024
e06ce10
Refactor matmul implementation for splitk layout
Jun 5, 2024
d67cc6d
Refactor matmul implementation for splitk layout
Jun 5, 2024
9e36b6d
Refactor matmul implementation for splitk layout
Jun 5, 2024
e1a0149
chore: Update version to 0.0.1.dev8
Jun 5, 2024
df0ed7a
chore: Enable debug output in bitblas.set_debug_level()
Jun 5, 2024
a0f651a
Refactor Linear module matmul implementation for splitk layout
Jun 5, 2024
88295a7
Refactor matmul implementation for splitk layout
Jun 5, 2024
3366dce
Merge branch 'main' of https://github.com/microsoft/BitBLAS into lei/…
Jun 5, 2024
25b5c63
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 5, 2024
26a9f1b
Bump version to v0.0.1.dev9
Jun 5, 2024
251bf08
Merge branch 'main' of https://github.com/microsoft/BitBLAS into lei/…
Jun 5, 2024
e0cf62c
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 6, 2024
2e4e8dd
Bump version to v0.0.1.dev10
Jun 6, 2024
0dec7d8
Merge branch 'main' into lei/splitk
LeiWang1999 Jun 6, 2024
81f5b9a
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 6, 2024
ec64f91
Merge branch 'lei/splitk' of https://github.com/LeiWang1999/MSBitBLAS…
Jun 6, 2024
5e71163
Bump version to v0.0.1.dev12 and add MatmulConfigWithSplitK and Matmu…
Jun 6, 2024
d0e0726
Merge branch 'main' into lei/splitk
LeiWang1999 Jun 6, 2024
30c0ae7
fix the typo
Jun 29, 2024
4bbccae
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 29, 2024
0d1b649
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 29, 2024
2ce41bb
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 29, 2024
866f561
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 30, 2024
8d2393c
Merge branch 'main' of https://github.com/microsoft/BitBLAS into lei/…
Jun 30, 2024
22c12d7
Merge branch 'microsoft:main' into main
LeiWang1999 Jun 30, 2024
d9fdc21
Merge branch 'main' of https://github.com/LeiWang1999/MSBitBLAS into …
Jun 30, 2024
1e534f4
Merge branch 'lei/splitk' of https://github.com/LeiWang1999/MSBitBLAS…
Jun 30, 2024
docs/QuickStart.md (3 additions, 3 deletions)

@@ -14,7 +14,7 @@ import torch
 
 # enabling debug output
 
-bitblas.set_debug_level("Debug")
+bitblas.set_log_level("Debug")
 matmul_config = bitblas.MatmulConfig(
     M=1,  # M dimension
     N=1024,  # N dimension
@@ -129,7 +129,7 @@ import bitblas
 import torch
 
 # enabling debug output
-bitblas.set_debug_level("Debug")
+bitblas.set_log_level("Debug")
 
 model = bitblas.Linear(
     in_features=1024,
@@ -185,7 +185,7 @@ from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import (
 )
 
 # enabling debug output
-bitblas.set_debug_level("Debug")
+bitblas.set_log_level("Debug")
 
 in_features = 1024
 out_features = 1024
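For context, the renamed entry point is used as below. A minimal sketch mirroring the QuickStart hunks; the K dimension is assumed, since each hunk truncates after N:

```python
import bitblas

# Renamed: bitblas.set_debug_level(...) -> bitblas.set_log_level(...)
bitblas.set_log_level("Debug")

matmul_config = bitblas.MatmulConfig(
    M=1,     # M dimension
    N=1024,  # N dimension
    K=1024,  # K dimension (assumed; not shown in the hunk)
)
matmul = bitblas.Matmul(config=matmul_config)
```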
python/bitblas/__init__.py (2 additions, 0 deletions)

@@ -40,6 +40,7 @@
 import logging
 from tqdm import tqdm
 
+
 class TqdmLoggingHandler(logging.Handler):
     """ Custom logging handler that directs log output to tqdm progress bar to avoid interference. """
 
@@ -61,6 +62,7 @@ def set_log_level(level):
 
     Args:
         level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO).
+            OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'
     """
     if isinstance(level, str):
         level = getattr(logging, level.upper(), logging.INFO)
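The string-to-level resolution shown in this hunk can be exercised standalone. A sketch of the same getattr fallback; resolve_level is a hypothetical helper mirroring set_log_level's logic:

```python
import logging

def resolve_level(level):
    # Same fallback as set_log_level: unknown names resolve to logging.INFO.
    if isinstance(level, str):
        level = getattr(logging, level.upper(), logging.INFO)
    return level

assert resolve_level("debug") == logging.DEBUG        # case-insensitive
assert resolve_level("not-a-level") == logging.INFO   # silent fallback
assert resolve_level(logging.ERROR) == logging.ERROR  # ints pass through
```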
python/bitblas/base/roller/arch/cuda.py (4 additions, 2 deletions)

@@ -4,7 +4,7 @@
 import tvm
 from tvm.target import Target
 from .arch_base import TileDevice
-from typing import List, Dict
+from typing import List, Dict, Union
 
 
 def check_sm_version(arch: str) -> int:
@@ -28,7 +28,9 @@ def __init__(
 
 class CUDA(TileDevice):
 
-    def __init__(self, target: Target):
+    def __init__(self, target: Union[Target, str]):
+        if isinstance(target, str):
+            target = tvm.target.Target(target)
         self.target = target
         self.sm_version = check_sm_version(self.target.arch)
         device = tvm.runtime.cuda(0)
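With this change the arch wrapper accepts either a tvm.target.Target or a plain target string. A minimal sketch; the target tag is illustrative, and constructing CUDA still probes device 0 via tvm.runtime.cuda(0), so a CUDA device must be present:

```python
import tvm
from bitblas.base.roller.arch.cuda import CUDA

# Equivalent after this change: the string form is wrapped in
# tvm.target.Target inside __init__.
arch_from_target = CUDA(tvm.target.Target("nvidia/nvidia-a100"))
arch_from_string = CUDA("nvidia/nvidia-a100")
```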
python/bitblas/base/roller/hint.py (9 additions, 3 deletions)

@@ -154,18 +154,20 @@ def __init__(self) -> None:
         self.arch = None
         self.use_tc = None  # todo(lei): this should be renamed.
 
-        # special axes tiling info
+        # Special axes tiling info
         self.block = []
         self.thread = []
-        # special axes for tensorCore
+        # Special axes for MMA
         self.warp = []
-        # reduce axes tiling info
+        # Reduce axes tiling info
         self.rstep = []
         self.reduce_thread = []
         self.rasterization_plan = NoRasterization()
         self.cached_tensors = []
         self.output_strides = {}
         self.schedule_stages = None
+        # Config for block reduction
+        self.block_reduction_depth = None  # type: int
 
         # Experimental
         self._raxis_order = []
@@ -203,6 +205,10 @@ def to_dict(self) -> Dict:
             dic["raxis_order"] = self._raxis_order
         if self.vectorize != {}:
             dic["vectorize"] = self.vectorize
+        if self.pipeline_stage != 1:
+            dic["pipeline_stage"] = self.pipeline_stage
+        if self.block_reduction_depth is not None:
+            dic["block_reduction_depth"] = self.block_reduction_depth
         return dic
 
     def from_dict(self, dic: Dict) -> "Hint":
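A sketch of the intended serialization behavior: the new keys appear in to_dict() only when they diverge from their defaults. This assumes Hint() is constructible with defaults, as its __init__ signature suggests:

```python
from bitblas.base.roller.hint import Hint

hint = Hint()
assert "block_reduction_depth" not in hint.to_dict()  # None by default
assert "pipeline_stage" not in hint.to_dict()         # defaults to 1

hint.block_reduction_depth = 2
hint.pipeline_stage = 2
dic = hint.to_dict()
assert dic["block_reduction_depth"] == 2
assert dic["pipeline_stage"] == 2
```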
python/bitblas/base/roller/policy/tensorcore.py (15 additions, 0 deletions)

@@ -25,6 +25,7 @@ def __init__(self,
         self.wmma_k = 16
         self.pipeline_stage: int = 1
         self.use_async_copy: bool = False
+        self.block_reduction_depth: Optional[int] = None
         self._legalize_info()
 
     def _legalize_info(self):
@@ -44,6 +45,11 @@ def _legalize_info(self):
             self.use_async_copy = True
         else:
             self.use_async_copy = False
+        # TODO: block reduction depth is not used for now,
+        # as there are still some performance issues with block reduction.
+        # block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth")
+        # if block_reduction_depth:
+        #     self.block_reduction_depth = block_reduction_depth
 
     def _compute_tc_strides(
         self,
@@ -114,6 +120,7 @@ def _check_small_tile(td: TileDict):
 
         smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
         rstep_map = td.rstep_map.copy()
+        is_block_reduction = self.block_reduction_depth is not None
 
         def _optimize(node, rstep):
            all_steps = self.get_node_reduce_step_candidates(node)
@@ -177,6 +184,13 @@ def _enlarge(rstep_id):
        if len(node.raxis) > 0:
            rstep = _optimize(node, rstep_map)
            rstep_map = rstep
+
+       if is_block_reduction:
+           # For block reduction, constrain the max reduce step to 64;
+           # otherwise it introduces a CUDA invalid-argument error.
+           MAX_REDUCE_K = 64
+           for k in rstep_map:
+               rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K)
        td.rstep_map = rstep_map
        td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
        return
@@ -289,6 +303,7 @@ def _score(node, thread):  # small is better
         codegen_dict.warp = warp_tile
         codegen_dict.use_tc = True
         codegen_dict.pipeline_stage = self.pipeline_stage
+        codegen_dict.block_reduction_depth = self.block_reduction_depth
         codegen_dict.use_async = self.use_async_copy
         codegen_dict.rstep = [int(rsteps[ax.var.name]) for ax in node.raxis]
         codegen_dict.cached_tensors = td.cached_tensors_map[node]
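The reduce-step clamp added here is self-contained. A standalone sketch with illustrative values:

```python
# When block reduction is enabled, every reduce step is capped at 64 so the
# generated kernel is never launched with invalid CUDA arguments.
MAX_REDUCE_K = 64
rstep_map = {"k": 256}  # illustrative oversized reduce step

for k in rstep_map:
    rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K)

assert rstep_map == {"k": 64}
```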
python/bitblas/base/utils.py (3 additions, 2 deletions)

@@ -168,6 +168,7 @@ def apply_and_build_parallel(func,
                              arch,
                              num_repeats=3,
                              max_workers=10,
+                             timeout=30,
                              data_distribution="uniform") -> CompileResult:
     cpresults = []
 
@@ -187,10 +188,10 @@ def _apply_schedule(f, c):
 
     with ThreadPoolExecutor(max_workers=4) as scheduler:
         futures = {scheduler.submit(_apply_schedule, func, config) for config in configs}
-        for future in as_completed(futures):
+        for future in as_completed(futures, timeout=timeout):
            _sched.append(future.result())
 
-    builder = PopenPoolExecutor(max_workers=max_workers)
+    builder = PopenPoolExecutor(max_workers=max_workers, timeout=timeout)
 
     # build in process parallel
    def _build(context) -> str:
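A sketch of the timeout semantics this hunk adopts: concurrent.futures.as_completed raises TimeoutError once the overall deadline passes, so a hung scheduling task no longer blocks tuning indefinitely (TVM's PopenPoolExecutor takes a per-task timeout in the same spirit). The sleep durations below are illustrative:

```python
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed

with ThreadPoolExecutor(max_workers=2) as pool:
    futures = {pool.submit(time.sleep, s) for s in (0.1, 3)}
    try:
        for future in as_completed(futures, timeout=1):
            future.result()
    except TimeoutError:
        # Raised once the 1-second deadline passes with futures still pending.
        print("a scheduling task exceeded the timeout")
```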
python/bitblas/gpu/gemv.py (1 addition, 1 deletion)

@@ -21,7 +21,7 @@
 # Modifications Copyright (c) Microsoft.
 # The code below is mostly copied from apache/tvm gemv.py in dlight.
 """A rule for GEMV and DecodeGEMV."""
-import re
+
 from functools import reduce
 from typing import List, Optional, Union, Dict
 
python/bitblas/gpu/gemv_dequantize.py (0 additions, 1 deletion)

@@ -287,7 +287,6 @@ def get_vectorize_factor(target_format):
     assert len(config.thread) == 2, "SplitK only support 2D thread config"
     num_warps = int(num_warps // config.thread[0])
 
-
     # get target dequantize buffer's idx
     def get_idx(weight_decode_info: Dict):
         # for LUT dequantize, the expr is LUT(w), the idx is 1
python/bitblas/gpu/matmul_analysis.py (10 additions, 0 deletions)

@@ -619,6 +619,16 @@ def check_last_trait(region: List[Range]):
     if func.attrs is not None and "weight_transform_kind" in func.attrs:
         intrin_info["weight_transform_kind"] = func.attrs["weight_transform_kind"]
     tags["intrin_info"] = intrin_info
+    # Analyze the block reduction optimization.
+    # Currently we only support block reduction depth 2 for small M.
+    # When the func is a dequantize-like op, we should consider M.
+    if hasattr(func.attrs, "dequantize_info"):
+        for arg in func.params:
+            inp_shape = func.buffer_map[arg].shape
+            M = inp_shape[0]
+            if isinstance(M, tir.IntImm) and M <= 128:
+                tags["block_reduction_depth"] = 2
+                break
 
     return tags
 
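The tagging rule above reduces to a small static check on M. A standalone sketch; pick_block_reduction_depth is a hypothetical helper mirroring the added branch:

```python
from typing import Optional

from tvm import tir

def pick_block_reduction_depth(m) -> Optional[int]:
    # Block reduction depth 2 only for static M <= 128; a dynamic
    # (symbolic) M never qualifies.
    if isinstance(m, tir.IntImm) and int(m) <= 128:
        return 2
    return None

assert pick_block_reduction_depth(tir.IntImm("int32", 64)) == 2
assert pick_block_reduction_depth(tir.IntImm("int32", 256)) is None
assert pick_block_reduction_depth(tir.Var("m", "int32")) is None
```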