Skip to content

Commit

Permalink
[mesh] Fix MeshTaichi warnings in CUDA backend (#6369)
Browse files Browse the repository at this point in the history
continue from: #6306

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
g1n0st and pre-commit-ci[bot] committed Oct 19, 2022
1 parent dc6537e commit 4646105
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions taichi/transforms/make_mesh_block_local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,11 @@ Stmt *MakeMeshBlockLocal::create_cache_mapping(
Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>(
offset,
TypeFactory::get_instance().get_pointer_type(mapping_data_type_));
Stmt *casted_val = body->push_back<UnaryOpStmt>(UnaryOpType::cast_value,
global_val(body, idx_val));
casted_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
[[maybe_unused]] Stmt *bls_store =
body->push_back<GlobalStoreStmt>(bls_ptr, global_val(body, idx_val));
body->push_back<GlobalStoreStmt>(bls_ptr, casted_val);
});
}

Expand Down Expand Up @@ -368,8 +371,11 @@ void MakeMeshBlockLocal::fetch_mapping(
Stmt *global_ptr = body->push_back<GlobalPtrStmt>(
mapping_snode_, std::vector<Stmt *>{global_offset});
Stmt *global_load = body->push_back<GlobalLoadStmt>(global_ptr);
attr_callback_handler(body, idx_val, global_load);
return global_load;
Stmt *casted_global_load = body->push_back<UnaryOpStmt>(
UnaryOpType::cast_value, global_load);
casted_global_load->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
attr_callback_handler(body, idx_val, casted_global_load);
return casted_global_load;
});
} else {
// int i = threadIdx.x;
Expand All @@ -388,8 +394,11 @@ void MakeMeshBlockLocal::fetch_mapping(
Stmt *global_ptr = body->push_back<GlobalPtrStmt>(
mapping_snode_, std::vector<Stmt *>{global_offset});
Stmt *global_load = body->push_back<GlobalLoadStmt>(global_ptr);
attr_callback_handler(body, idx_val, global_load);
return global_load;
Stmt *casted_global_load = body->push_back<UnaryOpStmt>(
UnaryOpType::cast_value, global_load);
casted_global_load->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
attr_callback_handler(body, idx_val, casted_global_load);
return casted_global_load;
});
}
}
Expand Down Expand Up @@ -513,7 +522,8 @@ MakeMeshBlockLocal::MakeMeshBlockLocal(OffloadedStmt *offload,
mapping_snode_ = (offload->mesh->index_mapping
.find(std::make_pair(element_type, conv_type))
->second);
mapping_data_type_ = mapping_snode_->dt.ptr_removed();
// mapping_data_type_ = mapping_snode_->dt.ptr_removed();
mapping_data_type_ = PrimitiveType::i32;
mapping_dtype_size_ = data_type_size(mapping_data_type_);

// Ensure BLS alignment
Expand Down

0 comments on commit 4646105

Please sign in to comment.