Skip to content

Commit

Permalink
[Fix] Tensor core type issue for dense (apache#7187)
Browse files Browse the repository at this point in the history
* fix tc type issue for dense

* fix lint

* rm float 32

Co-authored-by: Leyuan Wang <ziyu.guo@bytedance.com>
  • Loading branch information
2 people authored and Tushar Dey committed Jan 20, 2021
1 parent db2c7fa commit 105a9f0
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,9 +678,26 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
if target.kind.name == "cuda":
if nvcc.have_tensorcore(target=target):
if (
- (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
- or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
- or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
+ (
+ data.dtype in ["float16", "int8", "uint8"]
+ and (
+ (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
+ or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
+ or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
+ )
+ )
+ or (
+ data.dtype in ["int4", "uint4"]
+ and i % 32 == 0
+ and b % 8 == 0
+ and o % 8 == 0
+ )
+ or (
+ data.dtype in ["int1", "uint1"]
+ and i % 128 == 0
+ and b % 8 == 0
+ and o % 8 == 0
+ )
):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_tensorcore),
Expand Down

0 comments on commit 105a9f0

Please sign in to comment.