Skip to content

Commit

Permalink
attempt fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 24, 2023
1 parent cb92045 commit 242e669
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 32 deletions.
16 changes: 7 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.9.6"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -16,26 +17,24 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
NNlibAMDGPUExt = "AMDGPU"
NNlibCUDAExt = "CUDA"
NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibCUDAExt = "CUDA"

[compat]
AMDGPU = "0.5, 0.6"
Adapt = "3.2"
Atomix = "0.1"
ChainRulesCore = "1.13"
CUDA = "4, 5"
cuDNN = "1"
ChainRulesCore = "1.13"
GPUArraysCore = "0.1"
KernelAbstractions = "0.9.2"
Requires = "1.0"
cuDNN = "1"
julia = "1.9"

[extras]
Expand All @@ -44,6 +43,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand All @@ -55,6 +55,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter",
"FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff",
"StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme"]
test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeTestUtils"]
30 changes: 14 additions & 16 deletions ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
module NNlibEnzymeExt
module NNlibEnzymeCoreExt

using NNlib
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
isdefined(Base, :get_extension) ? (import EnzymeCore) : (import ..EnzymeCore)

using EnzymeCore
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT}

function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT}

@assert !(OutType <: Const)
if OutType <: Duplicated || OutType <: DuplicatedNoNeed
@assert !(OutType <: EnzymeCore.Const)
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed
func.val(y.val, x.val, w.val, cdims.val; kwargs...)
end

Expand All @@ -24,36 +22,36 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNli
end

# Cache x if its overwritten and w is active (and thus required)
cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing
cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing

# Cache w if its overwritten and x is active (and thus required)
cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing
cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing

cache = (cache_x, cache_w)

return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT}
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT}
cache_x, cache_w = cache

# Don't cache x if not overwritten and w is active (and thus required)
if !(typeof(w) <: Const)
if !(typeof(w) <: EnzymeCore.Const)
if !EnzymeCore.EnzymeRules.overwritten(config)[3]
cache_x = x.val
end
end

# Don't cache w if not overwritten and x is active (and thus required)
if !(typeof(x) <: Const)
if !(typeof(x) <: EnzymeCore.Const)
if !EnzymeCore.EnzymeRules.overwritten(config)[4]
cache_w = w.val
end
end

dys = y.dval
dxs = (typeof(x) <: Const) ? dys : x.dval
dws = (typeof(w) <: Const) ? dys : w.dval
dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval
dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval

if EnzymeCore.EnzymeRules.width(config) == 1
dys = (dys,)
Expand All @@ -62,11 +60,11 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}
end

for (dy, dx, dw) in zip(dys, dxs, dws)
if !(typeof(x) <: Const) && dx !== x
if !(typeof(x) <: EnzymeCore.Const) && dx !== x
# dx += grad wrt x
NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
end
if !(typeof(w) <: Const) && dw !== w
if !(typeof(w) <: EnzymeCore.Const) && dw !== w
# dw += grad wrt w
NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
end
Expand Down
6 changes: 1 addition & 5 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,6 @@ include("impl/depthwiseconv_im2col.jl")
include("impl/pooling_direct.jl")
include("deprecations.jl")

@init @static if !isdefined(Base, :get_extension)
@require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin
include("../ext/NNlibEnzymeCoreExt/NNlibEnzymeCoresExt.jl")
end
end
include("enzyme.jl")

end # module NNlib
2 changes: 1 addition & 1 deletion test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ end
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
gradtest((x, w) -> conv(x, w, cdims), x, w)
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rule=true) # https://github.com/FluxML/Flux.jl/issues/1055

y = conv(x, w, cdims)
gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using NNlib, Test, Statistics, Random
using ChainRulesCore, ChainRulesTestUtils
using Base.Broadcast: broadcasted
import EnzymeTestUtils
import FiniteDifferences
import ForwardDiff
import Zygote
Expand Down
18 changes: 17 additions & 1 deletion test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,30 @@ Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly
"""
function gradtest(
f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(),
check_rrule = false, fdm = :central, check_broadcast = false,
check_rrule = false, check_enzyme_rrule = false, fdm = :central, check_broadcast = false,
skip = false, broken = false,
)
# TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166
# is merged
if check_rrule
test_rrule(f, xs...; fkwargs = fkwargs)
end
if check_enzyme_rrule
if len(xs) == 2
for Tret in (Const, Active),
Tx in (Const, Duplicated, BatchDuplicated),
Ty in (Const, Duplicated, BatchDuplicated)

are_activities_compatible(Tret, Tx, Ty) || continue

test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol)
end
else
throw(AssertionError("Unsupported arg count for testing"))
end

EnzymeTestUtils.test_rrule(f, xs...; fkwargs = fkwargs)
end

if check_broadcast
length(fkwargs) > 0 && @warn("CHECK_BROADCAST: dropping keywords args")
Expand Down

0 comments on commit 242e669

Please sign in to comment.