From 8aa7d90de500e41f4cb51ddedf564d77bf23cbf0 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 12 May 2024 14:15:11 -0400
Subject: [PATCH 1/4] Mark certain operations as Enzyme inactive

---
 Project.toml              |  6 ++++--
 src/LuxLib.jl             |  1 +
 src/api/dropout.jl        |  4 ++++
 src/api/groupnorm.jl      | 19 +++----------------
 src/api/instancenorm.jl   |  1 +
 src/impl/groupnorm.jl     | 20 ++++++++++++++++++--
 src/impl/normalization.jl |  1 +
 src/utils.jl              |  9 +++++++++
 8 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/Project.toml b/Project.toml
index e81a41bd..e7cfde74 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.3.22"
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
 FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
@@ -44,19 +45,20 @@ ArrayInterface = "7.9"
 CUDA = "5.3.2"
 ChainRulesCore = "1.23"
 ComponentArrays = "0.15.8"
+EnzymeCore = "0.7"
 ExplicitImports = "1.4.1"
 FastBroadcast = "0.2.8"
 FastClosures = "0.3.2"
 ForwardDiff = "0.10.36"
 GPUArraysCore = "0.1.6"
-KernelAbstractions = "0.9.15"
+KernelAbstractions = "0.9.18"
 LinearAlgebra = "1.10"
 LuxAMDGPU = "0.2.1"
 LuxCUDA = "0.3.1"
 LuxCore = "0.1.13"
 LuxTestUtils = "0.1.15"
 Markdown = "1.10"
-NNlib = "0.9.10"
+NNlib = "0.9.13"
 PrecompileTools = "1.2"
 Random = "1.10"
 ReTestItems = "1.23.1"
diff --git a/src/LuxLib.jl b/src/LuxLib.jl
index 47dbdd2b..4895af17 100644
--- a/src/LuxLib.jl
+++ b/src/LuxLib.jl
@@ -5,6 +5,7 @@ using PrecompileTools: @recompile_invalidations
 @recompile_invalidations begin
     using ArrayInterface: ArrayInterface
     using ChainRulesCore: ChainRulesCore, NoTangent
+    using EnzymeCore: EnzymeCore, EnzymeRules
     using FastBroadcast: @..
     using FastClosures: @closure
     using GPUArraysCore: GPUArraysCore, AnyGPUArray
diff --git a/src/api/dropout.jl b/src/api/dropout.jl
index 21f9dbd5..ea4025ee 100644
--- a/src/api/dropout.jl
+++ b/src/api/dropout.jl
@@ -130,6 +130,7 @@ end
 @inline _dropout_fptype(x) = float(real(eltype(x)))

 CRC.@non_differentiable _dropout_fptype(::Any...)
+EnzymeRules.inactive(::typeof(_dropout_fptype), ::Any...) = nothing

 @inline function _alpha_dropout_noise(rng, x)
     rng = LuxCore.replicate(rng)
@@ -139,6 +140,7 @@ end

 CRC.@non_differentiable _alpha_dropout_noise(::Any...)
+EnzymeRules.inactive(::typeof(_alpha_dropout_noise), ::Any...) = nothing

 @inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims)
     realfptype = _dropout_fptype(x)
@@ -148,4 +150,6 @@ end

 CRC.@non_differentiable _generate_dropout_mask(::Any...)
+EnzymeRules.inactive(::typeof(_generate_dropout_mask), ::Any...) = nothing
 CRC.@non_differentiable _dropout_shape(::Any...)
+EnzymeRules.inactive(::typeof(_dropout_shape), ::Any...) = nothing
diff --git a/src/api/groupnorm.jl b/src/api/groupnorm.jl
index 3ed765f2..302ce081 100644
--- a/src/api/groupnorm.jl
+++ b/src/api/groupnorm.jl
@@ -45,12 +45,8 @@ function groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4},
         groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F}
     _test_valid_groupnorm_arguments(x, scale, bias, groups)
     # FIXME: We need to fuse the activation function into the kernel for optimal performance
-    return fast_activation!!(σ, __fast_groupnorm(x, groups, scale, bias, epsilon))
-end
-
-# Separate this out for a cleaner rrule later on
-@inline function __fast_groupnorm(x, groups, scale, bias, epsilon)
-    return first(_groupnorm(x, groups, scale, bias, epsilon))
+    return fast_activation!!(
+        σ, __groupnorm_kernel_abstractions(x, groups, scale, bias, epsilon))
 end

 # Slow Fallback (without custom Pullback Implementation)
@@ -71,16 +67,6 @@ end
     return :($(Val(Tuple(collect(1:(N - 1))))))
 end

-# Custom Pullbacks
-function CRC.rrule(::typeof(__fast_groupnorm), x, groups, scale, bias, epsilon)
-    y, μ, σ⁻¹ = _groupnorm(x, groups, scale, bias, epsilon)
-    ∇groupnorm = @closure Δ -> begin
-        ∂x, ∂scale, ∂bias = _∇groupnorm(Δ, y, x, groups, scale, bias, μ, σ⁻¹)
-        return NoTangent(), ∂x, NoTangent(), ∂scale, ∂bias, NoTangent()
-    end
-    return y, ∇groupnorm
-end
-
 function _test_valid_groupnorm_arguments(
         x::AbstractArray{T, N}, scale, bias, groups) where {T, N}
     _assert_same_backend(x, scale, bias)
     if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3)
         throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \
                             channels (N - 1 dim of the input array)."))
@@ -95,3 +81,4 @@ end

 CRC.@non_differentiable _test_valid_groupnorm_arguments(::Any...)
+EnzymeRules.inactive(::typeof(_test_valid_groupnorm_arguments), ::Any...) = nothing
diff --git a/src/api/instancenorm.jl b/src/api/instancenorm.jl
index d79ad234..9eee23ed 100644
--- a/src/api/instancenorm.jl
+++ b/src/api/instancenorm.jl
@@ -47,3 +47,4 @@ function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N}
 end

 CRC.@non_differentiable _test_valid_instancenorm_arguments(::Any...)
+EnzymeRules.inactive(::typeof(_test_valid_instancenorm_arguments), ::Any...) = nothing
diff --git a/src/impl/groupnorm.jl b/src/impl/groupnorm.jl
index 430223c6..03fc68db 100644
--- a/src/impl/groupnorm.jl
+++ b/src/impl/groupnorm.jl
@@ -44,7 +44,7 @@ end
 end

 # High-Level Function (Not User Facing)
-@inbounds function _groupnorm(
+@inbounds function _groupnorm_kernel_abstractions_impl(
         X::AbstractArray{TX, 4}, G::Int, γ::AbstractVector, β::AbstractVector, ϵ) where {TX}
     W, H, C, N = size(X)
     K = div(C, G)
@@ -72,7 +72,7 @@ end
     return Y, μ, σ⁻¹
 end

-@inbounds function _∇groupnorm(
+@inbounds function _∇groupnorm_kernel_abstractions_impl(
         dY::AbstractArray{T1, 4}, Y::AbstractArray{T2, 4}, X::AbstractArray{T3, 4},
         G::Int, γ::AbstractVector, β::AbstractVector, μ::AbstractArray{T4, 5},
         σ⁻¹::AbstractArray{T5, 5}) where {T1, T2, T3, T4, T5}
@@ -111,3 +111,19 @@ end

     return dX, dγ, dβ
 end
+
+# Separate this out for a cleaner rrule later on
+@inline function __groupnorm_kernel_abstractions(x, groups, scale, bias, epsilon)
+    return first(_groupnorm_kernel_abstractions_impl(x, groups, scale, bias, epsilon))
+end
+
+function CRC.rrule(
+        ::typeof(__groupnorm_kernel_abstractions), x, groups, scale, bias, epsilon)
+    y, μ, σ⁻¹ = _groupnorm_kernel_abstractions_impl(x, groups, scale, bias, epsilon)
+    ∇groupnorm = @closure Δ -> begin
+        ∂x, ∂scale, ∂bias = _∇groupnorm_kernel_abstractions_impl(
+            Δ, y, x, groups, scale, bias, μ, σ⁻¹)
+        return NoTangent(), ∂x, NoTangent(), ∂scale, ∂bias, NoTangent()
+    end
+    return y, ∇groupnorm
+end
diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl
index 7f47503b..2c5b4846 100644
--- a/src/impl/normalization.jl
+++ b/src/impl/normalization.jl
@@ -20,6 +20,7 @@ end
 @inline __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims)

 CRC.@non_differentiable __accum_size(::Any...)
+EnzymeRules.inactive(::typeof(__accum_size), ::Any...) = nothing

 @inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing,
         ::Val{rdims}, ::Val{false}, momentum) where {rdims}
diff --git a/src/utils.jl b/src/utils.jl
index 0b247eb2..8571241c 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -20,6 +20,7 @@ function __check_all_same_or_nothing(x::Union{AbstractVector, Tuple})
 end

 CRC.@non_differentiable _get_backend(::Any)
+EnzymeRules.inactive(::typeof(_get_backend), ::Any...) = nothing

 @inline _assert_same_backend(args...) = _assert_same_backend([args...])
 @inline function _assert_same_backend(xs)
@@ -33,6 +34,7 @@ CRC.@non_differentiable _get_backend(::Any)
 end

 CRC.@non_differentiable _assert_same_backend(::Any...)
+EnzymeRules.inactive(::typeof(_assert_same_backend), ::Any...) = nothing

 @inline @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x
@@ -47,6 +49,7 @@ CRC.@non_differentiable _assert_same_backend(::Any...)
 end

 CRC.@non_differentiable _get_reshape_dims(::Any...)
+EnzymeRules.inactive(::typeof(_get_reshape_dims), ::Any...) = nothing

 @inline _reshape_into_proper_shape(::Nothing, y) = nothing
 @inline _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x)))
@@ -56,6 +59,7 @@ _copy_autodiff_barrier(x) = copy(x)
 _copy_autodiff_barrier(::Nothing) = nothing

 CRC.@non_differentiable _copy_autodiff_barrier(::Any)
+EnzymeRules.inactive(::typeof(_copy_autodiff_barrier), ::Any...) = nothing

 # Meta Programming Utilities
 __is_tracked(x) = x == :TrackedArray || x == :TrackedVector
@@ -91,11 +95,13 @@ struct NotaNumber <: Real end
 @inline __is_immutable_array_val(x) = Val(__is_immutable_array(x))

 CRC.@non_differentiable __is_immutable_array_val(::Any...)
+EnzymeRules.inactive(::typeof(__is_immutable_array_val), ::Any...) = nothing

 @inline __has_dual(x) = false
 @inline __is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x))

 CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...)
+EnzymeRules.inactive(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing

 @inline function __expand_conv_bias_dims(
         bias::AbstractVector, ::AbstractArray{T, N}) where {T, N}
@@ -117,6 +123,7 @@ end
 end

 CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...)
+EnzymeRules.inactive(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing

 # Helper to add bias and apply activation function
 ## This is only meant to be used inside rrules
@@ -209,6 +216,7 @@ end
 end

 CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray)
+EnzymeRules.inactive(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing

 @inline function __reset_BLAS_threads(old_threads::Int)
     old_threads ≥ 1 && BLAS.set_num_threads(old_threads)
@@ -216,6 +224,7 @@ CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray)
 end

 CRC.@non_differentiable __reset_BLAS_threads(::Int)
+EnzymeRules.inactive(::typeof(__reset_BLAS_threads), ::Int) = nothing

 # Defined in ext/LuxLibCUDAExt.jl
 function _cublaslt_matmul_fused! end
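The annotations added in this patch pair each existing `CRC.@non_differentiable` declaration with an `EnzymeRules.inactive` method, so that Enzyme also treats the same structural helpers as primitives with no derivative. A minimal sketch of the pattern outside LuxLib (the helper `_output_eltype` below is hypothetical, not part of the package):

```julia
using ChainRulesCore: ChainRulesCore
using EnzymeCore: EnzymeRules

const CRC = ChainRulesCore

# A purely structural query: it influences types/shapes, never gradient values.
_output_eltype(x) = float(real(eltype(x)))

# Hide it from ChainRules-based ADs (Zygote, etc.) ...
CRC.@non_differentiable _output_eltype(::Any...)
# ... and tell Enzyme it is inactive, i.e. it should not be differentiated through.
EnzymeRules.inactive(::typeof(_output_eltype), ::Any...) = nothing
```

With both declarations in place, code that calls `_output_eltype` can be differentiated by either system without a rule being generated for the helper itself.
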
From 2feccc908cee67d859f95d9ddec989000736bd1e Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 12 May 2024 16:01:41 -0400
Subject: [PATCH 2/4] Remove KA special handling

---
 Project.toml                |   4 +-
 ext/LuxLibReverseDiffExt.jl |   2 -
 ext/LuxLibTrackerExt.jl     |  13 ----
 src/LuxLib.jl               |   5 +-
 src/api/groupnorm.jl        |  25 -------
 src/impl/groupnorm.jl       | 129 ------------------------------------
 src/utils.jl                |  38 -----------
 test/groupnorm_tests.jl     |  87 +++---------------------
 8 files changed, 12 insertions(+), 291 deletions(-)
 delete mode 100644 src/impl/groupnorm.jl

diff --git a/Project.toml b/Project.toml
index e7cfde74..8d37087e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal and contributors"]
-version = "0.3.22"
+version = "0.3.23"

 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -10,7 +10,6 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
 FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
-KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
@@ -51,7 +50,6 @@ FastBroadcast = "0.2.8"
 FastClosures = "0.3.2"
 ForwardDiff = "0.10.36"
 GPUArraysCore = "0.1.6"
-KernelAbstractions = "0.9.18"
 LinearAlgebra = "1.10"
 LuxAMDGPU = "0.2.1"
 LuxCUDA = "0.3.1"
diff --git a/ext/LuxLibReverseDiffExt.jl b/ext/LuxLibReverseDiffExt.jl
index fc11d484..a1458ee1 100644
--- a/ext/LuxLibReverseDiffExt.jl
+++ b/ext/LuxLibReverseDiffExt.jl
@@ -21,8 +21,6 @@ end
 @grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedArray)
 @grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedReal)

-LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(ReverseDiff.value(x))
-
 # api/dropout.jl
 LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(ReverseDiff.value(x))
diff --git a/ext/LuxLibTrackerExt.jl b/ext/LuxLibTrackerExt.jl
index 9221afa0..69581325 100644
--- a/ext/LuxLibTrackerExt.jl
+++ b/ext/LuxLibTrackerExt.jl
@@ -41,20 +41,7 @@
 function LuxLib._copy_autodiff_barrier(x::Union{TrackedArray, TrackedReal})
     return LuxLib._copy_autodiff_barrier(Tracker.data(x))
 end

-LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(Tracker.data(x))
-
 # api/dropout.jl
 LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(Tracker.data(x))

-# api/groupnorm.jl
-for T1 in (:TrackedArray, :AbstractArray),
-    T2 in (:TrackedVector, :AbstractVector),
-    T3 in (:TrackedVector, :AbstractVector)
-
-    LuxLib.__is_tracked(T1, T2, T3) || continue
-
-    @eval Tracker.@grad_from_chainrules LuxLib.__fast_groupnorm(
-        x::$T1, groups, scale::$T2, bias::$T3, epsilon::Real)
-end
-
 end
diff --git a/src/LuxLib.jl b/src/LuxLib.jl
index 4895af17..f12c7e52 100644
--- a/src/LuxLib.jl
+++ b/src/LuxLib.jl
@@ -9,25 +9,22 @@ using PrecompileTools: @recompile_invalidations
     using FastBroadcast: @..
     using FastClosures: @closure
    using GPUArraysCore: GPUArraysCore, AnyGPUArray
-    using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel
     using LinearAlgebra: LinearAlgebra, BLAS, mul!
     using LuxCore: LuxCore
     using Markdown: @doc_str
     using NNlib: NNlib
     using Random: Random, AbstractRNG, rand!
     using Reexport: @reexport
-    using Statistics: Statistics, mean, std, var
+    using Statistics: Statistics, mean, var
 end

 @reexport using NNlib

 const CRC = ChainRulesCore
-const KA = KernelAbstractions

 include("utils.jl")

 # Low-Level Implementations
-include("impl/groupnorm.jl")
 include("impl/normalization.jl")
 include("impl/fused_dense.jl")
 include("impl/fused_conv.jl")
diff --git a/src/api/groupnorm.jl b/src/api/groupnorm.jl
index 302ce081..b9ec0d51 100644
--- a/src/api/groupnorm.jl
+++ b/src/api/groupnorm.jl
@@ -21,35 +21,11 @@ statistics.

 The normalized array is returned.

-## Performance Considerations
-
-The most common case of this Op -- `x` is a 4D array -- is optimized using
-KernelAbstractions and has a fast custom backwards pass implemented. All other cases have a
-fallback implementation which is not especially optimized.
-
-We have tested the code path for `Float16` and it works, but gradient accumulation is
-extremely fragile. Hence, for `Float16` inputs, it uses the fallback implementation.
-
-If the batch size is small (< 16), then the fallback implementation will be faster than the
-KA version. However, this customization is not possible using the direct `groupnorm`
-interface.
-
 ## References

 [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference
 on computer vision (ECCV). 2018.
 """
-function groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4},
-        scale::AbstractVector{<:Union{Float32, Float64}},
-        bias::AbstractVector{<:Union{Float32, Float64}},
-        groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F}
-    _test_valid_groupnorm_arguments(x, scale, bias, groups)
-    # FIXME: We need to fuse the activation function into the kernel for optimal performance
-    return fast_activation!!(
-        σ, __groupnorm_kernel_abstractions(x, groups, scale, bias, epsilon))
-end
-
-# Slow Fallback (without custom Pullback Implementation)
 function groupnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector},
         bias::Union{Nothing, <:AbstractVector}, groups::Int,
         σ::F=identity, epsilon::Real=1.0f-5) where {F, N}
@@ -69,7 +45,6 @@ end
 function _test_valid_groupnorm_arguments(
         x::AbstractArray{T, N}, scale, bias, groups) where {T, N}
-    _assert_same_backend(x, scale, bias)
     if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3)
         throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \
                             channels (N - 1 dim of the input array)."))
diff --git a/src/impl/groupnorm.jl b/src/impl/groupnorm.jl
deleted file mode 100644
index 03fc68db..00000000
--- a/src/impl/groupnorm.jl
+++ /dev/null
@@ -1,129 +0,0 @@
-# Low-Level Kernels
-## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu
-@kernel function _compute_fused_params_kernel!(
-        scale, bias, @Const(C), @Const(K), @Const(μ), @Const(σ⁻¹), @Const(γ), @Const(β))
-    idx = @index(Global)
-    ng = _div_idx(idx, K)
-    c = _mod_idx(idx, C)
-
-    @inbounds scale_val = γ[c] * σ⁻¹[ng]
-    @inbounds scale[idx] = scale_val
-    @inbounds bias[idx] = β[c] - μ[ng] * scale_val
-end
-
-@kernel function _groupnorm_forward_kernel!(
-        Y, @Const(WxH), @Const(X), @Const(scale), @Const(bias))
-    idx = @index(Global)
-    nc = _div_idx(idx, WxH)
-    @inbounds Y[idx] = X[idx] * scale[nc] + bias[nc]
-end
-
-@kernel function _groupnorm_dy_dscale_kernel!(
-        dY_dscale, @Const(C), @Const(K), @Const(σ⁻¹), @Const(γ))
-    idx = @index(Global)
-    ng = _div_idx(idx, K)
-    c = _mod_idx(idx, C)
-
-    @inbounds dY_dscale[idx] = γ[c] * σ⁻¹[ng]
-end
-
-@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), @Const(μ),
-        @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum))
-    idx = @index(Global)
-    @inbounds x = (db_sum[idx] * μ[idx] - ds_sum[idx]) * (σ⁻¹[idx]^3) * alpha
-    @inbounds X_scale[idx] = x
-    @inbounds bias[idx] = -(x * μ[idx] + db_sum[idx] * σ⁻¹[idx] * alpha)
-end
-
-@kernel function _groupnorm_dx_kernel!(dX, @Const(WxH), @Const(K), @Const(dY_dscale),
-        @Const(dY), @Const(X_scale), @Const(X), @Const(bias))
-    idx = @index(Global)
-    nc = _div_idx(idx, WxH)
-    ng = _div_idx(nc, K)
-    @inbounds dX[idx] = dY[idx] * dY_dscale[nc] + X_scale[ng] * X[idx] + bias[ng]
-end
-
-# High-Level Function (Not User Facing)
-@inbounds function _groupnorm_kernel_abstractions_impl(
-        X::AbstractArray{TX, 4}, G::Int, γ::AbstractVector, β::AbstractVector, ϵ) where {TX}
-    W, H, C, N = size(X)
-    K = div(C, G)
-
-    X_reshaped = reshape(X, (W, H, K, G, N))
-    μ = mean(X_reshaped; dims=(1, 2, 3))
-    σ⁻¹ = 1 ./ (std(X_reshaped; mean=μ, dims=(1, 2, 3), corrected=false) .+ ϵ)
-
-    T = promote_type(eltype(μ), eltype(σ⁻¹), eltype(γ), eltype(β))
-    _scale = similar(X, T, (C, N))
-    _bias = similar(X, T, (C, N))
-    Y = similar(X, T)
-
-    backend = KA.get_backend(X)
-
-    compute_fixed_params! = _compute_fused_params_kernel!(backend)
-    groupnorm_forward! = _groupnorm_forward_kernel!(backend)
-
-    compute_fixed_params!(_scale, _bias, C, K, μ, σ⁻¹, γ, β; ndrange=size(_scale))
-    KA.synchronize(backend)
-
-    groupnorm_forward!(Y, W * H, X, _scale, _bias; ndrange=size(Y))
-    KA.synchronize(backend)
-
-    return Y, μ, σ⁻¹
-end
-
-@inbounds function _∇groupnorm_kernel_abstractions_impl(
-        dY::AbstractArray{T1, 4}, Y::AbstractArray{T2, 4}, X::AbstractArray{T3, 4},
-        G::Int, γ::AbstractVector, β::AbstractVector, μ::AbstractArray{T4, 5},
-        σ⁻¹::AbstractArray{T5, 5}) where {T1, T2, T3, T4, T5}
-    W, H, C, N = size(X)
-    K = div(C, G)
-    WxH = W * H
-    backend = KA.get_backend(X)
-
-    dbias = reshape(sum(dY; dims=(1, 2)), (1, 1, K, G, N))
-    dscale = reshape(sum(X .* dY; dims=(1, 2)), (1, 1, K, G, N))
-
-    dY_dscale = similar(X, promote_type(eltype(σ⁻¹), eltype(γ)), (C, N))
-    groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(backend)
-    groupnorm_dy_dscale!(dY_dscale, C, K, σ⁻¹, γ; ndrange=size(dY_dscale))
-
-    γ_ = reshape(γ, (1, 1, K, G, 1))
-    db_sum = sum(γ_ .* dbias; dims=3)
-    ds_sum = sum(γ_ .* dscale; dims=3)
-    KA.synchronize(backend)
-
-    T = promote_type(eltype(μ), eltype(σ⁻¹), eltype(ds_sum), eltype(db_sum))
-    X_scale = similar(X, T, (G, N))
-    bias = similar(X, T, (G, N))
-
-    groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend)
-    groupnorm_xscale_and_bias!(
-        X_scale, bias, T(1 / (K * WxH)), μ, σ⁻¹, ds_sum, db_sum; ndrange=size(X_scale))
-    KA.synchronize(backend)
-
-    dX = similar(X)
-    groupnorm_dx! = _groupnorm_dx_kernel!(backend)
-    groupnorm_dx!(dX, WxH, K, dY_dscale, dY, X_scale, X, bias; ndrange=size(dX))
-    dγ = vec(sum((-dbias .* μ .+ dscale) .* σ⁻¹; dims=5))
-    dβ = vec(sum(dbias; dims=5))
-    KA.synchronize(backend)
-
-    return dX, dγ, dβ
-end
-
-# Separate this out for a cleaner rrule later on
-@inline function __groupnorm_kernel_abstractions(x, groups, scale, bias, epsilon)
-    return first(_groupnorm_kernel_abstractions_impl(x, groups, scale, bias, epsilon))
-end
-
-function CRC.rrule(
-        ::typeof(__groupnorm_kernel_abstractions), x, groups, scale, bias, epsilon)
-    y, μ, σ⁻¹ = _groupnorm_kernel_abstractions_impl(x, groups, scale, bias, epsilon)
-    ∇groupnorm = @closure Δ -> begin
-        ∂x, ∂scale, ∂bias = _∇groupnorm_kernel_abstractions_impl(
-            Δ, y, x, groups, scale, bias, μ, σ⁻¹)
-        return NoTangent(), ∂x, NoTangent(), ∂scale, ∂bias, NoTangent()
-    end
-    return y, ∇groupnorm
-end
diff --git a/src/utils.jl b/src/utils.jl
index 8571241c..a6264a11 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,41 +1,3 @@
-# Utilities
-@inline _div_idx(idx, n) = div(idx - 1, n) + 1
-@inline _mod_idx(idx, n) = mod(idx - 1, n) + 1
-
-@inline _get_backend(::Nothing) = nothing
-@inline function _get_backend(d)
-    return hasmethod(KA.get_backend, (typeof(d),)) ? KA.get_backend(d) : nothing
-end
-@inline _get_backend(t::Tuple) = _get_backend.(t)
-
-function __check_all_same_or_nothing(x::Union{AbstractVector, Tuple})
-    @inbounds for i in eachindex(x)
-        x[i] === nothing && continue
-        for j in (i + 1):length(x)
-            x[j] === nothing && continue
-            x[i] != x[j] && return false
-        end
-    end
-    return true
-end
-
-CRC.@non_differentiable _get_backend(::Any)
-EnzymeRules.inactive(::typeof(_get_backend), ::Any...) = nothing
-
-@inline _assert_same_backend(args...) = _assert_same_backend([args...])
-@inline function _assert_same_backend(xs)
-    devs = _get_backend.(xs)
-    if !__check_all_same_or_nothing(devs)
-        throw(ArgumentError("All arguments must be on the same backend. This error is \
-                             encountered if you are calling a function with a mix of CPU \
-                             and GPU arrays."))
-    end
-    return
-end
-
-CRC.@non_differentiable _assert_same_backend(::Any...)
-EnzymeRules.inactive(::typeof(_assert_same_backend), ::Any...) = nothing
-
 @inline @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x

 @inline @inbounds function _get_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N}
diff --git a/test/groupnorm_tests.jl b/test/groupnorm_tests.jl
index b18a9b59..a5b070f7 100644
--- a/test/groupnorm_tests.jl
+++ b/test/groupnorm_tests.jl
@@ -1,85 +1,18 @@
-@testsetup module GroupNormSetup
-using LuxLib
-
-@inline __generate_fixed_array(::Type{T}, sz...) where {T} = __generate_fixed_array(T, sz)
-@inline function __generate_fixed_array(::Type{T}, sz) where {T}
-    return reshape(T.(collect(1:prod(sz)) ./ prod(sz)), sz...)
-end
-@inline __generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz)
-
-function _setup_groupnorm(aType, T, sz, groups)
-    x = __generate_fixed_array(T, sz) |> aType
-    scale = __generate_fixed_array(T, sz[end - 1]) |> aType
-    bias = __generate_fixed_array(T, sz[end - 1]) |> aType
-    return x, scale, bias
-end
-
-function _groupnorm_generic_fallback(x, scale, bias, epsilon, groups, act)
-    sz = size(x)
-    N = ndims(x)
-    x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N])
-    x_, xmean, xvar = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias,
-        Val(Tuple(collect(1:(N - 1)))), Val(false), nothing, epsilon, act)
-
-    return reshape(x_, sz)
-end
-
-export _setup_groupnorm, _groupnorm_generic_fallback
-end
-
-@testitem "Group Normalization KernelAbstractions" tags=[:singleworker, :normalization] setup=[
-    SharedTestSetup, GroupNormSetup] begin
-    @testset "$mode" for (mode, aType, on_gpu) in MODES
-        @testset "eltype $T, size $sz, ngroups $groups, $act" for T in (Float32, Float64),
-            sz in ((4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)),
-            groups in (2, 3),
-            act in (identity, relu, tanh_fast, sigmoid_fast, x -> gelu(x))
-
-            _f = (args...) -> groupnorm(args..., act; groups, epsilon)
-
-            epsilon = T(1e-5)
-            x, scale, bias = _setup_groupnorm(aType, T, sz, groups)
-
-            y = _f(x, scale, bias)
-
-            gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
-
-            @inferred groupnorm(x, scale, bias, act; groups, epsilon)
-
-            # Stresses CI too much
-            T !== Float16 && @jet groupnorm(x, scale, bias, act; groups, epsilon)
-
-            @test y isa aType{T, length(sz)}
-            @test size(y) == sz
-
-            # Use the generic implementation to compare against
-            __f = (args...) -> _groupnorm_generic_fallback(args..., epsilon, groups, act)
-
-            y_ = __f(x, scale, bias)
-
-            gs_x_, gs_scale_, gs_bias_ = Zygote.gradient(sum ∘ __f, x, scale, bias)
-
-            # The KA implementation reorders operations manually for maximal
-            # performance. Hence equality cannot be guaranteed.
-            @test check_approx(y, y_; atol=1.0f-1, rtol=1.0f-1)
-            @test check_approx(gs_x, gs_x_; atol=1.0f-1, rtol=1.0f-1)
-            @test check_approx(gs_scale, gs_scale_; atol=1.0f-1, rtol=1.0f-1)
-            @test check_approx(gs_bias, gs_bias_; atol=1.0f-1, rtol=1.0f-1)
-
-            fp16 = T == Float16
-            __f = (args...) -> sum(groupnorm(x, args..., act; groups, epsilon))
-            skip_fd = act === relu
-            @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 skip_finite_differences=$(skip_fd)
-        end
+@testitem "Group Normalization" tags=[:singleworker, :normalization] setup=[SharedTestSetup] begin
+    rng = get_stable_rng(12345)
+
+    function _setup_groupnorm(aType, T, sz, groups)
+        x = __generate_fixed_array(T, sz) |> aType
+        scale = __generate_fixed_array(T, sz[end - 1]) |> aType
+        bias = __generate_fixed_array(T, sz[end - 1]) |> aType
+        return x, scale, bias
     end
-end

-@testitem "Group Normalization Generic Fallback" tags=[:singleworker, :normalization] setup=[
-    SharedTestSetup, GroupNormSetup] begin
     @testset "$mode" for (mode, aType, on_gpu) in MODES
         @testset "eltype $T, size $sz, ngroups $groups, $act" for T in (
                 Float16, Float32, Float64),
-            sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)),
+            sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2),
+                (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)),
             groups in (2, 3),
             act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3)
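With the KernelAbstractions kernels deleted, every `groupnorm` call now goes through the generic reshape-based `_normalization` fallback, which is what the removed `_groupnorm_generic_fallback` test helper exercised. A rough standalone sketch of that computation for a 4D input (the function name `naive_groupnorm` and keyword `epsilon` are illustrative, not LuxLib API):

```julia
using Statistics: mean, var

function naive_groupnorm(x::AbstractArray{<:Real, 4}, scale, bias, groups; epsilon=1.0f-5)
    W, H, C, N = size(x)
    K = C ÷ groups
    xg = reshape(x, W, H, K, groups, N)                    # split channels into groups
    μ = mean(xg; dims=(1, 2, 3))                           # statistics per (group, batch)
    σ² = var(xg; mean=μ, dims=(1, 2, 3), corrected=false)
    x̂ = (xg .- μ) ./ sqrt.(σ² .+ epsilon)                  # normalize within each group
    γ = reshape(scale, 1, 1, K, groups, 1)                 # per-channel affine parameters
    β = reshape(bias, 1, 1, K, groups, 1)
    return reshape(γ .* x̂ .+ β, W, H, C, N)
end
```

The activation fusion and the custom pullback that the deleted kernels provided are gone; the fallback relies on the AD system to differentiate the broadcasted expression directly.
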
From 1a64b73d801af75e65eef56d68f0e5c92444cf43 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 12 May 2024 17:25:32 -0400
Subject: [PATCH 3/4] Try removing the EnzymeRules inactive

---
 src/api/dropout.jl        | 4 ----
 src/api/groupnorm.jl      | 1 -
 src/api/instancenorm.jl   | 1 -
 src/impl/normalization.jl | 1 -
 src/utils.jl              | 9 ++-------
 5 files changed, 2 insertions(+), 14 deletions(-)

diff --git a/src/api/dropout.jl b/src/api/dropout.jl
index ea4025ee..21f9dbd5 100644
--- a/src/api/dropout.jl
+++ b/src/api/dropout.jl
@@ -130,7 +130,6 @@ end
 @inline _dropout_fptype(x) = float(real(eltype(x)))

 CRC.@non_differentiable _dropout_fptype(::Any...)
-EnzymeRules.inactive(::typeof(_dropout_fptype), ::Any...) = nothing

 @inline function _alpha_dropout_noise(rng, x)
     rng = LuxCore.replicate(rng)
@@ -140,7 +139,6 @@ EnzymeRules.inactive(::typeof(_dropout_fptype), ::Any...) = nothing
 end

 CRC.@non_differentiable _alpha_dropout_noise(::Any...)
-EnzymeRules.inactive(::typeof(_alpha_dropout_noise), ::Any...) = nothing

 @inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims)
     realfptype = _dropout_fptype(x)
@@ -150,6 +148,4 @@ EnzymeRules.inactive(::typeof(_alpha_dropout_noise), ::Any...) = nothing
 end

 CRC.@non_differentiable _generate_dropout_mask(::Any...)
-EnzymeRules.inactive(::typeof(_generate_dropout_mask), ::Any...) = nothing
 CRC.@non_differentiable _dropout_shape(::Any...)
-EnzymeRules.inactive(::typeof(_dropout_shape), ::Any...) = nothing
diff --git a/src/api/groupnorm.jl b/src/api/groupnorm.jl
index b9ec0d51..40f4637d 100644
--- a/src/api/groupnorm.jl
+++ b/src/api/groupnorm.jl
@@ -56,4 +56,3 @@ function _test_valid_groupnorm_arguments(
 end

 CRC.@non_differentiable _test_valid_groupnorm_arguments(::Any...)
-EnzymeRules.inactive(::typeof(_test_valid_groupnorm_arguments), ::Any...) = nothing
diff --git a/src/api/instancenorm.jl b/src/api/instancenorm.jl
index 9eee23ed..d79ad234 100644
--- a/src/api/instancenorm.jl
+++ b/src/api/instancenorm.jl
@@ -47,4 +47,3 @@ function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N}
 end

 CRC.@non_differentiable _test_valid_instancenorm_arguments(::Any...)
-EnzymeRules.inactive(::typeof(_test_valid_instancenorm_arguments), ::Any...) = nothing
diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl
index 2c5b4846..7f47503b 100644
--- a/src/impl/normalization.jl
+++ b/src/impl/normalization.jl
@@ -20,7 +20,6 @@ end
 @inline __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims)

 CRC.@non_differentiable __accum_size(::Any...)
-EnzymeRules.inactive(::typeof(__accum_size), ::Any...) = nothing

 @inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing,
         ::Val{rdims}, ::Val{false}, momentum) where {rdims}
diff --git a/src/utils.jl b/src/utils.jl
index a6264a11..e6c4b8b9 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -11,7 +11,6 @@ end
 end

 CRC.@non_differentiable _get_reshape_dims(::Any...)
-EnzymeRules.inactive(::typeof(_get_reshape_dims), ::Any...) = nothing

 @inline _reshape_into_proper_shape(::Nothing, y) = nothing
 @inline _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x)))
@@ -21,7 +20,6 @@ _copy_autodiff_barrier(x) = copy(x)
 _copy_autodiff_barrier(::Nothing) = nothing

 CRC.@non_differentiable _copy_autodiff_barrier(::Any)
-EnzymeRules.inactive(::typeof(_copy_autodiff_barrier), ::Any...) = nothing

 # Meta Programming Utilities
 __is_tracked(x) = x == :TrackedArray || x == :TrackedVector
@@ -57,13 +55,11 @@ struct NotaNumber <: Real end
 @inline __is_immutable_array_val(x) = Val(__is_immutable_array(x))

 CRC.@non_differentiable __is_immutable_array_val(::Any...)
-EnzymeRules.inactive(::typeof(__is_immutable_array_val), ::Any...) = nothing

 @inline __has_dual(x) = false
 @inline __is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x))

 CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...)
-EnzymeRules.inactive(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing

 @inline function __expand_conv_bias_dims(
         bias::AbstractVector, ::AbstractArray{T, N}) where {T, N}
@@ -85,7 +81,6 @@ end
 end

 CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...)
-EnzymeRules.inactive(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing

 # Helper to add bias and apply activation function
 ## This is only meant to be used inside rrules
@@ -178,7 +173,7 @@ end
 end

 CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray)
-EnzymeRules.inactive(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing
+EnzymeRules.inactive_noinl(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing

 @inline function __reset_BLAS_threads(old_threads::Int)
     old_threads ≥ 1 && BLAS.set_num_threads(old_threads)
@@ -186,7 +181,7 @@ EnzymeRules.inactive(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = n
 end

 CRC.@non_differentiable __reset_BLAS_threads(::Int)
-EnzymeRules.inactive(::typeof(__reset_BLAS_threads), ::Int) = nothing
+EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing

 # Defined in ext/LuxLibCUDAExt.jl
 function _cublaslt_matmul_fused! end
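The only annotations this patch keeps (switched to the `inactive_noinl` variant) are the two BLAS thread-management helpers. They exist purely to perform a side effect around a computation, so nothing in them can carry gradient information. A standalone sketch of the pattern they support (`with_reduced_blas_threads` is illustrative, not a LuxLib function):

```julia
using LinearAlgebra: BLAS

# Temporarily drop to a single BLAS thread around `f()` and restore the previous
# setting afterwards. Neither the saved thread count nor the set calls are
# differentiable quantities, which is why the analogous LuxLib helpers are marked
# both `@non_differentiable` and Enzyme-inactive.
function with_reduced_blas_threads(f)
    old = BLAS.get_num_threads()
    BLAS.set_num_threads(1)
    try
        return f()
    finally
        BLAS.set_num_threads(old)
    end
end
```
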
From bb79996bda6d813ab15b54de688ac57a083ff14e Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 12 May 2024 17:28:58 -0400
Subject: [PATCH 4/4] Revert "Try removing the EnzymeRules inactive"

This reverts commit 1a64b73d801af75e65eef56d68f0e5c92444cf43.
---
 src/api/dropout.jl        | 4 ++++
 src/api/groupnorm.jl      | 1 +
 src/api/instancenorm.jl   | 1 +
 src/impl/normalization.jl | 1 +
 src/utils.jl              | 5 +++++
 5 files changed, 12 insertions(+)

diff --git a/src/api/dropout.jl b/src/api/dropout.jl
index 21f9dbd5..44a95ec2 100644
--- a/src/api/dropout.jl
+++ b/src/api/dropout.jl
@@ -130,6 +130,7 @@ end
 @inline _dropout_fptype(x) = float(real(eltype(x)))

 CRC.@non_differentiable _dropout_fptype(::Any...)
+EnzymeRules.inactive_noinl(::typeof(_dropout_fptype), ::Any...) = nothing

 @inline function _alpha_dropout_noise(rng, x)
     rng = LuxCore.replicate(rng)
@@ -139,6 +140,7 @@ CRC.@non_differentiable _dropout_fptype(::Any...)
 end

 CRC.@non_differentiable _alpha_dropout_noise(::Any...)
+EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing

 @inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims)
     realfptype = _dropout_fptype(x)
@@ -148,4 +150,6 @@ CRC.@non_differentiable _alpha_dropout_noise(::Any...)
 end

 CRC.@non_differentiable _generate_dropout_mask(::Any...)
+EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing
 CRC.@non_differentiable _dropout_shape(::Any...)
+EnzymeRules.inactive_noinl(::typeof(_dropout_shape), ::Any...) = nothing
diff --git a/src/api/groupnorm.jl b/src/api/groupnorm.jl
index 40f4637d..509e72f0 100644
--- a/src/api/groupnorm.jl
+++ b/src/api/groupnorm.jl
@@ -56,3 +56,4 @@ function _test_valid_groupnorm_arguments(
 end

 CRC.@non_differentiable _test_valid_groupnorm_arguments(::Any...)
+EnzymeRules.inactive_noinl(::typeof(_test_valid_groupnorm_arguments), ::Any...) = nothing
diff --git a/src/api/instancenorm.jl b/src/api/instancenorm.jl
index d79ad234..36b14424 100644
--- a/src/api/instancenorm.jl
+++ b/src/api/instancenorm.jl
@@ -47,3 +47,4 @@ function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N}
 end

 CRC.@non_differentiable _test_valid_instancenorm_arguments(::Any...)
+EnzymeRules.inactive_noinl(::typeof(_test_valid_instancenorm_arguments), ::Any...) = nothing
diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl
index 7f47503b..467821a7 100644
--- a/src/impl/normalization.jl
+++ b/src/impl/normalization.jl
@@ -20,6 +20,7 @@ end
 @inline __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims)

 CRC.@non_differentiable __accum_size(::Any...)
+EnzymeRules.inactive_noinl(::typeof(__accum_size), ::Any...) = nothing

 @inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing,
         ::Val{rdims}, ::Val{false}, momentum) where {rdims}
diff --git a/src/utils.jl b/src/utils.jl
index e6c4b8b9..c5e592fb 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -11,6 +11,7 @@ end
 end

 CRC.@non_differentiable _get_reshape_dims(::Any...)
+EnzymeRules.inactive_noinl(::typeof(_get_reshape_dims), ::Any...) = nothing

 @inline _reshape_into_proper_shape(::Nothing, y) = nothing
 @inline _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x)))
@@ -20,6 +21,7 @@ _copy_autodiff_barrier(x) = copy(x)
 _copy_autodiff_barrier(::Nothing) = nothing

 CRC.@non_differentiable _copy_autodiff_barrier(::Any)
+EnzymeRules.inactive_noinl(::typeof(_copy_autodiff_barrier), ::Any...) = nothing

 # Meta Programming Utilities
 __is_tracked(x) = x == :TrackedArray || x == :TrackedVector
@@ -55,11 +57,13 @@ struct NotaNumber <: Real end
 @inline __is_immutable_array_val(x) = Val(__is_immutable_array(x))

 CRC.@non_differentiable __is_immutable_array_val(::Any...)
+EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_val), ::Any...) = nothing

 @inline __has_dual(x) = false
 @inline __is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x))

 CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...)
+EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing

 @inline function __expand_conv_bias_dims(
         bias::AbstractVector, ::AbstractArray{T, N}) where {T, N}
@@ -81,6 +85,7 @@ end
 end

 CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...)
+EnzymeRules.inactive_noinl(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing

 # Helper to add bias and apply activation function
 ## This is only meant to be used inside rrules
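After the revert, the series settles on pairing every `CRC.@non_differentiable` helper with `EnzymeRules.inactive_noinl`. Several of the helpers in that list are RNG-driven, most notably the dropout-mask generator, whose output is treated as a constant during differentiation. A simplified sketch in the spirit of `_generate_dropout_mask` (details such as the keep rule are an assumption, not the exact LuxLib implementation):

```julia
using Random: AbstractRNG, rand!

# Draw uniform noise and turn it into an inverted-dropout mask: entries are kept
# with probability 1 - p and rescaled by invp = 1 / (1 - p). The mask depends only
# on the RNG, so marking the generator Enzyme-inactive is safe.
function make_dropout_mask(rng::AbstractRNG, x::AbstractArray{T}, p, invp) where {T}
    y = similar(x, float(real(T)))
    rand!(rng, y)
    @. y = (y > p) * invp
    return y
end
```
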