diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index 5cef8d63587d..11f19e934e02 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -19,9 +19,10 @@
 import pytest
 import tvm
 import tvm.testing
-from tvm import tir
+from tvm import tir, te
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
+from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN
 
 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -531,5 +532,47 @@ def test_tensorize_with_annotation():
     verify_trace_roundtrip(sch=s, mod=func)
 
 
+def test_tensorize_vnni():
+    """Tensorize a uint8 x int8 matmul over a packed weight with the VNNI intrinsic.
+
+    The weight is laid out as (n // 16, k // 4, 16, 4), so the schedule splits
+    the output column loop by 16 and the reduction loop by 4 before
+    tensorizing the resulting inner loops.
+    """
+    n, m, k = 128, 128, 128
+    X = te.placeholder((m, k), name="X", dtype="uint8")
+    packed_W = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8")
+
+    ak = te.reduce_axis((0, k), name="k")
+    matmul = te.compute(
+        (m, n),
+        lambda i, j: te.sum(
+            X[i, ak].astype("int32")
+            * packed_W[
+                tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4
+            ].astype("int32"),
+            axis=ak,
+        ),
+        name="compute",
+    )
+
+    func = te.create_prim_func([X, packed_W, matmul])
+
+    sch = tir.Schedule(func, debug_mask="all")
+    block = sch.get_block("compute")
+    # Distinct names for the loop handles: `k` above is the integer extent.
+    _, j_loop, k_loop = sch.get_loops(block)
+
+    _, ji = sch.split(j_loop, factors=[None, 16])
+    ko, ki = sch.split(k_loop, factors=[None, 4])
+    sch.reorder(ko, ji, ki)
+
+    # Split the init out of the reduction at `ko` before tensorizing.
+    sch.decompose_reduction(block, ko)
+    sch.tensorize(ji, VNNI_INTRIN)
+
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))