From 66b321712d34146b39fae0eb18aea0cd14b7a931 Mon Sep 17 00:00:00 2001 From: Chuang Zhu Date: Wed, 6 Dec 2023 08:55:09 +0000 Subject: [PATCH 1/2] fix inferencesample_option --- python/pylibwholegraph/pylibwholegraph/torch/common_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py index 0999fdfe5..42746add8 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py @@ -132,7 +132,7 @@ def add_common_sampler_options(argparser: ArgumentParser): argparser.add_argument( "-s", "--inferencesample", - type=int, + type=str, dest="inferencesample", default="30", help="inference sample count, -1 is all", From 82844809ffd8df3c7fff20cea9fa5d388fe6ffa9 Mon Sep 17 00:00:00 2001 From: Chuang Zhu Date: Wed, 17 Jan 2024 05:54:18 +0000 Subject: [PATCH 2/2] add sync before unregister nvshmem_buffer --- cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh | 2 +- cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh index d2d040a0e..5fa93ee12 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh +++ b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh @@ -29,7 +29,7 @@ class nvshmem_device_reference { : pointer_(static_cast(nvshmem_ref.pointer)), typed_stride_(nvshmem_ref.stride / sizeof(DataTypeT)) { - assert(gref.stride % sizeof(DataTypeT) == 0); + assert(nvshmem_ref.stride % sizeof(DataTypeT) == 0); } __device__ nvshmem_device_reference() = delete; diff --git a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu index a860cbc6c..4051f12bd 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu @@ -185,6 +185,7 @@ wholememory_error_code_t wholememory_gather_nvshmem( p_env_fns, stream); // ungistre + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); if (nvshmemx_buffer_unregister(temp_output_ptr) != 0) { WHOLEMEMORY_ERROR("nvshmemx_buffer_unregister error in wholememory_gather_nvshmem"); }