Skip to content

Commit

Permalink
conv2d topi test working
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 14, 2022
1 parent 6d53c50 commit a957dde
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<String>("mcpu")
.add_attr_option<String>("mtriple")
.add_attr_option<String>("mattr")
.add_attr_option<Array<String>>("mattr")
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(64))
Expand Down
20 changes: 10 additions & 10 deletions tests/python/topi/python/test_topi_conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,13 +361,6 @@ def get_ref_data():
# 4,
# False,
# ),
# (
# "rocm",
# lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
# topi.cuda.schedule_conv2d_NCHWc_int8,
# 4,
# False,
# ),
]

build_only_aarch64 = platform.machine() != "aarch64"
Expand All @@ -383,15 +376,22 @@ def get_ref_data():
)

if in_dtype == "int8":
targets.append(
targets += [
(
"llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
topi.arm_cpu.conv2d_NCHWc_int8,
topi.arm_cpu.schedule_conv2d_NCHWc_int8,
8,
build_only_aarch64,
)
)
),
(
"rocm -mattr=+dotprod",
lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
topi.cuda.schedule_conv2d_NCHWc_int8,
4,
False,
),
]

for target, compute, schedule, oc_block_factor, build_only in targets:
check_target(target, compute, schedule, oc_block_factor, build_only)
Expand Down

0 comments on commit a957dde

Please sign in to comment.