
Commit

remove gpu implementation
CarloLucibello committed Oct 20, 2024
1 parent 31dccd1 commit d0cb656
Showing 12 changed files with 19 additions and 432 deletions.
6 changes: 1 addition & 5 deletions Project.toml
@@ -29,18 +29,15 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
FluxAMDGPUExt = "AMDGPU"
FluxCUDAExt = "CUDA"
FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
FluxEnzymeExt = "Enzyme"
FluxMPIExt = "MPI"
FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
FluxMetalExt = "Metal"

[compat]
AMDGPU = "1"
@@ -50,11 +47,10 @@ ChainRulesCore = "1.12"
Compat = "4.10.0"
Enzyme = "0.12, 0.13"
Functors = "0.4"
MLDataDevices = "1.2.0"
MLDataDevices = "1.4.0"
MLUtils = "0.4"
MPI = "0.20.19"
MacroTools = "0.5"
Metal = "0.5, 1"
NCCL = "0.1.1"
NNlib = "0.9.22"
OneHotArrays = "0.2.4"
9 changes: 2 additions & 7 deletions docs/src/guide/gpu.md
@@ -63,13 +63,6 @@ There are two ways you can specify it:
gpu_backend = "AMDGPU"
```

Current GPU backend can be fetched from `Flux.GPU_BACKEND` variable:

```julia
julia> Flux.GPU_BACKEND
"CUDA"
```

The current backend will affect the behaviour of methods like the method `gpu` described below.
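
With `Flux.GPU_BACKEND` removed, the sketch below shows one way backend selection could look through MLDataDevices; it assumes `MLDataDevices.gpu_backend!` and `gpu_device` from the API reference further down and is an illustration, not part of this diff.

```julia
using Flux, MLDataDevices

# Persist the preferred backend in LocalPreferences.toml (read on the next Julia start).
MLDataDevices.gpu_backend!("CUDA")

# Resolve the preference to a concrete device; falls back to CPUDevice() if no GPU is functional.
device = gpu_device()

model = Dense(2 => 3) |> device   # the device object is callable, so |> moves the model
```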

## Basic GPU Usage

@@ -358,7 +351,9 @@ MLDataDevices.get_device
Expand Down Expand Up @@ -358,7 +351,9 @@ MLDataDevices.get_device
MLDataDevices.gpu_device
MLDataDevices.gpu_backend!
MLDataDevices.get_device_type
MLDataDevices.loaded
MLDataDevices.reset_gpu_device!
MLDataDevices.set_device!
MLDataDevices.supported_gpu_backends
MLDataDevices.DeviceIterator
```
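
A brief usage sketch for some of the entries listed above (assumed behaviour, not taken from this commit): `gpu_device` picks a device and `DeviceIterator` moves each batch onto it during iteration.

```julia
using MLDataDevices

device = gpu_device()   # CPUDevice() when no functional GPU backend is loaded

data = [(rand(Float32, 4, 32), rand(Float32, 2, 32)) for _ in 1:10]
for (x, y) in DeviceIterator(device, data)
    # x and y already live on `device` here; run the forward/backward pass on them
end
```
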
31 changes: 1 addition & 30 deletions ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -1,9 +1,7 @@
module FluxAMDGPUExt

import ChainRulesCore
import ChainRulesCore: NoTangent
import Flux
import Flux: FluxCPUAdaptor, FluxAMDGPUAdaptor, _amd, adapt_storage, fmap
import Flux: adapt_storage, fmap
import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
import NNlib
using MLDataDevices: MLDataDevices
@@ -14,38 +12,11 @@ using Zygote

const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat

# Set to boolean on the first call to check_use_amdgpu
const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)


function check_use_amdgpu()
    if !isnothing(USE_AMDGPU[])
        return
    end

    USE_AMDGPU[] = AMDGPU.functional()
    if USE_AMDGPU[]
        if !AMDGPU.functional(:MIOpen)
            @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available."
        end
    else
        @info """
        The AMDGPU function is being called but AMDGPU.jl is not functional.
        Defaulting back to the CPU. (No action is required if you want to run on the CPU).
        """ maxlog=1
    end
    return
end

ChainRulesCore.@non_differentiable check_use_amdgpu()

include("functor.jl")
include("batchnorm.jl")
include("conv.jl")

function __init__()
    Flux.AMDGPU_LOADED[] = true
end

# TODO
# fail early if input to the model is not on the device (e.g. on the host)
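
For comparison, a minimal sketch of how the deleted lazy fallback could be expressed with MLDataDevices; `MLDataDevices.functional` and the device constructors are assumed here and are not part of this commit.

```julia
using MLDataDevices

# Replaces the USE_AMDGPU[] flag and check_use_amdgpu() removed above (illustration only).
device = if MLDataDevices.functional(AMDGPUDevice)
    gpu_device()
else
    cpu_device()   # quiet fallback to the CPU, as the old @info branch did
end
```
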
79 changes: 4 additions & 75 deletions ext/FluxAMDGPUExt/functor.jl
@@ -1,61 +1,3 @@
# Convert Float64 to Float32, but preserve Float16.
function adapt_storage(to::FluxAMDGPUAdaptor, x::AbstractArray)
    if to.id === nothing
        if (typeof(x) <: AbstractArray{Float16, N} where N)
            N = length(size(x))
            return isbits(x) ? x : ROCArray{Float16, N}(x)
        elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N})
            N = length(size(x))
            return isbits(x) ? x : ROCArray{Float32, N}(x)
        else
            return isbits(x) ? x : ROCArray(x)
        end
    end

    old_id = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ids start from 0

    if !(x isa ROCArray)
        AMDGPU.device!(AMDGPU.devices()[to.id + 1]) # adding 1 because ids start from 0
        if (typeof(x) <: AbstractArray{Float16, N} where N)
            N = length(size(x))
            x_new = isbits(x) ? x : ROCArray{Float16, N}(x)
        elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N})
            N = length(size(x))
            x_new = isbits(x) ? x : ROCArray{Float32, N}(x)
        else
            x_new = isbits(x) ? x : ROCArray(x)
        end
        AMDGPU.device!(AMDGPU.devices()[old_id + 1])
        return x_new
    elseif AMDGPU.device_id(AMDGPU.device(x)) == to.id
        return x
    else
        AMDGPU.device!(AMDGPU.devices()[to.id + 1])
        x_new = copy(x)
        AMDGPU.device!(AMDGPU.devices()[old_id + 1])
        return x_new
    end
end

adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.FillArrays.AbstractFill) =
ROCArray(collect(x))
adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.OneElement) = ROCArray(collect(x))
adapt_storage(::FluxAMDGPUAdaptor, x::Random.TaskLocalRNG) = AMDGPU.rocrand_rng()
adapt_storage(::FluxAMDGPUAdaptor, x::AMDGPU.rocRAND.RNG) = x
adapt_storage(::FluxAMDGPUAdaptor, x::AbstractRNG) = error("""
    Cannot map RNG of type $(typeof(x)) to AMDGPU.
    AMDGPU execution only supports Random.default_rng().""")

adapt_storage(::FluxCPUAdaptor, x::AMDGPU.rocRAND.RNG) = Random.default_rng()

function ChainRulesCore.rrule(
    ::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::AMDGPU.AnyROCArray,
)
    adapt_storage(to, x), dx -> (
        NoTangent(), NoTangent(),
        adapt_storage(FluxAMDGPUAdaptor(), unthunk(dx)))
end

# Since MIOpen supports only cross-correlation as convolution,
# for the actual convolution, we flip horizontally and vertically the weights.
# Same for CPU -> GPU & GPU -> CPU movements.
@@ -70,23 +12,14 @@ const AMDGPU_CONV = FLUX_CONV{ROCArray}
_conv_basetype(::Conv) = Conv
_conv_basetype(::ConvTranspose) = ConvTranspose

Flux._isleaf(::AMDGPU_CONV) = true

_exclude(x) = Flux._isleaf(x)
_exclude(::CPU_CONV) = true

function _amd(id::Union{Nothing, Int}, x)
    check_use_amdgpu()
    USE_AMDGPU[] || return x
    fmap(x -> Adapt.adapt(FluxAMDGPUAdaptor(id), x), x; exclude=_exclude)
end
MLDataDevices.isleaf(::AMDGPU_CONV) = true

_other_args(m::Conv) = (m.stride, m.pad, m.dilation, m.groups)
_other_args(m::ConvTranspose) = (m.stride, m.pad, m.outpad, m.dilation, m.groups)

# CPU -> GPU

function Adapt.adapt_structure(to::FluxAMDGPUAdaptor, m::CPU_CONV)
function Adapt.adapt_structure(to::AMDGPUDevice, m::CPU_CONV)
    flipped_weight = reverse(m.weight; dims=ntuple(i -> i, ndims(m.weight) - 2))
    _conv_basetype(m)(
        Adapt.adapt(to, m.σ),
@@ -97,17 +30,13 @@ end

# Don't adapt again.

Adapt.adapt_structure(to::FluxAMDGPUAdaptor, m::AMDGPU_CONV) = m
Adapt.adapt_structure(to::AMDGPUDevice, m::AMDGPU_CONV) = m

# GPU -> CPU

function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMDGPU_CONV)
function Adapt.adapt_structure(to::CPUDevice, m::AMDGPU_CONV)
    dims = ntuple(i -> i, ndims(m.weight) - 2)
    _conv_basetype(m)(
        Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
        Adapt.adapt(to, m.bias), _other_args(m)...)
end

function Flux._get_device(::Val{:AMDGPU}, id::Int) # id should start from 0
    return MLDataDevices.gpu_device(id+1, force=true)
end
50 changes: 0 additions & 50 deletions ext/FluxCUDAExt/FluxCUDAExt.jl

This file was deleted.

61 changes: 0 additions & 61 deletions ext/FluxCUDAExt/functor.jl

This file was deleted.

35 changes: 0 additions & 35 deletions ext/FluxMetalExt/FluxMetalExt.jl

This file was deleted.

