Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
Fix format/CI
Browse files Browse the repository at this point in the history
  • Loading branch information
chang-l committed Nov 21, 2024
1 parent 6531ba0 commit a5117a6
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 31 deletions.
24 changes: 14 additions & 10 deletions cpp/src/nvml_wrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "nvml_wrap.h"

#if CUDA_VERSION >= 12030
#include <dlfcn.h>
#include <stdio.h>
#include <mutex>
#include "nvml_wrap.h"
#include <stdio.h>

namespace {

void* nvml_handle = nullptr;
std::mutex nvml_mutex;
bool nvml_loaded = false;

bool LoadNvmlLibrary() {
bool LoadNvmlLibrary()
{
nvml_handle = dlopen("libnvidia-ml.so.1", RTLD_NOW);
if (!nvml_handle) {
nvml_handle = dlopen("libnvidia-ml.so", RTLD_NOW);
Expand All @@ -36,11 +39,10 @@ bool LoadNvmlLibrary() {
}

template <typename T>
T LoadNvmlSymbol(const char* name) {
T LoadNvmlSymbol(const char* name)
{
void* symbol = dlsym(nvml_handle, name);
if (!symbol) {
return nullptr;
}
if (!symbol) { return nullptr; }
return reinterpret_cast<T>(symbol);
}

Expand All @@ -51,10 +53,11 @@ nvmlDeviceGetHandleByIndexFunc nvmlDeviceGetHandleByIndexPtr = nullptr;
nvmlDeviceGetGpuFabricInfoFunc nvmlDeviceGetGpuFabricInfoPtr = nullptr;

// Ensure NVML is loaded and symbols are initialized
bool NvmlFabricSymbolLoaded() {
bool NvmlFabricSymbolLoaded()
{
std::lock_guard<std::mutex> lock(nvml_mutex);
if (nvml_loaded) {
return true; // Already loaded
return true; // Already loaded
}

if (LoadNvmlLibrary()) {
Expand All @@ -71,4 +74,5 @@ bool NvmlFabricSymbolLoaded() {
}
}
return nvml_loaded;
}
}
#endif
6 changes: 4 additions & 2 deletions cpp/src/nvml_wrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
// limitations under the License.

#pragma once
#include <cuda.h>

#if CUDA_VERSION >= 12030
#include <nvml.h>


bool NvmlFabricSymbolLoaded();

typedef nvmlReturn_t (*nvmlDeviceGetHandleByIndexFunc)(unsigned int, nvmlDevice_t*);
typedef nvmlReturn_t (*nvmlDeviceGetGpuFabricInfoFunc)(nvmlDevice_t, nvmlGpuFabricInfo_t*);

extern nvmlDeviceGetHandleByIndexFunc nvmlDeviceGetHandleByIndexPtr;
extern nvmlDeviceGetGpuFabricInfoFunc nvmlDeviceGetGpuFabricInfoPtr;
extern nvmlDeviceGetGpuFabricInfoFunc nvmlDeviceGetGpuFabricInfoPtr;
#endif
38 changes: 21 additions & 17 deletions cpp/src/wholememory/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ void exchange_rank_info(wholememory_comm_t wm_comm)
wm_comm->clique_info.is_in_clique = 0;

#if CUDA_VERSION >= 12030
if(nvmlFabricSymbolLoaded) {
if (nvmlFabricSymbolLoaded) {
memset(&ri.fabric_info, 0, sizeof(ri.fabric_info));
WHOLEMEMORY_CHECK_NOTHROW(GetGpuFabricInfo(wm_comm->dev_id, &ri.fabric_info) ==
WHOLEMEMORY_SUCCESS);
Expand All @@ -548,7 +548,9 @@ void exchange_rank_info(wholememory_comm_t wm_comm)
wm_comm->clique_info.is_in_clique = 1;
}
} else {
WHOLEMEMORY_WARN("Some required NVML symbols are missing, likely due to an outdated GPU display driver. MNNVL support will be disabled.");
WHOLEMEMORY_WARN(
"Some required NVML symbols are missing, likely due to an outdated GPU display driver. MNNVL "
"support will be disabled.");
}

#endif
Expand Down Expand Up @@ -578,30 +580,32 @@ void exchange_rank_info(wholememory_comm_t wm_comm)
}

#if CUDA_VERSION >= 12030
if(nvmlFabricSymbolLoaded) {
if ((memcmp(ri.fabric_info.clusterUuid,
p_rank_info.get()[r].fabric_info.clusterUuid,
NVML_GPU_FABRIC_UUID_LEN) == 0) &&
(ri.fabric_info.cliqueId == p_rank_info.get()[r].fabric_info.cliqueId)) {
if (r == wm_comm->world_rank) {
wm_comm->clique_info.clique_rank = wm_comm->clique_info.clique_rank_num;
if (nvmlFabricSymbolLoaded) {
if ((memcmp(ri.fabric_info.clusterUuid,
p_rank_info.get()[r].fabric_info.clusterUuid,
NVML_GPU_FABRIC_UUID_LEN) == 0) &&
(ri.fabric_info.cliqueId == p_rank_info.get()[r].fabric_info.cliqueId)) {
if (r == wm_comm->world_rank) {
wm_comm->clique_info.clique_rank = wm_comm->clique_info.clique_rank_num;
}
if (wm_comm->clique_info.clique_rank_num == 0) {
wm_comm->clique_info.clique_first_rank = r;
}
wm_comm->clique_info.clique_rank_num++;
}
if (wm_comm->clique_info.clique_rank_num == 0) { wm_comm->clique_info.clique_first_rank = r; }
wm_comm->clique_info.clique_rank_num++;
clique_uuids.insert(
std::string(reinterpret_cast<const char*>(p_rank_info.get()[r].fabric_info.clusterUuid),
NVML_GPU_FABRIC_UUID_LEN));
}
clique_uuids.insert(
std::string(reinterpret_cast<const char*>(p_rank_info.get()[r].fabric_info.clusterUuid),
NVML_GPU_FABRIC_UUID_LEN));
}
#endif
}

#if CUDA_VERSION >= 12030
if(nvmlFabricSymbolLoaded) {
if (nvmlFabricSymbolLoaded) {
wm_comm->clique_info.clique_num = clique_uuids.size();

std::string uuid = std::string(reinterpret_cast<const char*>(ri.fabric_info.clusterUuid),
NVML_GPU_FABRIC_UUID_LEN);
NVML_GPU_FABRIC_UUID_LEN);
int id = 0;
for (auto clique_uuid : clique_uuids) {
if (clique_uuid == uuid) { wm_comm->clique_info.clique_id = id; }
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/wholememory/system_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#include "wholememory/wholememory.h"

#if CUDA_VERSION >= 12030
#include <nvml.h>
#include "nvml_wrap.h"
#include <nvml.h>
#endif
bool DevAttrPagebleMemoryAccess();

Expand All @@ -41,6 +41,6 @@ namespace wholememory {

inline bool nvmlFabricSymbolLoaded = NvmlFabricSymbolLoaded();
wholememory_error_code_t GetGpuFabricInfo(int dev, nvmlGpuFabricInfo_t* gpuFabricInfo);
}
} // namespace wholememory

#endif

0 comments on commit a5117a6

Please sign in to comment.