Skip to content

Commit

Permalink
[llvm] Support real function which has scalar arguments (#4422)
Browse files Browse the repository at this point in the history
* wip

* support function with scalar arguments

* remove debug

* fix test
  • Loading branch information
lin-hitonami authored and pull[bot] committed Jan 16, 2023
1 parent 8776f60 commit 1040523
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 170 deletions.
12 changes: 6 additions & 6 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,16 +1108,16 @@ def build_If(ctx, node):

@staticmethod
def build_Expr(ctx, node):
build_stmt(ctx, node.value)
if not isinstance(
node.value,
ast.Call) or not impl.get_runtime().experimental_real_function:
build_stmt(ctx, node.value)
return None

args = [build_stmt(ctx, node.value.func)
] + [arg.ptr for arg in build_stmts(ctx, node.value.args)]
impl.insert_expr_stmt_if_ti_func(ctx.ast_builder, *args)

is_taichi_function = getattr(node.value.func.ptr,
'_is_taichi_function', False)
if is_taichi_function:
func_call_result = node.value.ptr
ctx.ast_builder.insert_expr_stmt(func_call_result.ptr)
return None

@staticmethod
Expand Down
28 changes: 0 additions & 28 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,34 +236,6 @@ def make_tensor_element_expr(_var, _indices, shape, stride):
shape, stride))


@taichi_scope
def insert_expr_stmt_if_ti_func(ast_builder, func, *args, **kwargs):
"""This method is used only for real functions. It inserts a
FrontendExprStmt to the C++ AST to hold the function call if `func` is a
Taichi function.
Args:
func: The function to be called.
args: The arguments of the function call.
kwargs: The keyword arguments of the function call.
Returns:
The return value of the function call if it's a non-Taichi function.
Returns None if it's a Taichi function."""
is_taichi_function = getattr(func, '_is_taichi_function', False)
# If is_taichi_function is true: call a decorated Taichi function
# in a Taichi kernel/function.

if is_taichi_function:
# Compiles the function here.
# Invokes Func.__call__.
func_call_result = func(*args, **kwargs)
# Insert FrontendExprStmt here.
return ast_builder.insert_expr_stmt(func_call_result.ptr)
# Call the non-Taichi function directly.
return func(*args, **kwargs)


class PyTaichi:
def __init__(self, kernels=None):
self.materialized = False
Expand Down
24 changes: 24 additions & 0 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2376,6 +2376,30 @@ llvm::Value *CodeGenLLVM::create_mesh_xlogue(std::unique_ptr<Block> &block) {
return xlogue;
}

void CodeGenLLVM::visit(FuncCallStmt *stmt) {
if (!func_map.count(stmt->func)) {
auto guard = get_function_creation_guard(
{llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0)});
func_map.insert({stmt->func, guard.body});
stmt->func->ir->accept(this);
}
llvm::Function *llvm_func = func_map[stmt->func];
auto *new_ctx = builder->CreateAlloca(get_runtime_type("RuntimeContext"));
call("RuntimeContext_set_runtime", new_ctx, get_runtime());
for (int i = 0; i < stmt->args.size(); i++) {
auto *original = llvm_val[stmt->args[i]];
int src_bits = original->getType()->getPrimitiveSizeInBits();
auto *cast = builder->CreateBitCast(
original, llvm::Type::getIntNTy(*llvm_context, src_bits));
auto *val =
builder->CreateZExt(cast, llvm::Type::getInt64Ty(*llvm_context));
call("RuntimeContext_set_args", new_ctx,
llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val);
}

llvm_val[stmt] = create_call(llvm_func, {new_ctx});
}

TLANG_NAMESPACE_END

#endif // #ifdef TI_WITH_LLVM
4 changes: 4 additions & 0 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;

std::unordered_map<Function *, llvm::Function *> func_map;

using IRVisitor::visit;
using LLVMModuleBuilder::call;

Expand Down Expand Up @@ -378,6 +380,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *val,
std::function<llvm::Value *(llvm::Value *, llvm::Value *)> op);

void visit(FuncCallStmt *stmt) override;

~CodeGenLLVM() override = default;
};

Expand Down
16 changes: 11 additions & 5 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ void compile_to_offloads(IRNode *ir,
print("Simplified I");
irpass::analysis::verify(ir);

if (irpass::inlining(ir, config, {})) {
print("Functions inlined");
irpass::analysis::verify(ir);
}

if (is_extension_supported(config.arch, Extension::mesh)) {
irpass::analysis::gather_meshfor_relation_types(ir);
}
Expand Down Expand Up @@ -300,6 +295,17 @@ void compile_inline_function(IRNode *ir,
irpass::lower_ast(ir);
print("Lowered");
}
irpass::lower_access(ir, config, {{}, true});
print("Access lowered");
irpass::analysis::verify(ir);

irpass::die(ir);
print("DIE");
irpass::analysis::verify(ir);

irpass::flag_access(ir);
print("Access flagged III");
irpass::analysis::verify(ir);

irpass::type_check(ir, config);
print("Typechecked");
Expand Down
Loading

0 comments on commit 1040523

Please sign in to comment.