Skip to content

Commit

Permalink
[IR] Add type_check for TernaryOpExpression (#3381)
Browse files Browse the repository at this point in the history
* [IR] Add type_check for TernaryOpExpression

* Manual Format

* fix
  • Loading branch information
KuribohG authored Nov 4, 2021
1 parent 30195de commit bcd0783
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 0 deletions.
19 changes: 19 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,25 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void TernaryOpExpression::type_check() {
auto op1_type = op1->ret_type;
auto op2_type = op2->ret_type;
auto op3_type = op3->ret_type;
if (op1_type == PrimitiveType::unknown ||
op2_type == PrimitiveType::unknown || op3_type == PrimitiveType::unknown)
return;
auto error = [&]() {
throw std::runtime_error(fmt::format(
"TypeError: unsupported operand type(s) for '{}': '{}', '{}' and '{}'",
ternary_type_name(type), op1->ret_type->to_string(),
op2->ret_type->to_string(), op3->ret_type->to_string()));
};
if (!is_integral(op1_type) || !op2_type->is<PrimitiveType>() ||
!op3_type->is<PrimitiveType>())
error();
ret_type = promoted_type(op2_type, op3_type);
}

void TernaryOpExpression::flatten(FlattenContext *ctx) {
// if (stmt)
// return;
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ class TernaryOpExpression : public Expression {
this->op3.set(load_if_ptr(op3));
}

void type_check() override;

void serialize(std::ostream &ss) override {
ss << ternary_type_name(type) << '(';
op1->serialize(ss);
Expand Down
15 changes: 15 additions & 0 deletions tests/cpp/ir/frontend_type_inference_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,20 @@ TEST(FrontendTypeInference, UnaryOp) {
EXPECT_EQ(bit_not_i16->ret_type, PrimitiveType::i16);
}

TEST(FrontendTypeInference, TernaryOp) {
auto const_i16 = Expr::make<ConstExpression, int16>(-(1 << 10));
const_i16->type_check();
EXPECT_EQ(const_i16->ret_type, PrimitiveType::i16);
auto cast_i8 = cast(const_i16, PrimitiveType::i8);
cast_i8->type_check();
EXPECT_EQ(cast_i8->ret_type, PrimitiveType::i8);
auto const_f32 = Expr::make<ConstExpression, float32>(5.0);
const_f32->type_check();
EXPECT_EQ(const_f32->ret_type, PrimitiveType::f32);
auto ternary_f32 = expr_select(const_i16, cast_i8, const_f32);
ternary_f32->type_check();
EXPECT_EQ(ternary_f32->ret_type, PrimitiveType::f32);
}

} // namespace lang
} // namespace taichi
13 changes: 13 additions & 0 deletions tests/python/test_type_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,16 @@ def bitwise_float():

with pytest.raises(SystemExit):
bitwise_float()


@ti.test(arch=ti.cpu)
def test_ternary_op():
@ti.kernel
def select():
a = 1.1
b = 3
c = 3.6
d = b if a else c

with pytest.raises(SystemExit):
select()

0 comments on commit bcd0783

Please sign in to comment.