From 0ee07746656b879504736829435935cf5df031af Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Feb 2024 09:21:08 -0500 Subject: [PATCH] Handle `NCCL.avg` correctly (#54) --- Project.toml | 2 +- src/base.jl | 2 ++ test/runtests.jl | 57 ++++++++++++++++++++++++++++++++++-------------- 3 files changed, 44 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index 4d27d82..c630515 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NCCL" uuid = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" -version = "0.1.0" +version = "0.1.1" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/base.jl b/src/base.jl index c8a58bb..cb1efc9 100644 --- a/src/base.jl +++ b/src/base.jl @@ -33,6 +33,8 @@ ncclRedOp_t(::typeof(+)) = ncclSum ncclRedOp_t(::typeof(*)) = ncclProd ncclRedOp_t(::typeof(max)) = ncclMax ncclRedOp_t(::typeof(min)) = ncclMin +# Handles the case where the user directly passes in an `ncclRedOp_t` (e.g. `NCCL.avg`) +ncclRedOp_t(x::ncclRedOp_t) = x """ NCCl.avg diff --git a/test/runtests.jl b/test/runtests.jl index a96158b..df695c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,24 +25,49 @@ end @testset "Allreduce!" 
begin devs = CUDA.devices() comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - N = 512 - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), N)) - recvbuf[ii] = CUDA.zeros(Float64, N) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + + @testset "sum" begin + recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) + sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(Float64(ii), N)) + recvbuf[ii] = CUDA.zeros(Float64, N) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + end + end + answer = sum(1:length(devs)) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) end end - answer = sum(1:length(devs)) - for (ii, dev) in enumerate(devs) - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) + + @testset "NCCL.avg" begin + recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) + sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(Float64(ii), N)) + recvbuf[ii] = CUDA.zeros(Float64, N) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], NCCL.avg, comms[ii]) + end + end + answer = sum(1:length(devs)) / length(devs) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .≈ answer) + end end end