Skip to content

Commit

Permalink
verify batch is int or IntImm
Browse files Browse the repository at this point in the history
  • Loading branch information
tkonolige committed Jan 16, 2021
1 parent 792ac46 commit 6fc1fa0
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/tvm/topi/cuda/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=invalid-name, unused-argument
"""Schedule for dense operator"""
import logging
from tvm import te
from tvm import te, tir
import tvm.autotvm as autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cublas
Expand All @@ -44,8 +44,10 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
matmul = cublas.matmul(data, weight, False, True)
if isinstance(batch, int):
cfg.add_flop(batch * in_dim * out_dim * 2)
else:
elif isinstance(batch, tir.IntImm):
cfg.add_flop(batch.value * in_dim * out_dim * 2)
else:
assert isinstance(batch, (int, tir.IntImm)), f"batch must be an int or IntImm, but it is {type(batch)}"
if bias is not None:
matmul = te.compute(
(batch, out_dim), lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST
Expand Down

0 comments on commit 6fc1fa0

Please sign in to comment.