-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
non-differentiability of ops on AbstractArray{Bool} #310
Conversation
Some of these are dublicated of ones that we have for Also can't do types until JuliaDiff/ChainRulesCore.jl#213 is solved. |
What do you think about defining these on |
src/rulesets/Base/nondiff.jl
Outdated
@non_differentiable cumprod!(::Any, ::BitArray) | ||
@non_differentiable cumsum(::BitArray) | ||
@non_differentiable cumsum!(::Any, ::BitArray) | ||
@non_differentiable DenseMatrix(::BitArray) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A general Julia query, probably not too much to do with this PR specifically. If I keep this line of code, I get a warning on doing using ChainRules
:
┌ Info: Precompiling ChainRules [082447d4-558c-5d27-93f4-14fc19e9eca2]
└ @ Base loading.jl:1260
WARNING: Method definition frule(Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition frule##kw(Any, typeof(ChainRulesCore.frule), Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition rrule(UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition rrule##kw(Any, typeof(ChainRulesCore.rrule), UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
** incremental compilation may be fatally broken for this module **
A warning, which I believe I understand that somewhere I made a re-definition of a function call, though don't know at which line was the re-definition made.
And when I remove this line of code (remove re-definition of non-differentiability of DenseMatrix(::BitArray)
, and keep the below line of code about non-differentiability of Matrix(::BitArray)
I get the same exact warning:
┌ Info: Precompiling ChainRules [082447d4-558c-5d27-93f4-14fc19e9eca2]
└ @ Base loading.jl:1260
WARNING: Method definition frule(Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition frule##kw(Any, typeof(ChainRulesCore.frule), Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition rrule(UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition rrule##kw(Any, typeof(ChainRulesCore.rrule), UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
** incremental compilation may be fatally broken for this module **
, it didn't even point me to the line that was causing the error (was it DenseMatrix
line or the Matrix
line) which often leads me to think over about the non-exactness of the error and warning details in Julia (probably atleast compared to Python), is it just me or am I wrong in thinking like that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is because of the bug that is fixed by JuliaDiff/ChainRulesCore.jl#243
so if you bump the requirement for ChainRulesCore up to 0.9.19
it should be fixed (0.9.19 is currently registering)
Let's consider a method like
because first one will cause an error from the primal's side, instead of returning Is my understanding correct? |
Would it be ok, if we write the non-differentiability rules for them, but don't merge this until JuliaDiff/ChainRulesCore.jl#213 is solved? |
It causes #310 (comment) but now that it is fixed we don't have to worry. |
I would do the first, because even if someone defined something like |
While I don't understand the innards of this package, is there a possibility that many of these might be built-in at a higher level? Perhaps, before applying any rule, ChainRules can check whether any inputs are |
ChainRules can't do this. So yes, maybe at some point in the future we will clear house and remove a bunch of rules. |
src/rulesets/Base/nondiff.jl
Outdated
@non_differentiable strides(::AbstractArray{Bool}) | ||
@non_differentiable vcat(::AbstractArray{Bool}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be vcat(::AbstractArray{Bool}...)
if that's accepted?
Also, many of these like strudes
& isperm
surely aren't differentiable with any input, but isperm([true])
is an error anyway.
similar
also accepts further argumens.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think that should be the case. We'll wait on the implementation of support of Vararg
's as its implemented in PR JuliaDiff/ChainRulesCore.jl#254 ?
This is probably ready for a review now.. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for making the PR and adapting it quickly. LGTM after the tiny change is made
Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
Fixes #293
I think definitely there are more rules to be added.