[Dev] Merge BlockReduce with naive schedule template (#119)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix in test

* lint fix.

* test cuda i4 kernel

* Refactor copyright notice in i4matmul.hpp

* Refactor BitBLASLinear test module for improved readability and maintainability

* refactor test as Python versions below 3.9 cannot handle int32 overflow.

* format lint for test

* Refactor test_int4b_fp16_convert.py for improved readability and maintainability

* remove unused design file

* move tile device from package to base

* dummy impl for codegen

* Refactor file structure for ladder_permutate module

* Refactor backend class and fix typos in comments

* Deep refactor of lib-related code.

* remove ci pull.

* LintFix

* refactor builder for whl build

* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module

* Refactor lib_generator to set library and source paths

* lint fix

* BitNet vllm integration

* chore: update codespell to version 2.3.0

* Lintfix

* Bump version to 0.0.1.dev13

* lint fix

* disable fast decoding [u]int4xint8 by default.

* optimize from dict design in Hint

* Implement SplitK

* bitnet benchmark generation.

* Add benchmark script for BitNet integration

* AtomicAdd Support

* LintFix

* ci fix when 3rdparty tvm is initialized.

* bug fix for setup

* fix a bug in block reduce

* typo fix

* BUG Fix for block reduce.

* Lint fix

* Refactor block reduce schedule template

* transform branch from bitblas to bitblas_tl

* Fix subproject commit reference in 3rdparty/tvm

* chore: update submodule branch from bitblas to bitblas_tl

* force update config.cmake

* Bug fix

* Fix subproject commit reference in 3rdparty/cutlass

* chore: Add submodule for cutlass library

* update tl cutlass path

* Refactor BitBLASLinear test module for improved readability and maintainability

* format fix

* Copy CUTLASS to the package directory

* Refactor setup.py to include additional TVM header files

* lint fix

* bug fix

* Refactor BitBLASLinear test module for improved readability and maintainability

* Implement Matmul Benchmark Design

* chore: Update BitBLAS Matmul benchmark script

* lint fix

* Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* lint fix

* Benchmark bot test

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* int8 test case

* Refactor compare_benchmark.py to handle missing benchmark results gracefully

* ci fix

* disable ci for test benchmark

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* remove cli installation

* chore: Create virtual environment and install dependencies for benchmark

* chore: Update benchmark workflow to include comparison step

* Lint fix

* update tvm commit

* Improve lower warp memory pass

* Bug fix

* Enhance to support warp schedule.

* Enhance LOP3 Instructions

* Enhance LOP3 Instructions

* add test for stage3 propagate

* implement propagate func

* Stage3 Ladder Permutate integration

* get_ladder_stage3_propagate

* comment out benchmark scripts as the setting is too big

* ci fix for benchmark

* lint fix

* chore: Update benchmark workflow to trigger on pull request comments

* Add LDMatrix Transform 3

* Support GPTQ Test

* Fuse BlockReduce Schedule

* Support mma propagate 3

* Support MMA Propagate Stage 3

* Lint Fix
LeiWang1999 authored Aug 2, 2024
1 parent 632886d commit c37ff29
Showing 22 changed files with 708 additions and 474 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/benchmark.yml
@@ -51,6 +51,13 @@ jobs:
cd benchmark/operators
python ./benchmark_ops_matmul.py
benchmark_head:
# On pull requests and if the comment starts with `/run-benchmark`
if: github.event.issue.pull_request != null && startsWith(github.event.comment.body, '/run-benchmark')
runs-on: self-hosted
depends-on: [benchmark_base]

steps:
- name: Checkout PR branch code
uses: actions/checkout@v2
with:
@@ -92,7 +99,7 @@ jobs:
python ./benchmark_ops_matmul.py
benchmark_compare:
if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark')
if: github.event.issue.pull_request != null && contains(github.event.comment.body, '/run-benchmark')
needs: [benchmark_base, benchmark_head]
runs-on: self-hosted

2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -56,7 +56,7 @@ jobs:
run: |
source bitblas_ci/bin/activate
python -m pip install --upgrade pip
if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi
if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi
- name: Install project in wheel mode
run: |
2 changes: 1 addition & 1 deletion 3rdparty/tvm
64 changes: 32 additions & 32 deletions benchmark/operators/benchmark_ops_matmul.py
@@ -108,41 +108,41 @@ def prepare_benchmark_sets(self):
"FP16xFP16_ACCFP16_NT",
[
*self.prepare_set_group_4x("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 3200),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8640, 3200),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 8640),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 5120),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 13824, 5120),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 13824),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 6656),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 17920, 6656),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 17920),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 1024, 8192),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 8192),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 28672, 8192),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 28672),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 3200),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8640, 3200),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 8640),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 5120),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 13824, 5120),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 13824),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 6656),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 17920, 6656),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 17920),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 1024, 8192),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 8192),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 28672, 8192),
# *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 28672),
],
)

self.add_benchmark_set(
"INT8xINT8_ACCINT32_NT",
[
*self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 3200),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8640, 3200),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 8640),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 5120),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 13824, 5120),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 13824),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 6656),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 17920, 6656),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 17920),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 1024, 8192),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 8192),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 28672, 8192),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 28672),
],
)
# self.add_benchmark_set(
# "INT8xINT8_ACCINT32_NT",
# [
# *self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 3200),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8640, 3200),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 8640),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 5120),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 13824, 5120),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 13824),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 6656),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 17920, 6656),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 17920),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 1024, 8192),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 8192),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 28672, 8192),
# *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 28672),
# ],
# )

def generate_operator_config(self, name: str, M, N, K) -> MatmulConfig:
"""Generate configuration for the given operator."""
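For context on how one benchmark entry above maps onto an operator, here is a minimal standalone sketch that builds a single FP16xFP16_ACCFP16_NT case with the MatmulConfig/Matmul pair re-exported from bitblas/__init__.py (next diff). The field names (A_dtype, W_dtype, accum_dtype, out_dtype, layout) are assumptions based on the public BitBLAS API, not something introduced by this commit.

# Hypothetical sketch: one (M, N, K) shape from the commented-out LLM groups,
# wired through MatmulConfig/Matmul by hand instead of generate_operator_config().
import torch
import bitblas

config = bitblas.MatmulConfig(
    M=1,                    # sequence/batch dimension of the activation
    N=3200,                 # output features
    K=3200,                 # input features
    A_dtype="float16",      # activation dtype
    W_dtype="float16",      # weight dtype
    accum_dtype="float16",  # the "ACCFP16" part of the set name
    out_dtype="float16",
    layout="nt",            # NT: activation row-major, weight transposed
)
matmul = bitblas.Matmul(config=config)

a = torch.rand((1, 3200), dtype=torch.half).cuda()
w = torch.rand((3200, 3200), dtype=torch.half).cuda()
c = matmul(a, w)            # dispatches the generated CUDA kernel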
2 changes: 1 addition & 1 deletion bitblas/__init__.py
@@ -36,7 +36,7 @@
)

from . import testing # noqa: F401
from .utils import auto_detect_nvidia_target # noqa: F401
from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401
from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401
2 changes: 1 addition & 1 deletion bitblas/builder/lib_generator/__init__.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from bitblas import TileDevice
from bitblas.base.arch import TileDevice
import ctypes
import os
import tempfile
5 changes: 3 additions & 2 deletions bitblas/builder/wrapper/tir.py
@@ -3,13 +3,14 @@
from bitblas import tvm
from typing import Optional, List, Dict, Union
from tvm import IRModule
from bitblas import TileDevice
from bitblas.base.arch import TileDevice
from bitblas.utils import match_global_kernel
from bitblas.utils.rtmod_analysis import get_annotated_device_mod
import re
from .base import BaseWrapper
import logging

from .base import BaseWrapper

logger = logging.getLogger(__name__)


57 changes: 57 additions & 0 deletions bitblas/gpu/matmul_analysis.py
@@ -734,6 +734,63 @@ def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j):
return ldmatrix_index_map, inversed_index_map


# This function returns the index map for stage 3 of the
# Ladder weight propagation, which can be used to avoid the
# ldmatrix instructions.
def get_ladder_stage3_map(dtype="float16", index_dtype="int32"):

def shared_32x8_to_mma_32x8_layout(i, j):
thread_id = (i % 8) * 4 + (j // 2)
local_id = (i // 8) * 2 + (j % 2)
return thread_id, local_id

def shared_32x16_to_mma_32x16_layout(i, j):
thread_id = (i % 8) * 4 + (j // 4)
local_id = (i // 8) * 4 + (j % 4)
return thread_id, local_id

assert dtype in [
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
], "Only support float16, int8, e4m3_float8, e5m2_float8"
if dtype == "float16":
stage3_layout = shared_32x8_to_mma_32x8_layout
elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
stage3_layout = shared_32x16_to_mma_32x16_layout
else:
raise ValueError("Unknown dtype ", dtype)

# The intra-warp memory layout is produced by ldmatrix; we should lift the ldmatrix out.
def ladder_stage3_permutation_16x16_32x8_32x8_16x16(kernel_i, kernel_j):
thread_id = kernel_i * 2 + kernel_j // 8
local_id = kernel_j % 8
new_thread_id, new_local_id = stage3_layout(thread_id, local_id)
new_kernel_i = (new_thread_id * 8 + new_local_id) // 16
new_kernel_j = (new_thread_id * 8 + new_local_id) % 16
return new_kernel_i, new_kernel_j

def ladder_stage3_permutation_16x32_32x16_32x16_16x32(kernel_i, kernel_j):
thread_id = kernel_i * 2 + kernel_j // 16
local_id = kernel_j % 16
new_thread_id, new_local_id = stage3_layout(thread_id, local_id)
new_kernel_i = (new_thread_id * 16 + new_local_id) // 32
new_kernel_j = (new_thread_id * 16 + new_local_id) % 32
return new_kernel_i, new_kernel_j

if dtype == "float16":
stage3_index_map = ladder_stage3_permutation_16x16_32x8_32x8_16x16
else:
stage3_index_map = ladder_stage3_permutation_16x32_32x16_32x16_16x32

stage3_index_map = IndexMap.from_func(stage3_index_map, index_dtype=index_dtype)
# TODO(lei): index_dtype should be analyzed from the schedule
row, col = [16, 16] if dtype == "float16" else [16, 32]
inversed_index_map = stage3_index_map.inverse([row, col])
return stage3_index_map, inversed_index_map


def layout_propagate_chain(
sch: tir.Schedule,
start_block: BlockRV,
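As a quick sanity check on the stage-3 mapping added above, the snippet below re-implements the float16 path of get_ladder_stage3_map in plain Python and verifies that the 16x16 permutation is a bijection. The trailing transform_layout lines are only a sketch of one possible integration point (block and buffer names are hypothetical); the actual wiring lives in the block-reduce schedule template of this commit.

# Plain-Python copy of the float16 stage-3 permutation defined above.
def shared_32x8_to_mma_32x8_layout(i, j):
    thread_id = (i % 8) * 4 + (j // 2)
    local_id = (i // 8) * 2 + (j % 2)
    return thread_id, local_id

def ladder_stage3_permutation_16x16_32x8_32x8_16x16(kernel_i, kernel_j):
    thread_id = kernel_i * 2 + kernel_j // 8
    local_id = kernel_j % 8
    new_thread_id, new_local_id = shared_32x8_to_mma_32x8_layout(thread_id, local_id)
    linear = new_thread_id * 8 + new_local_id
    return linear // 16, linear % 16

# The permutation must hit every element of the 16x16 tile exactly once.
image = {ladder_stage3_permutation_16x16_32x8_32x8_16x16(i, j)
         for i in range(16) for j in range(16)}
assert len(image) == 256

# Hypothetical integration sketch (names are placeholders):
# index_map, inverse_map = get_ladder_stage3_map(dtype="float16")
# sch.transform_layout(weight_block, ("read", 1), index_map)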
