This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Fix host view for mnnvl #166

Merged
merged 2 commits on May 23, 2024
6 changes: 6 additions & 0 deletions cpp/include/wholememory/wholememory.h
@@ -387,6 +387,12 @@ wholememory_error_code_t wholememory_store_to_file(wholememory_handle_t wholemem
size_t file_entry_size,
const char* local_file_name);

/**
 * Check whether all ranks of the communicator are on the same node.
 * @param comm : WholeMemory Comm
 * @return : true if the communicator is intra-node, false otherwise
 */
bool wholememory_is_intranode_communicator(wholememory_comm_t comm);

bool wholememory_is_build_with_nvshmem();
#ifdef WITH_NVSHMEM_SUPPORT
wholememory_error_code_t wholememory_get_nvshmem_reference(
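A minimal usage sketch of the new query, not taken from the diff: it assumes a communicator obtained elsewhere and simply gates CPU-side access on the intra-node check.

#include <wholememory/wholememory.h>

// Hypothetical helper (not part of this PR): decide whether a CPU thread may
// dereference a host view of CONTINUOUS wholememory created on this communicator.
static bool host_view_usable_from_cpu(wholememory_comm_t comm)
{
  // Multi-node (MNNVL) continuous mappings are only accessible from GPU threads,
  // so a CPU host view is only valid when every rank lives on the same node.
  return wholememory_is_intranode_communicator(comm);
}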
8 changes: 7 additions & 1 deletion cpp/src/wholememory/memory_handle.cpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -99,6 +99,12 @@ class wholememory_impl {
if (local_ptr != nullptr) *local_ptr = local_partition_memory_pointer_;
if (local_size != nullptr) *local_size = rank_partition_strategy_.local_mem_size;
if (local_offset != nullptr) *local_offset = rank_partition_strategy_.local_mem_offset;
if (location_ == WHOLEMEMORY_ML_HOST && (type_ == WHOLEMEMORY_MT_CONTINUOUS) &&
(!(comm_->is_intranode()))) {
WHOLEMEMORY_WARN(
" Multi-node continuous type wholememory can only be accessed by GPU threads but not CPU "
"threads, regardless of whether the location of wholememory is host.");
}
}
virtual bool get_rank_memory(void** rank_memory_ptr,
size_t* rank_memory_size,
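For context, a hedged sketch of the rule this warning states, expressed against the public C API: wholememory_get_memory_location and wholememory_get_memory_type are assumed getter names, while wholememory_get_communicator and wholememory_is_intranode_communicator appear in this PR's diff.

#include <wholememory/wholememory.h>

// Sketch only: returns true when CPU threads may touch the handle's mapping.
// Assumption: wholememory_get_memory_location/_type getters exist with these names.
static bool cpu_access_allowed(wholememory_handle_t handle)
{
  wholememory_comm_t comm;
  if (wholememory_get_communicator(&comm, handle) != WHOLEMEMORY_SUCCESS) return false;
  bool continuous_host = wholememory_get_memory_location(handle) == WHOLEMEMORY_ML_HOST &&
                         wholememory_get_memory_type(handle) == WHOLEMEMORY_MT_CONTINUOUS;
  bool multi_node = !wholememory_is_intranode_communicator(comm);
  // Multi-node continuous host memory is mapped for GPU access only.
  return !(continuous_host && multi_node);
}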
5 changes: 5 additions & 0 deletions cpp/src/wholememory/wholememory.cpp
@@ -261,6 +261,11 @@ wholememory_error_code_t wholememory_load_from_hdfs_file(wholememory_handle_t wh
return WHOLEMEMORY_NOT_IMPLEMENTED;
}

bool wholememory_is_intranode_communicator(wholememory_comm_t comm)
{
return wholememory::is_intranode_communicator(comm);
}

bool wholememory_is_build_with_nvshmem()
{
#ifdef WITH_NVSHMEM_SUPPORT
@@ -184,7 +184,7 @@ cdef extern from "wholememory/wholememory.h":

cdef wholememory_distributed_backend_t wholememory_communicator_get_distributed_backend(
wholememory_comm_t comm)

cdef bool wholememory_is_intranode_communicator(wholememory_comm_t comm)

cpdef enum WholeMemoryErrorCode:
Success = WHOLEMEMORY_SUCCESS
@@ -1113,6 +1113,10 @@ cdef class PyWholeMemoryFlattenDlpack:
cdef wholememory_comm_t comm
cdef int world_rank
cdef int world_size
if self.device_type == MlHost and mem_type == MtContinuous:
check_wholememory_error_code(wholememory_get_communicator(&comm, handle.wholememory_handle))
if not wholememory_is_intranode_communicator(comm):
raise ValueError('Multi-node continuous type wholememory does not support host_view; only host_view=False is supported, even when the location is host.')
global_size = wholememory_get_total_size(handle.wholememory_handle)
if global_size % elt_size != 0:
raise ValueError('global_size=%d not multiple of elt_size=%d' % (global_size, elt_size))