Skip to content

Commit

Permalink
more wip
Browse files (browse the repository at this point in the history)
  • Loading branch information
masahi committed Apr 13, 2022
1 parent 0602f4a commit 6cc6280
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 33 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/op/strategy/rocm.py
Original file line number Diff line number Diff line change
@@ -186,6 +186,16 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.rocm.schedule_dense),
name="dense.rocm",
)
data, weights = inputs
if (data.dtype == "int8"
and weights.dtype == "int8"
and out_type.dtype == "int32"
):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_int8),
wrap_topi_schedule(topi.cuda.schedule_dense_int8),
name="dense_int8.rocm",
)
if target.kind.name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/cuda/conv2d_int8.py
Original file line number Diff line number Diff line change
@@ -312,8 +312,8 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
_, rc_block = s[conv].split(rc_block, factor=4)
target = tvm.target.Target.current(allow_none=False)
do_tensorize = True
if "vulkan" in target.keys or "rocm" in target.keys:
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
# if "vulkan" in target.keys or "rocm" in target.keys:
# do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
if do_tensorize:
dtypes = (pad_data.dtype, packed_kernel.dtype)
s[conv].tensorize(rc_block, dp4a("shared", "shared", "local", dtypes))
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/topi/cuda/dense.py
Original file line number Diff line number Diff line change
@@ -173,8 +173,9 @@ def _schedule_dense_int8(cfg, s, output):
ko, kt = cfg["tile_k"].apply(s, CC, ko)
target = tvm.target.Target.current(allow_none=False)
do_tensorize = True
if "vulkan" in target.keys or "rocm" in target.keys:
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
# if "vulkan" in target.keys or "rocm" in target.keys:
# do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
assert False
if do_tensorize:
dtypes = (data.dtype, weight.dtype)
s[CC].tensorize(ki, dp4a("shared", "shared", "local", dtypes))
Expand Down
59 changes: 30 additions & 29 deletions tests/python/topi/python/test_topi_conv2d_int8.py
Original file line number Diff line number Diff line change
@@ -346,45 +346,45 @@ def get_ref_data():
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

targets = [
(
"cuda",
lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
topi.cuda.schedule_conv2d_NCHWc_int8,
4,
False,
),
# Disable on CI since it does not support spirv int8 dot product
# (
# "vulkan -from_device=0",
# "cuda",
# lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
# topi.cuda.schedule_conv2d_NCHWc_int8,
# 4,
# False,
# ),
# Disable on CI since it does not support spirv int8 dot product
(
"rocm",
lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
topi.cuda.schedule_conv2d_NCHWc_int8,
4,
False,
),
]

build_only_aarch64 = platform.machine() != "aarch64"

targets.append(
(
"llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
topi.arm_cpu.conv2d_NCHWc_int8,
topi.arm_cpu.schedule_conv2d_NCHWc_int8,
8,
build_only_aarch64,
)
)

if in_dtype == "int8":
targets.append(
(
"llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
topi.arm_cpu.conv2d_NCHWc_int8,
topi.arm_cpu.schedule_conv2d_NCHWc_int8,
8,
build_only_aarch64,
)
)
# targets.append(
# (
# "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
# topi.arm_cpu.conv2d_NCHWc_int8,
# topi.arm_cpu.schedule_conv2d_NCHWc_int8,
# 8,
# build_only_aarch64,
# )
# )

# if in_dtype == "int8":
# targets.append(
# (
# "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
# topi.arm_cpu.conv2d_NCHWc_int8,
# topi.arm_cpu.schedule_conv2d_NCHWc_int8,
# 8,
# build_only_aarch64,
# )
# )

for target, compute, schedule, oc_block_factor, build_only in targets:
check_target(target, compute, schedule, oc_block_factor, build_only)
@@ -517,6 +517,7 @@ def test_conv2d_nchw(in_dtype):
with Int8Fallback():
# ResNet18 workloads where channels in / out are multiple of oc_block_factor
verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
return
verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0)
verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1)
verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0)
Expand Down

0 comments on commit 6cc6280

Please sign in to comment.