Skip to content

Commit

Permalink
add dense and bmm test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 14, 2022
1 parent a957dde commit 7291e47
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
38 changes: 38 additions & 0 deletions tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,5 +676,43 @@ def test_dense_vnni():
np.testing.assert_equal(out, ref)


@pytest.mark.skip("Requires GFX10 AMDGPU")
def test_dense_rocm_sdot4():
    """int8 dense + bias_add lowered for rocm with dot-product support.

    Builds the Relay module for "rocm -mattr=+dotprod", checks that the
    generated GPU assembly contains the v_dot4_i32_i8 instruction, then
    (on actual GFX10 hardware) compares the runtime output against a
    numpy int32 reference.
    """
    data_shape, weight_shape = (32, 96), (128, 96)
    data_dtype = "int8"

    d_var = relay.var("data", shape=data_shape, dtype=data_dtype)
    w_var = relay.var("weight", shape=weight_shape, dtype="int8")
    b_var = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
    mod = tvm.IRModule.from_expr(
        relay.nn.bias_add(relay.nn.dense(d_var, w_var, out_dtype="int32"), b_var)
    )

    target = "rocm -mattr=+dotprod"
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target)

    # The device code lives in the imported rocm module; its asm must use sdot4.
    asm = lib.lib.imported_modules[0].get_source("asm")
    assert "v_dot4_i32_i8" in asm

    dev = tvm.device(target, 0)
    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

    a_np = np.random.uniform(1, 10, size=data_shape).astype(data_dtype)
    w_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
    b_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")

    for name, arr in (("data", a_np), ("weight", w_np), ("bias", b_np)):
        runtime.set_input(name, arr)
    runtime.run()

    ref = np.dot(a_np.astype("int32"), w_np.transpose().astype("int32")) + b_np
    np.testing.assert_equal(runtime.get_output(0).numpy(), ref)


if __name__ == "__main__":
    # Allow running this test file directly as a script via pytest.
    pytest.main([__file__])
35 changes: 35 additions & 0 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,41 @@ def test_batch_matmul_vnni():
np.testing.assert_equal(out, ref)


@pytest.mark.skip("Requires GFX10 AMDGPU")
def test_batch_matmul_rocm_sdot4():
    """int8 batch_matmul lowered for rocm with dot-product support.

    Builds the Relay module for "rocm -mattr=+dotprod", checks that the
    generated GPU assembly contains the v_dot4_i32_i8 instruction, then
    (on actual GFX10 hardware) compares the runtime output against the
    topi.testing int32 reference implementation.
    """
    x_shape, y_shape = (16, 32, 96), (16, 128, 96)
    lhs_dtype = "int8"

    lhs = relay.var("x", shape=x_shape, dtype=lhs_dtype)
    rhs = relay.var("y", shape=y_shape, dtype="int8")
    mod = tvm.IRModule.from_expr(relay.nn.batch_matmul(lhs, rhs, out_dtype="int32"))

    target = "rocm -mattr=+dotprod"
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target)

    # The device code lives in the imported rocm module; its asm must use sdot4.
    asm = lib.lib.imported_modules[0].get_source("asm")
    assert "v_dot4_i32_i8" in asm

    dev = tvm.device(target, 0)
    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

    x_np = np.random.uniform(1, 10, size=x_shape).astype(lhs_dtype)
    y_np = np.random.uniform(1, 10, size=y_shape).astype("int8")

    for name, arr in (("x", x_np), ("y", y_np)):
        runtime.set_input(name, arr)
    runtime.run()

    ref = tvm.topi.testing.batch_matmul(x_np, y_np, out_dtype="int32")
    np.testing.assert_equal(runtime.get_output(0).numpy(), ref)


@tvm.testing.uses_gpu
def test_shape_of():
shape = (10, 5, 12)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def get_ref_data():
4,
False,
),
# Disable on CI since it does not support spirv int8 dot product or rocm
# Disable on CI since it does not support spirv int8 dot product
# (
# "vulkan -from_device=0",
# lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
Expand Down

0 comments on commit 7291e47

Please sign in to comment.