
Adding docstrings for the new device types, and the get_device function.
codetalker7 committed Jul 23, 2023
1 parent 70044fb commit 0dc5629
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions src/functor.jl
@@ -445,22 +445,47 @@ function gpu(d::MLUtils.DataLoader)
)
end

"""
Flux.AbstractDevice <: Function
An abstract type representing `device` objects for different GPU backends. The currently supported backends are `"CUDA"`, `"AMD"`, `"Metal"` and `"CPU"`; the `"CPU"` backend is the fallback case when no GPU is available.
"""
abstract type AbstractDevice <: Function end

"""
Flux.FluxCPUDevice <: Flux.AbstractDevice
A type representing `device` objects for the `"CPU"` backend for Flux. This is the fallback case when no GPU is available to Flux.
"""
Base.@kwdef struct FluxCPUDevice <: AbstractDevice
name::String = "CPU"
end

"""
Flux.FluxCUDADevice <: Flux.AbstractDevice
A type representing `device` objects for the `"CUDA"` backend for Flux.
"""
Base.@kwdef struct FluxCUDADevice <: AbstractDevice
name::String = "CUDA"
pkgid::PkgId = PkgId(UUID("052768ef-5323-5732-b1bb-66c8b64840ba"), "CUDA")
end

"""
Flux.FluxAMDDevice <: Flux.AbstractDevice
A type representing `device` objects for the `"AMD"` backend for Flux.
"""
Base.@kwdef struct FluxAMDDevice <: AbstractDevice
name::String = "AMD"
pkgid::PkgId = PkgId(UUID("21141c5a-9bdb-4563-92ae-f87d6854732e"), "AMDGPU")
end

"""
Flux.FluxMetalDevice <: Flux.AbstractDevice
A type representing `device` objects for the `"Metal"` backend for Flux.
"""
Base.@kwdef struct FluxMetalDevice <: AbstractDevice
name::String = "Metal"
pkgid::PkgId = PkgId(UUID("dde4c033-4e86-420c-a63e-0dd931031962"), "Metal")
@@ -475,9 +500,87 @@ function _get_device_name(t::T) where {T <: AbstractDevice}
return hasfield(T, :name) ? t.name : ""
end

# below order is important
const DEVICES = (FluxCUDADevice(), FluxAMDDevice(), FluxMetalDevice(), FluxCPUDevice())

"""
Flux.supported_devices()
Get all supported backends for Flux, in order of preference.
# Example
```jldoctest
julia> using Flux;
julia> Flux.supported_devices()
("CUDA", "AMD", "Metal", "CPU")
```
"""
supported_devices() = map(_get_device_name, DEVICES)

"""
Flux.get_device()::AbstractDevice
Returns a `device` object for the most appropriate backend for the current Julia session.
First, the function checks whether a backend preference has been set via the `gpu_backend!` function. If so, then an attempt is made to load this backend. If the corresponding trigger package has been loaded and the backend is functional, a `device` corresponding to the given backend is loaded. Otherwise, an appropriate backend is chosen.
If there is no preference, then each of the `"CUDA"`, `"AMD"`, `"Metal"` and `"CPU"` backends in the given order, this function checks whether the given backend has been loaded via the corresponding trigger package, and whether the backend is functional. If so, the `device` corresponding to the backend is returned. If no GPU backend is available, a `Flux.FluxCPUDevice` is returned.
# Examples
For the example given below, the backend preference was set to `"AMD"` via the [`gpu_backend!`](@ref) function.
```jldoctest
julia> using Flux;
julia> model = Dense(2 => 3)
Dense(2 => 3) # 9 parameters
julia> device = Flux.get_device() # this will just load the CPU device
[ Info: Using backend set in preferences: AMD.
┌ Warning: Trying to use backend AMD but package AMDGPU [21141c5a-9bdb-4563-92ae-f87d6854732e] is not loaded.
│ Please load the package and call this function again to respect the preferences backend.
└ @ Flux ~/fluxml/Flux.jl/src/functor.jl:496
[ Info: Running automatic device selection...
(::Flux.FluxCPUDevice) (generic function with 1 method)
julia> model = model |> device
Dense(2 => 3) # 9 parameters
julia> model.weight
3×2 Matrix{Float32}:
-0.304362 -0.700477
-0.861201 0.67825
-0.176017 0.234188
```
Here is the same example, but using `"CUDA"`:
```jldoctest
julia> using Flux, CUDA;
julia> model = Dense(2 => 3)
Dense(2 => 3) # 9 parameters
julia> device = Flux.get_device()
[ Info: Using backend set in preferences: AMD.
┌ Warning: Trying to use backend AMD but package AMDGPU [21141c5a-9bdb-4563-92ae-f87d6854732e] is not loaded.
│ Please load the package and call this function again to respect the preferences backend.
└ @ Flux ~/fluxml/Flux.jl/src/functor.jl:496
[ Info: Running automatic device selection...
(::Flux.FluxCUDADevice) (generic function with 1 method)
julia> model = model |> device
Dense(2 => 3) # 9 parameters
julia> model.weight
3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
0.820013 0.527131
-0.915589 0.549048
0.290744 -0.0592499
```
"""
function get_device()::AbstractDevice
backend = @load_preference("gpu_backend", nothing)
if backend !== nothing
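The body of `get_device` is collapsed in the diff above, but the automatic-selection step it describes can be sketched as follows. This is an illustrative reconstruction with simplified stand-in types, not the actual collapsed implementation; the `is_functional` helper and the `pick_device` name are assumptions made for the sketch.

```julia
# Illustrative sketch of preference-ordered device selection; simplified
# stand-in types, not the collapsed body of Flux.get_device.

abstract type AbstractDevice <: Function end

struct FluxCUDADevice <: AbstractDevice end
struct FluxCPUDevice  <: AbstractDevice end

# Hypothetical availability check: in Flux this would consult the trigger
# package recorded in `pkgid`; here only the CPU reports functional.
is_functional(::AbstractDevice) = false
is_functional(::FluxCPUDevice)  = true

# As in the diff, the tuple order encodes backend preference.
const DEVICES = (FluxCUDADevice(), FluxCPUDevice())

# Return the first functional device in preference order.
pick_device() = first(d for d in DEVICES if is_functional(d))

println(typeof(pick_device()))  # prints "FluxCPUDevice": no GPU stand-in is functional
```

Because `DEVICES` is iterated front to back, placing `FluxCPUDevice` last makes it the fallback, matching the "below order is important" comment in the diff.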
