Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dev] Refactor codebase to save import time #262

Merged
merged 11 commits into from
Dec 13, 2024

Conversation

LeiWang1999
Copy link
Contributor

This pull request mainly shift tvm intrin related import into lazy import format to save the module import time when we write import bitblas in python.

Before this modification be applied:

import time:  39164.91985321045 ms
config time:  6095.165014266968 ms
hardware_aware_finetune time:  33147.03679084778 ms

After this modification be applied:

import time:  2109.856128692627 ms
config time:  22495.219230651855 ms
hardware_aware_finetune time:  49844.87271308899 ms

Test script to reproduce the results:

import time
import_start_time = time.time()
import bitblas
import_end_time = time.time()

print("bitblas.path: ", bitblas.__path__)

from bitblas import tvm as tvm


M = N = K = 256
A_dtype = "float16"
W_dtype = "float16"
accum_dtype = "float16"
out_dtype = "float16"

bitblas.set_log_level("DEBUG")

config_create_time = time.time()
matmul_config = bitblas.MatmulConfig(
    M=M,  # M dimension
    N=N,  # N dimension
    K=K,  # K dimension
    A_dtype=A_dtype,  # activation A dtype
    W_dtype=W_dtype,  # weight W dtype
    accum_dtype=accum_dtype,  # accumulation dtype
    out_dtype=out_dtype,  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    propagate_b=False,  # propagate B matrix
)

matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False, backend="tl")
config_end_time = time.time()

hardware_aware_finetune_start_time = time.time()
matmul.hardware_aware_finetune()
hardware_aware_finetune_end_time = time.time()

print("import time: ", (import_end_time - import_start_time) * 1000, "ms")
print("config time: ", (config_end_time - config_create_time) * 1000, "ms")
print("hardware_aware_finetune time: ", (hardware_aware_finetune_end_time - hardware_aware_finetune_start_time) * 1000, "ms")

@LeiWang1999 LeiWang1999 merged commit f11474b into microsoft:main Dec 13, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant