diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index d29132450227..5db131c44f2a 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -79,12 +79,6 @@ Type GetType(const PrimExpr& expr) { return PrimType(dtype); } -// simple cast that only checks if type matches and cast -inline PrimExpr SimpleCast(const DataType& t, PrimExpr value, Span span) { - if (value.dtype() == t) return value; - return tir::Cast(t, value, span); -} - // LargeUIntImm PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) { return tir::Call( @@ -113,48 +107,47 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } if (lhs.dtype() == rhs.dtype()) return; + ltype = lhs.dtype(); + rtype = rhs.dtype(); // We keep dtypes conversion to be relatively consistent to reduce the amount code generated by // operators. This can be helpful for users to find potential type conversion problems. The // following are exceptions: - if (lhs.dtype().is_float() && rhs.dtype().is_float()) { + if (ltype.is_float() && rtype.is_float()) { // Given two dissimilar floats, cast the lower bit version to the higher bit version. // E.g. 
fp16 + fp32 --> fp32 + fp32 - if (lhs.dtype().bits() < rhs.dtype().bits()) { - lhs = cast(rhs.dtype(), lhs); - } else if (lhs.dtype().bits() > rhs.dtype().bits()) { - rhs = cast(lhs.dtype(), rhs); + if (ltype.bits() < rtype.bits()) { + lhs = cast(rtype, lhs); + } else { + rhs = cast(ltype, rhs); } - } else if (!lhs.dtype().is_float() && - (rhs.dtype().is_float() || - datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) { + } else if (!ltype.is_float() && + (rtype.is_float() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) { // Cast int->float when the other operand is a float - lhs = cast(rhs.dtype(), lhs); - } else if ((lhs.dtype().is_float() || - datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) && - !rhs.dtype().is_float()) { + lhs = cast(rtype, lhs); + } else if ((ltype.is_float() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) && + !rtype.is_float()) { // Cast int->float when the other operand is a float - rhs = cast(lhs.dtype(), rhs); - } else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) || - (lhs.dtype().is_uint() && rhs.dtype().is_uint())) { + rhs = cast(ltype, rhs); + } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) { // Promote int to higher bits e.g. 
int8 + int16 --> int16 + int16 - if (lhs.dtype().bits() < rhs.dtype().bits()) { - lhs = cast(rhs.dtype(), lhs); + if (ltype.bits() < rtype.bits()) { + lhs = cast(rtype, lhs); } else { - rhs = cast(lhs.dtype(), rhs); + rhs = cast(ltype, rhs); } - } else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) || - (lhs.dtype().is_uint() && rhs.dtype().is_int())) { + } else if ((ltype.is_int() && rtype.is_uint()) || (ltype.is_uint() && rtype.is_int())) { // Handle mixing signed and unsigned integers - int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits()); - // if the signed int range is bigger than that of uint, try uint->int - if (lhs.dtype().is_int() && rhs.dtype().bits() <= bits - 1) { - rhs = cast(lhs.dtype(), rhs); - } else if (rhs.dtype().is_int() && lhs.dtype().bits() <= bits - 1) { - lhs = cast(rhs.dtype(), lhs); + if (ltype.bits() < rtype.bits()) { + lhs = cast(rtype, lhs); + } else if (ltype.bits() > rtype.bits()) { + rhs = cast(ltype, rhs); } else { - // the ranges of uint and int types conflit, try SimpleCast - lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span); - rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span); + // The widths of the signed and unsigned integers are the same.
+ if (ltype.is_uint()) { + rhs = cast(ltype, rhs); + } else { + lhs = cast(rtype, lhs); + } } } else { LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype; diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 0b83b8de363d..93b9cfa07464 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5357,16 +5357,16 @@ def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None): real = get_random_uniform(shape=[10], seed=5) expected = np.asarray( [ - 0.8614111, - 0.46572232, - 0.6007328, - 0.21619737, - 0.6361222, - 0.7298056, - 0.13094282, - 0.03556716, - 0.32997167, - 0.2977605, + 0.043976, + 0.96656, + 0.292199, + 0.904297, + 0.25167, + 0.521778, + 0.778985, + 0.085463, + 0.939846, + 0.194201, ] ) tvm.testing.assert_allclose(real, expected, rtol=1e-5) diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py index aeec63abba27..9725650eadae 100644 --- a/tests/python/unittest/test_tir_ops.py +++ b/tests/python/unittest/test_tir_ops.py @@ -119,7 +119,8 @@ def verify_general_dtype_support(f, is_conditional=False): [("bool", "int32"), "int32"], [("int32", "float32"), "float32"], [("int32", "int64"), "int64"], - [("uint32", "int32"), "int32"], + [("uint32", "int8"), "uint32"], + [("uint32", "int32"), "uint32"], ] for (lhs_dtype, rhs_dtype), out_dtype in rules: lhs = te.var("lhs", dtype=lhs_dtype) @@ -184,8 +185,8 @@ def test_if_then_else(): [(te.var("cond", dtype="bool"), "bool", "int32"), "int32"], [(True, "int32", "float32"), "float32"], [(False, "int32", "int64"), "int64"], - [(te.var("cond", dtype="bool"), "uint32", "int32"), "int32"], - [(te.var("cond", dtype="int32"), "uint32", "int32"), "int32"], + [(te.var("cond", dtype="bool"), "uint32", "int32"), "uint32"], + [(te.var("cond", dtype="int32"), "uint32", "int32"), "uint32"], ] for (cond, lhs_dtype, rhs_dtype), out_dtype in cases: lhs = te.var("lhs", 
dtype=lhs_dtype)