Skip to content

Commit

Permalink
feat: add DistributedUtils (MPI&NCCL working)
Browse files Browse the repository at this point in the history
  • Loading branch information
askorupka committed Jul 7, 2024
1 parent f111c50 commit 71ae53d
Show file tree
Hide file tree
Showing 7 changed files with 652 additions and 124 deletions.
13 changes: 13 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,45 @@ version = "0.14.16"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
FluxAMDGPUExt = "AMDGPU"
FluxCUDAExt = "CUDA"
FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
FluxMPIExt = "MPI"
FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
FluxMetalExt = "Metal"

[compat]
Expand All @@ -43,14 +53,17 @@ ChainRulesCore = "1.12"
Compat = "4.10.0"
Functors = "0.4"
MLUtils = "0.4"
MPI = "0.20.19"
MacroTools = "0.5"
Metal = "0.5, 1"
NCCL = "0.1.1"
NNlib = "0.9.15"
OneHotArrays = "0.2.4"
Optimisers = "0.3.3"
Preferences = "1"
ProgressLogging = "0.1"
Reexport = "1.0"
Setfield = "1.1"
SpecialFunctions = "2.1.2"
Statistics = "1"
Zygote = "0.6.67"
Expand Down
124 changes: 0 additions & 124 deletions distributed.jl

This file was deleted.

169 changes: 169 additions & 0 deletions ext/FluxMPIExt/FluxMPIExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
module FluxMPIExt

using CUDA
# using AMDGPU ### TODO
using Flux: MPIBackend, NCCLBackend, DistributedUtils,
AbstractDevice, FluxCUDADevice, FluxAMDGPUDevice, cpu, gpu,
get_device, MPI_CUDA_AWARE, MPI_ROCM_AWARE
# using LuxDeviceUtils: LuxCUDADevice, LuxAMDGPUDevice
using MPI: MPI


"""
    DistributedUtils.__initialize(::Type{MPIBackend}; cuda_devices=nothing,
        amdgpu_devices=nothing, force_cuda=false, caller="", force_amdgpu=false)

Initialize MPI (via `MPI.Init()` if not already initialized), record the fact in
`DistributedUtils.MPI_Initialized`, and bind this process to a CUDA device based
on its rank in `MPI.COMM_WORLD`.

Device selection (only attempted when `CUDA.functional()`):
- `cuda_devices === missing` — skip device selection entirely;
- `cuda_devices === nothing` — round-robin ranks over the visible devices;
- otherwise — use `cuda_devices[local_rank + 1]` (1-based lookup for 0-based ranks).

`amdgpu_devices` and `force_amdgpu` are accepted but currently ignored
(AMDGPU support is a TODO; see the commented-out import above).
"""
function DistributedUtils.__initialize(
        ::Type{MPIBackend}; cuda_devices=nothing, amdgpu_devices=nothing,
        force_cuda::Bool=false, caller::String="", force_amdgpu::Bool=false) # Undocumented internal kwarg
    !MPI.Initialized() && MPI.Init()
    DistributedUtils.MPI_Initialized[] = true

    local_rank = MPI.Comm_rank(MPI.COMM_WORLD)

    if cuda_devices !== missing && CUDA.functional()
        if cuda_devices === nothing
            # NOTE(review): rank r is assigned device (r + 1) % ndevices, i.e.
            # rank 0 gets device index 1 — confirm the `+ 1` offset is intentional.
            CUDA.device!((local_rank + 1) % length(CUDA.devices()))
        else
            CUDA.device!(cuda_devices[local_rank + 1])
        end
    elseif force_cuda
        # Only raised when CUDA was explicitly required by the caller.
        error(lazy"CUDA devices are not functional and `force_cuda` is set to `true`. This is caused by backend: $(caller).")
    end

    return
end

# Build the MPI backend wrapping the global communicator.
function DistributedUtils.__get_distributed_backend(::Type{MPIBackend})
    return MPIBackend(MPI.COMM_WORLD)
end

# 0-based rank of this process within the backend's communicator.
function DistributedUtils.local_rank(backend::MPIBackend)
    return MPI.Comm_rank(backend.comm)
end

# Total number of processes in the backend's communicator.
function DistributedUtils.total_workers(backend::MPIBackend)
    return MPI.Comm_size(backend.comm)
end

# Broadcast `sendrecvbuf` from `root` to every rank, in-place.
# `dev` accepts `Union{AbstractDevice, Function}` because `Flux.cpu` is a plain
# `Function`, not an `AbstractDevice`; CPU dispatch is needed for the
# non-device-aware MPI fallbacks below.
function DistributedUtils.__bcast!(
        backend::MPIBackend, sendrecvbuf, dev::Union{AbstractDevice, Function}; root=0)
    MPI.Bcast!(sendrecvbuf, backend.comm; root=root)
    return sendrecvbuf
end

# Two-buffer broadcast: the root rank sends from `sendbuf`, every other rank
# receives into `recvbuf`. Delegates to the single-buffer method.
function DistributedUtils.__bcast!(
        backend::MPIBackend, sendbuf, recvbuf, dev::Union{AbstractDevice, Function}; root=0)
    buffer = DistributedUtils.local_rank(backend) == root ? sendbuf : recvbuf
    return DistributedUtils.__bcast!(backend, buffer, dev; root)
end

# If the MPI implementation is not CUDA/ROCm-aware we cannot hand device
# pointers to MPI directly, so the broadcast is staged through host memory.
for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
    if !aware
        @eval begin
            function DistributedUtils.__bcast!(
                    backend::MPIBackend, sendrecvbuf, dev::$dType; root=0)
                # Copy to host, broadcast there, then copy the result back into
                # the original device buffer. (Bug fix: previously
                # `sendrecvbuf |> gpu` built a new array whose value was
                # discarded, so the device buffer was returned unmodified.)
                sendrecvbuf_ = sendrecvbuf |> cpu
                DistributedUtils.__bcast!(backend, sendrecvbuf_, cpu; root)
                copyto!(sendrecvbuf, sendrecvbuf_)
                return sendrecvbuf
            end

            function DistributedUtils.__bcast!(
                    backend::MPIBackend, sendbuf, recvbuf, dev::$dType; root=0)
                # Stage both buffers through the CPU; copy the received data
                # back into the caller's device `recvbuf` before returning.
                sendbuf_ = sendbuf |> cpu
                recvbuf_ = recvbuf |> cpu
                DistributedUtils.__bcast!(backend, sendbuf_, recvbuf_, cpu; root)
                copyto!(recvbuf, recvbuf_)
                return recvbuf
            end
        end
    end
end


# Allreduce: combine `sendrecvbuf` across all ranks with `op`, leaving the
# result in `sendrecvbuf` on every rank. `DistributedUtils.avg` is realized as
# a `+` reduction followed by division by the worker count.
function DistributedUtils.__allreduce!(
        backend::MPIBackend, sendrecvbuf, op::F, dev::Union{AbstractDevice, Function};) where {F}
    averaging = op === DistributedUtils.avg
    MPI.Allreduce!(sendrecvbuf, averaging ? (+) : op, backend.comm)
    if averaging
        sendrecvbuf ./= DistributedUtils.total_workers(backend)
    end
    return sendrecvbuf
end

# Two-buffer allreduce: reduce `sendbuf` across ranks into `recvbuf` on every
# rank. Averaging is a `+` reduction divided by the worker count afterwards.
function DistributedUtils.__allreduce!(
        backend::MPIBackend, sendbuf, recvbuf, op::F, dev::Union{AbstractDevice, Function};) where {F}
    averaging = op === DistributedUtils.avg
    MPI.Allreduce!(sendbuf, recvbuf, averaging ? (+) : op, backend.comm)
    if averaging
        recvbuf ./= DistributedUtils.total_workers(backend)
    end
    return recvbuf
end

# Non-device-aware MPI fallback for allreduce: stage device buffers through
# host memory, then copy the reduced result back to the device buffer.
for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
    if !aware
        @eval begin
            function DistributedUtils.__allreduce!(
                    backend::MPIBackend, sendrecvbuf, op::F, dev::$dType) where {F}
                # Bug fix: previously `sendrecvbuf |> gpu` built a new array
                # whose value was discarded, so the reduced data never reached
                # the device buffer. Copy the host result back explicitly.
                sendrecvbuf_ = sendrecvbuf |> cpu
                DistributedUtils.__allreduce!(backend, sendrecvbuf_, op, cpu)
                copyto!(sendrecvbuf, sendrecvbuf_)
                return sendrecvbuf
            end

            function DistributedUtils.__allreduce!(
                    backend::MPIBackend, sendbuf, recvbuf, op::F, dev::$dType) where {F}
                sendbuf_ = sendbuf |> cpu
                recvbuf_ = recvbuf |> cpu
                DistributedUtils.__allreduce!(backend, sendbuf_, recvbuf_, op, cpu)
                # Copy the reduced host data back into the caller's device buffer.
                copyto!(recvbuf, recvbuf_)
                return recvbuf
            end
        end
    end
end

# Reduce: fold `sendrecvbuf` across all ranks with `op` onto rank `root`,
# in-place. Averaging is a `+` reduction divided by the worker count.
function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F,
        dev::Union{AbstractDevice, Function}; root::Int) where {F}
    averaging = op === DistributedUtils.avg
    MPI.Reduce!(sendrecvbuf, averaging ? (+) : op, backend.comm; root)
    # NOTE(review): the division runs on every rank, although MPI defines the
    # reduction result only at `root` — confirm non-root buffers are not
    # relied upon downstream.
    if averaging
        sendrecvbuf ./= DistributedUtils.total_workers(backend)
    end
    return sendrecvbuf
end

# Two-buffer reduce: fold `sendbuf` across ranks into `recvbuf` on rank `root`.
function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, op::F,
        dev::Union{AbstractDevice, Function}; root::Int) where {F}
    averaging = op === DistributedUtils.avg
    MPI.Reduce!(sendbuf, recvbuf, averaging ? (+) : op, backend.comm; root)
    if averaging
        recvbuf ./= DistributedUtils.total_workers(backend)
    end
    return recvbuf
end

# Non-device-aware MPI fallback for reduce: stage device buffers through host
# memory, then copy the reduced result back to the device buffer.
for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
    if !aware
        @eval begin
            function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F,
                    dev::$dType; root::Int) where {F}
                # Bug fix: previously `sendrecvbuf |> gpu` built a new array
                # whose value was discarded, so the reduced data never reached
                # the device buffer. Copy the host result back explicitly.
                sendrecvbuf_ = sendrecvbuf |> cpu
                DistributedUtils.__reduce!(backend, sendrecvbuf_, op, cpu; root)
                copyto!(sendrecvbuf, sendrecvbuf_)
                return sendrecvbuf
            end

            function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf,
                    op::F, dev::$dType; root::Int) where {F}
                sendbuf_ = sendbuf |> cpu
                recvbuf_ = recvbuf |> cpu
                DistributedUtils.__reduce!(backend, sendbuf_, recvbuf_, op, cpu; root)
                # Copy the reduced host data back into the caller's device buffer.
                copyto!(recvbuf, recvbuf_)
                return recvbuf
            end
        end
    end
end

end
Loading

0 comments on commit 71ae53d

Please sign in to comment.