
Not able to @opt_out rules with RuleConfig #1342

Closed
tansongchen opened this issue Dec 22, 2022 · 12 comments · Fixed by #1358

Comments

@tansongchen

tansongchen commented Dec 22, 2022

Let's say I have a type WeirdNumber <: Number that is so weird that I don't want the derivative of its power function (^) to be computed by the rrules for ^ and literal_pow; instead, AD should go through and differentiate its definition. Since Zygote defines AD-specific rules for literal_pow with RuleConfig in its code base, I also had to @opt_out of this rule. However, the following MWE didn't work:

using ChainRulesCore: @opt_out, RuleConfig
import Base: ^

struct WeirdNumber <: Number
    a::Float64
end

^(x::WeirdNumber, p::Int) = WeirdNumber(x.a ^ p)

@opt_out rrule(::typeof(Base.literal_pow), ::typeof(^), x::WeirdNumber, ::Val{p}) where {p}
@opt_out rrule(::RuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::WeirdNumber,
               ::Val{p}) where {p}

using Zygote

fun(x::WeirdNumber) = (x^4).a

gradient(fun, WeirdNumber(2.))

Error:

ERROR: type Nothing has no field method
Stacktrace:
  [1] getproperty(x::Nothing, f::Symbol)
    @ Base ./Base.jl:38
  [2] has_chain_rrule(T::Type)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:21
  [3] #s2948#1074
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:20 [inlined]
  [4] var"#s2948#1074"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [5] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
  [6] _pullback
    @ ~/Applications/project/TaylorDiff.jl/.vscode/opt_lit_pow.jl:16 [inlined]
  [7] _pullback(ctx::Zygote.Context{false}, f::typeof(fun), args::WeirdNumber)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [8] pullback(f::Function, cx::Zygote.Context{false}, args::WeirdNumber)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:44
  [9] pullback
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:42 [inlined]
 [10] gradient(f::Function, args::WeirdNumber)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:96
@oxinabox
Member

Oh weird. This is definitely a bug. I am on vacation for the next few weeks and don't know exactly when I will have time to look into it.
If someone else wants to dig in, feel encouraged.
If not, I will try to attend to it when I get back.

@tansongchen
Author

Thanks for confirming that's a bug. I will share this info on the Julia slack and see if anyone can help.

@tansongchen
Author

tansongchen commented Dec 22, 2022

Update: if I instead add an @opt_out directly for ZygoteRuleConfig, Zygote gives the correct result (a = 32,). So this may be related to a multiple-dispatch ambiguity on rrule. The following is a working example:

using ChainRulesCore: @opt_out, RuleConfig, rrule
using Zygote: ZygoteRuleConfig
import Base: ^

struct WeirdNumber <: Number
    a::Float64
end

^(x::WeirdNumber, p::Int) = WeirdNumber(x.a ^ p)

@opt_out rrule(::typeof(^), x::WeirdNumber, p::Number)
@opt_out rrule(::typeof(Base.literal_pow), ::typeof(^), x::WeirdNumber, ::Val{p}) where {p}
@opt_out rrule(::RuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::WeirdNumber,
               ::Val{p}) where {p}
@opt_out rrule(::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::WeirdNumber,
               ::Val{p}) where {p}

using Zygote

fun(x::WeirdNumber) = (x^4).a

gradient(fun, WeirdNumber(2.))

@tansongchen
Author

I'm attaching some of the Slack discussion here in case it helps diagnose the issue:


Brian Chen
Got a tricky case here with @opt_out, rules with RuleConfigs, and ambiguities: #1342. Is there anything that can be done without having packages with opt-outs take a dep on Zygote?

Frames Catherine White
There don't seem to be any ambiguities there?
Just some kind of bug in has_chain_rrule?

Brian Chen
The error comes from https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/chainrules.jl#L21, which means configured_rrule_m = meta(...) is returning nothing. The only other time I've seen this is #1234, where the issue was IRTools giving up on an ambiguous method match. I'm not 100% sure this is the culprit, but given the signatures we're working with are
rrule(::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), ::Number, ::Val{p}) # Zygote
rrule(::RuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::WeirdNumber, ::Val{p}) # from @opt_out
no_rrule(::RuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::WeirdNumber, ::Val{p}) # from @opt_out
It seems like a decent guess.

Songchen Tan
Hi Brian, I am the author of this issue. I can confirm that if I use @opt_out on ::ZygoteRuleConfig instead of ::RuleConfig, Zygote can correctly give the result.

Songchen Tan
So this is probably related to the ambiguity. However, semantically I would like to declare that this rule is opted out for every AD framework, so that I don’t need to depend on Zygote.

Songchen Tan
I’ll attach our discussion on this to the GitHub issue so others may help.

@oxinabox
Member

Hmm.
Some proposed solutions:

sol0

Just give an informative ambiguity error if meta(...) returns nothing.

sol1

If meta(...) returns nothing,
we assume that there is an ambiguity (or another problem with the rrule) and fall back to a sensible default behaviour: ignoring the rule and doing AD.

This solves this case, since the goal is to ignore the rule.
The problem with this is that missed rules can lead to slow code with no sign that this is occurring.
I don't think we can even emit a debug log here, as it's in the generating part of a generated function (so only Core.println works).

sol2

There might be an alternative solution where we only do this if there is a no_rrule among the ambiguous candidates (and otherwise give an informative error message).
This would mean there is an extra rule to remember: the opt-out wins in case of ambiguity.

I am not sure if this is logically sound.
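A minimal sketch of the sol0/sol1 trade-off in plain Julia. For illustration only, it models rule lookup with `Base.which`, which throws when no unique method matches a signature (for example, when the call is ambiguous); per the discussion above, Zygote's `meta(...)` returns `nothing` in that situation instead. `unique_method`, `use_rule`, the `strict` flag, and `g` are hypothetical names, not part of Zygote or ChainRulesCore:

```julia
# Hypothetical helper: return the unique method matching `argtypes`,
# or `nothing` if there is no unique match (no method at all, or an ambiguity).
function unique_method(f, argtypes::Type{<:Tuple})
    try
        return which(f, argtypes)
    catch
        return nothing
    end
end

# Hypothetical policy switch:
#   strict = true  -> sol0: raise an informative error
#   strict = false -> sol1: ignore the rule and differentiate the primal code
function use_rule(f, argtypes::Type{<:Tuple}; strict::Bool)
    if unique_method(f, argtypes) === nothing
        strict && error("no unique method of $f for $argtypes (possibly an ambiguity)")
        return false  # sol1: fall back to AD through the definition
    end
    return true       # a unique method exists, so the rule would be used
end

# Two deliberately ambiguous methods standing in for an AD rule and an opt-out
g(::Integer, ::Any) = 1
g(::Any, ::Integer) = 2

use_rule(g, Tuple{Int, String}; strict = true)   # true: only the first method matches
use_rule(g, Tuple{Int, Int}; strict = false)     # false: ambiguous, so fall back
```

Calling `use_rule(g, Tuple{Int, Int}; strict = true)` raises an error instead, which is the sol0 behaviour. The sol1 branch is exactly where a missed rule would go unnoticed.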

@tansongchen
Copy link
Author

Thanks for the proposed solutions. I'm not a Zygote expert, so I can't really compare them; it may be worth inviting other Zygote maintainers here.

To me this ambiguity sounds like a design problem for ChainRulesCore.jl, since for any AD system, the coexistence of a "RuleConfig + Specific Type" method defined by a user and an "XXXRuleConfig + Generic Type" method defined by the AD author will cause an ambiguity.
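The clash described here can be reproduced with plain Julia dispatch. In this sketch, `MyRuleConfig`, `MyADRuleConfig`, `OtherADRuleConfig`, and `rule` are hypothetical stand-ins for ChainRulesCore's `RuleConfig`, an AD package's concrete config such as `ZygoteRuleConfig`, some other AD's config, and `rrule`. They are not the real types:

```julia
# Hypothetical stand-ins, not ChainRulesCore/Zygote types
abstract type MyRuleConfig end                # plays the role of RuleConfig
struct MyADRuleConfig <: MyRuleConfig end     # plays the role of ZygoteRuleConfig
struct OtherADRuleConfig <: MyRuleConfig end  # some other AD's config

struct WeirdNumber <: Number
    a::Float64
end

# "XXXRuleConfig + Generic Type": what the AD author defines
rule(::MyADRuleConfig, x::Number) = :ad_rule

# "RuleConfig + Specific Type": what the downstream user defines
rule(::MyRuleConfig, x::WeirdNumber) = :opt_out

# For (MyADRuleConfig, WeirdNumber) neither method is more specific,
# so Julia throws a MethodError reporting the ambiguity.
function is_ambiguous_call()
    try
        rule(MyADRuleConfig(), WeirdNumber(2.0))
        return false
    catch err
        return err isa MethodError
    end
end

rule(MyADRuleConfig(), 1.0)              # :ad_rule
rule(OtherADRuleConfig(), WeirdNumber(2.0))  # :opt_out
is_ambiguous_call()                      # true
```

Adding a method for the overlap, `rule(::MyADRuleConfig, ::WeirdNumber)`, removes the ambiguity. This mirrors the workaround of writing the `@opt_out` against `ZygoteRuleConfig` directly.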

@ToucheSir
Member

I don't think we can even emit a debug log here as it's in the generating part of a generated function. (So only Core.println works)

We could add a param to the context which controls whether the fallback is allowed or an error is thrown. I agree that this feels suboptimal though.

Sol2 could be narrowed to the case where there is a no_rrule with a RuleConfig >: MyADRuleConfig. Then, if the AD authors really believe they have a better manual rule, they can either take the opting-out package as a (weak) dep or upstream their rule. That said, one could make a similar argument for the reverse, where downstream libraries take a weak dep on the AD. It's unfortunate that we can't somehow indicate that the RuleConfig is less important for dispatch than the actual args in case of ambiguity; a language-level fix would be most elegant.
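A rough sketch of the narrowed sol2 decision rule, assuming the AD could enumerate the methods involved in an ambiguous lookup (something the thread does not say Zygote currently does). `Candidate`, `resolve_ambiguity`, `AbstractCfg`, `ADCfg`, and `OtherCfg` are hypothetical names used only to show the policy: fall back to AD if an opt-out's config type is a supertype of the AD's own config, otherwise error:

```julia
# Hypothetical description of one method involved in an ambiguous rule lookup
struct Candidate
    is_opt_out::Bool  # true if the method is a no_rrule produced by @opt_out
    config::Type      # the RuleConfig type the method is declared for
end

# Narrowed sol2: differentiate the primal code only if some opt-out covers
# every config of type `ad_config` (opt-out config >: ad_config);
# otherwise raise an informative error, as in sol0.
function resolve_ambiguity(candidates::Vector{Candidate}, ad_config::Type)
    if any(c -> c.is_opt_out && ad_config <: c.config, candidates)
        return :differentiate_primal
    end
    error("ambiguous rule lookup for $ad_config; add a more specific method")
end

# Hypothetical config hierarchy
abstract type AbstractCfg end
struct ADCfg <: AbstractCfg end
struct OtherCfg <: AbstractCfg end

# An AD-specific rule clashing with a generic opt-out: the opt-out wins
resolve_ambiguity([Candidate(false, ADCfg), Candidate(true, AbstractCfg)], ADCfg)
```

An opt-out declared only for `OtherCfg` does not cover `ADCfg`, so that combination still errors.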

@oxinabox
Member

To me this ambiguity sounds like a design problem for ChainRulesCore.jl, since for any AD system, the coexistence of a "RuleConfig + Specific Type" method defined by a user and an "XXXRuleConfig + Generic Type" method defined by the AD author will cause an ambiguity.

I mean, it's a real ambiguity. There is no definitive right way to handle it.
Those who prefer an AD that is always fast but sometimes errors would like this to just error (sol 0).
Those who prefer an AD that always works would prefer sol 1.
So it should be up to the preference of the AD, not ChainRulesCore.

@oxinabox
Member

I am inclined to open a PR with Sol 1, because it seems easy enough to implement,
and we can upgrade it to Sol 2 or similar later.

@tansongchen
Author

Thanks for the follow-up; solution 1 definitely works for my problem (although it won't work if someone else is trying to use the "RuleConfig + Specific Type" pattern to implement custom rules instead of opting out). I'll be happy to try it out when you start working on the solution 😄

@tansongchen
Copy link
Author

@oxinabox Thank you for this quick fix. I will add Zygote as a dependency and write @opt_outs specifically for Zygote as a workaround. If in the future you plan to implement your other solutions, please also let me know 😃

@oxinabox
Copy link
Member

I think that is fine. If you update to Julia 1.9 you can make it a weak dependency.

Labels
None yet
Projects
None yet

3 participants