Skip to content

Commit

Permalink
Fix the recent problem in the CUDA test.
Browse files Browse the repository at this point in the history
This is done by using `cudaDeviceReset()` only when the last executor on
the device is being deleted.
  • Loading branch information
tcojean committed Mar 6, 2019
1 parent cd720b7 commit 377baa3
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
6 changes: 6 additions & 0 deletions core/base/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,10 @@ const char *Operation::get_name() const noexcept
}


int CudaExecutor::num_execs[max_devices];


std::mutex CudaExecutor::mutex[max_devices];


} // namespace gko
10 changes: 6 additions & 4 deletions cuda/base/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,10 @@ std::shared_ptr<CudaExecutor> CudaExecutor::create(
new CudaExecutor(device_id, std::move(master)),
[device_id](CudaExecutor *exec) {
delete exec;
device_guard g(device_id);
cudaDeviceReset();
if (!CudaExecutor::get_num_execs(device_id)) {
device_guard g(device_id);
cudaDeviceReset();
}
});
}

Expand Down Expand Up @@ -191,8 +193,8 @@ void CudaExecutor::raw_copy_to(const CudaExecutor *src, size_type num_bytes,
const void *src_ptr, void *dest_ptr) const
{
device_guard g(this->get_device_id());
GKO_ASSERT_NO_CUDA_ERRORS(cudaMemcpyPeer(dest_ptr, this->device_id_, src_ptr,
src->get_device_id(), num_bytes));
GKO_ASSERT_NO_CUDA_ERRORS(cudaMemcpyPeer(
dest_ptr, this->device_id_, src_ptr, src->get_device_id(), num_bytes));
}


Expand Down
9 changes: 9 additions & 0 deletions cuda/test/base/cuda_executor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ protected:
};


TEST_F(CudaExecutor, CanInstantiateTwoExecutorsOnOneDevice)
{
auto cuda = gko::CudaExecutor::create(0, omp);
auto cuda2 = gko::CudaExecutor::create(0, omp);

// We want automatic deinitialization to not create any error
}


TEST_F(CudaExecutor, MasterKnowsNumberOfDevices)
{
int count = 0;
Expand Down
27 changes: 27 additions & 0 deletions include/ginkgo/core/base/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <memory>
#include <mutex>
#include <sstream>
#include <tuple>
#include <type_traits>
Expand Down Expand Up @@ -799,6 +800,8 @@ class CudaExecutor : public detail::ExecutorBase<CudaExecutor>,
static std::shared_ptr<CudaExecutor> create(
int device_id, std::shared_ptr<Executor> master);

~CudaExecutor() { decrease_num_execs(this->device_id_); }

std::shared_ptr<Executor> get_master() noexcept override;

std::shared_ptr<const Executor> get_master() const noexcept override;
Expand Down Expand Up @@ -877,8 +880,10 @@ class CudaExecutor : public detail::ExecutorBase<CudaExecutor>,
major_(0),
minor_(0)
{
assert(device_id < max_devices);
this->set_gpu_property();
this->init_handles();
increase_num_execs(device_id);
}

void *raw_alloc(size_type size) const override;
Expand All @@ -887,6 +892,24 @@ class CudaExecutor : public detail::ExecutorBase<CudaExecutor>,

GKO_ENABLE_FOR_ALL_EXECUTORS(GKO_OVERRIDE_RAW_COPY_TO);

static void increase_num_execs(int device_id)
{
std::lock_guard<std::mutex> guard(mutex[device_id]);
num_execs[device_id]++;
}

static void decrease_num_execs(int device_id)
{
std::lock_guard<std::mutex> guard(mutex[device_id]);
num_execs[device_id]--;
}

static int get_num_execs(int device_id)
{
std::lock_guard<std::mutex> guard(mutex[device_id]);
return num_execs[device_id];
}

private:
int device_id_;
std::shared_ptr<Executor> master_;
Expand All @@ -899,6 +922,10 @@ class CudaExecutor : public detail::ExecutorBase<CudaExecutor>,
using handle_manager = std::unique_ptr<T, std::function<void(T *)>>;
handle_manager<cublasContext> cublas_handle_;
handle_manager<cusparseContext> cusparse_handle_;

static constexpr int max_devices = 64;
static int num_execs[max_devices];
static std::mutex mutex[max_devices];
};


Expand Down

0 comments on commit 377baa3

Please sign in to comment.