Skip to content

Commit

Permalink
[aot] [vulkan] Output shapes/dims to AOT exported module (#4382)
Browse files Browse the repository at this point in the history
* [aot] export symbol visibility of DeviceAllocation

* [aot] [vulkan] Output shape/dim NdArray information

This is needed when loading the AOT module so the args of the kernel can
be filled properly.

* Update taichi/backends/vulkan/aot_module_builder_impl.cpp

Co-authored-by: Ye Kuang <k-ye@users.noreply.github.com>
  • Loading branch information
ghuau-innopeak and k-ye authored Feb 25, 2022
1 parent 2fbf3db commit 1e6904a
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 13 deletions.
6 changes: 3 additions & 3 deletions taichi/backends/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ struct LLVMRuntime;

// TODO: Figure out how to support images. Temporary solutions is to have all
// opque types such as images work as an allocation
struct DeviceAllocation {
struct TI_DLL_EXPORT DeviceAllocation {
Device *device{nullptr};
uint32_t alloc_id{0};

Expand All @@ -68,14 +68,14 @@ struct DeviceAllocation {
}
};

struct DeviceAllocationGuard : public DeviceAllocation {
struct TI_DLL_EXPORT DeviceAllocationGuard : public DeviceAllocation {
DeviceAllocationGuard(DeviceAllocation alloc) : DeviceAllocation(alloc) {
}
DeviceAllocationGuard(const DeviceAllocationGuard &) = delete;
~DeviceAllocationGuard();
};

struct DevicePtr : public DeviceAllocation {
struct TI_DLL_EXPORT DevicePtr : public DeviceAllocation {
uint64_t offset{0};

bool operator==(const DevicePtr &other) const {
Expand Down
22 changes: 13 additions & 9 deletions taichi/backends/vulkan/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,19 @@ class AotDataConverter {
res.args_buffer_size = in.ctx_attribs.args_bytes();
res.rets_buffer_size = in.ctx_attribs.rets_bytes();
for (const auto &arg : in.ctx_attribs.args()) {
res.scalar_args[arg.index] = visit(arg);
if (!arg.is_array) {
aot::ScalarArg scalar_arg{};
scalar_arg.dtype_name = arg.dt.to_string();
scalar_arg.offset_in_args_buf = arg.offset_in_mem;
res.scalar_args[arg.index] = scalar_arg;
} else {
aot::ArrayArg arr_arg{};
arr_arg.dtype_name = arg.dt.to_string();
arr_arg.field_dim = arg.field_dim;
arr_arg.element_shape = arg.element_shape;
arr_arg.shape_offset_in_args_buf = arg.index * sizeof(int32_t);
res.arr_args[arg.index] = arr_arg;
}
}
return res;
}
Expand All @@ -62,14 +74,6 @@ class AotDataConverter {
res.gpu_block_size = in.advisory_num_threads_per_group;
return res;
}

aot::ScalarArg visit(
const spirv::KernelContextAttributes::ArgAttributes &in) const {
aot::ScalarArg res{};
res.dtype_name = in.dt.to_string();
res.offset_in_args_buf = in.offset_in_mem;
return res;
}
};

} // namespace
Expand Down
4 changes: 4 additions & 0 deletions taichi/codegen/spirv/kernel_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel)
aa.dt = ka.dt;
const size_t dt_bytes = data_type_size(aa.dt);
aa.is_array = ka.is_array;
if (aa.is_array) {
aa.field_dim = ka.total_dim - ka.element_shape.size();
aa.element_shape = ka.element_shape;
}
aa.stride = dt_bytes;
aa.index = arg_attribs_vec_.size();
arg_attribs_vec_.push_back(aa);
Expand Down
4 changes: 3 additions & 1 deletion taichi/codegen/spirv/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,10 @@ class KernelContextAttributes {
int index{-1};
DataType dt;
bool is_array{false};
std::vector<int> element_shape;
std::size_t field_dim{0};

TI_IO_DEF(stride, offset_in_mem, index, is_array);
TI_IO_DEF(stride, offset_in_mem, index, is_array, element_shape, field_dim);
};

public:
Expand Down

0 comments on commit 1e6904a

Please sign in to comment.