[Dev] Enhance Operator Cache to support multi-thread environments #205

Merged: 14 commits, Oct 1, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 1 file
+1 −2 python/tvm/tl/engine.py
67 changes: 39 additions & 28 deletions bitblas/cache/operator.py
@@ -12,6 +12,7 @@
from bitblas import tvm
from tvm.contrib.tar import tar
import logging
import threading

logger = logging.getLogger(__name__)

@@ -24,53 +25,63 @@ class OperatorCache:
"""
Manages a cache for operator instances (e.g., Matmul, Convolution) based on their configurations.
"""
# A lock to synchronize access to the cache
# RLock is used to allow reentrant locking
# As load_from_database calls _load_operator which
# calls _instantiate_and_add_operator
cache_locker = threading.RLock()

def __init__(self):
self.cache = {}

def add(self, config: OperatorConfig, op_inst: Operator):
self.cache[config] = op_inst
with self.cache_locker:
self.cache[config] = op_inst

def get(self, config: OperatorConfig):
return self.cache.get(config)
with self.cache_locker:
return self.cache.get(config)

def exists(self, config):
return config in self.cache

def clear(self):
self.cache.clear()
with self.cache_locker:
self.cache.clear()

def size(self):
return len(self.cache)

def save_into_database(self, database_path=None, target=None):
database_path = self._ensure_database_path(database_path)
for config, op_inst in self.cache.items():
arch_str = self._determine_arch_str(op_inst, target)
arch_path = os.path.join(database_path, arch_str)
self._ensure_directory(arch_path)
hash_str = sha256(repr(config).encode()).hexdigest()
config_path = os.path.join(arch_path, hash_str)
# if the config already exists, skip saving
if os.path.exists(config_path):
continue
self._ensure_directory(config_path)
self._save_operator_config_and_artifact(config, op_inst, config_path)
with self.cache_locker:
database_path = self._ensure_database_path(database_path)
for config, op_inst in self.cache.items():
arch_str = self._determine_arch_str(op_inst, target)
arch_path = os.path.join(database_path, arch_str)
self._ensure_directory(arch_path)
hash_str = sha256(repr(config).encode()).hexdigest()
config_path = os.path.join(arch_path, hash_str)
# if the config already exists, skip saving
if os.path.exists(config_path):
continue
self._ensure_directory(config_path)
self._save_operator_config_and_artifact(config, op_inst, config_path)

def load_from_database(self, database_path, target=None):
if not os.path.exists(database_path):
logger.info(
f"Database path {database_path} does not exist, skipping loading operators from the database"
)
return
arch_str = self._determine_target_arch_str(target)
arch_path = os.path.join(database_path, arch_str)
if not os.path.exists(arch_path):
logger.info(
f"Target {arch_str} does not exist in the database, skipping loading operators from the database"
)
return
self._load_operators_from_arch_path(arch_path, target)
with self.cache_locker:
if not os.path.exists(database_path):
logger.info(
f"Database path {database_path} does not exist, skipping loading operators from the database"
)
return
arch_str = self._determine_target_arch_str(target)
arch_path = os.path.join(database_path, arch_str)
if not os.path.exists(arch_path):
logger.info(
f"Target {arch_str} does not exist in the database, skipping loading operators from the database"
)
return
self._load_operators_from_arch_path(arch_path, target)

def _ensure_database_path(self, database_path):
if database_path is None:
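The key design point in `operator.py` is that `cache_locker` is a reentrant lock (`threading.RLock`) declared at class level, so a thread that already holds the lock inside `load_from_database` can safely re-acquire it when the load path reaches `add` via `_load_operator` and `_instantiate_and_add_operator`, as the new comment notes. A minimal standalone sketch of that pattern (illustrative only, not BitBLAS code; `TinyCache` and its methods are invented for the example):

```python
import threading


class TinyCache:
    """Illustrative cache guarded by a class-level reentrant lock (not the BitBLAS class)."""

    _lock = threading.RLock()  # shared by all instances, like cache_locker in the PR

    def __init__(self):
        self._store = {}

    def add(self, key, value):
        # Safe whether or not the calling thread already holds _lock.
        with self._lock:
            self._store[key] = value

    def load_many(self, items):
        # Holds the lock for the whole batch and re-enters it via add();
        # a non-reentrant threading.Lock would deadlock on that nested acquire.
        with self._lock:
            for key, value in items:
                self.add(key, value)


if __name__ == "__main__":
    cache = TinyCache()
    workers = [
        threading.Thread(target=cache.load_many, args=([(f"k{i}", i)],))
        for i in range(4)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    print(len(cache._store))  # -> 4
```

A plain `threading.Lock` would block forever on the nested acquire in `load_many`, which is exactly the call-chain concern spelled out in the comment above `cache_locker`.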
12 changes: 1 addition & 11 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
@@ -1,14 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm as tvm
from tvm import DataType
import tvm.tl.language as T
from typing import Optional
from bitblas.tl.utils import (
get_mma_micro_size,
make_swizzle_layout,
)

from bitblas.ops.base_scheduler import BaseScheduler

from dataclasses import dataclass
@@ -43,17 +36,14 @@ class MatmulFineGrainSIMTScheduler(BaseScheduler):
def with_default_config(self):
raise NotImplementedError

def apply_config(
self,
):
def apply_config(self,):

# M, N, K = self.M, self.N, self.K
# trans_A, trans_B = self.trans_A, self.trans_B
# in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype

raise NotImplementedError


def __post_init__(self):
# Validate the matrix transpose settings
assert self.trans_A is False, "Currently only support Matrix A not transposed"
35 changes: 13 additions & 22 deletions bitblas/tl/tuner.py
@@ -5,26 +5,19 @@
import os
from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from typing import List, Tuple, Optional, Dict, Union, Literal, Callable
from typing import List, Tuple, Optional, Dict, Literal
from tvm import tir, IRModule
from tvm.runtime import Module
from tvm.tir import Schedule
from tvm.relax.expr import Function
import tvm.tl as tl
import bitblas
from bitblas.ops.base_scheduler import BaseScheduler
from bitblas.base.arch import CUDA
from bitblas.base import Hint
from bitblas.base.utils import get_dummy_input_arrays
from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
import tempfile
import itertools
from tvm.ir.supply import GlobalVarSupply
from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
from bitblas.utils.tensor_adapter import (
np_float2np_bf16,)
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,8 +60,8 @@ def profile(self, data_distribution="uniform"):


def _apply_config(
scheduler: BaseScheduler,
config: Dict = None,
scheduler: BaseScheduler,
config: Dict = None,
) -> Optional[IRModule]:
"""
find rules:
@@ -121,13 +114,15 @@ def _build(context) -> str:
return idx, None, None

config = configs[idx]
assert config is not None

@tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True)
def tvm_callback_cuda_postproc(code, _):
code = tensor_replace_dp4a(code)
code = tensor_remove_make_int4(code)
code = tensor_remove_make_int2(code)
return code

# check only have one function in the module
if len(mod.functions) > 1:
raise ValueError("Only support one function in the module")
@@ -168,12 +163,12 @@ def tvm_callback_cuda_postproc(code, _):
continue
rt_mod = tvm.runtime.load_module(artifact_path)
# Transform Tuning Config to Hint
hint = Hint.from_dict(
{
**{"arch": arch},
**config,
}
)
hint = Hint.from_dict({
**{
"arch": arch
},
**config,
})
cpresult = CompileResult(hint, sch, rt_mod)
timer_cuda_mod = rt_mod.time_evaluator(
rt_mod.entry_name, arch.device, number=num_repeats)
@@ -250,11 +245,8 @@ def fast_tune(
raise NotImplementedError(
"Currently do not support fast tune with none-dynamic range set")
if opt_shapes:
for name, shape in opt_shapes.items():
var = find_var_from_func(func, name)
specilized_func = func.specialize({
var: shape.astype(var.dtype)
}).with_attr("is_specialized")
raise NotImplementedError(
"Currently do not support fast tune with none-dynamic range set")

arch = CUDA(target)

@@ -281,4 +273,3 @@
)

return cpresults, best

1 change: 0 additions & 1 deletion setup.py
@@ -24,7 +24,6 @@
PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true"
PACKAGE_NAME = "bitblas"
ROOT_DIR = os.path.dirname(__file__)
MAIN_CUDA_VERSION = "12.1"

# BitBLAS only supports Linux platform
assert sys.platform.startswith("linux"), "BitBLAS only supports Linux platform (including WSL)."
126 changes: 126 additions & 0 deletions testing/python/cache/test_operator_cache_spin_lock.py
@@ -0,0 +1,126 @@
import pytest
import os
import torch
import bitblas
import threading
from bitblas import Matmul, MatmulConfig
from bitblas.cache import global_operator_cache
from bitblas import tvm as tvm
from tvm.contrib import utils

target = bitblas.utils.auto_detect_nvidia_target()
bitblas.set_log_level("DEBUG")


def get_codegen_result(ops, target):
code = ops.get_source(target=target)
return code


def tune_op_in_thread(thread_id, matmul_config, database_path):
"""Each thread tunes the given Matmul operation and tries to save it into the global cache."""
matmul = Matmul(
config=matmul_config,
target=target,
enable_tuning=False,
)
print(f"Thread {thread_id}: Starting hardware-aware tuning...")
# matmul.hardware_aware_finetune(topk=20)
success = False
try:
print(f"Thread {thread_id}: Adding operation to global cache...")
global_operator_cache.add(matmul.config, matmul)

global_operator_cache.save_into_database(database_path, target=target)
assert os.path.exists(database_path), "Database file was not created."
global_operator_cache.clear()
assert global_operator_cache.size() == 0, "Global cache was not cleared properly."
global_operator_cache.load_from_database(database_path, target=target)
assert global_operator_cache.size() > 0, (
f"Thread {thread_id}: Global cache was not loaded properly as it is empty.")

success = True
except Exception as hash_error:
print(f"Thread {thread_id}: Error encountered - {hash_error}")
assert success, f"Thread {thread_id}: Failed to add operation to global cache."


@pytest.mark.parametrize(
"M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout",
[
(1, 1024, 1024, "float16", "float16", "float16", False, False, False, "nt"),
],
)
def test_global_cache_save_to_database_multithreaded(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
with_bias,
propagate_a,
propagate_b,
layout,
):
num_threads = 4
global_operator_cache.clear()

# For real world scenarios, all workers should share the same database path
tempdir = utils.tempdir()
database_path = str(tempdir.path)

matmul_config = MatmulConfig(
M=M,
N=N,
K=K,
A_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
with_bias=with_bias,
propagate_a=propagate_a,
propagate_b=propagate_b,
layout=layout,
)

# Launch four threads, each tuning the same operation
threads = []
for thread_id in range(num_threads):
thread = threading.Thread(
target=tune_op_in_thread, args=(thread_id, matmul_config, database_path))
threads.append(thread)
thread.start()

# Wait for all threads to complete
for thread in threads:
thread.join()

matmul = global_operator_cache.get(matmul_config)
assert matmul is not None, "Matmul operation not found in cache after reload."

# Verify that the operation produces correct results
input_shape = (M, K)
weight_shape = (N, K) if layout == "nt" else (K, N)

inputs = []
inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda())
inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda())
ref_result = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1])

permuted_inputs = []
if matmul.input_transform is not None:
permuted_inputs.append(matmul.input_transform(inputs[0].cpu()).cuda())
else:
permuted_inputs.append(inputs[0])
if matmul.weight_transform is not None:
permuted_inputs.append(matmul.weight_transform(inputs[1].cpu()).cuda())
else:
permuted_inputs.append(inputs[1])

bitblas_output = matmul(*permuted_inputs)
torch.testing.assert_close(bitblas_output, ref_result, rtol=1e-2, atol=1e-2)


# fmt: on
if __name__ == "__main__":
bitblas.testing.main()
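For context on what the four concurrent workers in this test would trip over without the new lock: `save_into_database` iterates `self.cache.items()` while other threads may call `add` or `clear` on the same dict. The underlying hazard can be shown deterministically in a single thread (a toy sketch, not BitBLAS code):

```python
# Mutating a dict while it is being iterated raises in CPython; this is the
# kind of interleaving the cache lock rules out when one thread is inside
# save_into_database()'s loop and another thread calls add() or clear().
cache = {"a": 1, "b": 2}
try:
    for key in cache:
        cache["c"] = 3  # stands in for a concurrent add() from another thread
except RuntimeError as err:
    print(err)  # dictionary changed size during iteration
```

The test itself can be run through pytest or executed directly (the file ends with `bitblas.testing.main()`), assuming a Linux machine with a supported NVIDIA GPU, as required elsewhere in the repo.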