Any example to use float16xfp4_e2m1 matmul? #253

Closed
yaoyaoding opened this issue Nov 28, 2024 · 2 comments

Comments

@yaoyaoding

Hi @LeiWang1999,

I encounter the following error when I try to build a matmul with a_dtype = 'float16' and b_dtype = 'fp4_e2m1' in bitblas:

tvm.error.InternalError: Traceback (most recent call last):
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::PrimExpr (tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::PrimExpr (*)(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span)>(tvm::PrimExpr (*)(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::reinterpret(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span)
  File "/root/BitBLAS/3rdparty/tvm/src/tir/op/op.cc", line 415
InternalError: Check failed: (value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) is false: Bitcast requires size match float16 vs uint32

This is caused by the following Python function:

def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "float16"
    assert val.dtype == "uint32"
    # e_f4 == 0 -> e_f16 = 0
    # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
    mask = tvm.tir.const((1 << nbit) - 1, "uint32")
    f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask
    s = f4 >> tir.const(3, "uint32")
    e_f4 = f4 & tir.const(7, "uint32")
    e_f16 = e_f4 | tir.const(8, "uint32")
    val_f16 = tir.reinterpret("float16",
                              (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32"))
    return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)

The error is raised at this line:

    val_f16 = tir.reinterpret("float16",
                              (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32"))

It seems bitblas is trying to reinterpret a "uint32" value as "float16", which makes tvm complain.
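
A minimal standalone snippet (my own reduction, not bitblas code) trips the same check:

import tvm
from tvm import tir

# 32-bit source, 16-bit target: fails the size check inside tvm::reinterpret
x = tir.const(0x3C00, "uint32")
tir.reinterpret("float16", x)  # InternalError: Bitcast requires size match float16 vs uint32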

I am using the following matmul config.

config = MatmulConfig(
    M=4096,
    N=4096,
    K=4096,
    A_dtype="float16",
    W_dtype="fp4_e2m1",
    out_dtype="float16",
    group_size=128,
    accum_dtype="float32",
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",
    storage_dtype="uint32"
)

Do you have any example to run the fp16xfp4 matmul in bitblas?

(I am using version v0.0.1.dev15.)

@LeiWang1999 (Contributor)

@yaoyaoding Thanks for reporting, yeah I think it should be:

def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "float16"
    assert val.dtype == "uint32"
    # e_f4 == 0 -> e_f16 = 0
    # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
    mask = tvm.tir.const((1 << nbit) - 1, "uint16")
    f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
    s = f4 >> tir.const(3, "uint16")
    e_f4 = f4 & tir.const(7, "uint16")
    e_f16 = e_f4 | tir.const(8, "uint16")
    val_f16 = tir.reinterpret("float16",
                              ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
    return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16)
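
For a quick sanity check, here is a pure-NumPy mirror of the same bit manipulation (a hypothetical helper, not part of bitblas); note the bitcast is now 16 bits vs 16 bits:

import numpy as np

def _ref_u32_to_f4_to_f16(val: int, pos: int) -> np.float16:
    # mirror of the TIR lambda above, for spot-checking single values
    f4 = (val >> (pos * 4)) & 0xF
    s = f4 >> 3
    e_f4 = f4 & 7
    if e_f4 == 0:
        return np.float16(0.0)
    # the sign lands at bit 15; the biased 3-bit payload (e_f4 | 8) lands
    # in the float16 exponent field at bits 10..13
    bits = ((e_f4 | 8) | (s << 5)) << 10
    return np.array([bits], dtype=np.uint16).view(np.float16)[0]

print(_ref_u32_to_f4_to_f16(0x00000003, 0))  # decode the f4 nibble 0b0011 at pos 0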

And we should extend this to a general cast to make it compatible with any storage dtype.
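
Something along these lines (an untested sketch; the idea is to shift and mask in the storage dtype, then narrow to uint16 before the bitcast):

def _tir_any_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "float16"
    # shift and mask in whatever storage dtype `val` carries
    mask = tir.const((1 << nbit) - 1, val.dtype)
    f4 = ((val >> (pos.astype(val.dtype) * tir.const(nbit, val.dtype))) & mask).astype("uint16")
    s = f4 >> tir.const(3, "uint16")
    e_f4 = f4 & tir.const(7, "uint16")
    e_f16 = e_f4 | tir.const(8, "uint16")
    # the bitcast is now always 16 bits vs 16 bits, regardless of storage dtype
    val_f16 = tir.reinterpret("float16",
                              (e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16"))
    return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16)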

For our fp4 benchmarking, we use nf4.
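
For reference, end-to-end usage looks roughly like this (a sketch following the Matmul wrapper pattern from the BitBLAS README; scaling/zeros are omitted for brevity, and transform_weight handling of fp4 weights is assumed to be analogous to the int4 example):

import torch
import bitblas

config = bitblas.MatmulConfig(
    M=4096, N=4096, K=4096,
    A_dtype="float16", W_dtype="fp4_e2m1", out_dtype="float16",
    accum_dtype="float32", storage_dtype="uint32",
)
matmul = bitblas.Matmul(config=config)

A = torch.rand((4096, 4096), dtype=torch.float16).cuda()
W = torch.rand((4096, 4096), dtype=torch.float16).cuda()
W_q = matmul.transform_weight(W)  # pack weights into the fp4 storage layout
C = matmul(A, W_q)                # fp16 activations x fp4 weights -> fp16 output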

@yaoyaoding (Author) commented Nov 29, 2024

Got it, thanks @LeiWang1999 for the timely response and fix!
