Skip to content

Commit

Permalink
cl/ext
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Aug 25, 2023
1 parent b7b6d7a commit 5efa720
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 14 deletions.
6 changes: 6 additions & 0 deletions ext/FluxMetalExt/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ function _metal(x)
USE_METAL[] || return x
fmap(x -> Adapt.adapt(FluxMetalAdaptor(), x), x; exclude=_isleaf)
end

function Flux.get_device(::Val{:Metal}, idx::Int)
@assert idx == 0 # Metal only supports one device at the moment
return Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]]
end

15 changes: 10 additions & 5 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,10 +650,10 @@ function get_device(; verbose=false)::AbstractDevice
end

"""
Flux.get_device(backend::String, ordinal::Int = 0)::Flux.AbstractDevice
Flux.get_device(backend::String, idx::Int = 0)::Flux.AbstractDevice
Get a device object for a backend specified by the string `backend` and `ordinal`. The currently supported values
of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. `ordinal` must be an integer value between `0` and the number of available devices.
Get a device object for a backend specified by the string `backend` and `idx`. The currently supported values
of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. `idx` must be an integer value between `0` and the number of available devices.
# Examples
Expand Down Expand Up @@ -683,10 +683,15 @@ julia> cpu_device = Flux.get_device("CPU")
```
"""
function get_device(backend::String, ordinal::Int = 0)
function get_device(backend::String, idx::Int = 0)
if backend == "CPU"
return FluxCPUDevice()
else
return get_device(Val(Symbol(backend)), ordinal)
return get_device(Val(Symbol(backend)), idx)
end
end

# Fallback
function get_device(::Val{D}, idx) where D
error("Unsupported backend: $(D). Try importing the corresponding package.")
end
32 changes: 23 additions & 9 deletions test/ext_metal/get_devices.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
metal_device = Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]]
@testset "Flux.DEVICES" begin
metal_device = Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]]

# should pass, whether or not Metal is functional
@test typeof(metal_device) <: Flux.FluxMetalDevice
# should pass, whether or not Metal is functional
@test typeof(metal_device) <: Flux.FluxMetalDevice

if Metal.functional()
@test typeof(metal_device.deviceID) <: Metal.MTLDevice
else
@test typeof(metal_device.deviceID) <: Nothing
@test typeof(metal_device.deviceID) <: Metal.MTLDevice
end

# testing get_device
if Metal.functional()
@testset "get_devices()" begin
# testing get_device
metal_device = Flux.get_device()

@test typeof(metal_device) <: Flux.FluxMetalDevice
Expand All @@ -23,3 +21,19 @@ if Metal.functional()
@test cx isa Metal.MtlArray
@test Metal.device(cx).registryID == metal_device.deviceID.registryID
end

@testset "get_devices(Metal)" begin
# testing get_device
metal_device = Flux.get_device("Metal")

@test typeof(metal_device) <: Flux.FluxMetalDevice
@test typeof(metal_device.deviceID) <: Metal.MTLDevice
@test Flux._get_device_name(metal_device) in Flux.supported_devices()

metal_device = Flux.get_device("Metal", 0)

@test typeof(metal_device) <: Flux.FluxMetalDevice
@test typeof(metal_device.deviceID) <: Metal.MTLDevice
@test Flux._get_device_name(metal_device) in Flux.supported_devices()
end

0 comments on commit 5efa720

Please sign in to comment.