Skip to content

Commit

Permalink
[lang] Split stmt typechecking to the frontend (#3633)
Browse files Browse the repository at this point in the history
  • Loading branch information
re-xyr committed Dec 25, 2021
1 parent 0a7a7c4 commit fa0eb36
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 35 deletions.
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ void full_simplify(IRNode *root,
const CompileConfig &config,
const FullSimplifyPass::Args &args);
void print(IRNode *root, std::string *output = nullptr);
void frontend_type_check(IRNode *root);
void lower_ast(IRNode *root);
void type_check(IRNode *root, const CompileConfig &config);
bool inlining(IRNode *root,
Expand Down
2 changes: 2 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ void compile_to_offloads(IRNode *ir,
}

if (start_from_ast) {
irpass::frontend_type_check(ir);
irpass::lower_ast(ir);
print("Lowered");
}
Expand Down Expand Up @@ -307,6 +308,7 @@ void compile_inline_function(IRNode *ir,
}

if (start_from_ast) {
irpass::frontend_type_check(ir);
irpass::lower_ast(ir);
print("Lowered");
}
Expand Down
110 changes: 110 additions & 0 deletions taichi/transforms/frontend_type_check.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/frontend_ir.h"
#include "taichi/ir/statements.h"

namespace taichi::lang {

class FrontendTypeCheck : public IRVisitor {
public:
explicit FrontendTypeCheck() {
allow_undefined_visitor = true;
}

void visit(Block *block) override {
std::vector<Stmt *> stmts;
// Make a copy since type casts may be inserted for type promotion.
for (auto &stmt : block->statements) {
stmts.push_back(stmt.get());
}
for (auto stmt : stmts)
stmt->accept(this);
}

void visit(FrontendExprStmt *stmt) override {
// Noop
}

void visit(FrontendAllocaStmt *stmt) override {
// Noop
}

void visit(FrontendSNodeOpStmt *stmt) override {
// Noop
}

void visit(FrontendAssertStmt *stmt) override {
if (not stmt->cond->ret_type->is_primitive(PrimitiveTypeID::i32))
throw TaichiTypeError(fmt::format(
"`assert` conditions must be of type int32; found {}. Consider using "
"`if x != 0:` instead of `if x:` for float values.",
stmt->cond->ret_type->to_string()));
}

void visit(FrontendAssignStmt *stmt) override {
// No implicit cast at frontend for now
}

void visit(FrontendIfStmt *stmt) override {
// TODO: use PrimitiveType::u1 when it's supported
if (not stmt->condition->ret_type->is_primitive(PrimitiveTypeID::i32))
throw TaichiTypeError(fmt::format(
"`if` conditions must be of type int32; found {}. Consider using "
"`assert x != 0` instead of `assert x` for float values.",
stmt->condition->ret_type->to_string()));
if (stmt->true_statements)
stmt->true_statements->accept(this);
if (stmt->false_statements) {
stmt->false_statements->accept(this);
}
}

void visit(FrontendPrintStmt *stmt) override {
// Noop
}

void visit(FrontendEvalStmt *stmt) override {
// Noop
}

void visit(FrontendForStmt *stmt) override {
stmt->body->accept(this);
}

void visit(FrontendFuncDefStmt *stmt) override {
stmt->body->accept(this);
// Determine ret_type after this is actually used
}

void visit(FrontendBreakStmt *stmt) override {
// Noop
}

void visit(FrontendContinueStmt *stmt) override {
// Noop
}

void visit(FrontendWhileStmt *stmt) override {
if (not stmt->cond->ret_type->is_primitive(PrimitiveTypeID::i32))
throw TaichiTypeError(fmt::format(
"`while` conditions must be of type int32; found {}. Consider using "
"`while x != 0:` instead of `while x:` for float values.",
stmt->cond->ret_type->to_string()));
stmt->body->accept(this);
}

void visit(FrontendReturnStmt *stmt) override {
// Noop
}
};

namespace irpass {

void frontend_type_check(IRNode *root) {
TI_AUTO_PROF;
FrontendTypeCheck checker;
root->accept(&checker);
}

} // namespace irpass

} // namespace taichi::lang
26 changes: 3 additions & 23 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ class TypeCheck : public IRVisitor {
}

void visit(IfStmt *if_stmt) override {
// TODO: use PrimitiveType::u1 when it's supported
TI_ASSERT_INFO(
if_stmt->cond->ret_type->is_primitive(PrimitiveTypeID::i32),
"`if` conditions must be of type int32, consider using `if x != 0:` "
"instead of `if x:` for float values.");
if (if_stmt->true_statements)
if_stmt->true_statements->accept(this);
if (if_stmt->false_statements) {
Expand Down Expand Up @@ -256,17 +251,9 @@ class TypeCheck : public IRVisitor {
stmt->ret_type = stmt->cast_type;
}
if (!is_real(stmt->operand->ret_type)) {
if (is_trigonometric(stmt->op_type)) {
TI_ERROR("[{}] Trigonometric operator takes real inputs only, at {}",
stmt->name(), stmt->tb);
} else if (stmt->op_type == UnaryOpType::round ||
stmt->op_type == UnaryOpType::floor ||
stmt->op_type == UnaryOpType::ceil) {
TI_ERROR("[{}] round/floor/ceil takes real inputs only at {}",
stmt->name(), stmt->tb);
} else if (stmt->op_type == UnaryOpType::sqrt ||
stmt->op_type == UnaryOpType::exp ||
stmt->op_type == UnaryOpType::log) {
if (stmt->op_type == UnaryOpType::sqrt ||
stmt->op_type == UnaryOpType::exp ||
stmt->op_type == UnaryOpType::log) {
cast(stmt->operand, config_.default_fp);
}
}
Expand Down Expand Up @@ -361,11 +348,6 @@ class TypeCheck : public IRVisitor {
if (!matching) {
error();
}
if (binary_is_bitwise(stmt->op_type)) {
if (!is_integral(stmt->lhs->ret_type)) {
error("Error: bitwise operations can only apply to integral types.");
}
}
if (is_comparison(stmt->op_type)) {
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(
stmt->lhs->width(), PrimitiveType::i32);
Expand Down Expand Up @@ -403,7 +385,6 @@ class TypeCheck : public IRVisitor {
}

void visit(RangeAssumptionStmt *stmt) override {
TI_ASSERT(stmt->input->ret_type == stmt->base->ret_type);
stmt->ret_type = stmt->input->ret_type;
}

Expand All @@ -425,7 +406,6 @@ class TypeCheck : public IRVisitor {
// TODO: Maybe have a type_inference() pass, which takes in the args/rets
// defined by the kernel. After that, type_check() pass will purely do
// verification, without modifying any types.
TI_ASSERT(rt != PrimitiveType::unknown);
TI_ASSERT(rt->vector_width() == 1);
stmt->ret_type.set_is_pointer(stmt->is_ptr);
}
Expand Down
24 changes: 12 additions & 12 deletions tests/python/test_type_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,18 @@ def bitwise_float():
bitwise_float()


# @ti.test(arch=ti.cpu)
# def test_ternary_op():
# @ti.kernel
# def select():
# a = 1.1
# b = 3
# c = 3.6
# d = b if a else c
#
# with pytest.raises(ti.TaichiCompilationError,
# match="for 'select': 'f32', 'i32' and 'f32'"):
# select()
@ti.test(arch=ti.cpu)
def test_ternary_op():
@ti.kernel
def select():
a = 1.1
b = 3
c = 3.6
d = b if a else c

with pytest.raises(TypeError,
match="`if` conditions must be of type int32"):
select()


@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.')
Expand Down

0 comments on commit fa0eb36

Please sign in to comment.