non-differentiability of ops on AbstractArray{Bool} #310

Merged: 13 commits, Dec 3, 2020

gxyd
Contributor

@gxyd gxyd commented Nov 19, 2020

Fixes #293

I think there are definitely more rules to be added.

@oxinabox
Member

oxinabox commented Nov 19, 2020

Some of these are duplicates of ones that we already have for ::Any; we should remove those.

Also can't do types until JuliaDiff/ChainRulesCore.jl#213 is solved.

@oxinabox
Member

What do you think about defining these on AbstractArray{Bool} so they apply to both BitArray and Array{Bool} etc?

@non_differentiable cumprod!(::Any, ::BitArray)
@non_differentiable cumsum(::BitArray)
@non_differentiable cumsum!(::Any, ::BitArray)
@non_differentiable DenseMatrix(::BitArray)
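
To illustrate the suggestion above, here is a minimal sketch of how a single rule on AbstractArray{Bool} would cover both BitArray and Array{Bool}. The function `count_true` and the helper `has_rule` are hypothetical names made up for this example; they are not part of ChainRules. It assumes that `rrule` falls back to returning `nothing` when no rule matches, as discussed further down in this thread.

```julia
using ChainRulesCore

# Hypothetical function, defined only for this illustration.
count_true(x) = count(x)

# One rule on the abstract type covers every array whose eltype is Bool.
@non_differentiable count_true(::AbstractArray{Bool})

bits  = BitArray([true, false, true])   # a BitVector
bools = [true, false, true]             # a Vector{Bool}
reals = [1.0, 0.0, 1.0]                 # a Vector{Float64}, not covered

# Hypothetical helper: `rrule` returns `nothing` when no rule applies.
has_rule(f, x) = rrule(f, x) !== nothing

println(has_rule(count_true, bits))
println(has_rule(count_true, bools))
println(has_rule(count_true, reals))
```

Both Bool-valued arrays pick up the rule because BitVector and Vector{Bool} are subtypes of AbstractArray{Bool}, while the Float64 vector does not.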
Contributor Author

@gxyd gxyd Nov 20, 2020

A general Julia question, probably not specific to this PR: if I keep this line of code, I get a warning when running using ChainRules:

┌ Info: Precompiling ChainRules [082447d4-558c-5d27-93f4-14fc19e9eca2]
└ @ Base loading.jl:1260
WARNING: Method definition frule(Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
  ** incremental compilation may be fatally broken for this module **

WARNING: Method definition frule##kw(Any, typeof(ChainRulesCore.frule), Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
  ** incremental compilation may be fatally broken for this module **

WARNING: Method definition rrule(UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
  ** incremental compilation may be fatally broken for this module **

WARNING: Method definition rrule##kw(Any, typeof(ChainRulesCore.rrule), UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
  ** incremental compilation may be fatally broken for this module **

I believe the warning means that I have redefined a method somewhere, but I don't know which line the redefinition is on.

And when I remove this line (dropping the non-differentiability rule for DenseMatrix(::BitArray)) and keep the line below it, about the non-differentiability of Matrix(::BitArray), I get exactly the same warning:

┌ Info: Precompiling ChainRules [082447d4-558c-5d27-93f4-14fc19e9eca2]
└ @ Base loading.jl:1260
WARNING: Method definition frule(Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
  ** incremental compilation may be fatally broken for this module **

WARNING: Method definition frule##kw(Any, typeof(ChainRulesCore.frule), Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
  ** incremental compilation may be fatally broken for this module **

WARNING: Method definition rrule(UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
  ** incremental compilation may be fatally broken for this module **

WARNING: Method definition rrule##kw(Any, typeof(ChainRulesCore.rrule), UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
  ** incremental compilation may be fatally broken for this module **

It didn't even point me to the line that caused the warning (was it the DenseMatrix line or the Matrix line?). This often makes me wonder about how imprecise Julia's error and warning details are (at least compared to Python). Is it just me, or am I wrong to think that?

Member

That is because of the bug fixed by JuliaDiff/ChainRulesCore.jl#243, so if you bump the ChainRulesCore requirement up to 0.9.19 it should go away (0.9.19 is currently being registered).
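
For readers wondering why both the DenseMatrix line and the Matrix line produce the same warning, here is a small sketch of one plausible mechanism. This is an assumption based on the `UnionAll` that appears in the warning's method signature, not something confirmed in this thread. The functions `rule` and `rule2` are hypothetical stand-ins for generated rules.

```julia
# Both names are UnionAll types, so `typeof` cannot tell them apart.
println(typeof(DenseMatrix))   # UnionAll
println(typeof(Matrix))        # UnionAll

# Hypothetical stand-in for rules keyed on `typeof(f)`:
# the two definitions have the identical signature (UnionAll, Any),
# so the second silently replaces the first.
rule(::typeof(DenseMatrix), x) = :dense_rule
rule(::typeof(Matrix), x)      = :matrix_rule

# Keying on `Core.Typeof(f)` gives Type{DenseMatrix} and Type{Matrix},
# which are distinct, so both methods survive.
rule2(::Core.Typeof(DenseMatrix), x) = :dense_rule
rule2(::Core.Typeof(Matrix), x)      = :matrix_rule

println(rule(DenseMatrix, 1))
println(rule2(DenseMatrix, 1))
println(rule2(Matrix, 1))
```

Under that assumption, either line alone would still collide with any other rule defined on a type constructor, which would explain why removing one of them does not make the warning go away.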

@gxyd
Contributor Author

gxyd commented Nov 20, 2020

Let's consider a method like repeat(::AbstractArray{Bool}, ::Int), which expects the second argument to be an Int (non-integers aren't accepted). Writing a rule for

  1. repeat(::AbstractArray{Bool}, ::Any) is preferable to writing a rule for
  2. repeat(::AbstractArray{Bool}, ::Int)

because with the first one an invalid call raises an error from the primal, whereas with the second one the rule returns nothing, which could also mean that the rule simply has not been written yet. So the first is the better approach, and ultimately more helpful to an AD engine like Zygote?

Is my understanding correct?

@gxyd
Contributor Author

gxyd commented Nov 20, 2020

Also can't do types until JuliaDiff/ChainRulesCore.jl#213 is solved.

Would it be ok, if we write the non-differentiability rules for them, but don't merge this until JuliaDiff/ChainRulesCore.jl#213 is solved?

@oxinabox
Member

Would it be ok, if we write the non-differentiability rules for them, but don't merge this until JuliaDiff/ChainRulesCore.jl#213 is solved?

It causes #310 (comment)

but now that it is fixed we don't have to worry.

@oxinabox
Member

Let's consider a method like repeat(::AbstractArray{Bool}, ::Int), which expects the second argument to be an Int (non-integers aren't accepted). Writing a rule for

1. `repeat(::AbstractArray{Bool}, ::Any)` is preferable to writing a rule for

2. `repeat(::AbstractArray{Bool}, ::Int)`

because with the first one an invalid call raises an error from the primal, whereas with the second one the rule returns nothing, which could also mean that the rule simply has not been written yet. So the first is the better approach, and ultimately more helpful to an AD engine like Zygote?

Is my understanding correct?

I would do the first, because even if someone defined something like repeat(::Vector, ::Float64), and even if they somehow made it do something sensible, e.g. repeat([1,2,3,4], 1.5) == [1,2,3,4,1,2], by the docstring of the repeat function the second argument is an intrinsically non-differentiable quantity.
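
The difference between the two options can be sketched as follows. `repeat_wide` and `repeat_narrow` are hypothetical wrappers around `repeat`, introduced only so that the two rule signatures can be compared side by side; the sketch assumes that the rule generated by `@non_differentiable` calls the primal function, and that `rrule` returns `nothing` when no rule matches.

```julia
using ChainRulesCore

# Hypothetical wrappers: both primals only accept an Int count.
repeat_wide(x::AbstractArray{Bool}, n::Int)   = repeat(x, n)
repeat_narrow(x::AbstractArray{Bool}, n::Int) = repeat(x, n)

# Option 1: the rule accepts anything as the second argument.
@non_differentiable repeat_wide(::AbstractArray{Bool}, ::Any)
# Option 2: the rule mirrors the primal's signature.
@non_differentiable repeat_narrow(::AbstractArray{Bool}, ::Int)

x = [true, false]

# Option 1: the rule matches, calls the primal, and the primal throws.
wide_threw = try
    rrule(repeat_wide, x, 1.5)
    false
catch err
    err isa MethodError
end

# Option 2: no rule matches, so the caller just sees `nothing`,
# which looks the same as "nobody has written a rule yet".
narrow_result = rrule(repeat_narrow, x, 1.5)

println(wide_threw)
println(narrow_result === nothing)
```

With the wide signature the user gets the same MethodError they would get without AD, which is arguably the clearer failure.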

@mcabbott
Member

While I don't understand the innards of this package, is there a possibility that many of these might be built-in at a higher level?

Perhaps, before applying any rule, ChainRules can check whether any inputs are Bool or AbstractArray{Bool}, and then don't expand those thunks / don't do anything at all. Then (perhaps) explicit instructions to ignore things would only be needed for functions for which there isn't (yet) a rule at all.

@oxinabox
Member

Perhaps, before applying any rule, ChainRules can check whether any inputs are Bool or AbstractArray{Bool},

ChainRules can't do this.
But potentially the AD can.
Though how it does so would differ for each AD.
For a typical operator-overloading AD, it could simply return a non-overloaded primitive value whenever it is asked to make an overloaded type for a Bool etc.
For a source-to-source AD that runs on untyped code it seems harder, but it might be able to generate a suitable function that does the checks.

So yes, maybe at some point in the future we will clear house and remove a bunch of rules.
But I am happy to do that later.
When and if this is solved
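
As a rough illustration of the operator-overloading case described above, here is a toy sketch. The `Dual` type and the `lift` function are hypothetical and do not correspond to any particular AD package; they only show how Bool inputs could be passed through without being wrapped.

```julia
# Toy tracked number: a value plus one derivative slot.
struct Dual
    value::Float64
    partial::Float64
end

# Hypothetical "make an overloaded value" entry point of an AD.
lift(x::Real) = Dual(float(x), 0.0)
lift(x::AbstractArray{<:Real}) = map(lift, x)

# Bool and Bool arrays are returned untouched, so nothing downstream
# ever tries to differentiate with respect to them.
lift(x::Bool) = x
lift(x::AbstractArray{Bool}) = x

println(typeof(lift(2.0)))
println(typeof(lift(true)))
println(typeof(lift(trues(3))))
println(typeof(lift([1.0, 2.0])))
```

The Bool methods win dispatch because Bool is a subtype of Real and AbstractArray{Bool} is a subtype of AbstractArray{<:Real}, so the more specific methods are selected.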

Comment on lines 91 to 92
@non_differentiable strides(::AbstractArray{Bool})
@non_differentiable vcat(::AbstractArray{Bool})
Member

Should this be vcat(::AbstractArray{Bool}...) if that's accepted?

Also, many of these, like strides & isperm, surely aren't differentiable with any input, but isperm([true]) is an error anyway.

similar also accepts further arguments.

Contributor Author

Yes, I think that should be the case. Shall we wait for Vararg support, which is being implemented in PR JuliaDiff/ChainRulesCore.jl#254?
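
To make the signature difference concrete, here is a plain-dispatch sketch that does not use ChainRulesCore at all. `covers_one` and `covers_many` are hypothetical functions that only mimic which calls a one-argument signature and a Vararg signature would match.

```julia
# Mimics a rule written as vcat(::AbstractArray{Bool}).
covers_one(::AbstractArray{Bool}) = true
covers_one(args...) = false            # fallback: "no rule"

# Mimics a rule written as vcat(::AbstractArray{Bool}...).
covers_many(::AbstractArray{Bool}...) = true
covers_many(args...) = false           # fallback: "no rule"

a = [true, false]
b = trues(2)

println(covers_one(a))          # single argument: matched
println(covers_one(a, b))       # two arguments: falls through
println(covers_many(a, b))      # Vararg signature: matched
println(covers_many(a, [1.0]))  # mixed eltypes: falls through
```

Since vcat is normally called with several arrays, the one-argument signature would miss the common case.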

@gxyd gxyd changed the title [WIP] non-differentiability of ops on BitArrays non-differentiability of ops on BitArrays Dec 2, 2020
@gxyd
Contributor Author

gxyd commented Dec 2, 2020

This is probably ready for review now.

@gxyd gxyd changed the title non-differentiability of ops on BitArrays non-differentiability of ops on AbstractArray{Bool} Dec 2, 2020
Member

@mzgubic mzgubic left a comment

Thanks for making the PR and adapting it quickly. LGTM after the tiny change is made

src/rulesets/Base/nondiff.jl (outdated review thread, resolved)
Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
@mzgubic mzgubic merged commit f212f0a into JuliaDiff:master Dec 3, 2020
Successfully merging this pull request may close these issues.

Make operations on BitArray's non-differentiable
4 participants