diff --git a/Project.toml b/Project.toml index cbd243e4a..8ea3da8f0 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.9.6" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -16,26 +17,24 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] NNlibAMDGPUExt = "AMDGPU" -NNlibCUDAExt = "CUDA" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] -NNlibEnzymeCoreExt = "EnzymeCore" +NNlibCUDAExt = "CUDA" [compat] AMDGPU = "0.5, 0.6" Adapt = "3.2" Atomix = "0.1" -ChainRulesCore = "1.13" CUDA = "4, 5" -cuDNN = "1" +ChainRulesCore = "1.13" GPUArraysCore = "0.1" KernelAbstractions = "0.9.2" Requires = "1.0" +cuDNN = "1" julia = "1.9" [extras] @@ -44,6 +43,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -55,6 +55,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", - "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", - "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme"] +test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeTestUtils"] diff --git a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl index 874764374..3d109fa06 100644 --- a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl +++ b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl @@ -1,14 +1,12 @@ -module NNlibEnzymeExt +module NNlibEnzymeCoreExt using NNlib -isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) +isdefined(Base, :get_extension) ? (import EnzymeCore) : (import ..EnzymeCore) -using EnzymeCore +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} -function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} - - @assert !(OutType <: Const) - if OutType <: Duplicated || OutType <: DuplicatedNoNeed + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed func.val(y.val, x.val, w.val, cdims.val; kwargs...) end @@ -24,36 +22,36 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNli end # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing + cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing cache = (cache_x, cache_w) return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) end -function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} cache_x, cache_w = cache # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: Const) + if !(typeof(w) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[3] cache_x = x.val end end # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: Const) + if !(typeof(x) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[4] cache_w = w.val end end dys = y.dval - dxs = (typeof(x) <: Const) ? dys : x.dval - dws = (typeof(w) <: Const) ? dys : w.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval if EnzymeCore.EnzymeRules.width(config) == 1 dys = (dys,) @@ -62,11 +60,11 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)} end for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(x) <: Const) && dx !== x + if !(typeof(x) <: EnzymeCore.Const) && dx !== x # dx += grad wrt x NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) end - if !(typeof(w) <: Const) && dw !== w + if !(typeof(w) <: EnzymeCore.Const) && dw !== w # dw += grad wrt w NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) end diff --git a/src/NNlib.jl b/src/NNlib.jl index c4ad18750..14ba2c70f 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -123,10 +123,6 @@ include("impl/depthwiseconv_im2col.jl") include("impl/pooling_direct.jl") include("deprecations.jl") -@init @static if !isdefined(Base, :get_extension) - @require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin - include("../ext/NNlibEnzymeCoreExt/NNlibEnzymeCoresExt.jl") - end -end +include("enzyme.jl") end # module NNlib diff --git a/test/conv.jl b/test/conv.jl index 8edc4bf24..8b67873ae 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -861,7 +861,7 @@ end w = rand(rng, repeat([3], spatial_rank)..., 3, 3) cdims = DenseConvDims(x, w) gradtest((x, w) -> conv(x, w, cdims), x, w) - gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055 + gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rule=true) # https://github.com/FluxML/Flux.jl/issues/1055 y = conv(x, w, cdims) gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w) diff --git a/test/runtests.jl b/test/runtests.jl index 03602a40d..660f9167f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using NNlib, Test, Statistics, Random using ChainRulesCore, ChainRulesTestUtils using Base.Broadcast: broadcasted +import EnzymeTestUtils import FiniteDifferences import ForwardDiff import Zygote diff --git a/test/test_utils.jl b/test/test_utils.jl index 16b3998dc..598f0f66b 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -12,7 +12,7 @@ Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly """ function gradtest( f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(), - check_rrule = false, fdm = :central, check_broadcast = false, + check_rrule = false, check_enzyme_rrule = false, fdm = :central, check_broadcast = false, skip = false, broken = false, ) # TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166 @@ -20,6 +20,22 @@ function gradtest( if check_rrule test_rrule(f, xs...; fkwargs = fkwargs) end + if check_enzyme_rrule + if len(xs) == 2 + for Tret in (Const, Active), + Tx in (Const, Duplicated, BatchDuplicated), + Ty in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx, Ty) || continue + + test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol) + end + else + throw(AssertionError("Unsupported arg count for testing")) + end + + EnzymeTestUtils.test_rrule(f, xs...; fkwargs = fkwargs) + end if check_broadcast length(fkwargs) > 0 && @warn("CHECK_BROADCAST: dropping keywords args")