diff --git a/taichi/codegen/cc/cc_kernel.h b/taichi/codegen/cc/cc_kernel.h index e333e9451a50e..ac6086bc23bed 100644 --- a/taichi/codegen/cc/cc_kernel.h +++ b/taichi/codegen/cc/cc_kernel.h @@ -24,7 +24,7 @@ class CCKernel { } void compile(); - void launch(RuntimeContext *ctx); + void launch(LaunchContextBuilder &ctx); std::string get_object() { return obj_path_; } diff --git a/taichi/codegen/cc/cc_program.cpp b/taichi/codegen/cc/cc_program.cpp index be01aef5a06c2..df992d4a19e75 100644 --- a/taichi/codegen/cc/cc_program.cpp +++ b/taichi/codegen/cc/cc_program.cpp @@ -23,9 +23,9 @@ FunctionType CCProgramImpl::compile(const CompileConfig &compile_config, this->add_kernel(std::move(ker)); return [ker_ptr](LaunchContextBuilder &ctx) { for (auto &[idx, ptr] : ctx.array_ptrs) { - ctx.get_context().args[idx[0]] = (uint64)ptr; + ctx.cc_args[idx[0]] = (uint64)ptr; } - return ker_ptr->launch(&ctx.get_context()); + return ker_ptr->launch(ctx); }; } @@ -102,7 +102,7 @@ void CCRuntime::compile() { execute(cc_program_impl_->config->cc_compile_cmd, obj_path_, src_path_); } -void CCKernel::launch(RuntimeContext *ctx) { +void CCKernel::launch(LaunchContextBuilder &ctx) { ActionRecorder::get_instance().record("launch_kernel", { ActionArg("kernel_name", name_), @@ -181,10 +181,11 @@ CCFuncEntryType *CCProgramImpl::load_kernel(std::string const &name) { return reinterpret_cast(dll_->load_function("Tk_" + name)); } -CCContext *CCProgramImpl::update_context(RuntimeContext *ctx) { +CCContext *CCProgramImpl::update_context(LaunchContextBuilder &ctx) { // TODO(k-ye): Do you have other zero-copy ideas for arg buf? - std::memcpy(context_->args, ctx->args, taichi_max_num_args * sizeof(uint64)); - context_->earg = (int *)ctx->extra_args; + std::memcpy(context_->args, ctx.cc_args, + taichi_max_num_args * sizeof(uint64)); + context_->earg = (int *)ctx.get_context().extra_args; return context_.get(); } diff --git a/taichi/codegen/cc/cc_program.h b/taichi/codegen/cc/cc_program.h index ead359f240d53..09fa946298ff5 100644 --- a/taichi/codegen/cc/cc_program.h +++ b/taichi/codegen/cc/cc_program.h @@ -71,7 +71,7 @@ class CCProgramImpl : public ProgramImpl { CCFuncEntryType *load_kernel(std::string const &name); void relink(); - CCContext *update_context(RuntimeContext *ctx); + CCContext *update_context(LaunchContextBuilder &ctx); void context_to_result_buffer(); void dump_cache_data_to_disk() override { diff --git a/taichi/program/context.h b/taichi/program/context.h index dba590cd5690c..4c909c473f956 100644 --- a/taichi/program/context.h +++ b/taichi/program/context.h @@ -16,11 +16,7 @@ struct RuntimeContext { char *arg_buffer{nullptr}; LLVMRuntime *runtime{nullptr}; - // args can contain: - // - primitive_types - // - raw ptrs: for external array, or torch-based ndarray - // - DeviceAllocation*: for taichi ndaray - uint64_t args[taichi_max_num_args_total]; + uint64_t grad_args[taichi_max_num_args_total]; int32_t extra_args[taichi_max_num_args_extra][taichi_max_num_indices]; int32_t cpu_thread_id; diff --git a/taichi/program/launch_context_builder.cpp b/taichi/program/launch_context_builder.cpp index 8efbd78a0877f..9bb4168f26b9e 100644 --- a/taichi/program/launch_context_builder.cpp +++ b/taichi/program/launch_context_builder.cpp @@ -145,7 +145,9 @@ T LaunchContextBuilder::get_grad_arg(int i) { template void LaunchContextBuilder::set_arg(int i, T v) { set_struct_arg({i}, v); - ctx_->args[i] = taichi_union_cast_with_different_sizes(v); + if (kernel_->arch == Arch::cc) { + cc_args[i] = taichi_union_cast_with_different_sizes(v); + } set_array_device_allocation_type(i, DevAllocType::kNone); } diff --git a/taichi/program/launch_context_builder.h b/taichi/program/launch_context_builder.h index 5cac0dc2f7c18..a463686a51662 100644 --- a/taichi/program/launch_context_builder.h +++ b/taichi/program/launch_context_builder.h @@ -103,6 +103,9 @@ class LaunchContextBuilder { size_t result_buffer_size{0}; bool has_grad[taichi_max_num_args_total]; + // TODO: remove this after CC backend is removed + uint64 cc_args[taichi_max_num_args]{0}; + // Note that I've tried to group `array_runtime_size` and // `is_device_allocations` into a small struct. However, it caused some test // cases to stuck. diff --git a/taichi/runtime/llvm/runtime_module/runtime.cpp b/taichi/runtime/llvm/runtime_module/runtime.cpp index 5371c203a80c0..3d13b1905c54f 100644 --- a/taichi/runtime/llvm/runtime_module/runtime.cpp +++ b/taichi/runtime/llvm/runtime_module/runtime.cpp @@ -286,7 +286,6 @@ STRUCT_FIELD_ARRAY(PhysicalCoordinates, val); #include "taichi/program/context.h" #include "taichi/runtime/llvm/runtime_module/mem_request.h" -STRUCT_FIELD_ARRAY(RuntimeContext, args); STRUCT_FIELD_ARRAY(RuntimeContext, grad_args); STRUCT_FIELD(RuntimeContext, runtime); STRUCT_FIELD(RuntimeContext, result_buffer)