Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
jinhongyii committed Dec 13, 2024
1 parent 825d5aa commit 5b81499
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions src/runtime/contrib/nvshmem/kv_transfer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,15 @@ int _KVTransfer(DLTensor* remote_pages, DLTensor* k, DLTensor* v, DLTensor* remo
{DISPATCH_NUM_KV_HEAD(
remote_num_kv_head, REMOTE_NUM_KV_HEAD,
{DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
dtype_in* remote_pages_data =
reinterpret_cast<dtype_in*>((char*)remote_pages->data + remote_pages->byte_offset);
dtype_in* remote_pages_data = reinterpret_cast<dtype_in*>(
(char*)remote_pages->data + remote_pages->byte_offset);
dtype_in* k_data = reinterpret_cast<dtype_in*>((char*)k->data + k->byte_offset);
dtype_in* v_data = reinterpret_cast<dtype_in*>((char*)v->data + v->byte_offset);
int32_t* remote_position_map_data =
reinterpret_cast<int32_t*>((char*)remote_position_map->data +
remote_position_map->byte_offset);
int32_t* remote_position_map_data = reinterpret_cast<int32_t*>(
(char*)remote_position_map->data + remote_position_map->byte_offset);
int32_t* remote_tp_group_pe_offset_data =
reinterpret_cast<int32_t*>((char*)remote_tp_group_pe_offset->data +
remote_tp_group_pe_offset->byte_offset);
remote_tp_group_pe_offset->byte_offset);
KVTransfer<dtype_in, LOCAL_NUM_KV_HEAD, REMOTE_NUM_KV_HEAD, HEAD_DIM, PAGE_SIZE>
<<<blocks, threads, 0, static_cast<cudaStream_t>(transfer_stream)>>>(
remote_pages_data, k_data, v_data, remote_position_map_data, kv_len,
Expand Down Expand Up @@ -303,18 +302,17 @@ int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages,
{DISPATCH_NUM_KV_HEAD(
remote_num_kv_head, REMOTE_NUM_KV_HEAD,
{DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
dtype_in* remote_pages_data =
reinterpret_cast<dtype_in*>((char*)remote_pages->data + remote_pages->byte_offset);
dtype_in* local_pages_data =
reinterpret_cast<dtype_in*>((char*)local_pages->data + local_pages->byte_offset);
int32_t* remote_position_map_data =
reinterpret_cast<int32_t*>((char*)remote_position_map->data +
remote_position_map->byte_offset);
int32_t* local_position_map_data = reinterpret_cast<int32_t*>((char*)local_position_map->data +
local_position_map->byte_offset);
dtype_in* remote_pages_data = reinterpret_cast<dtype_in*>(
(char*)remote_pages->data + remote_pages->byte_offset);
dtype_in* local_pages_data = reinterpret_cast<dtype_in*>(
(char*)local_pages->data + local_pages->byte_offset);
int32_t* remote_position_map_data = reinterpret_cast<int32_t*>(
(char*)remote_position_map->data + remote_position_map->byte_offset);
int32_t* local_position_map_data = reinterpret_cast<int32_t*>(
(char*)local_position_map->data + local_position_map->byte_offset);
int32_t* remote_tp_group_pe_offset_data =
reinterpret_cast<int32_t*>((char*)remote_tp_group_pe_offset->data +
remote_tp_group_pe_offset->byte_offset);
remote_tp_group_pe_offset->byte_offset);
KVTransferPageToPage<dtype_in, LOCAL_NUM_KV_HEAD, REMOTE_NUM_KV_HEAD, HEAD_DIM,
PAGE_SIZE>
<<<blocks, threads, 0, static_cast<cudaStream_t>(transfer_stream)>>>(
Expand Down

0 comments on commit 5b81499

Please sign in to comment.