From 97990530af30bf43ef2c007db30184095ac61146 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Wed, 15 Feb 2023 01:27:55 +0200 Subject: [PATCH 01/15] Initial implementation of AMDGPU extension --- Project.toml | 7 ++++++ ext/AMDGPUExt/AMDGPUExt.jl | 20 +++++++++++++++ ext/AMDGPUExt/functor.jl | 51 ++++++++++++++++++++++++++++++++++++++ src/functor.jl | 16 ++++++++++++ 4 files changed, 94 insertions(+) create mode 100644 ext/AMDGPUExt/AMDGPUExt.jl create mode 100644 ext/AMDGPUExt/functor.jl diff --git a/Project.toml b/Project.toml index 8292de97e2..be67c887a4 100644 --- a/Project.toml +++ b/Project.toml @@ -23,8 +23,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + +[extensions] +AMDGPUExt = "AMDGPU" + [compat] Adapt = "3.0" +AMDGPU = "0.4.8" CUDA = "3, 4" ChainRulesCore = "1.12" Functors = "0.3, 0.4" diff --git a/ext/AMDGPUExt/AMDGPUExt.jl b/ext/AMDGPUExt/AMDGPUExt.jl new file mode 100644 index 0000000000..cc8d462899 --- /dev/null +++ b/ext/AMDGPUExt/AMDGPUExt.jl @@ -0,0 +1,20 @@ +module AMDGPUExt + +using AMDGPU +using Adapt +using Random +using Zygote +import ChainRulesCore +import Functors: fmap +import Flux +import Flux: FluxCPUAdaptor, adapt_storage, _isleaf, _amd + +const use_amdgpu = Ref{Bool}(false) + +include("functor.jl") + +function __init__() + Flux.amdgpu_loaded[] = true +end + +end diff --git a/ext/AMDGPUExt/functor.jl b/ext/AMDGPUExt/functor.jl new file mode 100644 index 0000000000..bbd15bfe22 --- /dev/null +++ b/ext/AMDGPUExt/functor.jl @@ -0,0 +1,51 @@ +struct FluxAMDGPUAdaptor end + +adapt_storage(::FluxAMDGPUAdaptor, x) = ROCArray(x) +adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.FillArrays.AbstractFill) = + ROCArray(collect(x)) +adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.OneElement) = ROCArray(collect(x)) +adapt_storage(::FluxAMDGPUAdaptor, x::Random.TaskLocalRNG) = + AMDGPU.rocRAND.default_rng() +adapt_storage(::FluxAMDGPUAdaptor, x::AMDGPU.rocRAND.RNG) = x +adapt_storage(::FluxAMDGPUAdaptor, x::AbstractRNG) = error(""" + Cannot map RNG of type $(typeof(x)) to AMDGPU. + AMDGPU execution only supports Random.default_rng().""") + +# TODO adaptor for Conv + +adapt_storage(::FluxCPUAdaptor, x::AMDGPU.rocRAND.RNG) = Random.default_rng() + +function ChainRulesCore.rrule(::Type{Array}, x::ROCArray) + Array(x), dx -> (NoTangent(), ROCArray(unthunk(dx))) +end + +function ChainRulesCore.rrule( + ::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::AMDGPU.AnyROCArray, +) + adapt_storage(to, x), dx -> ( + NoTangent(), NoTangent(), + adapt_storage(FluxAMDGPUAdaptor(), unthunk(dx))) +end + +function _amd(x) + check_use_amdgpu() + use_amdgpu[] ? fmap(x -> Adapt.adapt(FluxAMDGPUAdaptor(), x)) : x +end + +function check_use_amdgpu() + use_amdgpu[] === nothing || return + + use_amdgpu[] = AMDGPU.functional() + if use_amdgpu[] + if !AMDGPU.functional(:MIOpen) + @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available." + end + else + @info """ + The AMDGPU function is being called but the AMDGPU is not functional. + Defaulting back to the CPU. (No action is required if you want to run on the CPU). 
+ """ maxlog=1 + end + return +end +ChainRulesCore.@non_differentiable check_use_amdgpu() diff --git a/src/functor.jl b/src/functor.jl index 50bc6df52c..da66c73945 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -280,3 +280,19 @@ f16(m) = _paramtype(Float16, m) @functor Cholesky trainable(c::Cholesky) = () +# AMDGPU extension. + +const amdgpu_loaded = Ref{Bool}(false) + +function amd(x) + if amdgpu_loaded[] + return _amd(x) + else + @info """ + The AMDGPU functionality is being called via `Flux.amd` but + `AMDGPU` must be loaded to access it. + """ maxlog=1 + end +end + +function _amd end From 30f076e2378c2a464f169bedc93116ea158ed745 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Wed, 15 Feb 2023 14:44:43 +0200 Subject: [PATCH 02/15] Add more tests --- Project.toml | 2 +- ext/AMDGPUExt/AMDGPUExt.jl | 34 ++++++++++++++++---- ext/AMDGPUExt/functor.jl | 45 +++++++++++---------------- src/functor.jl | 4 +-- test/amd/basic.jl | 64 ++++++++++++++++++++++++++++++++++++++ test/amd/runtests.jl | 7 +++++ test/amd/utils.jl | 47 ++++++++++++++++++++++++++++ test/cuda/cuda.jl | 8 ++--- test/runtests.jl | 15 +++++++++ 9 files changed, 186 insertions(+), 40 deletions(-) create mode 100644 test/amd/basic.jl create mode 100644 test/amd/runtests.jl create mode 100644 test/amd/utils.jl diff --git a/Project.toml b/Project.toml index be67c887a4..8ccec51734 100644 --- a/Project.toml +++ b/Project.toml @@ -57,4 +57,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] +test = ["AMDGPU", "Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] diff --git a/ext/AMDGPUExt/AMDGPUExt.jl b/ext/AMDGPUExt/AMDGPUExt.jl index cc8d462899..2b6c037afc 100644 --- a/ext/AMDGPUExt/AMDGPUExt.jl +++ b/ext/AMDGPUExt/AMDGPUExt.jl @@ -1,20 +1,42 @@ module AMDGPUExt +import ChainRulesCore +import Flux +import Flux: FluxCPUAdaptor, _amd, _isleaf, adapt_storage, fmap + using AMDGPU using Adapt using Random using Zygote -import ChainRulesCore -import Functors: fmap -import Flux -import Flux: FluxCPUAdaptor, adapt_storage, _isleaf, _amd -const use_amdgpu = Ref{Bool}(false) +const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing) + +function check_use_amdgpu() + isnothing(USE_AMDGPU[]) || return + + USE_AMDGPU[] = AMDGPU.functional() + if USE_AMDGPU[] + if !AMDGPU.functional(:MIOpen) + @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available." + end + else + @info """ + The AMDGPU function is being called but the AMDGPU is not functional. + Defaulting back to the CPU. (No action is required if you want to run on the CPU). + """ maxlog=1 + end + return +end +ChainRulesCore.@non_differentiable check_use_amdgpu() include("functor.jl") function __init__() - Flux.amdgpu_loaded[] = true + Flux.AMDGPU_LOADED[] = true end +# TODO +# fail early if input to the model is not on the device (e.g. on the host) +# otherwise we get very cryptic errors & segfaults at the rocBLAS level + end diff --git a/ext/AMDGPUExt/functor.jl b/ext/AMDGPUExt/functor.jl index bbd15bfe22..12c6ee0d69 100644 --- a/ext/AMDGPUExt/functor.jl +++ b/ext/AMDGPUExt/functor.jl @@ -1,13 +1,20 @@ -struct FluxAMDGPUAdaptor end +struct FluxAMDAdaptor end -adapt_storage(::FluxAMDGPUAdaptor, x) = ROCArray(x) -adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.FillArrays.AbstractFill) = +# Convert Float64 to Float32, but preserve Float16. 
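+# Non-floating-point arrays keep their element type when moved; isbits arrays are returned unchanged on the host.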
+adapt_storage(::FluxAMDAdaptor, x::T) where T <: AbstractArray = + isbits(x) ? x : ROCArray(x) +adapt_storage(::FluxAMDAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} = + isbits(x) ? x : ROCArray{Float32, N}(x) +adapt_storage(::FluxAMDAdaptor, x::AbstractArray{Float16, N}) where N = + isbits(x) ? x : ROCArray{Float16, N}(x) + +adapt_storage(::FluxAMDAdaptor, x::Zygote.FillArrays.AbstractFill) = ROCArray(collect(x)) -adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.OneElement) = ROCArray(collect(x)) -adapt_storage(::FluxAMDGPUAdaptor, x::Random.TaskLocalRNG) = +adapt_storage(::FluxAMDAdaptor, x::Zygote.OneElement) = ROCArray(collect(x)) +adapt_storage(::FluxAMDAdaptor, x::Random.TaskLocalRNG) = AMDGPU.rocRAND.default_rng() -adapt_storage(::FluxAMDGPUAdaptor, x::AMDGPU.rocRAND.RNG) = x -adapt_storage(::FluxAMDGPUAdaptor, x::AbstractRNG) = error(""" +adapt_storage(::FluxAMDAdaptor, x::AMDGPU.rocRAND.RNG) = x +adapt_storage(::FluxAMDAdaptor, x::AbstractRNG) = error(""" Cannot map RNG of type $(typeof(x)) to AMDGPU. AMDGPU execution only supports Random.default_rng().""") @@ -24,28 +31,12 @@ function ChainRulesCore.rrule( ) adapt_storage(to, x), dx -> ( NoTangent(), NoTangent(), - adapt_storage(FluxAMDGPUAdaptor(), unthunk(dx))) + adapt_storage(FluxAMDAdaptor(), unthunk(dx))) end function _amd(x) check_use_amdgpu() - use_amdgpu[] ? fmap(x -> Adapt.adapt(FluxAMDGPUAdaptor(), x)) : x -end - -function check_use_amdgpu() - use_amdgpu[] === nothing || return - - use_amdgpu[] = AMDGPU.functional() - if use_amdgpu[] - if !AMDGPU.functional(:MIOpen) - @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available." - end - else - @info """ - The AMDGPU function is being called but the AMDGPU is not functional. - Defaulting back to the CPU. (No action is required if you want to run on the CPU). - """ maxlog=1 - end - return + USE_AMDGPU[] ? + fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_isleaf) : + x end -ChainRulesCore.@non_differentiable check_use_amdgpu() diff --git a/src/functor.jl b/src/functor.jl index da66c73945..8e2f4df300 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -282,10 +282,10 @@ trainable(c::Cholesky) = () # AMDGPU extension. 
-const amdgpu_loaded = Ref{Bool}(false) +const AMDGPU_LOADED = Ref{Bool}(false) function amd(x) - if amdgpu_loaded[] + if AMDGPU_LOADED[] return _amd(x) else @info """ diff --git a/test/amd/basic.jl b/test/amd/basic.jl new file mode 100644 index 0000000000..1b095e3a30 --- /dev/null +++ b/test/amd/basic.jl @@ -0,0 +1,64 @@ +@test Flux.AMDGPU_LOADED[] + +@testset "Basic GPU movement" begin + @test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1} + @test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2} + @test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2} + @test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3} + + @test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple + @test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple +end + +@testset "Dense no bias" begin + m = Dense(3 => 4; bias=false) |> Flux.amd + x = zeros(Float32, 3, 4) |> Flux.amd + @test sum(m(x)) ≈ 0f0 + gs = gradient(m -> sum(m(x)), m) + @test isnothing(gs[1].bias) +end + +@testset "Chain of Dense layers" begin + m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32 + x = rand(Float32, 10, 10) + amdgputest(m, x) +end + +@testset "Cross-correlation" begin + m = CrossCor((2, 2), 3 => 4) |> f32 + x = rand(Float32, 10, 10, 3, 2) + amdgputest(m, x; atol=1f-3) +end + +@testset "Restructure" begin + m = Dense(1, 1) |> Flux.amd + θ, m̂ = Flux.destructure(m) + foo(x) = sum(re(p)(x)) + + x = Flux.amd(rand(Float32, 1)) + @test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32} +end + +@testset "Flux.amd(x) on structured arrays" begin + g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5))) + @test Flux.amd(g1) isa ROCMatrix{Int64} + g2 = Zygote.Fill(1f0, 2) + @test Flux.amd(g2) isa ROCArray{Float32, 1} + g3 = transpose(Float32[1 2; 3 4]) + @test parent(Flux.amd(g3)) isa ROCMatrix{Float32} +end + +@testset "Flux.onecold gpu" begin + y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd + l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] + @test Flux.onecold(y) isa ROCArray + @test y[3, :] isa ROCArray + @test Flux.onecold(y, l) == ['a', 'a', 'a'] +end + +# FIXME scalar indexing. Needs NNlib.scatter? +# @testset "Flux.onehot gpu" begin +# y = Flux.onehotbatch(ones(3), 1:2) |> Flux.amd +# x = rand(3, 2) |> Flux.amd +# @test gradient(x -> sum(x * y), x)[1] isa ROCArray +# end diff --git a/test/amd/runtests.jl b/test/amd/runtests.jl new file mode 100644 index 0000000000..d82e1841c9 --- /dev/null +++ b/test/amd/runtests.jl @@ -0,0 +1,7 @@ +include("utils.jl") + +AMDGPU.allowscalar(false) + +@testset "Basic" begin + include("basic.jl") +end diff --git a/test/amd/utils.jl b/test/amd/utils.jl new file mode 100644 index 0000000000..bd73d83d00 --- /dev/null +++ b/test/amd/utils.jl @@ -0,0 +1,47 @@ +function amdgputest(model, xs...; checkgrad=true, atol=1e-6, kws...) + cpu_model = model + gpu_model = Flux.amd(model) + + cpu_in = xs + gpu_in = Flux.amd.(xs) + + cpu_out = cpu_model(cpu_in...) + gpu_out = gpu_model(gpu_in...) 
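+    # Compare the forward passes elementwise; `collect` copies the GPU output back to the host.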
+ @test collect(cpu_out) ≈ collect(gpu_out) atol=atol + + if checkgrad + cpu_grad = gradient(m -> sum(m(cpu_in...)), cpu_model) + gpu_grad = gradient(m -> sum(m(gpu_in...)), gpu_model) + amd_check_grad(gpu_grad, cpu_grad; atol) + end +end + +function amd_check_grad(g_gpu, g_cpu; atol) + @show g_gpu g_cpu + @test false +end + +amd_check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol) = + amd_check_grad(g_gpu[], g_cpu[]; atol) +amd_check_grad(g_gpu::Nothing, g_cpu::Nothing; atol) = + @test true +amd_check_grad(g_gpu::Float32, g_cpu::Float32; atol) = + @test g_cpu ≈ g_gpu atol=atol +amd_check_grad(g_gpu::ROCArray{Float32}, g_cpu::Array{Float32}; atol) = + @test g_cpu ≈ collect(g_gpu) atol=atol +amd_check_grad( + g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill; atol, +) = @test collect(g_cpu) ≈ collect(g_gpu) atol=atol + +function amd_check_grad(g_gpu::Tuple, g_cpu::Tuple; atol) + for (v1, v2) in zip(g_gpu, g_cpu) + amd_check_grad(v1, v2; atol) + end +end + +function amd_check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; atol) + for ((k1, v1), (k2, v2)) in zip(pairs(g_gpu), pairs(g_cpu)) + @test k1 == k2 + amd_check_grad(v1, v2; atol) + end +end diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 9f624b2882..e5e28d428b 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -91,7 +91,7 @@ end struct SimpleBits field::Int32 end - + @test gpu((;a=ones(1))).a isa CuVector{Float32} @test gpu((;a=['a', 'b', 'c'])).a isa CuVector{Char} @test gpu((;a=[SimpleBits(1)])).a isa CuVector{SimpleBits} @@ -167,14 +167,14 @@ end @test parent(gpu(g3)) isa CuArray - #Issue #2116 + #Issue #2116 struct A2116 x::Int y::Int end x = [A2116(1,1), A2116(2,2)] - xgpu = gpu(x) + xgpu = gpu(x) @test xgpu isa CuVector{A2116} - @test cpu(xgpu) isa Vector{A2116} + @test cpu(xgpu) isa Vector{A2116} @test cpu(gpu([CartesianIndex(1)])) isa Vector{CartesianIndex{1}} end diff --git a/test/runtests.jl b/test/runtests.jl index ae57cf5ad6..503ae2f2f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,4 +60,19 @@ Random.seed!(0) doctest(Flux) end end + + if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true" + using AMDGPU + AMDGPU.versioninfo() + if AMDGPU.functional() && AMDGPU.functional(:MIOpen) + @show AMDGPU.MIOpen.version() + @testset "AMDGPU" begin + include("amd/runtests.jl") + end + else + @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." + end + else + @info "Skipping AMDGPU tests, set FLUX_TEST_CUDA=true to run them." + end end From 211ab212b6e969cfb96c2dc2419ac059eed1b2d8 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 16 Feb 2023 01:03:34 +0200 Subject: [PATCH 03/15] Partially handle convolution --- ext/AMDGPUExt/functor.jl | 27 +++++++++- test/amd/basic.jl | 104 ++++++++++++++++++++++----------------- test/amd/utils.jl | 2 +- test/runtests.jl | 84 +++++++++++++++---------------- 4 files changed, 127 insertions(+), 90 deletions(-) diff --git a/ext/AMDGPUExt/functor.jl b/ext/AMDGPUExt/functor.jl index 12c6ee0d69..11ba232039 100644 --- a/ext/AMDGPUExt/functor.jl +++ b/ext/AMDGPUExt/functor.jl @@ -18,7 +18,23 @@ adapt_storage(::FluxAMDAdaptor, x::AbstractRNG) = error(""" Cannot map RNG of type $(typeof(x)) to AMDGPU. AMDGPU execution only supports Random.default_rng().""") -# TODO adaptor for Conv +function adapt_storage(to::FluxAMDAdaptor, m::Flux.Conv) + Flux.Conv( + Adapt.adapt(to, m.σ), + Adapt.adapt(to, m.weight[end:-1:1, end:-1:1, :, :]), + Adapt.adapt(to, m.bias), + m.stride, m.pad, m.dilation, m.groups) +end + +# # Don't adapt again. 
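+# # (The CPU -> GPU adaptor flips the convolution kernel, so it must not run a second time on a layer whose weights are already ROCArrays.)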
+# function adapt_storage( +# to::FluxAMDAdaptor, m::Flux.Conv{N, M, F, A, V}, +# ) where {N, M, F, A <: ROCArray, V} +# return m +# end + +# TODO GPU -> CPU adaptor +# TODO don't adapt again when already on AMDGPU adapt_storage(::FluxCPUAdaptor, x::AMDGPU.rocRAND.RNG) = Random.default_rng() @@ -40,3 +56,12 @@ function _amd(x) fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_isleaf) : x end + +function _amd(m::Flux.Conv) + to = FluxAMDAdaptor() + Flux.Conv( + Adapt.adapt(to, m.σ), + Adapt.adapt(to, m.weight[end:-1:1, end:-1:1, :, :]), + Adapt.adapt(to, m.bias), + m.stride, m.pad, m.dilation, m.groups) +end diff --git a/test/amd/basic.jl b/test/amd/basic.jl index 1b095e3a30..5d686189c5 100644 --- a/test/amd/basic.jl +++ b/test/amd/basic.jl @@ -1,60 +1,72 @@ @test Flux.AMDGPU_LOADED[] -@testset "Basic GPU movement" begin - @test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1} - @test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2} - @test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2} - @test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3} +# @testset "Basic GPU movement" begin +# @test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1} +# @test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2} +# @test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2} +# @test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3} - @test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple - @test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple -end +# @test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple +# @test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple +# end -@testset "Dense no bias" begin - m = Dense(3 => 4; bias=false) |> Flux.amd - x = zeros(Float32, 3, 4) |> Flux.amd - @test sum(m(x)) ≈ 0f0 - gs = gradient(m -> sum(m(x)), m) - @test isnothing(gs[1].bias) -end +# @testset "Dense no bias" begin +# m = Dense(3 => 4; bias=false) |> Flux.amd +# x = zeros(Float32, 3, 4) |> Flux.amd +# @test sum(m(x)) ≈ 0f0 +# gs = gradient(m -> sum(m(x)), m) +# @test isnothing(gs[1].bias) +# end -@testset "Chain of Dense layers" begin - m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32 - x = rand(Float32, 10, 10) - amdgputest(m, x) -end +# @testset "Chain of Dense layers" begin +# m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32 +# x = rand(Float32, 10, 10) +# amdgputest(m, x) +# end -@testset "Cross-correlation" begin - m = CrossCor((2, 2), 3 => 4) |> f32 - x = rand(Float32, 10, 10, 3, 2) - amdgputest(m, x; atol=1f-3) +@testset "Convolution" begin + m = Conv((2, 2), 1 => 1) |> f32 + x = rand(Float32, 4, 4, 1, 1) + amdgputest(m, x; atol=1f-3, checkgrad=false) + + # Gradients are flipped as well. 
+ md, xd = Flux.amd.((m, x)) + gs = gradient(m -> sum(m(x)), m) + gsd = gradient(m -> sum(m(xd)), md) + @test gs[1].weight[end:-1:1, end:-1:1, :, :] ≈ Array(gsd[1].weight) atol=1f-3 end -@testset "Restructure" begin - m = Dense(1, 1) |> Flux.amd - θ, m̂ = Flux.destructure(m) - foo(x) = sum(re(p)(x)) +# @testset "Cross-correlation" begin +# m = CrossCor((2, 2), 3 => 4) |> f32 +# x = rand(Float32, 10, 10, 3, 2) +# amdgputest(m, x; atol=1f-3) +# end - x = Flux.amd(rand(Float32, 1)) - @test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32} -end +# @testset "Restructure" begin +# m = Dense(1, 1) |> Flux.amd +# θ, m̂ = Flux.destructure(m) +# foo(x) = sum(re(p)(x)) -@testset "Flux.amd(x) on structured arrays" begin - g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5))) - @test Flux.amd(g1) isa ROCMatrix{Int64} - g2 = Zygote.Fill(1f0, 2) - @test Flux.amd(g2) isa ROCArray{Float32, 1} - g3 = transpose(Float32[1 2; 3 4]) - @test parent(Flux.amd(g3)) isa ROCMatrix{Float32} -end +# x = Flux.amd(rand(Float32, 1)) +# @test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32} +# end -@testset "Flux.onecold gpu" begin - y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd - l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] - @test Flux.onecold(y) isa ROCArray - @test y[3, :] isa ROCArray - @test Flux.onecold(y, l) == ['a', 'a', 'a'] -end +# @testset "Flux.amd(x) on structured arrays" begin +# g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5))) +# @test Flux.amd(g1) isa ROCMatrix{Int64} +# g2 = Zygote.Fill(1f0, 2) +# @test Flux.amd(g2) isa ROCArray{Float32, 1} +# g3 = transpose(Float32[1 2; 3 4]) +# @test parent(Flux.amd(g3)) isa ROCMatrix{Float32} +# end + +# @testset "Flux.onecold gpu" begin +# y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd +# l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] +# @test Flux.onecold(y) isa ROCArray +# @test y[3, :] isa ROCArray +# @test Flux.onecold(y, l) == ['a', 'a', 'a'] +# end # FIXME scalar indexing. Needs NNlib.scatter? # @testset "Flux.onehot gpu" begin diff --git a/test/amd/utils.jl b/test/amd/utils.jl index bd73d83d00..b8b93caf35 100644 --- a/test/amd/utils.jl +++ b/test/amd/utils.jl @@ -1,4 +1,4 @@ -function amdgputest(model, xs...; checkgrad=true, atol=1e-6, kws...) 
+function amdgputest(model, xs...; checkgrad=true, atol=1e-6) cpu_model = model gpu_model = Flux.amd(model) diff --git a/test/runtests.jl b/test/runtests.jl index 503ae2f2f0..6276834401 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,55 +11,55 @@ Random.seed!(0) @testset verbose=true "Flux.jl" begin - @testset "Utils" begin - include("utils.jl") - end + # @testset "Utils" begin + # include("utils.jl") + # end - @testset "Optimise / Train" begin - include("optimise.jl") - include("train.jl") - end + # @testset "Optimise / Train" begin + # include("optimise.jl") + # include("train.jl") + # end - @testset "Data" begin - include("data.jl") - end + # @testset "Data" begin + # include("data.jl") + # end - @testset "Losses" begin - include("losses.jl") - include("ctc.jl") - CUDA.functional() && include("ctc-gpu.jl") - end + # @testset "Losses" begin + # include("losses.jl") + # include("ctc.jl") + # CUDA.functional() && include("ctc-gpu.jl") + # end - @testset "Layers" begin - include("layers/basic.jl") - include("layers/normalisation.jl") - include("layers/stateless.jl") - include("layers/recurrent.jl") - include("layers/conv.jl") - include("layers/upsample.jl") - include("layers/show.jl") - end + # @testset "Layers" begin + # include("layers/basic.jl") + # include("layers/normalisation.jl") + # include("layers/stateless.jl") + # include("layers/recurrent.jl") + # include("layers/conv.jl") + # include("layers/upsample.jl") + # include("layers/show.jl") + # end - @testset "outputsize" begin - using Flux: outputsize - include("outputsize.jl") - end + # @testset "outputsize" begin + # using Flux: outputsize + # include("outputsize.jl") + # end - @testset "CUDA" begin - if CUDA.functional() - include("cuda/runtests.jl") - else - @warn "CUDA unavailable, not testing GPU support" - end - end + # @testset "CUDA" begin + # if CUDA.functional() + # include("cuda/runtests.jl") + # else + # @warn "CUDA unavailable, not testing GPU support" + # end + # end - @static if VERSION == v"1.6" - using Documenter - @testset "Docs" begin - DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) - doctest(Flux) - end - end + # @static if VERSION == v"1.6" + # using Documenter + # @testset "Docs" begin + # DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) + # doctest(Flux) + # end + # end if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true" using AMDGPU From d0eb6a03fa0754a22944b9b0bfdaf5150d73def1 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 16 Feb 2023 12:54:09 +0200 Subject: [PATCH 04/15] Handle convolutions correctly --- ext/AMDGPUExt/functor.jl | 57 ++++++++++++------- test/amd/basic.jl | 120 +++++++++++++++++++++------------------ 2 files changed, 101 insertions(+), 76 deletions(-) diff --git a/ext/AMDGPUExt/functor.jl b/ext/AMDGPUExt/functor.jl index 11ba232039..02ccb5b47a 100644 --- a/ext/AMDGPUExt/functor.jl +++ b/ext/AMDGPUExt/functor.jl @@ -18,24 +18,6 @@ adapt_storage(::FluxAMDAdaptor, x::AbstractRNG) = error(""" Cannot map RNG of type $(typeof(x)) to AMDGPU. AMDGPU execution only supports Random.default_rng().""") -function adapt_storage(to::FluxAMDAdaptor, m::Flux.Conv) - Flux.Conv( - Adapt.adapt(to, m.σ), - Adapt.adapt(to, m.weight[end:-1:1, end:-1:1, :, :]), - Adapt.adapt(to, m.bias), - m.stride, m.pad, m.dilation, m.groups) -end - -# # Don't adapt again. 
-# function adapt_storage( -# to::FluxAMDAdaptor, m::Flux.Conv{N, M, F, A, V}, -# ) where {N, M, F, A <: ROCArray, V} -# return m -# end - -# TODO GPU -> CPU adaptor -# TODO don't adapt again when already on AMDGPU - adapt_storage(::FluxCPUAdaptor, x::AMDGPU.rocRAND.RNG) = Random.default_rng() function ChainRulesCore.rrule(::Type{Array}, x::ROCArray) @@ -57,11 +39,44 @@ function _amd(x) x end -function _amd(m::Flux.Conv) - to = FluxAMDAdaptor() +# Since MIOpen supports only cross-correlation as convolution, +# for the actual convolution, we flip horizontally and vertically the weights. +# Same for CPU -> GPU & GPU -> CPU movements. +# Note, that gradients are also flipped. + +# CPU -> GPU + +function adapt_storage(to::FluxAMDAdaptor, m::Flux.Conv) + flipped_weight = reverse(m.weight; dims=ntuple(i -> i, ndims(m.weight) - 2)) + Flux.Conv( + Adapt.adapt(to, m.σ), + Adapt.adapt(to, flipped_weight), + Adapt.adapt(to, m.bias), + m.stride, m.pad, m.dilation, m.groups) +end + +# Don't adapt again. +function adapt_storage( + to::FluxAMDAdaptor, m::Flux.Conv{N, M, F, A, V}, +) where {N, M, F, A <: ROCArray, V} + return m +end + +_amd(m::Flux.Conv) = adapt_storage(FluxAMDAdaptor(), m) + +# GPU -> CPU + +function Flux.cpu(m::Flux.Conv{N, M, F, A, V}) where {N, M, F, A <: ROCArray, V} + adapt_storage(FluxCPUAdaptor(), m) +end + +function adapt_storage( + to::FluxCPUAdaptor, m::Flux.Conv{N, M, F, A, V}, +) where {N, M, F, A <: ROCArray, V} + dims = ntuple(i -> i, ndims(m.weight) - 2) Flux.Conv( Adapt.adapt(to, m.σ), - Adapt.adapt(to, m.weight[end:-1:1, end:-1:1, :, :]), + reverse(Adapt.adapt(to, m.weight); dims), Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups) end diff --git a/test/amd/basic.jl b/test/amd/basic.jl index 5d686189c5..d89aab1591 100644 --- a/test/amd/basic.jl +++ b/test/amd/basic.jl @@ -1,72 +1,82 @@ @test Flux.AMDGPU_LOADED[] -# @testset "Basic GPU movement" begin -# @test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1} -# @test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2} -# @test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2} -# @test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3} +@testset "Basic GPU movement" begin + @test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1} + @test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2} + @test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2} + @test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3} -# @test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple -# @test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple -# end + @test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple + @test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple +end -# @testset "Dense no bias" begin -# m = Dense(3 => 4; bias=false) |> Flux.amd -# x = zeros(Float32, 3, 4) |> Flux.amd -# @test sum(m(x)) ≈ 0f0 -# gs = gradient(m -> sum(m(x)), m) -# @test isnothing(gs[1].bias) -# end +@testset "Dense no bias" begin + m = Dense(3 => 4; bias=false) |> Flux.amd + x = zeros(Float32, 3, 4) |> Flux.amd + @test sum(m(x)) ≈ 0f0 + gs = gradient(m -> sum(m(x)), m) + @test isnothing(gs[1].bias) +end -# @testset "Chain of Dense layers" begin -# m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32 -# x = rand(Float32, 10, 10) -# amdgputest(m, x) -# end +@testset "Chain of Dense layers" begin + m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32 + x = rand(Float32, 10, 10) + amdgputest(m, x) +end @testset "Convolution" begin - m = 
Conv((2, 2), 1 => 1) |> f32 - x = rand(Float32, 4, 4, 1, 1) - amdgputest(m, x; atol=1f-3, checkgrad=false) + for nd in (1, 2, 3) + m = Conv(tuple(fill(2, nd)...), 3 => 4) |> f32 + x = rand(Float32, fill(10, nd)..., 3, 5) - # Gradients are flipped as well. - md, xd = Flux.amd.((m, x)) - gs = gradient(m -> sum(m(x)), m) - gsd = gradient(m -> sum(m(xd)), md) - @test gs[1].weight[end:-1:1, end:-1:1, :, :] ≈ Array(gsd[1].weight) atol=1f-3 + # Ensure outputs are the same. + amdgputest(m, x; atol=1f-3, checkgrad=false) + + # Gradients are flipped as well. + md, xd = Flux.amd.((m, x)) + gs = gradient(m -> sum(m(x)), m) + gsd = gradient(m -> sum(m(xd)), md) + + dims = ntuple(i -> i, ndims(m.weight) - 2) + @test reverse(gs[1].weight; dims) ≈ Array(gsd[1].weight) atol=1f-2 + + # Movement back to CPU flips weights back. + mh = Flux.cpu(md) + @test m.weight ≈ mh.weight + end end -# @testset "Cross-correlation" begin -# m = CrossCor((2, 2), 3 => 4) |> f32 -# x = rand(Float32, 10, 10, 3, 2) -# amdgputest(m, x; atol=1f-3) -# end +@testset "Cross-correlation" begin + m = CrossCor((2, 2), 3 => 4) |> f32 + x = rand(Float32, 10, 10, 3, 2) + amdgputest(m, x; atol=1f-3) +end -# @testset "Restructure" begin -# m = Dense(1, 1) |> Flux.amd -# θ, m̂ = Flux.destructure(m) -# foo(x) = sum(re(p)(x)) +@testset "Restructure" begin + m = Dense(1, 1) |> Flux.amd + θ, m̂ = Flux.destructure(m) + foo(x) = sum(re(p)(x)) -# x = Flux.amd(rand(Float32, 1)) -# @test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32} -# end + x = Flux.amd(rand(Float32, 1)) + @test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32} +end -# @testset "Flux.amd(x) on structured arrays" begin -# g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5))) -# @test Flux.amd(g1) isa ROCMatrix{Int64} -# g2 = Zygote.Fill(1f0, 2) -# @test Flux.amd(g2) isa ROCArray{Float32, 1} -# g3 = transpose(Float32[1 2; 3 4]) -# @test parent(Flux.amd(g3)) isa ROCMatrix{Float32} -# end +@testset "Flux.amd(x) on structured arrays" begin + g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5))) + @test Flux.amd(g1) isa ROCMatrix{Int64} + g2 = Zygote.Fill(1f0, 2) + @test Flux.amd(g2) isa ROCArray{Float32, 1} + g3 = transpose(Float32[1 2; 3 4]) + @test parent(Flux.amd(g3)) isa ROCMatrix{Float32} +end -# @testset "Flux.onecold gpu" begin -# y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd -# l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] -# @test Flux.onecold(y) isa ROCArray -# @test y[3, :] isa ROCArray -# @test Flux.onecold(y, l) == ['a', 'a', 'a'] -# end +@testset "Flux.onecold gpu" begin + y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd + l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] + @test Flux.onecold(y) isa ROCArray + @test y[3, :] isa ROCArray + @test Flux.onecold(y, l) == ['a', 'a', 'a'] +end # FIXME scalar indexing. Needs NNlib.scatter? 
# @testset "Flux.onehot gpu" begin From de91f9a74ecfadddfefe73c7e4fb4adf47eee7e0 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 16 Feb 2023 15:00:24 +0200 Subject: [PATCH 05/15] Add batchnorm --- ext/AMDGPUExt/AMDGPUExt.jl | 2 ++ ext/AMDGPUExt/batchnorm.jl | 20 ++++++++++++++++++++ test/amd/basic.jl | 6 ++++++ 3 files changed, 28 insertions(+) create mode 100644 ext/AMDGPUExt/batchnorm.jl diff --git a/ext/AMDGPUExt/AMDGPUExt.jl b/ext/AMDGPUExt/AMDGPUExt.jl index 2b6c037afc..1b2dc9a2e8 100644 --- a/ext/AMDGPUExt/AMDGPUExt.jl +++ b/ext/AMDGPUExt/AMDGPUExt.jl @@ -1,6 +1,7 @@ module AMDGPUExt import ChainRulesCore +import ChainRulesCore: NoTangent import Flux import Flux: FluxCPUAdaptor, _amd, _isleaf, adapt_storage, fmap @@ -9,6 +10,7 @@ using Adapt using Random using Zygote +const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing) function check_use_amdgpu() diff --git a/ext/AMDGPUExt/batchnorm.jl b/ext/AMDGPUExt/batchnorm.jl new file mode 100644 index 0000000000..393d3d6918 --- /dev/null +++ b/ext/AMDGPUExt/batchnorm.jl @@ -0,0 +1,20 @@ +function (b::Flux.BatchNorm)(x::ROCArray{T}) where T <: MIOPENFloat + bλ.(_amd_batchnorm(x, b.γ, b.β; μ=b.μ, σ²=b.σ², ϵ=b.ϵ)) +end + +function _amd_batchnorm(x, γ, β; μ, σ², ϵ) + if NNlib.within_gradient(x) + return AMDGPU.MIOpen.batchnorm_training(x, γ, β, μ, σ²; ϵ, iteration=0) # TODO iteration + else + return AMDGPU.MIOpen.batchnorm_inference(x, γ, β, μ, σ²; ϵ) + end +end + +function ChainRulesCore.rrule(::typeof(_amd_batchnorm), x, γ, β; μ, σ², ϵ) + y, μ_saved, ν_saved = _amd_batchnorm(x, γ, β; μ, σ², ϵ) + function _batchnorm_pullback(Δ) + dx, dγ, dβ = MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved) + (NoTangent(), dx, dγ, dβ) + end + y, _batchnorm_pullback +end diff --git a/test/amd/basic.jl b/test/amd/basic.jl index d89aab1591..bf3433e117 100644 --- a/test/amd/basic.jl +++ b/test/amd/basic.jl @@ -78,6 +78,12 @@ end @test Flux.onecold(y, l) == ['a', 'a', 'a'] end +@testset "Batchnorm" begin + bn = BatchNorm(3, σ) + x = rand(Float32, 16, 16, 3, 4) + amdgputest(bn, x; atol=1f-3) +end + # FIXME scalar indexing. Needs NNlib.scatter? # @testset "Flux.onehot gpu" begin # y = Flux.onehotbatch(ones(3), 1:2) |> Flux.amd From a37ee90b710bf5aecdaf9404f9f425956a7e903d Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 16 Feb 2023 23:48:35 +0200 Subject: [PATCH 06/15] Add gpu backend switch mechanism --- LocalPreferences.toml | 2 + NEWS.md | 3 ++ Project.toml | 5 ++- ext/AMDGPUExt/AMDGPUExt.jl | 2 +- ext/AMDGPUExt/functor.jl | 2 - src/Flux.jl | 19 +++++++++ src/functor.jl | 8 +++- test/amd/basic.jl | 42 ++++++++++--------- test/amd/runtests.jl | 2 + test/amd/utils.jl | 4 +- test/runtests.jl | 84 +++++++++++++++++++------------------- 11 files changed, 103 insertions(+), 70 deletions(-) create mode 100644 LocalPreferences.toml diff --git a/LocalPreferences.toml b/LocalPreferences.toml new file mode 100644 index 0000000000..48efdf096a --- /dev/null +++ b/LocalPreferences.toml @@ -0,0 +1,2 @@ +[Flux] +gpu_backend = "AMD" diff --git a/NEWS.md b/NEWS.md index dc770e0f73..a65bf12235 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,9 @@ ## v0.13.13 * Added `f16` which changes precision to `Float16`, recursively. +* Initial support for AMDGPU via extension mechanism. +* Add `gpu_backend` preference to select GPU backend using `LocalPreference.toml`. +* Add `Flux.gpu_backend!` method to switch between GPU backends. ## v0.13.12 * CUDA.jl 4.0 compatibility. 
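A minimal usage sketch of the backend switch added in this patch (illustrative only; `Flux.gpu_backend!` and the `gpu_backend` preference are the additions above, and a fresh Julia session may be needed before `gpu` picks up the new setting):

```julia
using Flux

Flux.gpu_backend!("AMD")        # persists gpu_backend = "AMD" via Preferences.jl

# Once the preference is in effect, `gpu` dispatches to the AMD path:
x = gpu(rand(Float32, 4, 4))
```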
diff --git a/Project.toml b/Project.toml index 8ccec51734..2968d43ca9 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -30,8 +31,8 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" AMDGPUExt = "AMDGPU" [compat] -Adapt = "3.0" AMDGPU = "0.4.8" +Adapt = "3.0" CUDA = "3, 4" ChainRulesCore = "1.12" Functors = "0.3, 0.4" @@ -57,4 +58,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AMDGPU", "Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] diff --git a/ext/AMDGPUExt/AMDGPUExt.jl b/ext/AMDGPUExt/AMDGPUExt.jl index 1b2dc9a2e8..2b422a0393 100644 --- a/ext/AMDGPUExt/AMDGPUExt.jl +++ b/ext/AMDGPUExt/AMDGPUExt.jl @@ -3,7 +3,7 @@ module AMDGPUExt import ChainRulesCore import ChainRulesCore: NoTangent import Flux -import Flux: FluxCPUAdaptor, _amd, _isleaf, adapt_storage, fmap +import Flux: FluxCPUAdaptor, FluxAMDAdaptor, _amd, _isleaf, adapt_storage, fmap using AMDGPU using Adapt diff --git a/ext/AMDGPUExt/functor.jl b/ext/AMDGPUExt/functor.jl index 02ccb5b47a..c94778cf3e 100644 --- a/ext/AMDGPUExt/functor.jl +++ b/ext/AMDGPUExt/functor.jl @@ -1,5 +1,3 @@ -struct FluxAMDAdaptor end - # Convert Float64 to Float32, but preserve Float16. adapt_storage(::FluxAMDAdaptor, x::T) where T <: AbstractArray = isbits(x) ? x : ROCArray(x) diff --git a/src/Flux.jl b/src/Flux.jl index 041827905c..db6dee2946 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -1,6 +1,7 @@ module Flux using Base: tail +using Preferences using LinearAlgebra, Statistics, Random # standard lib using MacroTools, Reexport, ProgressLogging, SpecialFunctions using MacroTools: @forward @@ -72,4 +73,22 @@ include("deprecations.jl") include("cuda/cuda.jl") +const GPU_BACKENDS = Dict( + "CUDA" => FluxCUDAAdaptor(), + "AMD" => FluxAMDAdaptor()) + +const GPU_BACKEND = Ref{Union{FluxCUDAAdaptor, FluxAMDAdaptor}}( + GPU_BACKENDS[@load_preference("gpu_backend", "CUDA")]) + +function gpu_backend!(backend::String) + backend in keys(GPU_BACKENDS) || throw(ArgumentError(""" + Unsupported GPU backend: $backend. + Supported backends are: $(keys(GPU_BACKENDS)). + """)) + + @set_preferences!("gpu_backend" => backend) + GPU_BACKEND[] = GPU_BACKENDS[@load_preference("gpu_backend")] + return +end + end # module diff --git a/src/functor.jl b/src/functor.jl index 8e2f4df300..09fa7d467c 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -209,6 +209,10 @@ CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer} ``` """ function gpu(x) + gpu(GPU_BACKEND[], x) +end + +function gpu(::FluxCUDAAdaptor, x) check_use_cuda() use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x end @@ -282,9 +286,11 @@ trainable(c::Cholesky) = () # AMDGPU extension. 
+struct FluxAMDAdaptor end + const AMDGPU_LOADED = Ref{Bool}(false) -function amd(x) +function gpu(::FluxAMDAdaptor, x) if AMDGPU_LOADED[] return _amd(x) else diff --git a/test/amd/basic.jl b/test/amd/basic.jl index bf3433e117..bcd4443d00 100644 --- a/test/amd/basic.jl +++ b/test/amd/basic.jl @@ -1,18 +1,18 @@ @test Flux.AMDGPU_LOADED[] @testset "Basic GPU movement" begin - @test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1} - @test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2} - @test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2} - @test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3} + @test Flux.gpu(rand(Float64, 16)) isa ROCArray{Float32, 1} + @test Flux.gpu(rand(Float64, 16, 16)) isa ROCArray{Float32, 2} + @test Flux.gpu(rand(Float32, 16, 16)) isa ROCArray{Float32, 2} + @test Flux.gpu(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3} - @test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple + @test gradient(x -> sum(Flux.gpu(x)), rand(Float32, 4, 4)) isa Tuple @test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple end @testset "Dense no bias" begin - m = Dense(3 => 4; bias=false) |> Flux.amd - x = zeros(Float32, 3, 4) |> Flux.amd + m = Dense(3 => 4; bias=false) |> Flux.gpu + x = zeros(Float32, 3, 4) |> Flux.gpu @test sum(m(x)) ≈ 0f0 gs = gradient(m -> sum(m(x)), m) @test isnothing(gs[1].bias) @@ -25,7 +25,7 @@ end end @testset "Convolution" begin - for nd in (1, 2, 3) + for nd in 1:3 m = Conv(tuple(fill(2, nd)...), 3 => 4) |> f32 x = rand(Float32, fill(10, nd)..., 3, 5) @@ -33,7 +33,7 @@ end amdgputest(m, x; atol=1f-3, checkgrad=false) # Gradients are flipped as well. - md, xd = Flux.amd.((m, x)) + md, xd = Flux.gpu.((m, x)) gs = gradient(m -> sum(m(x)), m) gsd = gradient(m -> sum(m(xd)), md) @@ -53,25 +53,25 @@ end end @testset "Restructure" begin - m = Dense(1, 1) |> Flux.amd + m = Dense(1, 1) |> Flux.gpu θ, m̂ = Flux.destructure(m) foo(x) = sum(re(p)(x)) - x = Flux.amd(rand(Float32, 1)) + x = Flux.gpu(rand(Float32, 1)) @test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32} end -@testset "Flux.amd(x) on structured arrays" begin +@testset "Flux.gpu(x) on structured arrays" begin g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5))) - @test Flux.amd(g1) isa ROCMatrix{Int64} + @test Flux.gpu(g1) isa ROCMatrix{Int64} g2 = Zygote.Fill(1f0, 2) - @test Flux.amd(g2) isa ROCArray{Float32, 1} + @test Flux.gpu(g2) isa ROCArray{Float32, 1} g3 = transpose(Float32[1 2; 3 4]) - @test parent(Flux.amd(g3)) isa ROCMatrix{Float32} + @test parent(Flux.gpu(g3)) isa ROCMatrix{Float32} end @testset "Flux.onecold gpu" begin - y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd + y = Flux.onehotbatch(ones(3), 1:10) |> Flux.gpu l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] @test Flux.onecold(y) isa ROCArray @test y[3, :] isa ROCArray @@ -80,13 +80,15 @@ end @testset "Batchnorm" begin bn = BatchNorm(3, σ) - x = rand(Float32, 16, 16, 3, 4) - amdgputest(bn, x; atol=1f-3) + for nd in 1:3 + x = rand(Float32, fill(16, nd - 1)..., 3, 4) + amdgputest(bn, x; atol=1f-3) + end end # FIXME scalar indexing. Needs NNlib.scatter? 
# @testset "Flux.onehot gpu" begin -# y = Flux.onehotbatch(ones(3), 1:2) |> Flux.amd -# x = rand(3, 2) |> Flux.amd +# y = Flux.onehotbatch(ones(3), 1:2) |> Flux.gpu +# x = rand(3, 2) |> Flux.gpu # @test gradient(x -> sum(x * y), x)[1] isa ROCArray # end diff --git a/test/amd/runtests.jl b/test/amd/runtests.jl index d82e1841c9..2a2a95160d 100644 --- a/test/amd/runtests.jl +++ b/test/amd/runtests.jl @@ -1,3 +1,5 @@ +Flux.gpu_backend!("AMD") + include("utils.jl") AMDGPU.allowscalar(false) diff --git a/test/amd/utils.jl b/test/amd/utils.jl index b8b93caf35..dbb5f9dabc 100644 --- a/test/amd/utils.jl +++ b/test/amd/utils.jl @@ -1,9 +1,9 @@ function amdgputest(model, xs...; checkgrad=true, atol=1e-6) cpu_model = model - gpu_model = Flux.amd(model) + gpu_model = Flux.gpu(model) cpu_in = xs - gpu_in = Flux.amd.(xs) + gpu_in = Flux.gpu.(xs) cpu_out = cpu_model(cpu_in...) gpu_out = gpu_model(gpu_in...) diff --git a/test/runtests.jl b/test/runtests.jl index 6276834401..503ae2f2f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,55 +11,55 @@ Random.seed!(0) @testset verbose=true "Flux.jl" begin - # @testset "Utils" begin - # include("utils.jl") - # end + @testset "Utils" begin + include("utils.jl") + end - # @testset "Optimise / Train" begin - # include("optimise.jl") - # include("train.jl") - # end + @testset "Optimise / Train" begin + include("optimise.jl") + include("train.jl") + end - # @testset "Data" begin - # include("data.jl") - # end + @testset "Data" begin + include("data.jl") + end - # @testset "Losses" begin - # include("losses.jl") - # include("ctc.jl") - # CUDA.functional() && include("ctc-gpu.jl") - # end + @testset "Losses" begin + include("losses.jl") + include("ctc.jl") + CUDA.functional() && include("ctc-gpu.jl") + end - # @testset "Layers" begin - # include("layers/basic.jl") - # include("layers/normalisation.jl") - # include("layers/stateless.jl") - # include("layers/recurrent.jl") - # include("layers/conv.jl") - # include("layers/upsample.jl") - # include("layers/show.jl") - # end + @testset "Layers" begin + include("layers/basic.jl") + include("layers/normalisation.jl") + include("layers/stateless.jl") + include("layers/recurrent.jl") + include("layers/conv.jl") + include("layers/upsample.jl") + include("layers/show.jl") + end - # @testset "outputsize" begin - # using Flux: outputsize - # include("outputsize.jl") - # end + @testset "outputsize" begin + using Flux: outputsize + include("outputsize.jl") + end - # @testset "CUDA" begin - # if CUDA.functional() - # include("cuda/runtests.jl") - # else - # @warn "CUDA unavailable, not testing GPU support" - # end - # end + @testset "CUDA" begin + if CUDA.functional() + include("cuda/runtests.jl") + else + @warn "CUDA unavailable, not testing GPU support" + end + end - # @static if VERSION == v"1.6" - # using Documenter - # @testset "Docs" begin - # DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) - # doctest(Flux) - # end - # end + @static if VERSION == v"1.6" + using Documenter + @testset "Docs" begin + DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) + doctest(Flux) + end + end if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true" using AMDGPU From eaada4b5b363d6d04e88045351ebf8874dbdb28f Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 16 Feb 2023 23:49:08 +0200 Subject: [PATCH 07/15] Update .gitignore --- .gitignore | 2 +- LocalPreferences.toml | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) delete mode 100644 LocalPreferences.toml diff --git a/.gitignore 
b/.gitignore index 61c47905dc..45b845a41b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,4 @@ docs/site/ deps .vscode Manifest.toml - +LocalPreferences.toml diff --git a/LocalPreferences.toml b/LocalPreferences.toml deleted file mode 100644 index 48efdf096a..0000000000 --- a/LocalPreferences.toml +++ /dev/null @@ -1,2 +0,0 @@ -[Flux] -gpu_backend = "AMD" From 37ce734cea936999a4fe1ac7c8833c34a45a4217 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 16 Feb 2023 23:52:02 +0200 Subject: [PATCH 08/15] Add AMDGPU to extras --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 2968d43ca9..8cd6574d86 100644 --- a/Project.toml +++ b/Project.toml @@ -50,6 +50,7 @@ Zygote = "0.6.49" julia = "1.6" [extras] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" From cf6dd42d8ec73054e8f3db104508427371f59fd6 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Fri, 17 Feb 2023 15:45:48 +0200 Subject: [PATCH 09/15] Fix batchnorm & handle regular convolutions --- Project.toml | 2 +- ext/AMDGPUExt/AMDGPUExt.jl | 4 ++++ ext/AMDGPUExt/batchnorm.jl | 20 ++++++++++++-------- ext/AMDGPUExt/conv.jl | 9 +++++++++ src/Flux.jl | 18 ------------------ src/functor.jl | 35 ++++++++++++++++++++++++++++++++++- test/amd/basic.jl | 4 ++-- test/amd/utils.jl | 38 ++++++++++++++++++++++---------------- 8 files changed, 84 insertions(+), 46 deletions(-) create mode 100644 ext/AMDGPUExt/conv.jl diff --git a/Project.toml b/Project.toml index 8cd6574d86..23ce3b0a3d 100644 --- a/Project.toml +++ b/Project.toml @@ -59,4 +59,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] +test = ["AMDGPU", "Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] diff --git a/ext/AMDGPUExt/AMDGPUExt.jl b/ext/AMDGPUExt/AMDGPUExt.jl index 2b422a0393..934af63b7f 100644 --- a/ext/AMDGPUExt/AMDGPUExt.jl +++ b/ext/AMDGPUExt/AMDGPUExt.jl @@ -4,6 +4,8 @@ import ChainRulesCore import ChainRulesCore: NoTangent import Flux import Flux: FluxCPUAdaptor, FluxAMDAdaptor, _amd, _isleaf, adapt_storage, fmap +import Flux: DenseConvDims, Conv, conv, conv_reshape_bias +import NNlib using AMDGPU using Adapt @@ -32,6 +34,8 @@ end ChainRulesCore.@non_differentiable check_use_amdgpu() include("functor.jl") +include("batchnorm.jl") +include("conv.jl") function __init__() Flux.AMDGPU_LOADED[] = true diff --git a/ext/AMDGPUExt/batchnorm.jl b/ext/AMDGPUExt/batchnorm.jl index 393d3d6918..10b7d4be43 100644 --- a/ext/AMDGPUExt/batchnorm.jl +++ b/ext/AMDGPUExt/batchnorm.jl @@ -1,19 +1,23 @@ function (b::Flux.BatchNorm)(x::ROCArray{T}) where T <: MIOPENFloat - bλ.(_amd_batchnorm(x, b.γ, b.β; μ=b.μ, σ²=b.σ², ϵ=b.ϵ)) + b.λ.(_amd_batchnorm( + x, b.γ, b.β; μ=b.μ, σ²=b.σ², ϵ=b.ϵ, + within_grad=NNlib.within_gradient(x))) end -function _amd_batchnorm(x, γ, β; μ, σ², ϵ) - if NNlib.within_gradient(x) - return AMDGPU.MIOpen.batchnorm_training(x, γ, β, μ, σ²; ϵ, iteration=0) # TODO iteration +function _amd_batchnorm(x, γ, β; μ, σ², ϵ, within_grad::Bool) + if within_grad + return AMDGPU.MIOpen.batchnorm_training(x, γ, β, μ, σ²; ϵ=Float64(ϵ), iteration=0) # TODO iteration else - return AMDGPU.MIOpen.batchnorm_inference(x, γ, β, μ, σ²; ϵ) + return AMDGPU.MIOpen.batchnorm_inference(x, γ, β, μ, σ²; 
ϵ=Float64(ϵ)) end end -function ChainRulesCore.rrule(::typeof(_amd_batchnorm), x, γ, β; μ, σ², ϵ) - y, μ_saved, ν_saved = _amd_batchnorm(x, γ, β; μ, σ², ϵ) +function ChainRulesCore.rrule( + ::typeof(_amd_batchnorm), x, γ, β; μ, σ², ϵ, within_grad::Bool, +) + y, μ_saved, ν_saved = _amd_batchnorm(x, γ, β; μ, σ², ϵ, within_grad) function _batchnorm_pullback(Δ) - dx, dγ, dβ = MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved) + dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved) (NoTangent(), dx, dγ, dβ) end y, _batchnorm_pullback diff --git a/ext/AMDGPUExt/conv.jl b/ext/AMDGPUExt/conv.jl new file mode 100644 index 0000000000..1af952b6eb --- /dev/null +++ b/ext/AMDGPUExt/conv.jl @@ -0,0 +1,9 @@ +function (c::Conv)(x::T) where T <: ROCArray + Flux._size_check(c, x, ndims(x) - 1 => Flux._channels_in(c)) + σ = NNlib.fast_act(c.σ, x) + cdims = DenseConvDims( + x, c.weight; stride=c.stride, padding=c.pad, + dilation=c.dilation, groups=c.groups, flipkernel=true) + xT = Flux._match_eltype(c, x) + σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c)) +end diff --git a/src/Flux.jl b/src/Flux.jl index db6dee2946..5280259b1f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -73,22 +73,4 @@ include("deprecations.jl") include("cuda/cuda.jl") -const GPU_BACKENDS = Dict( - "CUDA" => FluxCUDAAdaptor(), - "AMD" => FluxAMDAdaptor()) - -const GPU_BACKEND = Ref{Union{FluxCUDAAdaptor, FluxAMDAdaptor}}( - GPU_BACKENDS[@load_preference("gpu_backend", "CUDA")]) - -function gpu_backend!(backend::String) - backend in keys(GPU_BACKENDS) || throw(ArgumentError(""" - Unsupported GPU backend: $backend. - Supported backends are: $(keys(GPU_BACKENDS)). - """)) - - @set_preferences!("gpu_backend" => backend) - GPU_BACKEND[] = GPU_BACKENDS[@load_preference("gpu_backend")] - return -end - end # module diff --git a/src/functor.jl b/src/functor.jl index 09fa7d467c..5fa1c926cc 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -177,6 +177,30 @@ _isbitsarray(x) = false _isleaf(::AbstractRNG) = true _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) +const GPU_BACKENDS = ("CUDA", "AMD") +const GPU_BACKEND = @load_preference("gpu_backend", "CUDA") + +function gpu_backend!(backend::String) + if backend == GPU_BACKEND + @info """ + GPU backend is already set to: $backend. + No need to do anything else. + """ + return + end + + backend in GPU_BACKENDS || throw(ArgumentError(""" + Unsupported GPU backend: $backend. + Supported backends are: $GPU_BACKENDS. + """)) + + @set_preferences!("gpu_backend" => backend) + @info """ + New GPU backend set: $backend. + Restart your Julia session for this change to take effect! + """ +end + """ gpu(x) @@ -209,7 +233,16 @@ CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer} ``` """ function gpu(x) - gpu(GPU_BACKEND[], x) + @static if GPU_BACKEND == "CUDA" + gpu(FluxCUDAAdaptor(), x) + elseif GPU_BACKEND == "AMD" + gpu(FluxAMDAdaptor(), x) + else + error(""" + Unsupported GPU backend: $GPU_BACKEND. + Supported backends are: $GPU_BACKENDS. 
+ """) + end end function gpu(::FluxCUDAAdaptor, x) diff --git a/test/amd/basic.jl b/test/amd/basic.jl index bcd4443d00..68bcda7d46 100644 --- a/test/amd/basic.jl +++ b/test/amd/basic.jl @@ -81,8 +81,8 @@ end @testset "Batchnorm" begin bn = BatchNorm(3, σ) for nd in 1:3 - x = rand(Float32, fill(16, nd - 1)..., 3, 4) - amdgputest(bn, x; atol=1f-3) + x = rand(Float32, fill(2, nd - 1)..., 3, 4) + amdgputest(bn, x; atol=1f-3, allow_nothing=true) end end diff --git a/test/amd/utils.jl b/test/amd/utils.jl index dbb5f9dabc..96aa47fc4d 100644 --- a/test/amd/utils.jl +++ b/test/amd/utils.jl @@ -1,4 +1,6 @@ -function amdgputest(model, xs...; checkgrad=true, atol=1e-6) +function amdgputest( + model, xs...; checkgrad=true, atol=1e-6, allow_nothing::Bool = false, +) cpu_model = model gpu_model = Flux.gpu(model) @@ -12,36 +14,40 @@ function amdgputest(model, xs...; checkgrad=true, atol=1e-6) if checkgrad cpu_grad = gradient(m -> sum(m(cpu_in...)), cpu_model) gpu_grad = gradient(m -> sum(m(gpu_in...)), gpu_model) - amd_check_grad(gpu_grad, cpu_grad; atol) + amd_check_grad(gpu_grad, cpu_grad; atol, allow_nothing) end end -function amd_check_grad(g_gpu, g_cpu; atol) - @show g_gpu g_cpu - @test false +function amd_check_grad(g_gpu, g_cpu; atol, allow_nothing) + allow_nothing && return + @show g_gpu g_cpu + @test false end -amd_check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol) = - amd_check_grad(g_gpu[], g_cpu[]; atol) -amd_check_grad(g_gpu::Nothing, g_cpu::Nothing; atol) = +amd_check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, allow_nothing) = + amd_check_grad(g_gpu[], g_cpu[]; atol, allow_nothing) +amd_check_grad(g_gpu::Nothing, g_cpu::Nothing; atol, allow_nothing) = @test true -amd_check_grad(g_gpu::Float32, g_cpu::Float32; atol) = +amd_check_grad(g_gpu::Float32, g_cpu::Float32; atol, allow_nothing) = @test g_cpu ≈ g_gpu atol=atol -amd_check_grad(g_gpu::ROCArray{Float32}, g_cpu::Array{Float32}; atol) = - @test g_cpu ≈ collect(g_gpu) atol=atol amd_check_grad( - g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill; atol, + g_gpu::ROCArray{Float32}, g_cpu::Array{Float32}; + atol, allow_nothing, +) = @test g_cpu ≈ collect(g_gpu) atol=atol +amd_check_grad( + g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill; + atol, allow_nothing ) = @test collect(g_cpu) ≈ collect(g_gpu) atol=atol -function amd_check_grad(g_gpu::Tuple, g_cpu::Tuple; atol) +function amd_check_grad(g_gpu::Tuple, g_cpu::Tuple; atol, allow_nothing) for (v1, v2) in zip(g_gpu, g_cpu) - amd_check_grad(v1, v2; atol) + amd_check_grad(v1, v2; atol, allow_nothing) end end -function amd_check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; atol) +function amd_check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; atol, allow_nothing) for ((k1, v1), (k2, v2)) in zip(pairs(g_gpu), pairs(g_cpu)) @test k1 == k2 - amd_check_grad(v1, v2; atol) + amd_check_grad(v1, v2; atol, allow_nothing) end end From 0a9daf72d581a9de266f7c2256de9f4d0c1a4974 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Fri, 17 Feb 2023 15:46:53 +0200 Subject: [PATCH 10/15] Cleanup --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 23ce3b0a3d..8cd6574d86 100644 --- a/Project.toml +++ b/Project.toml @@ -59,4 +59,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AMDGPU", "Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", 
"FillArrays", "ComponentArrays"] From f68e54d828214f8cbed964e95eab1351e18b54fa Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 21 Feb 2023 13:55:30 +0200 Subject: [PATCH 11/15] Update NNlib compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8cd6574d86..b240e56150 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ ChainRulesCore = "1.12" Functors = "0.3, 0.4" MLUtils = "0.2, 0.3.1, 0.4" MacroTools = "0.5" -NNlib = "0.8.15" +NNlib = "0.8.19" NNlibCUDA = "0.2.6" OneHotArrays = "0.1, 0.2" Optimisers = "0.2.12" From 1e8bfb2f1157a40995898c092d540de813bfdad7 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 21 Feb 2023 14:33:36 +0200 Subject: [PATCH 12/15] Add documentation about AMD GPU support --- docs/src/gpu.md | 36 ++++++++++++++++++++++++++++++++++++ src/functor.jl | 4 ++-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/docs/src/gpu.md b/docs/src/gpu.md index 46fed4e1bf..dc3fbe9f38 100644 --- a/docs/src/gpu.md +++ b/docs/src/gpu.md @@ -2,6 +2,8 @@ NVIDIA GPU support should work out of the box on systems with CUDA and CUDNN installed. For more details see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) readme. +AMD GPU support is available since Julia 1.9 on systems with ROCm and MIOpen installed. For more details refer to the [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl) repository. + ## Checking GPU Availability By default, Flux will run the checks on your system to see if it can support GPU functionality. You can check if Flux identified a valid GPU setup by typing the following: @@ -13,6 +15,40 @@ julia> CUDA.functional() true ``` +For AMD GPU: + +```julia +julia> using AMDGPU + +julia> AMDGPU.functional() +true + +julia> AMDGPU.functional(:MIOpen) +true +``` + +## Selecting GPU backend + +Available GPU backends are: `CUDA`, `AMD`. + +Flux relies on [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl) for selecting default GPU backend to use. + +There are two ways you can specify it: + +- From the REPL/code in your project, call `Flux.gpu_backend!("AMD")` and restart (if needed) Julia session for the changes to take effect. +- In `LocalPreferences.toml` file in you project directory specify: +```toml +[Flux] +gpu_backend = "AMD" +``` + +Current GPU backend can be fetched from `Flux.GPU_BACKEND` variable: + +```julia +julia> Flux.GPU_BACKEND +"CUDA" +``` + ## GPU Usage Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CUDA](https://github.com/JuliaGPU/CUDA.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it. diff --git a/src/functor.jl b/src/functor.jl index 5fa1c926cc..dd819adb93 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -204,11 +204,11 @@ end """ gpu(x) -Copies `m` to the current GPU device, if one is available. +Copies `m` to the current GPU device (using current GPU backend), if one is available. If no GPU is available, it does nothing (but prints a warning the first time). On arrays, this calls CUDA's `cu`, which also changes arrays -with Float64 elements to Float32 while copying them to the device. +with Float64 elements to Float32 while copying them to the device (same for AMDGPU). To act on arrays within a struct, the struct type must be marked with [`@functor`](@ref). Use [`cpu`](@ref) to copy back to ordinary `Array`s. 
From 746caa51abf2d6d0b542f1ee85031a5792822bd2 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Mon, 27 Feb 2023 13:43:28 +0200 Subject: [PATCH 13/15] Handle ConvTranspose correctly & refactor --- ext/AMDGPUExt/AMDGPUExt.jl | 2 +- ext/AMDGPUExt/conv.jl | 24 ++++++++++++++++------ ext/AMDGPUExt/functor.jl | 42 ++++++++++++++++++++++++++++---------- test/amd/basic.jl | 11 ++-------- test/amd/utils.jl | 2 +- 5 files changed, 53 insertions(+), 28 deletions(-) diff --git a/ext/AMDGPUExt/AMDGPUExt.jl b/ext/AMDGPUExt/AMDGPUExt.jl index 934af63b7f..a8c768f332 100644 --- a/ext/AMDGPUExt/AMDGPUExt.jl +++ b/ext/AMDGPUExt/AMDGPUExt.jl @@ -4,7 +4,7 @@ import ChainRulesCore import ChainRulesCore: NoTangent import Flux import Flux: FluxCPUAdaptor, FluxAMDAdaptor, _amd, _isleaf, adapt_storage, fmap -import Flux: DenseConvDims, Conv, conv, conv_reshape_bias +import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias import NNlib using AMDGPU diff --git a/ext/AMDGPUExt/conv.jl b/ext/AMDGPUExt/conv.jl index 1af952b6eb..29a6042c81 100644 --- a/ext/AMDGPUExt/conv.jl +++ b/ext/AMDGPUExt/conv.jl @@ -1,9 +1,21 @@ -function (c::Conv)(x::T) where T <: ROCArray - Flux._size_check(c, x, ndims(x) - 1 => Flux._channels_in(c)) - σ = NNlib.fast_act(c.σ, x) - cdims = DenseConvDims( +function Flux.conv_dims(c::Conv, x::T) where T <: ROCArray + DenseConvDims( x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups, flipkernel=true) - xT = Flux._match_eltype(c, x) - σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c)) +end + +function Flux.conv_transpose_dims(c::ConvTranspose, x::T) where T <: ROCArray + # Calculate size of "input", from ∇conv_data()'s perspective... + combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end]) + I = (size(x)[1:end - 2] .- 1) .* c.stride .+ 1 .+ + (size(c.weight)[1:end - 2] .- 1) .* c.dilation .- combined_pad + C_in = size(c.weight)[end - 1] * c.groups + batch_size = size(x)[end] + + # Create DenseConvDims() that looks like the corresponding conv(). + w_size = size(c.weight) + DenseConvDims( + (I..., C_in, batch_size), w_size; + stride=c.stride, padding=c.pad, dilation=c.dilation, + groups=c.groups, flipkernel=true) end diff --git a/ext/AMDGPUExt/functor.jl b/ext/AMDGPUExt/functor.jl index c94778cf3e..3f7971aff0 100644 --- a/ext/AMDGPUExt/functor.jl +++ b/ext/AMDGPUExt/functor.jl @@ -44,9 +44,12 @@ end # CPU -> GPU -function adapt_storage(to::FluxAMDAdaptor, m::Flux.Conv) +_conv_basetype(c::Type{C}) where C <: Conv = Conv +_conv_basetype(c::Type{C}) where C <: ConvTranspose = ConvTranspose + +function adapt_storage(to::FluxAMDAdaptor, m::C) where C <: Union{Conv, ConvTranspose} flipped_weight = reverse(m.weight; dims=ntuple(i -> i, ndims(m.weight) - 2)) - Flux.Conv( + _conv_basetype(C)( Adapt.adapt(to, m.σ), Adapt.adapt(to, flipped_weight), Adapt.adapt(to, m.bias), @@ -55,26 +58,43 @@ end # Don't adapt again. 
function adapt_storage( - to::FluxAMDAdaptor, m::Flux.Conv{N, M, F, A, V}, + to::FluxAMDAdaptor, m::Conv{N, M, F, A, V}, ) where {N, M, F, A <: ROCArray, V} return m end -_amd(m::Flux.Conv) = adapt_storage(FluxAMDAdaptor(), m) +function adapt_storage( + to::FluxAMDAdaptor, m::ConvTranspose{N, M, F, A, V}, +) where {N, M, F, A <: ROCArray, V} + return m +end + +_amd(m::Union{Conv, ConvTranspose}) = adapt_storage(FluxAMDAdaptor(), m) # GPU -> CPU -function Flux.cpu(m::Flux.Conv{N, M, F, A, V}) where {N, M, F, A <: ROCArray, V} +function Flux.cpu(m::Conv{N, M, F, A, V}) where {N, M, F, A <: ROCArray, V} + adapt_storage(FluxCPUAdaptor(), m) +end + +function Flux.cpu(m::ConvTranspose{N, M, F, A, V}) where {N, M, F, A <: ROCArray, V} adapt_storage(FluxCPUAdaptor(), m) end function adapt_storage( - to::FluxCPUAdaptor, m::Flux.Conv{N, M, F, A, V}, + to::FluxCPUAdaptor, m::Conv{N, M, F, A, V}, ) where {N, M, F, A <: ROCArray, V} dims = ntuple(i -> i, ndims(m.weight) - 2) - Flux.Conv( - Adapt.adapt(to, m.σ), - reverse(Adapt.adapt(to, m.weight); dims), - Adapt.adapt(to, m.bias), - m.stride, m.pad, m.dilation, m.groups) + Conv( + Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims), + Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups) +end + +function adapt_storage( + to::FluxCPUAdaptor, m::ConvTranspose{N, M, F, A, V}, +) where {N, M, F, A <: ROCArray, V} + dims = ntuple(i -> i, ndims(m.weight) - 2) + ConvTranspose( + Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims), + Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups) end diff --git a/test/amd/basic.jl b/test/amd/basic.jl index 68bcda7d46..294c2df974 100644 --- a/test/amd/basic.jl +++ b/test/amd/basic.jl @@ -25,8 +25,8 @@ end end @testset "Convolution" begin - for nd in 1:3 - m = Conv(tuple(fill(2, nd)...), 3 => 4) |> f32 + for conv_type in (Conv, ConvTranspose), nd in 1:3 + m = conv_type(tuple(fill(2, nd)...), 3 => 4) |> f32 x = rand(Float32, fill(10, nd)..., 3, 5) # Ensure outputs are the same. @@ -85,10 +85,3 @@ end amdgputest(bn, x; atol=1f-3, allow_nothing=true) end end - -# FIXME scalar indexing. Needs NNlib.scatter? 
-# @testset "Flux.onehot gpu" begin -# y = Flux.onehotbatch(ones(3), 1:2) |> Flux.gpu -# x = rand(3, 2) |> Flux.gpu -# @test gradient(x -> sum(x * y), x)[1] isa ROCArray -# end diff --git a/test/amd/utils.jl b/test/amd/utils.jl index 96aa47fc4d..23ae6fa05e 100644 --- a/test/amd/utils.jl +++ b/test/amd/utils.jl @@ -37,7 +37,7 @@ amd_check_grad( amd_check_grad( g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill; atol, allow_nothing -) = @test collect(g_cpu) ≈ collect(g_gpu) atol=atol +) = @test g_cpu ≈ collect(g_gpu) atol=atol function amd_check_grad(g_gpu::Tuple, g_cpu::Tuple; atol, allow_nothing) for (v1, v2) in zip(g_gpu, g_cpu) From 54c494697f3feebc9bc7b03fd1fca10391710554 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Mon, 27 Feb 2023 20:44:45 +0200 Subject: [PATCH 14/15] Unify gpu test utils --- test/amd/basic.jl | 8 ++-- test/amd/runtests.jl | 20 +++++++++- test/amd/utils.jl | 53 -------------------------- test/cuda/runtests.jl | 1 - test/cuda/test_utils.jl | 72 ----------------------------------- test/runtests.jl | 2 + test/test_utils.jl | 84 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 108 insertions(+), 132 deletions(-) delete mode 100644 test/amd/utils.jl delete mode 100644 test/cuda/test_utils.jl create mode 100644 test/test_utils.jl diff --git a/test/amd/basic.jl b/test/amd/basic.jl index 294c2df974..d053337381 100644 --- a/test/amd/basic.jl +++ b/test/amd/basic.jl @@ -21,7 +21,7 @@ end @testset "Chain of Dense layers" begin m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32 x = rand(Float32, 10, 10) - amdgputest(m, x) + gpu_autodiff_test(m, x) end @testset "Convolution" begin @@ -30,7 +30,7 @@ end x = rand(Float32, fill(10, nd)..., 3, 5) # Ensure outputs are the same. - amdgputest(m, x; atol=1f-3, checkgrad=false) + gpu_autodiff_test(m, x; atol=1f-3, checkgrad=false) # Gradients are flipped as well. md, xd = Flux.gpu.((m, x)) @@ -49,7 +49,7 @@ end @testset "Cross-correlation" begin m = CrossCor((2, 2), 3 => 4) |> f32 x = rand(Float32, 10, 10, 3, 2) - amdgputest(m, x; atol=1f-3) + gpu_autodiff_test(m, x) end @testset "Restructure" begin @@ -82,6 +82,6 @@ end bn = BatchNorm(3, σ) for nd in 1:3 x = rand(Float32, fill(2, nd - 1)..., 3, 4) - amdgputest(bn, x; atol=1f-3, allow_nothing=true) + gpu_autodiff_test(bn, x; atol=1f-3, allow_nothing=true) end end diff --git a/test/amd/runtests.jl b/test/amd/runtests.jl index 2a2a95160d..70c1876487 100644 --- a/test/amd/runtests.jl +++ b/test/amd/runtests.jl @@ -1,9 +1,25 @@ Flux.gpu_backend!("AMD") -include("utils.jl") - AMDGPU.allowscalar(false) +# Extend test utils to AMDGPU. + +function check_grad( + g_gpu::ROCArray{Float32}, g_cpu::Array{Float32}, atol, rtol; + allow_nothing::Bool, +) + @test g_cpu ≈ collect(g_gpu) atol=atol rtol=rtol +end + +function check_grad( + g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill, + atol, rtol; allow_nothing::Bool, +) + @test g_cpu ≈ collect(g_gpu) atol=atol rtol=rtol +end + +check_type(x::ROCArray{Float32}) = true + @testset "Basic" begin include("basic.jl") end diff --git a/test/amd/utils.jl b/test/amd/utils.jl deleted file mode 100644 index 23ae6fa05e..0000000000 --- a/test/amd/utils.jl +++ /dev/null @@ -1,53 +0,0 @@ -function amdgputest( - model, xs...; checkgrad=true, atol=1e-6, allow_nothing::Bool = false, -) - cpu_model = model - gpu_model = Flux.gpu(model) - - cpu_in = xs - gpu_in = Flux.gpu.(xs) - - cpu_out = cpu_model(cpu_in...) - gpu_out = gpu_model(gpu_in...) 
- @test collect(cpu_out) ≈ collect(gpu_out) atol=atol - - if checkgrad - cpu_grad = gradient(m -> sum(m(cpu_in...)), cpu_model) - gpu_grad = gradient(m -> sum(m(gpu_in...)), gpu_model) - amd_check_grad(gpu_grad, cpu_grad; atol, allow_nothing) - end -end - -function amd_check_grad(g_gpu, g_cpu; atol, allow_nothing) - allow_nothing && return - @show g_gpu g_cpu - @test false -end - -amd_check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, allow_nothing) = - amd_check_grad(g_gpu[], g_cpu[]; atol, allow_nothing) -amd_check_grad(g_gpu::Nothing, g_cpu::Nothing; atol, allow_nothing) = - @test true -amd_check_grad(g_gpu::Float32, g_cpu::Float32; atol, allow_nothing) = - @test g_cpu ≈ g_gpu atol=atol -amd_check_grad( - g_gpu::ROCArray{Float32}, g_cpu::Array{Float32}; - atol, allow_nothing, -) = @test g_cpu ≈ collect(g_gpu) atol=atol -amd_check_grad( - g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill; - atol, allow_nothing -) = @test g_cpu ≈ collect(g_gpu) atol=atol - -function amd_check_grad(g_gpu::Tuple, g_cpu::Tuple; atol, allow_nothing) - for (v1, v2) in zip(g_gpu, g_cpu) - amd_check_grad(v1, v2; atol, allow_nothing) - end -end - -function amd_check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; atol, allow_nothing) - for ((k1, v1), (k2, v2)) in zip(pairs(g_gpu), pairs(g_cpu)) - @test k1 == k2 - amd_check_grad(v1, v2; atol, allow_nothing) - end -end diff --git a/test/cuda/runtests.jl b/test/cuda/runtests.jl index d9cd7c707a..5e25829999 100644 --- a/test/cuda/runtests.jl +++ b/test/cuda/runtests.jl @@ -6,7 +6,6 @@ using Random, LinearAlgebra, Statistics @info "Testing GPU Support" CUDA.allowscalar(false) -include("test_utils.jl") include("cuda.jl") include("losses.jl") include("layers.jl") diff --git a/test/cuda/test_utils.jl b/test/cuda/test_utils.jl deleted file mode 100644 index bc0db37474..0000000000 --- a/test/cuda/test_utils.jl +++ /dev/null @@ -1,72 +0,0 @@ -function check_grad(g_gpu, g_cpu, atol, rtol) - @show g_gpu g_cpu - @test false -end -check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, rtol) = - check_grad(g_gpu[], g_cpu[], atol, rtol) -check_grad(g_gpu::Nothing, g_cpu::Nothing, atol, rtol) = @test true -check_grad(g_gpu::Float32, g_cpu::Float32, atol, rtol) = @test g_cpu ≈ g_gpu rtol=rtol atol=atol -check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}, atol, rtol) = - @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol - -function check_grad(g_gpu::Tuple, g_cpu::Tuple, atol, rtol) - for (v1, v2) in zip(g_gpu, g_cpu) - check_grad(v1, v2, atol, rtol) - end -end - -function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple, atol, rtol) - for ((k1,v1), (k2,v2)) in zip(pairs(g_gpu), pairs(g_cpu)) - @test k1 == k2 - # @show k2 v2 - check_grad(v1, v2, atol, rtol) - end -end - -function gpu_autodiff_test(f_cpu, xs_cpu::Array{Float32}...; - test_equal=true, rtol=1e-4, atol=1e-4) - - check_type(x) = false - check_type(x::Float32) = true - check_type(x::CuArray{Float32}) = true - check_type(x::Array{Float32}) = true - - ### GRADIENT WITH RESPECT TO INPUT ##### - # y_cpu, back_cpu = pullback((f, x...) -> f(x...), f_cpu, xs_cpu...) - y_cpu, back_cpu = pullback((x...) -> f_cpu(x...), xs_cpu...) - @test check_type(y_cpu) - Δ_cpu = size(y_cpu) == () ? randn(Float32) : randn(Float32, size(y_cpu)) - gs_cpu = back_cpu(Δ_cpu) - - f_gpu = f_cpu |> gpu - xs_gpu = gpu.(xs_cpu) - Δ_gpu = Δ_cpu |> gpu - # y_gpu, back_gpu = pullback((f, x...) -> f(x...), f_gpu, xs_gpu...) - y_gpu, back_gpu = pullback((x...) -> f_gpu(x...), xs_gpu...) 
- @test check_type(y_gpu) - gs_gpu = back_gpu(Δ_gpu) - - if test_equal - @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol - for (g_gpu, g_cpu) in zip(gs_gpu, gs_cpu) - check_grad(g_gpu, g_cpu, atol, rtol) - end - end - - ### GRADIENT WITH RESPECT TO f ##### - ps_cpu = Flux.params(f_cpu) - y_cpu, back_cpu = pullback(() -> f_cpu(xs_cpu...), ps_cpu) - gs_cpu = back_cpu(Δ_cpu) - - ps_gpu = Flux.params(f_gpu) - y_gpu, back_gpu = pullback(() -> f_gpu(xs_gpu...), ps_gpu) - gs_gpu = back_gpu(Δ_gpu) - - if test_equal - @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol - @assert length(ps_gpu) == length(ps_cpu) - for (p_gpu, p_cpu) in zip(ps_gpu, ps_cpu) - check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu], atol, rtol) - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 503ae2f2f0..5bd756324d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,8 @@ using IterTools: ncycle using Zygote using CUDA +include("test_utils.jl") + Random.seed!(0) @testset verbose=true "Flux.jl" begin diff --git a/test/test_utils.jl b/test/test_utils.jl new file mode 100644 index 0000000000..2b07e59d08 --- /dev/null +++ b/test/test_utils.jl @@ -0,0 +1,84 @@ +function check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing::Bool) + allow_nothing && return + @show g_gpu g_cpu + @test false +end +check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, rtol; allow_nothing::Bool) = + check_grad(g_gpu[], g_cpu[], atol, rtol; allow_nothing) +check_grad(g_gpu::Nothing, g_cpu::Nothing, atol, rtol; allow_nothing::Bool) = + @test true +check_grad(g_gpu::Float32, g_cpu::Float32, atol, rtol; allow_nothing::Bool) = + @test g_cpu ≈ g_gpu rtol=rtol atol=atol +check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}, atol, rtol; allow_nothing::Bool) = + @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol + +function check_grad(g_gpu::Tuple, g_cpu::Tuple, atol, rtol; allow_nothing::Bool) + for (v1, v2) in zip(g_gpu, g_cpu) + check_grad(v1, v2, atol, rtol; allow_nothing) + end +end + +function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple, atol, rtol; allow_nothing::Bool) + for ((k1,v1), (k2,v2)) in zip(pairs(g_gpu), pairs(g_cpu)) + @test k1 == k2 + check_grad(v1, v2, atol, rtol; allow_nothing) + end +end + +check_type(x) = false +check_type(x::Float32) = true +check_type(x::CuArray{Float32}) = true +check_type(x::Array{Float32}) = true + +function gpu_autodiff_test( + f_cpu, xs_cpu::Array{Float32}...; + test_equal=true, rtol=1e-4, atol=1e-4, + checkgrad::Bool = true, allow_nothing::Bool = false, +) + # Compare CPU & GPU function outputs. + f_gpu = f_cpu |> gpu + xs_gpu = gpu.(xs_cpu) + + y_cpu = f_cpu(xs_cpu...) + y_gpu = f_gpu(xs_gpu...) + @test collect(y_cpu) ≈ collect(y_gpu) atol=atol rtol=rtol + + checkgrad || return + + ### GRADIENT WITH RESPECT TO INPUT ### + + y_cpu, back_cpu = pullback((x...) -> f_cpu(x...), xs_cpu...) + @test check_type(y_cpu) + Δ_cpu = size(y_cpu) == () ? randn(Float32) : randn(Float32, size(y_cpu)) + gs_cpu = back_cpu(Δ_cpu) + + Δ_gpu = Δ_cpu |> gpu + y_gpu, back_gpu = pullback((x...) -> f_gpu(x...), xs_gpu...) 
+ @test check_type(y_gpu) + gs_gpu = back_gpu(Δ_gpu) + + if test_equal + @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol + for (g_gpu, g_cpu) in zip(gs_gpu, gs_cpu) + check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing) + end + end + + ### GRADIENT WITH RESPECT TO f ### + + ps_cpu = Flux.params(f_cpu) + y_cpu, back_cpu = pullback(() -> f_cpu(xs_cpu...), ps_cpu) + gs_cpu = back_cpu(Δ_cpu) + + ps_gpu = Flux.params(f_gpu) + y_gpu, back_gpu = pullback(() -> f_gpu(xs_gpu...), ps_gpu) + gs_gpu = back_gpu(Δ_gpu) + + if test_equal + @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol + @assert length(ps_gpu) == length(ps_cpu) + for (p_gpu, p_cpu) in zip(ps_gpu, ps_cpu) + check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu], atol, rtol; allow_nothing) + end + end +end From 621829bed3f790aa4418d1ccfa29bc4056c8ff24 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 28 Feb 2023 13:19:22 +0200 Subject: [PATCH 15/15] Refactor --- ext/AMDGPUExt/functor.jl | 9 ++------- test/runtests.jl | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/ext/AMDGPUExt/functor.jl b/ext/AMDGPUExt/functor.jl index 3f7971aff0..797a55c110 100644 --- a/ext/AMDGPUExt/functor.jl +++ b/ext/AMDGPUExt/functor.jl @@ -18,10 +18,6 @@ adapt_storage(::FluxAMDAdaptor, x::AbstractRNG) = error(""" adapt_storage(::FluxCPUAdaptor, x::AMDGPU.rocRAND.RNG) = Random.default_rng() -function ChainRulesCore.rrule(::Type{Array}, x::ROCArray) - Array(x), dx -> (NoTangent(), ROCArray(unthunk(dx))) -end - function ChainRulesCore.rrule( ::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::AMDGPU.AnyROCArray, ) @@ -32,9 +28,8 @@ end function _amd(x) check_use_amdgpu() - USE_AMDGPU[] ? - fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_isleaf) : - x + USE_AMDGPU[] || return x + fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_isleaf) end # Since MIOpen supports only cross-correlation as convolution, diff --git a/test/runtests.jl b/test/runtests.jl index 5bd756324d..a2a8f66323 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,18 +63,18 @@ Random.seed!(0) end end - if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true" - using AMDGPU - AMDGPU.versioninfo() - if AMDGPU.functional() && AMDGPU.functional(:MIOpen) - @show AMDGPU.MIOpen.version() - @testset "AMDGPU" begin - include("amd/runtests.jl") - end - else - @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." - end + if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true" + using AMDGPU + AMDGPU.versioninfo() + if AMDGPU.functional() && AMDGPU.functional(:MIOpen) + @show AMDGPU.MIOpen.version() + @testset "AMDGPU" begin + include("amd/runtests.jl") + end else - @info "Skipping AMDGPU tests, set FLUX_TEST_CUDA=true to run them." + @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." end + else + @info "Skipping AMDGPU tests, set FLUX_TEST_AMDGPU=true to run them." + end end
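Finally, the AMDGPU test group added in `test/runtests.jl` is opt-in via the `FLUX_TEST_AMDGPU` environment variable. A minimal sketch for running it locally (assuming AMDGPU.jl plus a functional ROCm/MIOpen installation; `withenv` and `Pkg.test` are standard Julia tooling, not something introduced by this series):

```julia
# Run Flux's test suite with the AMDGPU group enabled.
# FLUX_TEST_AMDGPU is the gate introduced in test/runtests.jl above;
# the spawned test process inherits the environment variable.
using Pkg

withenv("FLUX_TEST_AMDGPU" => "true") do
    Pkg.test("Flux")
end
```

Setting `ENV["FLUX_TEST_AMDGPU"] = "true"` before calling `Pkg.test` has the same effect.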