Skip to content

Commit

Permalink
[refactor] Remove 'args' from 'RuntimeContext'
Browse files Browse the repository at this point in the history
ghstack-source-id: d4d1b95e60bed611c9f78b7d2c6fca489e6b0e04
Pull Request resolved: #7730
  • Loading branch information
lin-hitonami authored and Taichi Gardener committed Apr 6, 2023
1 parent 201d882 commit 4495c05
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 15 deletions.
2 changes: 1 addition & 1 deletion taichi/codegen/cc/cc_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class CCKernel {
}

void compile();
void launch(RuntimeContext *ctx);
void launch(LaunchContextBuilder &ctx);
std::string get_object() {
return obj_path_;
}
Expand Down
13 changes: 7 additions & 6 deletions taichi/codegen/cc/cc_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ FunctionType CCProgramImpl::compile(const CompileConfig &compile_config,
this->add_kernel(std::move(ker));
return [ker_ptr](LaunchContextBuilder &ctx) {
for (auto &[idx, ptr] : ctx.array_ptrs) {
ctx.get_context().args[idx[0]] = (uint64)ptr;
ctx.cc_args[idx[0]] = (uint64)ptr;
}
return ker_ptr->launch(&ctx.get_context());
return ker_ptr->launch(ctx);
};
}

Expand Down Expand Up @@ -102,7 +102,7 @@ void CCRuntime::compile() {
execute(cc_program_impl_->config->cc_compile_cmd, obj_path_, src_path_);
}

void CCKernel::launch(RuntimeContext *ctx) {
void CCKernel::launch(LaunchContextBuilder &ctx) {
ActionRecorder::get_instance().record("launch_kernel",
{
ActionArg("kernel_name", name_),
Expand Down Expand Up @@ -181,10 +181,11 @@ CCFuncEntryType *CCProgramImpl::load_kernel(std::string const &name) {
return reinterpret_cast<CCFuncEntryType *>(dll_->load_function("Tk_" + name));
}

CCContext *CCProgramImpl::update_context(RuntimeContext *ctx) {
CCContext *CCProgramImpl::update_context(LaunchContextBuilder &ctx) {
// TODO(k-ye): Do you have other zero-copy ideas for arg buf?
std::memcpy(context_->args, ctx->args, taichi_max_num_args * sizeof(uint64));
context_->earg = (int *)ctx->extra_args;
std::memcpy(context_->args, ctx.cc_args,
taichi_max_num_args * sizeof(uint64));
context_->earg = (int *)ctx.get_context().extra_args;
return context_.get();
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/cc/cc_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class CCProgramImpl : public ProgramImpl {
CCFuncEntryType *load_kernel(std::string const &name);
void relink();

CCContext *update_context(RuntimeContext *ctx);
CCContext *update_context(LaunchContextBuilder &ctx);
void context_to_result_buffer();

void dump_cache_data_to_disk() override {
Expand Down
6 changes: 1 addition & 5 deletions taichi/program/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@ struct RuntimeContext {
char *arg_buffer{nullptr};

LLVMRuntime *runtime{nullptr};
// args can contain:
// - primitive_types
// - raw ptrs: for external array, or torch-based ndarray
// - DeviceAllocation*: for taichi ndaray
uint64_t args[taichi_max_num_args_total];

uint64_t grad_args[taichi_max_num_args_total];
int32_t extra_args[taichi_max_num_args_extra][taichi_max_num_indices];
int32_t cpu_thread_id;
Expand Down
4 changes: 3 additions & 1 deletion taichi/program/launch_context_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ T LaunchContextBuilder::get_grad_arg(int i) {
template <typename T>
void LaunchContextBuilder::set_arg(int i, T v) {
set_struct_arg({i}, v);
ctx_->args[i] = taichi_union_cast_with_different_sizes<uint64>(v);
if (kernel_->arch == Arch::cc) {
cc_args[i] = taichi_union_cast_with_different_sizes<uint64>(v);
}
set_array_device_allocation_type(i, DevAllocType::kNone);
}

Expand Down
3 changes: 3 additions & 0 deletions taichi/program/launch_context_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ class LaunchContextBuilder {
size_t result_buffer_size{0};
bool has_grad[taichi_max_num_args_total];

// TODO: remove this after CC backend is removed
uint64 cc_args[taichi_max_num_args]{0};

// Note that I've tried to group `array_runtime_size` and
// `is_device_allocations` into a small struct. However, it caused some test
// cases to stuck.
Expand Down
1 change: 0 additions & 1 deletion taichi/runtime/llvm/runtime_module/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ STRUCT_FIELD_ARRAY(PhysicalCoordinates, val);
#include "taichi/program/context.h"
#include "taichi/runtime/llvm/runtime_module/mem_request.h"

STRUCT_FIELD_ARRAY(RuntimeContext, args);
STRUCT_FIELD_ARRAY(RuntimeContext, grad_args);
STRUCT_FIELD(RuntimeContext, runtime);
STRUCT_FIELD(RuntimeContext, result_buffer)
Expand Down

0 comments on commit 4495c05

Please sign in to comment.