Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 14, 2022
1 parent b2208a7 commit e781ee1
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/tvm/tir/tensor_intrin/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def sdot4(
T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
T.int32(0),
T.bool(1),
dtype="int32"
dtype="int32",
)


Expand Down
13 changes: 9 additions & 4 deletions python/tvm/topi/cuda/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,15 @@ def _instr(index):
prev_z = 0 if index == 0 else zz.vload(0)

# new_z = tvm.tir.call_pure_extern(zz_dtype, "__dp4a", vec_x, vec_y, prev_z)
new_z = tvm.tir.call_llvm_pure_intrin(zz_dtype, "llvm.amdgcn.sdot4", tvm.tir.const(4, "uint32"),
tvm.tir.call_intrin("int32", "tir.reinterpret", vec_x),
tvm.tir.call_intrin("int32", "tir.reinterpret", vec_y),
prev_z, True)
new_z = tvm.tir.call_llvm_pure_intrin(
zz_dtype,
"llvm.amdgcn.sdot4",
tvm.tir.const(4, "uint32"),
tvm.tir.call_intrin("int32", "tir.reinterpret", vec_x),
tvm.tir.call_intrin("int32", "tir.reinterpret", vec_y),
prev_z,
True,
)
ib.emit(zz.vstore(0, new_z))

return ib.get()
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def swap(arr, axis):


def is_target(names):
"""Return True if the name of the current target is one of provided names"""
names = [names] if isinstance(names, str) else names
target = tvm.target.Target.current(allow_none=False)
return any(name in target.keys for name in names)
1 change: 0 additions & 1 deletion tests/python/topi/python/test_topi_conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,6 @@ def test_conv2d_nchw(in_dtype):
with Int8Fallback():
# ResNet18 workloads where channels in / out are multiple of oc_block_factor
verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
return
verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0)
verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1)
verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0)
Expand Down

0 comments on commit e781ee1

Please sign in to comment.