Skip to content

Commit

Permalink
[TIR] Change Integer Implicit Conversion Rule to C Standard Way (#8733)
Browse files Browse the repository at this point in the history
  • Loading branch information
Johnson9009 authored Aug 17, 2021
1 parent dbf9ce5 commit 2793113
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 48 deletions.
63 changes: 28 additions & 35 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,6 @@ Type GetType(const PrimExpr& expr) {
return PrimType(dtype);
}

// simple cast that only checks if type matches and cast
inline PrimExpr SimpleCast(const DataType& t, PrimExpr value, Span span) {
if (value.dtype() == t) return value;
return tir::Cast(t, value, span);
}

// LargeUIntImm
PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) {
return tir::Call(
Expand Down Expand Up @@ -113,48 +107,47 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
}
if (lhs.dtype() == rhs.dtype()) return;

ltype = lhs.dtype();
rtype = rhs.dtype();
// We keep dtypes conversion to be relatively consistent to reduce the amount code generated by
// operators. This can be helpful for users to find potential type conversion problems. The
// following are exceptions:
if (lhs.dtype().is_float() && rhs.dtype().is_float()) {
if (ltype.is_float() && rtype.is_float()) {
// Given two dissimilar floats, cast the lower bit version to the higher bit version.
// E.g. fp16 + fp32 --> fp32 + fp32
if (lhs.dtype().bits() < rhs.dtype().bits()) {
lhs = cast(rhs.dtype(), lhs);
} else if (lhs.dtype().bits() > rhs.dtype().bits()) {
rhs = cast(lhs.dtype(), rhs);
if (ltype.bits() < rtype.bits()) {
lhs = cast(rtype, lhs);
} else {
rhs = cast(ltype, rhs);
}
} else if (!lhs.dtype().is_float() &&
(rhs.dtype().is_float() ||
datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) {
} else if (!ltype.is_float() &&
(rtype.is_float() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
// Cast int->float when the other operand is a float
lhs = cast(rhs.dtype(), lhs);
} else if ((lhs.dtype().is_float() ||
datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) &&
!rhs.dtype().is_float()) {
lhs = cast(rtype, lhs);
} else if ((ltype.is_float() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
!rtype.is_float()) {
// Cast int->float when the other operand is a float
rhs = cast(lhs.dtype(), rhs);
} else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) ||
(lhs.dtype().is_uint() && rhs.dtype().is_uint())) {
rhs = cast(ltype, rhs);
} else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
// Promote int to higher bits e.g. int8 + int16 --> int16 + int16
if (lhs.dtype().bits() < rhs.dtype().bits()) {
lhs = cast(rhs.dtype(), lhs);
if (ltype.bits() < rtype.bits()) {
lhs = cast(rtype, lhs);
} else {
rhs = cast(lhs.dtype(), rhs);
rhs = cast(ltype, rhs);
}
} else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) ||
(lhs.dtype().is_uint() && rhs.dtype().is_int())) {
} else if ((ltype.is_int() && rtype.is_uint()) || (ltype.is_uint() && rtype.is_int())) {
// Handle mixing signed and unsigned integers
int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits());
// if the signed int range is bigger than that of uint, try uint->int
if (lhs.dtype().is_int() && rhs.dtype().bits() <= bits - 1) {
rhs = cast(lhs.dtype(), rhs);
} else if (rhs.dtype().is_int() && lhs.dtype().bits() <= bits - 1) {
lhs = cast(rhs.dtype(), lhs);
if (ltype.bits() < rtype.bits()) {
lhs = cast(rtype, lhs);
} else if (ltype.bits() > rtype.bits()) {
rhs = cast(ltype, rhs);
} else {
// the ranges of uint and int types conflit, try SimpleCast
lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span);
rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span);
// The width of signed and unsigned integers is same.
if (ltype.is_uint()) {
rhs = cast(ltype, rhs);
} else {
lhs = cast(rtype, lhs);
}
}
} else {
LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
Expand Down
20 changes: 10 additions & 10 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5357,16 +5357,16 @@ def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None):
real = get_random_uniform(shape=[10], seed=5)
expected = np.asarray(
[
0.8614111,
0.46572232,
0.6007328,
0.21619737,
0.6361222,
0.7298056,
0.13094282,
0.03556716,
0.32997167,
0.2977605,
0.043976,
0.96656,
0.292199,
0.904297,
0.25167,
0.521778,
0.778985,
0.085463,
0.939846,
0.194201,
]
)
tvm.testing.assert_allclose(real, expected, rtol=1e-5)
Expand Down
7 changes: 4 additions & 3 deletions tests/python/unittest/test_tir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def verify_general_dtype_support(f, is_conditional=False):
[("bool", "int32"), "int32"],
[("int32", "float32"), "float32"],
[("int32", "int64"), "int64"],
[("uint32", "int32"), "int32"],
[("uint32", "int8"), "uint32"],
[("uint32", "int32"), "uint32"],
]
for (lhs_dtype, rhs_dtype), out_dtype in rules:
lhs = te.var("lhs", dtype=lhs_dtype)
Expand Down Expand Up @@ -184,8 +185,8 @@ def test_if_then_else():
[(te.var("cond", dtype="bool"), "bool", "int32"), "int32"],
[(True, "int32", "float32"), "float32"],
[(False, "int32", "int64"), "int64"],
[(te.var("cond", dtype="bool"), "uint32", "int32"), "int32"],
[(te.var("cond", dtype="int32"), "uint32", "int32"), "int32"],
[(te.var("cond", dtype="bool"), "uint32", "int32"), "uint32"],
[(te.var("cond", dtype="int32"), "uint32", "int32"), "uint32"],
]
for (cond, lhs_dtype, rhs_dtype), out_dtype in cases:
lhs = te.var("lhs", dtype=lhs_dtype)
Expand Down

0 comments on commit 2793113

Please sign in to comment.