Skip to content

Commit

Permalink
[TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType …
Browse files Browse the repository at this point in the history
…passes (#10172)

[TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes
  • Loading branch information
lazycal authored Feb 22, 2022
1 parent d8d28bf commit d8e39fd
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 2 deletions.
17 changes: 17 additions & 0 deletions src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,23 @@ class DataTypeRewriter : public StmtExprMutator {
return StmtExprMutator::VisitExpr_(op);
}

PrimExpr VisitExpr_(const RampNode* op) final {
PrimExpr base = VisitExpr(op->base);
PrimExpr stride = VisitExpr(op->stride);
if (base.same_as(op->base) && stride.same_as(op->stride)) {
return GetRef<PrimExpr>(op);
} else {
if (base.dtype().is_int()) {
ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype();
int bits = std::max(base.dtype().bits(), stride.dtype().bits());
DataType dtype = base.dtype().with_bits(bits);
if (base.dtype() != dtype) base = cast(dtype, base);
if (stride.dtype() != dtype) stride = cast(dtype, stride);
}
return Ramp(base, stride, op->lanes);
}
}

PrimExpr VisitExpr_(const SizeVarNode* op) final {
if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
if (vmap_.find(op) == vmap_.end()) {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
using StmtMutator::operator();

Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(0, 1, var_lanes);
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
}

Stmt VisitStmt(const Stmt& stmt) final {
Expand Down
25 changes: 24 additions & 1 deletion tests/python/unittest/test_tir_transform_narrow_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def lower_stmt(params, stmt, target_bits):
return stmt


def lower_sch(sch, args, target_bits):
def lower_sch(sch, args, target_bits, extra_passes=None):
binds = {}
arg_list = []
for x in args:
Expand All @@ -42,6 +42,9 @@ def lower_sch(sch, args, target_bits):

mod = schedule_to_module(sch, args)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
if extra_passes:
for p in extra_passes:
mod = p(mod)
return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body


Expand Down Expand Up @@ -255,6 +258,25 @@ def check(shape, index, target_bits, target_dtype):
)


def test_ramp_dtype_consistency():
"""
for (i :int64, (int64)0, (int64)4) {
A[ramp(i*(int64)2, (int64)1, 2)] = cast(int64, 2 ** 31 - 1) * i;
}
The infer result:
base: int64 -> int64 (since i is involved in another int64 expr)
stride: int64 -> int32
Thus ramp should still use int64 for both stride and base after rewrite.
"""
n = tvm.tir.IntImm("int64", 4)
m = tvm.tir.IntImm("int64", 2)
A = te.compute((n, m), lambda i, j: tvm.tir.Cast("int64", 2 ** 31 - 1) * i, name="A")
s = te.create_schedule(A.op)
s[A].vectorize(A.op.axis[1])
lower_sch(s, [A], 32, extra_passes=[tvm.tir.transform.VectorizeLoop()])


if __name__ == "__main__":
test_basic()
test_thread_axis()
Expand All @@ -263,3 +285,4 @@ def check(shape, index, target_bits, target_dtype):
test_slice()
test_relay_basic()
test_relay_take()
test_ramp_dtype_consistency()
9 changes: 9 additions & 0 deletions tests/python/unittest/test_tir_transform_vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,14 @@ def test_ir(A, B, C):
assert expected in error_msg


def test_vectorize_dtype_mismatch():
n = tvm.tir.IntImm("int64", 4)
A = te.compute((n,), lambda i: tvm.tir.IntImm("int64", 2 ** 31 - 1) + i, name="A")
s = te.create_schedule(A.op)
s[A].vectorize(A.op.axis[0])
tvm.lower(s, [A], "llvm", simple_mode=True)


if __name__ == "__main__":
test_vectorize_vector()
test_vectorize_with_if()
Expand All @@ -214,3 +222,4 @@ def test_ir(A, B, C):
test_vectorize_with_ge_cond()
test_vectorize_let()
test_vectorize_while_fail()
test_vectorize_dtype_mismatch()

0 comments on commit d8e39fd

Please sign in to comment.