Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[vulkan] Fix some vulkan stuff #3198

Merged
merged 3 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions taichi/backends/vulkan/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ class VkRuntime ::Impl {
}

void synchronize() {
if (current_cmdlist_) {
device_->get_compute_stream()->submit_synced(current_cmdlist_.get());
current_cmdlist_ = nullptr;
}
device_->get_compute_stream()->command_sync();
}

Expand Down
14 changes: 8 additions & 6 deletions taichi/backends/vulkan/snode_struct_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace {

class StructCompiler {
public:
CompiledSNodeStructs run(const SNode &root) {
CompiledSNodeStructs run(SNode &root) {
TI_ASSERT(root.type == SNodeType::root);

CompiledSNodeStructs result;
Expand All @@ -21,7 +21,7 @@ class StructCompiler {
}

private:
std::size_t compute_snode_size(const SNode *sn) {
std::size_t compute_snode_size(SNode *sn) {
const bool is_place = sn->is_place();

SNodeDescriptor sn_desc;
Expand All @@ -31,9 +31,9 @@ class StructCompiler {
sn_desc.container_stride = sn_desc.cell_stride;
} else {
std::size_t cell_stride = 0;
for (const auto &ch : sn->ch) {
const auto child_offset = cell_stride;
const auto *ch_snode = ch.get();
for (auto &ch : sn->ch) {
auto child_offset = cell_stride;
auto *ch_snode = ch.get();
cell_stride += compute_snode_size(ch_snode);
snode_descriptors_.find(ch_snode->id)
->second.mem_offset_in_parent_cell = child_offset;
Expand All @@ -43,6 +43,8 @@ class StructCompiler {
cell_stride * sn_desc.cells_per_container_pot();
}

sn->cell_size_bytes = sn_desc.cell_stride;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does VK backend actually use this? That said, we should make this part backend-neutral.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using cell size bytes for computing SNode Device ptr in ggui.


sn_desc.total_num_cells_from_root = 1;
for (const auto &e : sn->extractors) {
// Note that the extractors are set in two places:
Expand Down Expand Up @@ -77,7 +79,7 @@ int SNodeDescriptor::cells_per_container_pot() const {
return snode->num_cells_per_container;
}

CompiledSNodeStructs compile_snode_structs(const SNode &root) {
CompiledSNodeStructs compile_snode_structs(SNode &root) {
StructCompiler compiler;
return compiler.run(root);
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/vulkan/snode_struct_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct CompiledSNodeStructs {
SNodeDescriptorsMap snode_descriptors;
};

CompiledSNodeStructs compile_snode_structs(const SNode &root);
CompiledSNodeStructs compile_snode_structs(SNode &root);

} // namespace vulkan
} // namespace lang
Expand Down
6 changes: 4 additions & 2 deletions taichi/backends/vulkan/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -704,8 +704,10 @@ Value IRBuilder::float_atomic(AtomicOpType op_type,
.add_seq(tmp1, true_label, merge_label)
.commit(&func_);
ib_.begin(spv::OpLabel).add(true_label).commit(&func_);
Value tmp2 = load_variable(addr_ptr, t_int32_);
store_variable(old_val, tmp2);
Value tmp2 = load_variable(addr_ptr, t_fp32_);
Value tmp2_int = new_value(t_int32_, ValueKind::kNormal);
ib_.begin(spv::OpBitcast).add_seq(t_int32_, tmp2_int, tmp2).commit(&func_);
store_variable(old_val, tmp2_int);
Value tmp3 = load_variable(old_val, t_int32_);
Value tmp4 = new_value(t_fp32_, ValueKind::kNormal);
ib_.begin(spv::OpBitcast).add_seq(t_fp32_, tmp4, tmp3).commit(&func_);
Expand Down