diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index b02073799988a..288197c0efac8 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -699,7 +699,7 @@ FunctionType CUDAModuleToFunctionConverter::convert( TI_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, task.block_dim); cuda_module->launch(task.name, task.grid_dim, task.block_dim, 0, - {&context}); + {&context}, {}); } // copy data back to host diff --git a/taichi/inc/archs.inc.h b/taichi/inc/archs.inc.h index d478981ef502c..979646b0a646b 100644 --- a/taichi/inc/archs.inc.h +++ b/taichi/inc/archs.inc.h @@ -14,6 +14,6 @@ PER_ARCH(opengl) // OpenGL Compute Shaders PER_ARCH(dx11) // Microsoft DirectX 11, WIP PER_ARCH(dx12) // Microsoft DirectX 12, WIP PER_ARCH(opencl) // OpenCL, N/A -PER_ARCH(amdgpu) // AMD GPU, N/A +PER_ARCH(amdgpu) // AMD GPU, WIP PER_ARCH(vulkan) // Vulkan PER_ARCH(gles) // OpenGL ES diff --git a/taichi/jit/jit_module.h b/taichi/jit/jit_module.h index 1fa6e19060a05..0ea19ab08ef0d 100644 --- a/taichi/jit/jit_module.h +++ b/taichi/jit/jit_module.h @@ -2,6 +2,7 @@ #include <memory> #include <functional> +#include <tuple> #include "taichi/inc/constants.h" #include "taichi/util/lang_util.h" @@ -33,31 +34,36 @@ class JITModule { return ret; } - static std::vector<void *> get_arg_pointers() { - return std::vector<void *>(); + inline std::tuple<std::vector<void *>, std::vector<int> > get_arg_pointers() { + return std::make_tuple(std::vector<void *>(), std::vector<int>()); } template <typename T, typename... Args> - static std::vector<void *> get_arg_pointers(T &t, Args &...args) { - auto ret = get_arg_pointers(args...); - ret.insert(ret.begin(), &t); - return ret; + inline std::tuple<std::vector<void *>, std::vector<int> > get_arg_pointers( + T &t, + Args &...args) { + auto [arg_pointers, arg_sizes] = get_arg_pointers(args...); + arg_pointers.insert(arg_pointers.begin(), &t); + arg_sizes.insert(arg_sizes.begin(), sizeof(t)); + return std::make_tuple(arg_pointers, arg_sizes); } // Note: **call** is for serial functions // Note: args must pass by
value + // Note: AMDGPU needs to pass args by extra_arg currently template <typename... Args> void call(const std::string &name, Args... args) { if (direct_dispatch()) { get_function(name)(args...); } else { - auto arg_pointers = JITModule::get_arg_pointers(args...); - call(name, arg_pointers); + auto [arg_pointers, arg_sizes] = JITModule::get_arg_pointers(args...); + call(name, arg_pointers, arg_sizes); } } virtual void call(const std::string &name, - const std::vector<void *> &arg_pointers) { + const std::vector<void *> &arg_pointers, + const std::vector<int> &arg_sizes) { TI_NOT_IMPLEMENTED } @@ -69,20 +75,20 @@ class JITModule { std::size_t block_dim, std::size_t shared_mem_bytes, Args... args) { - auto arg_pointers = JITModule::get_arg_pointers(args...); - launch(name, grid_dim, block_dim, shared_mem_bytes, arg_pointers); + auto [arg_pointers, arg_sizes] = JITModule::get_arg_pointers(args...); + launch(name, grid_dim, block_dim, shared_mem_bytes, arg_pointers, + arg_sizes); } virtual void launch(const std::string &name, std::size_t grid_dim, std::size_t block_dim, std::size_t shared_mem_bytes, - const std::vector<void *> &arg_pointers) { + const std::vector<void *> &arg_pointers, + const std::vector<int> &arg_sizes) { TI_NOT_IMPLEMENTED } - // directly call the function (e.g. on CPU), or via another runtime system - // (e.g. cudaLaunch)?
virtual bool direct_dispatch() const = 0; virtual ~JITModule() { diff --git a/taichi/rhi/amdgpu/amdgpu_context.cpp b/taichi/rhi/amdgpu/amdgpu_context.cpp index 04fb173a0e73b..114c6241817ab 100644 --- a/taichi/rhi/amdgpu/amdgpu_context.cpp +++ b/taichi/rhi/amdgpu/amdgpu_context.cpp @@ -61,17 +61,72 @@ std::string AMDGPUContext::get_device_name() { return str; } +int AMDGPUContext::get_args_byte(std::vector<int> arg_sizes) { + int byte_cnt = 0; + int naive_add = 0; + for (auto &size : arg_sizes) { + naive_add += size; + if (size < 32) { + if ((byte_cnt + size) % 32 > (byte_cnt) % 32 || + (byte_cnt + size) % 32 == 0) + byte_cnt += size; + else + byte_cnt += 32 - byte_cnt % 32 + size; + } else { + if (byte_cnt % 32 != 0) + byte_cnt += 32 - byte_cnt % 32 + size; + else + byte_cnt += size; + } + } + return byte_cnt; +} + +void AMDGPUContext::pack_args(std::vector<void *> arg_pointers, + std::vector<int> arg_sizes, + char *arg_packed) { + int byte_cnt = 0; + for (int ii = 0; ii < arg_pointers.size(); ii++) { + // The parameter is taken as a vec4 + if (arg_sizes[ii] < 32) { + if ((byte_cnt + arg_sizes[ii]) % 32 > (byte_cnt % 32) || + (byte_cnt + arg_sizes[ii]) % 32 == 0) { + std::memcpy(arg_packed + byte_cnt, arg_pointers[ii], arg_sizes[ii]); + byte_cnt += arg_sizes[ii]; + } else { + int padding_size = 32 - byte_cnt % 32; + byte_cnt += padding_size; + std::memcpy(arg_packed + byte_cnt, arg_pointers[ii], arg_sizes[ii]); + byte_cnt += arg_sizes[ii]; + } + } else { + if (byte_cnt % 32 != 0) { + int padding_size = 32 - byte_cnt % 32; + byte_cnt += padding_size; + std::memcpy(arg_packed + byte_cnt, arg_pointers[ii], arg_sizes[ii]); + byte_cnt += arg_sizes[ii]; + } else { + std::memcpy(arg_packed + byte_cnt, arg_pointers[ii], arg_sizes[ii]); + byte_cnt += arg_sizes[ii]; + } + } + } +} + void AMDGPUContext::launch(void *func, const std::string &task_name, - void *arg_pointers, + const std::vector<void *> &arg_pointers, + const std::vector<int> &arg_sizes, unsigned grid_dim, unsigned block_dim, - std::size_t
dynamic_shared_mem_bytes, - int arg_bytes) { + std::size_t dynamic_shared_mem_bytes) { + auto pack_size = get_args_byte(arg_sizes); + char *packed_arg = (char *)std::malloc(pack_size); + pack_args(arg_pointers, arg_sizes, packed_arg); if (grid_dim > 0) { std::lock_guard<std::mutex> _(lock_); - void *config[] = {(void *)0x01, const_cast<void *>(arg_pointers), - (void *)0x02, &arg_bytes, (void *)0x03}; + void *config[] = {(void *)0x01, (void *)packed_arg, (void *)0x02, + (void *)&pack_size, (void *)0x03}; driver_.launch_kernel(func, grid_dim, 1, 1, block_dim, 1, 1, dynamic_shared_mem_bytes, nullptr, nullptr, reinterpret_cast<void **>(&config)); diff --git a/taichi/rhi/amdgpu/amdgpu_context.h b/taichi/rhi/amdgpu/amdgpu_context.h index 7e182e07ea3d7..6c0b3048824f1 100644 --- a/taichi/rhi/amdgpu/amdgpu_context.h +++ b/taichi/rhi/amdgpu/amdgpu_context.h @@ -34,13 +34,19 @@ class AMDGPUContext { return dev_count_ != 0; } + void pack_args(std::vector<void *> arg_pointers, + std::vector<int> arg_sizes, + char *arg_packed); + + int get_args_byte(std::vector<int> arg_sizes); + void launch(void *func, const std::string &task_name, - void *arg_pointers, + const std::vector<void *> &arg_pointers, + const std::vector<int> &arg_sizes, unsigned grid_dim, unsigned block_dim, - std::size_t dynamic_shared_mem_bytes, - int arg_bytes); + std::size_t dynamic_shared_mem_bytes); void set_debug(bool debug) { debug_ = debug; } diff --git a/taichi/rhi/cuda/cuda_context.cpp b/taichi/rhi/cuda/cuda_context.cpp index 63c29da697337..766d358ef568d 100644 --- a/taichi/rhi/cuda/cuda_context.cpp +++ b/taichi/rhi/cuda/cuda_context.cpp @@ -75,6 +75,7 @@ std::string CUDAContext::get_device_name() { void CUDAContext::launch(void *func, const std::string &task_name, std::vector<void *> arg_pointers, + std::vector<int> arg_sizes, unsigned grid_dim, unsigned block_dim, std::size_t dynamic_shared_mem_bytes) { diff --git a/taichi/rhi/cuda/cuda_context.h b/taichi/rhi/cuda/cuda_context.h index f514b3164e6b3..578aa69af38dd 100644 --- a/taichi/rhi/cuda/cuda_context.h +++
b/taichi/rhi/cuda/cuda_context.h @@ -42,6 +42,7 @@ class CUDAContext { void launch(void *func, const std::string &task_name, std::vector<void *> arg_pointers, + std::vector<int> arg_sizes, unsigned grid_dim, unsigned block_dim, std::size_t dynamic_shared_mem_bytes); diff --git a/taichi/runtime/cuda/jit_cuda.h b/taichi/runtime/cuda/jit_cuda.h index 16cbc3957c28c..e231bc5cb6654 100644 --- a/taichi/runtime/cuda/jit_cuda.h +++ b/taichi/runtime/cuda/jit_cuda.h @@ -60,18 +60,21 @@ class JITModuleCUDA : public JITModule { } void call(const std::string &name, - const std::vector<void *> &arg_pointers) override { - launch(name, 1, 1, 0, arg_pointers); + const std::vector<void *> &arg_pointers, + const std::vector<int> &arg_sizes) override { + launch(name, 1, 1, 0, arg_pointers, arg_sizes); } void launch(const std::string &name, std::size_t grid_dim, std::size_t block_dim, std::size_t dynamic_shared_mem_bytes, - const std::vector<void *> &arg_pointers) override { + const std::vector<void *> &arg_pointers, + const std::vector<int> &arg_sizes) override { auto func = lookup_function(name); - CUDAContext::get_instance().launch(func, name, arg_pointers, grid_dim, - block_dim, dynamic_shared_mem_bytes); + CUDAContext::get_instance().launch(func, name, arg_pointers, arg_sizes, + grid_dim, block_dim, + dynamic_shared_mem_bytes); } bool direct_dispatch() const override {