Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Change Integer Implicit Conversion Rule to C Standard Way #8733

Merged
merged 1 commit into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 28 additions & 35 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,6 @@ Type GetType(const PrimExpr& expr) {
return PrimType(dtype);
}

// simple cast that only checks if type matches and cast
inline PrimExpr SimpleCast(const DataType& t, PrimExpr value, Span span) {
if (value.dtype() == t) return value;
return tir::Cast(t, value, span);
}

// LargeUIntImm
PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) {
return tir::Call(
Expand Down Expand Up @@ -113,48 +107,47 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
}
if (lhs.dtype() == rhs.dtype()) return;

ltype = lhs.dtype();
rtype = rhs.dtype();
// We keep dtypes conversion to be relatively consistent to reduce the amount code generated by
// operators. This can be helpful for users to find potential type conversion problems. The
// following are exceptions:
if (lhs.dtype().is_float() && rhs.dtype().is_float()) {
if (ltype.is_float() && rtype.is_float()) {
// Given two dissimilar floats, cast the lower bit version to the higher bit version.
// E.g. fp16 + fp32 --> fp32 + fp32
if (lhs.dtype().bits() < rhs.dtype().bits()) {
lhs = cast(rhs.dtype(), lhs);
} else if (lhs.dtype().bits() > rhs.dtype().bits()) {
rhs = cast(lhs.dtype(), rhs);
if (ltype.bits() < rtype.bits()) {
lhs = cast(rtype, lhs);
} else {
rhs = cast(ltype, rhs);
}
} else if (!lhs.dtype().is_float() &&
(rhs.dtype().is_float() ||
datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) {
} else if (!ltype.is_float() &&
(rtype.is_float() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
// Cast int->float when the other operand is a float
lhs = cast(rhs.dtype(), lhs);
} else if ((lhs.dtype().is_float() ||
datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) &&
!rhs.dtype().is_float()) {
lhs = cast(rtype, lhs);
} else if ((ltype.is_float() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
!rtype.is_float()) {
// Cast int->float when the other operand is a float
rhs = cast(lhs.dtype(), rhs);
} else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) ||
(lhs.dtype().is_uint() && rhs.dtype().is_uint())) {
rhs = cast(ltype, rhs);
} else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
// Promote int to higher bits e.g. int8 + int16 --> int16 + int16
if (lhs.dtype().bits() < rhs.dtype().bits()) {
lhs = cast(rhs.dtype(), lhs);
if (ltype.bits() < rtype.bits()) {
lhs = cast(rtype, lhs);
} else {
rhs = cast(lhs.dtype(), rhs);
rhs = cast(ltype, rhs);
}
} else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) ||
(lhs.dtype().is_uint() && rhs.dtype().is_int())) {
} else if ((ltype.is_int() && rtype.is_uint()) || (ltype.is_uint() && rtype.is_int())) {
// Handle mixing signed and unsigned integers
int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits());
// if the signed int range is bigger than that of uint, try uint->int
if (lhs.dtype().is_int() && rhs.dtype().bits() <= bits - 1) {
rhs = cast(lhs.dtype(), rhs);
} else if (rhs.dtype().is_int() && lhs.dtype().bits() <= bits - 1) {
lhs = cast(rhs.dtype(), lhs);
if (ltype.bits() < rtype.bits()) {
lhs = cast(rtype, lhs);
} else if (ltype.bits() > rtype.bits()) {
rhs = cast(ltype, rhs);
} else {
// the ranges of uint and int types conflit, try SimpleCast
lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span);
rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span);
// The width of signed and unsigned integers is same.
if (ltype.is_uint()) {
rhs = cast(ltype, rhs);
} else {
lhs = cast(rtype, lhs);
}
}
} else {
LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
Expand Down
20 changes: 10 additions & 10 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5358,16 +5358,16 @@ def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None):
real = get_random_uniform(shape=[10], seed=5)
expected = np.asarray(
[
0.8614111,
0.46572232,
0.6007328,
0.21619737,
0.6361222,
0.7298056,
0.13094282,
0.03556716,
0.32997167,
0.2977605,
0.043976,
0.96656,
0.292199,
0.904297,
0.25167,
0.521778,
0.778985,
0.085463,
0.939846,
0.194201,
]
)
tvm.testing.assert_allclose(real, expected, rtol=1e-5)
Expand Down
7 changes: 4 additions & 3 deletions tests/python/unittest/test_tir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def verify_general_dtype_support(f, is_conditional=False):
[("bool", "int32"), "int32"],
[("int32", "float32"), "float32"],
[("int32", "int64"), "int64"],
[("uint32", "int32"), "int32"],
[("uint32", "int8"), "uint32"],
[("uint32", "int32"), "uint32"],
]
for (lhs_dtype, rhs_dtype), out_dtype in rules:
lhs = te.var("lhs", dtype=lhs_dtype)
Expand Down Expand Up @@ -184,8 +185,8 @@ def test_if_then_else():
[(te.var("cond", dtype="bool"), "bool", "int32"), "int32"],
[(True, "int32", "float32"), "float32"],
[(False, "int32", "int64"), "int64"],
[(te.var("cond", dtype="bool"), "uint32", "int32"), "int32"],
[(te.var("cond", dtype="int32"), "uint32", "int32"), "int32"],
[(te.var("cond", dtype="bool"), "uint32", "int32"), "uint32"],
[(te.var("cond", dtype="int32"), "uint32", "int32"), "uint32"],
]
for (cond, lhs_dtype, rhs_dtype), out_dtype in cases:
lhs = te.var("lhs", dtype=lhs_dtype)
Expand Down