From 825d5aa74de6ee139beb3781b6609a60c29b56d2 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Fri, 13 Dec 2024 21:23:36 +0000 Subject: [PATCH] fix lint --- src/runtime/contrib/nvshmem/init.cc | 8 +++----- src/runtime/contrib/nvshmem/kv_transfer.cu | 24 +++++++++++----------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc index 051022ced8f1..33a787b5b9f3 100644 --- a/src/runtime/contrib/nvshmem/init.cc +++ b/src/runtime/contrib/nvshmem/init.cc @@ -18,14 +18,13 @@ */ #include #include +#include #include #include #include #include "../../cuda/cuda_common.h" -#include - namespace tvm { namespace runtime { @@ -106,15 +105,14 @@ void InitNVSHMEMWrapper(String args) { int worker_id_start = static_cast(obj["pe_start"].get()); InitNVSHMEM(uid_64, num_workers, worker_id_start); - } TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID); TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM); -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper").set_body_typed(InitNVSHMEMWrapper); - +TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper") + .set_body_typed(InitNVSHMEMWrapper); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/nvshmem/kv_transfer.cu b/src/runtime/contrib/nvshmem/kv_transfer.cu index 6cd365357dcb..71eac46bd260 100644 --- a/src/runtime/contrib/nvshmem/kv_transfer.cu +++ b/src/runtime/contrib/nvshmem/kv_transfer.cu @@ -50,7 +50,7 @@ __global__ void KVTransfer(T* pages, T* k_data, T* v_data, int32_t* remote_posit int position = remote_position_map[global_pos]; if (position == -1) { continue; - }; + } if (local_num_kv_head <= remote_num_kv_head) { // gather assert(remote_num_kv_head % local_num_kv_head == 0); @@ -98,7 +98,7 @@ __global__ void KVTransferPageToPage(T* remote_pages, T* local_pages, int32_t* r int local_position = local_position_map[global_pos]; if (remote_position == -1 || local_position == -1) { continue; - }; + } if (local_num_kv_head <= remote_num_kv_head) { // gather assert(remote_num_kv_head % local_num_kv_head == 0); @@ -234,14 +234,14 @@ int _KVTransfer(DLTensor* remote_pages, DLTensor* k, DLTensor* v, DLTensor* remo remote_num_kv_head, REMOTE_NUM_KV_HEAD, {DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, { dtype_in* remote_pages_data = - (dtype_in*)((char*)remote_pages->data + remote_pages->byte_offset); - dtype_in* k_data = (dtype_in*)((char*)k->data + k->byte_offset); - dtype_in* v_data = (dtype_in*)((char*)v->data + v->byte_offset); + reinterpret_cast((char*)remote_pages->data + remote_pages->byte_offset); + dtype_in* k_data = reinterpret_cast((char*)k->data + k->byte_offset); + dtype_in* v_data = reinterpret_cast((char*)v->data + v->byte_offset); int32_t* remote_position_map_data = - (int32_t*)((char*)remote_position_map->data + + reinterpret_cast((char*)remote_position_map->data + remote_position_map->byte_offset); int32_t* remote_tp_group_pe_offset_data = - (int32_t*)((char*)remote_tp_group_pe_offset->data + + reinterpret_cast((char*)remote_tp_group_pe_offset->data + remote_tp_group_pe_offset->byte_offset); KVTransfer <<(transfer_stream)>>>( @@ -304,16 +304,16 @@ int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages, remote_num_kv_head, REMOTE_NUM_KV_HEAD, {DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, { dtype_in* remote_pages_data = - (dtype_in*)((char*)remote_pages->data + remote_pages->byte_offset); + reinterpret_cast((char*)remote_pages->data + remote_pages->byte_offset); dtype_in* local_pages_data = - (dtype_in*)((char*)local_pages->data + local_pages->byte_offset); + reinterpret_cast((char*)local_pages->data + local_pages->byte_offset); int32_t* remote_position_map_data = - (int32_t*)((char*)remote_position_map->data + + reinterpret_cast((char*)remote_position_map->data + remote_position_map->byte_offset); - int32_t* local_position_map_data = (int32_t*)((char*)local_position_map->data + + int32_t* local_position_map_data = reinterpret_cast((char*)local_position_map->data + local_position_map->byte_offset); int32_t* remote_tp_group_pe_offset_data = - (int32_t*)((char*)remote_tp_group_pe_offset->data + + reinterpret_cast((char*)remote_tp_group_pe_offset->data + remote_tp_group_pe_offset->byte_offset); KVTransferPageToPage