Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MethodError if configured rrule is ambiguous #1358

Merged
merged 2 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ such that if a suitable rule is defined later, the generated function will recom
function has_chain_rrule(T)
config_T, arg_Ts = Iterators.peel(T.parameters)
configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...})
if _is_rrule_redispatcher(configured_rrule_m.method)
is_ambig = configured_rrule_m === nothing # this means there was an ambiguity error, on configured_rrule


if !is_ambig && _is_rrule_redispatcher(configured_rrule_m.method)
# The config is not being used:
# it is being redispatched without config, so we need the method it redispatches to
rrule_m = meta(Tuple{typeof(rrule), arg_Ts...})
Expand All @@ -33,6 +36,8 @@ function has_chain_rrule(T)
no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...})
end

is_ambig |= rrule_m === nothing # this means there was an ambiguity error on unconfigured rrule

# To understand why we only need to check if the sigs match between no_rrule_m and rrule_m
# in order to decide if to use, one must consider the following facts:
# - for every method in `no_rrule` there is a identical one in `rrule` that returns nothing
Expand All @@ -51,16 +56,16 @@ function has_chain_rrule(T)
# It can be seen that checking if it matches is the correct way to decide if we should use the rrule or not.


do_not_use_rrule = matching_cr_sig(no_rrule_m, rrule_m)
if do_not_use_rrule
if !is_ambig && matching_cr_sig(no_rrule_m, rrule_m) # Not ambigious, and opted-out.
# Return instance for configured_rrule_m as that will be invalidated
# directly if configured rule added, or indirectly if unconfigured rule added
# Do not need an edge for `no_rrule` as no addition of methods to that can cause this
# decision to need to be revisited (only changes to `rrule`), since we are already not
# using the rrule, so not using more rules wouldn't change anything.
return false, configured_rrule_m.instance
else
# Otherwise found a rrule, no need to add any edges for `rrule`, as it will generate
# Either is ambigious, and we should try to use it, and then error
# or we are uses a rrule, no need to add any edges for `rrule`, as it will generate
# code with natural edges if a new method is defined there.
# We also do not need an edge to `no_rrule`, as any time a method is added to `no_rrule`
# a corresponding method is added to `rrule` (to return `nothing`), thus we will already
Expand All @@ -73,7 +78,7 @@ matching_cr_sig(t, s) = matching_cr_sig(t.method.sig, s.method.sig)
matching_cr_sig(::DataType, ::UnionAll) = false
matching_cr_sig(::UnionAll, ::DataType) = false
matching_cr_sig(t::Type, s::Type) = type_tuple_tail(t) == type_tuple_tail(s)
matching_cr_sig(::Any, ::Nothing) = false # https://github.com/FluxML/Zygote.jl/issues/1234
matching_cr_sig(::Any, ::Nothing) = false # ambigious https://github.com/FluxML/Zygote.jl/issues/1234

type_tuple_tail(d::DataType) = Tuple{d.parameters[2:end]...}
function type_tuple_tail(d::UnionAll)
Expand Down
19 changes: 14 additions & 5 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,20 @@ using Zygote: ZygoteRuleConfig

# https://github.com/FluxML/Zygote.jl/issues/1234
@testset "rrule lookup ambiguities" begin
f_ambig(x, y) = x + y
ChainRulesCore.rrule(::typeof(f_ambig), x::Int, y) = x + y, _ -> (0, 0)
ChainRulesCore.rrule(::typeof(f_ambig), x, y::Int) = x + y, _ -> (0, 0)

@test_throws MethodError pullback(f_ambig, 1, 2)
@testset "unconfigured" begin
f_ambig(x, y) = x + y
ChainRulesCore.rrule(::typeof(f_ambig), x::Int, y) = x + y, _ -> (0, 0)
ChainRulesCore.rrule(::typeof(f_ambig), x, y::Int) = x + y, _ -> (0, 0)

@test_throws MethodError pullback(f_ambig, 1, 2)
end
@testset "configured" begin
h_ambig(x, y) = x + y
ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(h_ambig), x, y) = x + y, _ -> (0, 0)
ChainRulesCore.rrule(::RuleConfig, ::typeof(h_ambig), x::Int, y::Int) = x + y, _ -> (0, 0)

@test_throws MethodError pullback(h_ambig, 1, 2)
end
end
end

Expand Down