[Lang] [spirv] Support dynamic indexing in spirv #6990

Merged: 3 commits, Dec 30, 2022
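What this enables: non-constant (run-time) indexing into local matrices now lowers to SPIR-V, so it works on the Vulkan, DX11, and OpenGL backends. A minimal usage sketch, assuming a Taichi build containing this PR and a Vulkan-capable device (the kernel `pick` is illustrative, not part of the PR):

    import taichi as ti

    ti.init(arch=ti.vulkan, dynamic_index=True)

    @ti.kernel
    def pick(i: ti.i32) -> ti.f32:
        v = ti.Vector([1.0, 2.0, 3.0])  # local matrix (TensorType alloca)
        return v[i]                     # index is only known at runtime

    assert pick(2) == 3.0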
1 change: 1 addition & 0 deletions taichi/codegen/spirv/snode_struct_compiler.cpp
@@ -94,6 +94,7 @@ class StructCompiler {
       cell_stride += snode_size;
       snode_descriptors_.find(ch_snode->id)
           ->second.mem_offset_in_parent_cell = child_offset;
+      ch_snode->offset_bytes_in_parent_cell = child_offset;
     }
     sn_desc.cell_stride = cell_stride;

53 changes: 33 additions & 20 deletions taichi/codegen/spirv/spirv_codegen.cpp
@@ -219,40 +219,53 @@ class TaskCodegen : public IRVisitor {
   }
 
   void visit(AllocaStmt *alloca) override {
+    spirv::Value ptr_val;
     if (alloca->ret_type->is<TensorType>()) {
-      // Alloca for shared memory / workgroup memory
-      if (!alloca->is_shared) {
-        TI_ERROR(
-            "Tensor type for dyanmic index is not yet supported on Vulkan.");
-      }
       auto tensor_type = alloca->ret_type->cast<TensorType>();
       auto elem_num = tensor_type->get_num_elements();
       spirv::SType elem_type =
           ir_->get_primitive_type(tensor_type->get_element_type());
-
       spirv::SType arr_type = ir_->get_array_type(elem_type, elem_num);
-      spirv::Value ptr_val = ir_->alloca_workgroup_array(arr_type);
-      shared_array_binds_.push_back(ptr_val);
-      ir_->register_value(alloca->raw_name(), ptr_val);
+      if (alloca->is_shared) {  // for shared memory / workgroup memory
+        ptr_val = ir_->alloca_workgroup_array(arr_type);
+        shared_array_binds_.push_back(ptr_val);
+      } else {  // for function memory
+        ptr_val = ir_->alloca_variable(arr_type);
+      }
     } else {
       // Alloca for a single variable
       spirv::SType src_type = ir_->get_primitive_type(alloca->element_type());
-      spirv::Value ptr_val = ir_->alloca_variable(src_type);
+      ptr_val = ir_->alloca_variable(src_type);
       ir_->store_variable(ptr_val, ir_->get_zero(src_type));
-      ir_->register_value(alloca->raw_name(), ptr_val);
     }
+    ir_->register_value(alloca->raw_name(), ptr_val);
   }
 
   void visit(MatrixPtrStmt *stmt) override {
-    spirv::SType data_type =
-        ir_->get_primitive_type(stmt->element_type().ptr_removed());
-    spirv::SType ptr_type =
-        ir_->get_pointer_type(data_type, spv::StorageClassWorkgroup);
-    auto origin_val = ir_->query_value(stmt->origin->raw_name());
-    auto offset_val = ir_->query_value(stmt->offset->raw_name());
-    Value offset_ptr =
-        ir_->make_value(spv::OpAccessChain, ptr_type, origin_val, offset_val);
-    ir_->register_value(stmt->raw_name(), offset_ptr);
+    spirv::Value ptr_val;
+    spirv::Value origin_val = ir_->query_value(stmt->origin->raw_name());
+    spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name());
+    auto dt = stmt->element_type().ptr_removed();
+    if (stmt->offset_used_as_index()) {
+      if (stmt->origin->is<AllocaStmt>()) {
+        spirv::SType ptr_type = ir_->get_pointer_type(
+            ir_->get_primitive_type(dt), origin_val.stype.storage_class);
+        ptr_val = ir_->make_value(spv::OpAccessChain, ptr_type, origin_val,
+                                  offset_val);
+      } else if (stmt->origin->is<GlobalTemporaryStmt>()) {
+        spirv::Value dt_bytes = ir_->int_immediate_number(
+            ir_->i32_type(), ir_->get_primitive_type_size(dt), false);
+        spirv::Value offset_bytes = ir_->mul(dt_bytes, offset_val);
+        ptr_val = ir_->add(origin_val, offset_bytes);
+        ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
+      } else {
+        TI_NOT_IMPLEMENTED;
+      }
+    } else {  // offset used as bytes
+      ptr_val = ir_->add(origin_val, ir_->cast(origin_val.stype, offset_val));
+      ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
+    }
+    ir_->register_value(stmt->raw_name(), ptr_val);
   }
 
   void visit(LocalLoadStmt *stmt) override {
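The rewritten MatrixPtrStmt visitor above distinguishes two interpretations of an offset. When the offset is an element index, an alloca origin is addressed logically through OpAccessChain (no byte math), while a global-temporary origin is addressed by scaling the index by the element size; otherwise the offset is already a byte count added to the base pointer. A plain-Python sketch of the byte arithmetic for the pointer-offset cases (the helper name is hypothetical, for illustration only):

    def matrix_ptr_address(origin_addr, offset, elem_size_bytes, used_as_index):
        if used_as_index:
            # GlobalTemporaryStmt branch: scale the element index to bytes.
            return origin_addr + elem_size_bytes * offset
        # "offset used as bytes" branch: the offset is already a byte count.
        return origin_addr + offset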
4 changes: 3 additions & 1 deletion taichi/program/extension.cpp
@@ -23,8 +23,10 @@ bool is_extension_supported(Arch arch, Extension ext) {
       {Arch::metal,
        {Extension::adstack, Extension::assertion, Extension::dynamic_index,
         Extension::sparse}},
-      {Arch::opengl, {Extension::extfunc}},
+      {Arch::opengl, {Extension::dynamic_index, Extension::extfunc}},
       {Arch::gles, {}},
+      {Arch::vulkan, {Extension::dynamic_index}},
+      {Arch::dx11, {Extension::dynamic_index}},
       {Arch::cc, {Extension::data64, Extension::extfunc, Extension::adstack}},
   };
   // if (with_opengl_extension_data64())
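Since the table above now registers dynamic_index for opengl, vulkan, and dx11, the capability can be queried from Python. A hedged sketch, assuming ti.is_extension_supported is exposed (as it was in Taichi releases of this era):

    import taichi as ti

    # Prints whether each SPIR-V-based backend reports dynamic_index support.
    for arch in (ti.vulkan, ti.dx11, ti.opengl):
        print(arch, ti.is_extension_supported(arch, ti.extension.dynamic_index))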
1 change: 1 addition & 0 deletions taichi/rhi/opengl/opengl_api.cpp
@@ -53,6 +53,7 @@ bool initialize_opengl(bool use_gles, bool error_tolerance) {
     glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3);
     glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 1);
   } else {
+    glfwWindowHint(GLFW_CLIENT_API, GLFW_OPENGL_API);
     glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);
     glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 4);
     glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3);
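GLFW window hints persist across window-creation calls until changed or reset, so if an earlier GLES probe ran in the same process, the desktop-GL path must set GLFW_CLIENT_API back explicitly; that is what the added hint does. The equivalent sequence through the pyGLFW bindings, as a sketch (assumes the pip-installed glfw package):

    import glfw

    glfw.init()
    # Hints are sticky: reset the client API in case a GLES probe ran first.
    glfw.window_hint(glfw.CLIENT_API, glfw.OPENGL_API)
    glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE)
    glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 4)
    glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3)
    window = glfw.create_window(64, 64, "gl-probe", None, None)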
22 changes: 8 additions & 14 deletions tests/python/test_matrix.py
@@ -166,39 +166,33 @@ def run():
 
 def _test_local_matrix_non_constant_index():
     @ti.kernel
-    def func1():
+    def func1() -> ti.types.vector(3, ti.i32):
         tmp = ti.Vector([1, 2, 3])
         for i in range(3):
             vec = ti.Vector([4, 5, 6])
             for j in range(3):
                 vec[tmp[i] % 3] += vec[j]
             tmp[i] = vec[tmp[i] % 3]
-        assert tmp[0] == 24
-        assert tmp[1] == 30
-        assert tmp[2] == 19
+        return tmp
 
-    func1()
+    assert (func1() == ti.Vector([24, 30, 19])).all()
 
     @ti.kernel
-    def func2(i: ti.i32, j: ti.i32, k: ti.i32):
+    def func2(i: ti.i32, j: ti.i32, k: ti.i32) -> ti.i32:
         tmp = ti.Matrix([[k, k * 2], [k * 2, k * 3]])
-        assert tmp[i, j] == k * (i + j + 1)
+        return tmp[i, j]
 
     for i in range(2):
         for j in range(2):
-            func2(i, j, 10)
+            assert func2(i, j, 10) == 10 * (i + j + 1)
 
 
-@test_utils.test(require=ti.extension.dynamic_index,
-                 dynamic_index=True,
-                 debug=True)
+@test_utils.test(require=ti.extension.dynamic_index, dynamic_index=True)
 def test_local_matrix_non_constant_index():
     _test_local_matrix_non_constant_index()
 
 
-@test_utils.test(arch=[ti.cuda, ti.cpu],
-                 real_matrix_scalarize=False,
-                 debug=True)
+@test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix_scalarize=False)
 def test_local_matrix_non_constant_index_real_matrix():
     _test_local_matrix_non_constant_index()