diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index dc34626205a1..dd5f54e52455 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -253,6 +253,23 @@ class DataTypeRewriter : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } + PrimExpr VisitExpr_(const RampNode* op) final { + PrimExpr base = VisitExpr(op->base); + PrimExpr stride = VisitExpr(op->stride); + if (base.same_as(op->base) && stride.same_as(op->stride)) { + return GetRef(op); + } else { + if (base.dtype().is_int()) { + ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype(); + int bits = std::max(base.dtype().bits(), stride.dtype().bits()); + DataType dtype = base.dtype().with_bits(bits); + if (base.dtype() != dtype) base = cast(dtype, base); + if (stride.dtype() != dtype) stride = cast(dtype, stride); + } + return Ramp(base, stride, op->lanes); + } + } + PrimExpr VisitExpr_(const SizeVarNode* op) final { if (visitor_.vmap.find(op) != visitor_.vmap.end()) { if (vmap_.find(op) == vmap_.end()) { diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index cd2d230f5775..0c9c97af650d 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -101,7 +101,7 @@ class Vectorizer : public StmtMutator, public ExprFunctordtype, 0), IntImm(var->dtype, 1), var_lanes); } Stmt VisitStmt(const Stmt& stmt) final { diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 9b95266d3287..667fad0317db 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -27,7 +27,7 @@ def lower_stmt(params, stmt, target_bits): return stmt -def lower_sch(sch, args, target_bits): +def lower_sch(sch, args, target_bits, extra_passes=None): binds = {} arg_list = [] for x in args: @@ -42,6 +42,9 @@ def lower_sch(sch, args, target_bits): mod = schedule_to_module(sch, args) mod = tvm.tir.transform.StorageFlatten(64)(mod) + if extra_passes: + for p in extra_passes: + mod = p(mod) return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body @@ -255,6 +258,25 @@ def check(shape, index, target_bits, target_dtype): ) +def test_ramp_dtype_consistency(): + """ + for (i :int64, (int64)0, (int64)4) { + A[ramp(i*(int64)2, (int64)1, 2)] = cast(int64, 2 ** 31 - 1) * i; + } + The infer result: + base: int64 -> int64 (since i is involved in another int64 expr) + stride: int64 -> int32 + + Thus ramp should still use int64 for both stride and base after rewrite. + """ + n = tvm.tir.IntImm("int64", 4) + m = tvm.tir.IntImm("int64", 2) + A = te.compute((n, m), lambda i, j: tvm.tir.Cast("int64", 2 ** 31 - 1) * i, name="A") + s = te.create_schedule(A.op) + s[A].vectorize(A.op.axis[1]) + lower_sch(s, [A], 32, extra_passes=[tvm.tir.transform.VectorizeLoop()]) + + if __name__ == "__main__": test_basic() test_thread_axis() @@ -263,3 +285,4 @@ def check(shape, index, target_bits, target_dtype): test_slice() test_relay_basic() test_relay_take() + test_ramp_dtype_consistency() diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index b1e580957b24..1a0d84a4f807 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -205,6 +205,14 @@ def test_ir(A, B, C): assert expected in error_msg +def test_vectorize_dtype_mismatch(): + n = tvm.tir.IntImm("int64", 4) + A = te.compute((n,), lambda i: tvm.tir.IntImm("int64", 2 ** 31 - 1) + i, name="A") + s = te.create_schedule(A.op) + s[A].vectorize(A.op.axis[0]) + tvm.lower(s, [A], "llvm", simple_mode=True) + + if __name__ == "__main__": test_vectorize_vector() test_vectorize_with_if() @@ -214,3 +222,4 @@ def test_ir(A, B, C): test_vectorize_with_ge_cond() test_vectorize_let() test_vectorize_while_fail() + test_vectorize_dtype_mismatch()