diff --git a/.gitignore b/.gitignore index 78756acf1..aa5ffbd93 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *.jl.mem docs/build Manifest.toml +dev/ diff --git a/Project.toml b/Project.toml index 32849f933..0bdee524c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.10" +version = "0.6.11" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" @@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" ChainRules = "0.7.55" -ChainRulesCore = "0.9.32" +ChainRulesCore = "0.9.44" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11" ForwardDiff = "0.10" diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index c573d98f7..8392a27ce 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -47,7 +47,7 @@ for T_outer in (:Tuple, :NamedTuple) # we create separate methods rather than using a `Union` + an `if` so that we avoid a # branch that changes output type, because nested AD on that kinda thing makes Zygote less # than happy. - @eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer} + @eval @inline function wrap_chainrules_output(x::ChainRules.Tangent{P, T}) where {P, T<:$T_outer} xp = map(wrap_chainrules_output, canonicalize(x)) convert($T_outer, xp) end @@ -59,10 +59,10 @@ end Convert `x` from the format Zygote uses internally to differentials types ChainRules uses. """ @inline wrap_chainrules_input(x) = x -@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero() +@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent() @inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) xp = map(wrap_chainrules_input, xs) - ChainRules.Composite{Any, typeof(xp)}(xp) + ChainRules.Tangent{Any, typeof(xp)}(xp) end """ diff --git a/test/chainrules.jl b/test/chainrules.jl index 7fd8c6be5..8b7034753 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -131,7 +131,7 @@ using Zygote, Test, ChainRules not_diff_eg(x, i) = [10, 20][i] function ChainRules.rrule(::typeof(not_diff_eg), x, i) function not_diff_eg_pullback(Δ) - return ChainRules.NO_FIELDS, ChainRules.Zero(), ChainRules.DoesNotExist() + return ChainRules.NO_FIELDS, ChainRules.ZeroTangent(), ChainRules.NoTangent() end return not_diff_eg(x, i), not_diff_eg_pullback end @@ -204,7 +204,7 @@ using Zygote, Test, ChainRules not_diff_kw_eg(x, i; kw=1.0) = [10, 20][i] function ChainRules.rrule(::typeof(not_diff_kw_eg), x, i; kwargs...) function not_diff_kw_eg_pullback(Δ) - return ChainRules.NO_FIELDS, ChainRules.Zero(), ChainRules.DoesNotExist() + return ChainRules.NO_FIELDS, ChainRules.ZeroTangent(), ChainRules.NoTangent() end return not_diff_kw_eg(x, i; kwargs...), not_diff_kw_eg_pullback end