diff --git a/Project.toml b/Project.toml index 5261e5114..8d0bf8ec4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.54" +version = "0.6.55" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 7c7de8655..a19d7f230 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -18,7 +18,10 @@ such that if a suitable rule is defined later, the generated function will recom function has_chain_rrule(T) config_T, arg_Ts = Iterators.peel(T.parameters) configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...}) - if _is_rrule_redispatcher(configured_rrule_m.method) + is_ambig = configured_rrule_m === nothing # this means there was an ambiguity error, on configured_rrule + + + if !is_ambig && _is_rrule_redispatcher(configured_rrule_m.method) # The config is not being used: # it is being redispatched without config, so we need the method it redispatches to rrule_m = meta(Tuple{typeof(rrule), arg_Ts...}) @@ -33,6 +36,8 @@ function has_chain_rrule(T) no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...}) end + is_ambig |= rrule_m === nothing # this means there was an ambiguity error on unconfigured rrule + # To understand why we only need to check if the sigs match between no_rrule_m and rrule_m # in order to decide if to use, one must consider the following facts: # - for every method in `no_rrule` there is a identical one in `rrule` that returns nothing @@ -51,8 +56,7 @@ function has_chain_rrule(T) # It can be seen that checking if it matches is the correct way to decide if we should use the rrule or not. - do_not_use_rrule = matching_cr_sig(no_rrule_m, rrule_m) - if do_not_use_rrule + if !is_ambig && matching_cr_sig(no_rrule_m, rrule_m) # Not ambigious, and opted-out. # Return instance for configured_rrule_m as that will be invalidated # directly if configured rule added, or indirectly if unconfigured rule added # Do not need an edge for `no_rrule` as no addition of methods to that can cause this @@ -60,7 +64,8 @@ function has_chain_rrule(T) # using the rrule, so not using more rules wouldn't change anything. return false, configured_rrule_m.instance else - # Otherwise found a rrule, no need to add any edges for `rrule`, as it will generate + # Either is ambigious, and we should try to use it, and then error + # or we are uses a rrule, no need to add any edges for `rrule`, as it will generate # code with natural edges if a new method is defined there. # We also do not need an edge to `no_rrule`, as any time a method is added to `no_rrule` # a corresponding method is added to `rrule` (to return `nothing`), thus we will already @@ -73,7 +78,7 @@ matching_cr_sig(t, s) = matching_cr_sig(t.method.sig, s.method.sig) matching_cr_sig(::DataType, ::UnionAll) = false matching_cr_sig(::UnionAll, ::DataType) = false matching_cr_sig(t::Type, s::Type) = type_tuple_tail(t) == type_tuple_tail(s) -matching_cr_sig(::Any, ::Nothing) = false # https://github.com/FluxML/Zygote.jl/issues/1234 +matching_cr_sig(::Any, ::Nothing) = false # ambigious https://github.com/FluxML/Zygote.jl/issues/1234 type_tuple_tail(d::DataType) = Tuple{d.parameters[2:end]...} function type_tuple_tail(d::UnionAll) diff --git a/test/chainrules.jl b/test/chainrules.jl index e9cb4afbc..51e0a80c6 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -278,11 +278,20 @@ using Zygote: ZygoteRuleConfig # https://github.com/FluxML/Zygote.jl/issues/1234 @testset "rrule lookup ambiguities" begin - f_ambig(x, y) = x + y - ChainRulesCore.rrule(::typeof(f_ambig), x::Int, y) = x + y, _ -> (0, 0) - ChainRulesCore.rrule(::typeof(f_ambig), x, y::Int) = x + y, _ -> (0, 0) - - @test_throws MethodError pullback(f_ambig, 1, 2) + @testset "unconfigured" begin + f_ambig(x, y) = x + y + ChainRulesCore.rrule(::typeof(f_ambig), x::Int, y) = x + y, _ -> (0, 0) + ChainRulesCore.rrule(::typeof(f_ambig), x, y::Int) = x + y, _ -> (0, 0) + + @test_throws MethodError pullback(f_ambig, 1, 2) + end + @testset "configured" begin + h_ambig(x, y) = x + y + ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(h_ambig), x, y) = x + y, _ -> (0, 0) + ChainRulesCore.rrule(::RuleConfig, ::typeof(h_ambig), x::Int, y::Int) = x + y, _ -> (0, 0) + + @test_throws MethodError pullback(h_ambig, 1, 2) + end end end