Skip to content

Commit

Permalink
[gfx] Let gfx backends use LaunchContextBuilder to build arguments in…
Browse files Browse the repository at this point in the history
… struct type

ghstack-source-id: 171f46b640f0d679d078ca9c219d664be4de66c3
Pull Request resolved: #7662
  • Loading branch information
lin-hitonami authored and Taichi Gardener committed Apr 6, 2023
1 parent 0040789 commit 201d882
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 70 deletions.
22 changes: 21 additions & 1 deletion taichi/aot/graph_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,27 @@ void CompiledGraph::init_runtime_context(
symbolic_arg.tag == aot::ArgKind::kMatrix) {
TI_ASSERT(ival.tag == aot::ArgKind::kScalar);
// Matrix args are flattened so they're same as scalars.
ctx.set_arg(i, ival.val);
int type_size = data_type_size(symbolic_arg.dtype());
switch (type_size) {
case 1:
ctx.set_arg(i,
taichi_union_cast_with_different_sizes<int8>(ival.val));
break;
case 2:
ctx.set_arg(i,
taichi_union_cast_with_different_sizes<int16>(ival.val));
break;
case 4:
ctx.set_arg(i,
taichi_union_cast_with_different_sizes<int32>(ival.val));
break;
case 8:
ctx.set_arg(i,
taichi_union_cast_with_different_sizes<int64>(ival.val));
break;
default:
TI_ERROR("Unsupported type size {}", type_size);
}
} else if (symbolic_arg.tag == aot::ArgKind::kTexture) {
TI_ASSERT(ival.tag == aot::ArgKind::kTexture);
Texture *tex = reinterpret_cast<Texture *>(ival.val);
Expand Down
19 changes: 6 additions & 13 deletions taichi/program/launch_context_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,8 @@ LaunchContextBuilder::LaunchContextBuilder(CallableBase *kernel)
: kernel_(kernel),
owned_ctx_(std::make_unique<RuntimeContext>()),
ctx_(owned_ctx_.get()),
arg_buffer_(std::make_unique<char[]>(
arch_uses_llvm(kernel->arch)
? kernel->args_size
: sizeof(uint64) * taichi_max_num_args_total)),
result_buffer_(std::make_unique<char[]>(
arch_uses_llvm(kernel->arch)
? kernel->ret_size
: sizeof(uint64) * taichi_result_buffer_entries)),
arg_buffer_(std::make_unique<char[]>(kernel->args_size)),
result_buffer_(std::make_unique<char[]>(kernel->ret_size)),
ret_type_(kernel->ret_type),
arg_buffer_size(kernel->args_size),
args_type(kernel->args_type),
Expand Down Expand Up @@ -123,24 +117,23 @@ void LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) {

// Stores `v` into the context's argument buffer at the offset that
// args_type->get_element_offset(index) reports for `index`.
//
// T:     type of the value written at that offset.
// index: index path handed unchanged to get_element_offset.
// v:     value to store.
//
// NOTE(review): the two guard lines below appear to be the removed and added
// sides of a rendered diff (the braces do not balance as listed). Only one of
// them can be live in the real file -- confirm against the repository.
template <typename T>
void LaunchContextBuilder::set_struct_arg(std::vector<int> index, T v) {
if (!arch_uses_llvm(kernel_->arch)) {
if (kernel_->arch == Arch::cc) {
// Early exit: nothing is written to the argument buffer on this path.
return;
}
int offset = args_type->get_element_offset(index);
// The sizeof(T) bytes starting at `offset` must end within arg_buffer_size.
TI_ASSERT(offset + sizeof(T) <= arg_buffer_size);
// Raw store: the destination is treated as a T with no conversion.
*(T *)(ctx_->arg_buffer + offset) = v;
}

// Returns argument `i` read as type T.
//
// As listed, the LLVM-arch branch forwards to get_struct_arg<T>({i}) and the
// fall-through converts ctx_->args[i] with
// taichi_union_cast_with_different_sizes<T>.
//
// NOTE(review): the final return statement is unreachable as listed. This
// looks like removed and added diff lines shown together; presumably the
// post-change body is just the last return -- confirm against the repository.
template <typename T>
T LaunchContextBuilder::get_arg(int i) {
if (arch_uses_llvm(kernel_->arch)) {
return get_struct_arg<T>({i});
}
return taichi_union_cast_with_different_sizes<T>(ctx_->args[i]);
return get_struct_arg<T>({i});
}

// Reads a value of type T from the context's argument buffer.
//
// T:       type the stored data is read as.
// index:   index path handed unchanged to args_type->get_element_offset.
// returns: the T located at ctx_->arg_buffer + offset.
template <typename T>
T LaunchContextBuilder::get_struct_arg(std::vector<int> index) {
// Presumably a byte offset into the argument buffer -- TODO confirm the
// pointee type of ctx_->arg_buffer.
int offset = args_type->get_element_offset(index);
// The sizeof(T) bytes starting at `offset` must end within arg_buffer_size.
TI_ASSERT(offset + sizeof(T) <= arg_buffer_size);
// Raw load: the location is treated as a T with no conversion.
return *(T *)(ctx_->arg_buffer + offset);
}

Expand Down
86 changes: 30 additions & 56 deletions taichi/runtime/gfx/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,74 +53,48 @@ class HostDeviceContextBlitter {
TI_ASSERT(device_->map(*device_args_buffer_, &device_base) ==
RhiResult::success);

#define TO_DEVICE(short_type, type) \
if (arg.dtype == PrimitiveTypeID::short_type) { \
auto d = host_ctx_.get_arg<type>(i); \
reinterpret_cast<type *>(device_ptr)[0] = d; \
break; \
}

for (int i = 0; i < ctx_attribs_->args().size(); ++i) {
const auto &arg = ctx_attribs_->args()[i];
void *device_ptr = (uint8_t *)device_base + arg.offset_in_mem;
do {
if (arg.is_array) {
if (host_ctx_.device_allocation_type[i] ==
LaunchContextBuilder::DevAllocType::kNone &&
ext_arr_size.at(i)) {
// Only need to blit ext arrs (host array)
uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i));
if (access & uint32_t(irpass::ExternalPtrAccess::READ)) {
DeviceAllocation buffer = ext_arrays.at(i);
void *device_arr_ptr{nullptr};
TI_ASSERT(device_->map(buffer, &device_arr_ptr) ==
RhiResult::success);
const void *host_ptr = host_ctx_.array_ptrs[{i}];
std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(i));
device_->unmap(buffer);
}
}
// Substitute in the device address.

// (penguinliong) We don't check the availability of physical pointer
// here. It should be done before you need this class.
if ((host_ctx_.device_allocation_type[i] ==
LaunchContextBuilder::DevAllocType::kNone ||
host_ctx_.device_allocation_type[i] ==
LaunchContextBuilder::DevAllocType::kNdarray)) {
uint64_t addr =
device_->get_memory_physical_pointer(ext_arrays.at(i));
reinterpret_cast<uint64 *>(device_ptr)[0] = addr;
if (arg.is_array) {
if (host_ctx_.device_allocation_type[i] ==
LaunchContextBuilder::DevAllocType::kNone &&
ext_arr_size.at(i)) {
// Only need to blit ext arrs (host array)
uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i));
if (access & uint32_t(irpass::ExternalPtrAccess::READ)) {
DeviceAllocation buffer = ext_arrays.at(i);
void *device_arr_ptr{nullptr};
TI_ASSERT(device_->map(buffer, &device_arr_ptr) ==
RhiResult::success);
const void *host_ptr = host_ctx_.array_ptrs[{i}];
std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(i));
device_->unmap(buffer);
}
// We should not process the rest
break;
}
// (penguinliong) Same. The availability of short/long int types depends
// on the kernels and compute graphs and the check should already be
// done during module loads.
TO_DEVICE(i8, int8)
TO_DEVICE(u8, uint8)
TO_DEVICE(i16, int16)
TO_DEVICE(u16, uint16)
TO_DEVICE(i32, int32)
TO_DEVICE(u32, uint32)
TO_DEVICE(f32, float32)
TO_DEVICE(i64, int64)
TO_DEVICE(u64, uint64)
TO_DEVICE(f64, float64)
TO_DEVICE(f16, uint16)
TI_ERROR("Device does not support arg type={}",
PrimitiveType::get(arg.dtype).to_string());
} while (false);
// Substitute in the device address.

if ((host_ctx_.device_allocation_type[i] ==
LaunchContextBuilder::DevAllocType::kNone ||
host_ctx_.device_allocation_type[i] ==
LaunchContextBuilder::DevAllocType::kNdarray) &&
device_->get_caps().get(
DeviceCapability::spirv_has_physical_storage_buffer)) {
uint64_t addr =
device_->get_memory_physical_pointer(ext_arrays.at(i));
host_ctx_.set_arg(i, addr);
}
}
}

std::memcpy(device_base, host_ctx_.get_context().arg_buffer,
ctx_attribs_->args_bytes());

void *device_ptr =
(uint8_t *)device_base + ctx_attribs_->extra_args_mem_offset();
std::memcpy(device_ptr, host_ctx_.get_context().extra_args,
ctx_attribs_->extra_args_bytes());

device_->unmap(*device_args_buffer_);
#undef TO_DEVICE
}

bool device_to_host(
Expand Down

0 comments on commit 201d882

Please sign in to comment.