Skip to content

Commit

Permalink
fix lint
Browse files (browse the repository at this point in the history)
  • Loading branch information
jinhongyii committed Dec 13, 2024
1 parent d94cfd3 commit 825d5aa
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
8 changes: 3 additions & 5 deletions src/runtime/contrib/nvshmem/init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
*/
#include <nvshmem.h>
#include <nvshmemx.h>
#include <picojson.h>
#include <tvm/runtime/disco/disco_worker.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include "../../cuda/cuda_common.h"

#include <picojson.h>

namespace tvm {
namespace runtime {

Expand Down Expand Up @@ -106,15 +105,14 @@ void InitNVSHMEMWrapper(String args) {
int worker_id_start = static_cast<int>(obj["pe_start"].get<int64_t>());

InitNVSHMEM(uid_64, num_workers, worker_id_start);

}

TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);

TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);

TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper").set_body_typed(InitNVSHMEMWrapper);

TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper")
.set_body_typed(InitNVSHMEMWrapper);

} // namespace runtime
} // namespace tvm
24 changes: 12 additions & 12 deletions src/runtime/contrib/nvshmem/kv_transfer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ __global__ void KVTransfer(T* pages, T* k_data, T* v_data, int32_t* remote_posit
int position = remote_position_map[global_pos];
if (position == -1) {
continue;
};
}
if (local_num_kv_head <= remote_num_kv_head) {
// gather
assert(remote_num_kv_head % local_num_kv_head == 0);
Expand Down Expand Up @@ -98,7 +98,7 @@ __global__ void KVTransferPageToPage(T* remote_pages, T* local_pages, int32_t* r
int local_position = local_position_map[global_pos];
if (remote_position == -1 || local_position == -1) {
continue;
};
}
if (local_num_kv_head <= remote_num_kv_head) {
// gather
assert(remote_num_kv_head % local_num_kv_head == 0);
Expand Down Expand Up @@ -234,14 +234,14 @@ int _KVTransfer(DLTensor* remote_pages, DLTensor* k, DLTensor* v, DLTensor* remo
remote_num_kv_head, REMOTE_NUM_KV_HEAD,
{DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
dtype_in* remote_pages_data =
(dtype_in*)((char*)remote_pages->data + remote_pages->byte_offset);
dtype_in* k_data = (dtype_in*)((char*)k->data + k->byte_offset);
dtype_in* v_data = (dtype_in*)((char*)v->data + v->byte_offset);
reinterpret_cast<dtype_in*>((char*)remote_pages->data + remote_pages->byte_offset);
dtype_in* k_data = reinterpret_cast<dtype_in*>((char*)k->data + k->byte_offset);
dtype_in* v_data = reinterpret_cast<dtype_in*>((char*)v->data + v->byte_offset);
int32_t* remote_position_map_data =
(int32_t*)((char*)remote_position_map->data +
reinterpret_cast<int32_t*>((char*)remote_position_map->data +
remote_position_map->byte_offset);
int32_t* remote_tp_group_pe_offset_data =
(int32_t*)((char*)remote_tp_group_pe_offset->data +
reinterpret_cast<int32_t*>((char*)remote_tp_group_pe_offset->data +
remote_tp_group_pe_offset->byte_offset);
KVTransfer<dtype_in, LOCAL_NUM_KV_HEAD, REMOTE_NUM_KV_HEAD, HEAD_DIM, PAGE_SIZE>
<<<blocks, threads, 0, static_cast<cudaStream_t>(transfer_stream)>>>(
Expand Down Expand Up @@ -304,16 +304,16 @@ int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages,
remote_num_kv_head, REMOTE_NUM_KV_HEAD,
{DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
dtype_in* remote_pages_data =
(dtype_in*)((char*)remote_pages->data + remote_pages->byte_offset);
reinterpret_cast<dtype_in*>((char*)remote_pages->data + remote_pages->byte_offset);
dtype_in* local_pages_data =
(dtype_in*)((char*)local_pages->data + local_pages->byte_offset);
reinterpret_cast<dtype_in*>((char*)local_pages->data + local_pages->byte_offset);
int32_t* remote_position_map_data =
(int32_t*)((char*)remote_position_map->data +
reinterpret_cast<int32_t*>((char*)remote_position_map->data +
remote_position_map->byte_offset);
int32_t* local_position_map_data = (int32_t*)((char*)local_position_map->data +
int32_t* local_position_map_data = reinterpret_cast<int32_t*>((char*)local_position_map->data +
local_position_map->byte_offset);
int32_t* remote_tp_group_pe_offset_data =
(int32_t*)((char*)remote_tp_group_pe_offset->data +
reinterpret_cast<int32_t*>((char*)remote_tp_group_pe_offset->data +
remote_tp_group_pe_offset->byte_offset);
KVTransferPageToPage<dtype_in, LOCAL_NUM_KV_HEAD, REMOTE_NUM_KV_HEAD, HEAD_DIM,
PAGE_SIZE>
Expand Down

0 comments on commit 825d5aa

Please sign in to comment.