Make operations on BitArray's non-differentiable #293

Closed
oxinabox opened this issue Oct 22, 2020 · 7 comments · Fixed by #310
Labels
missing rule non-differentiable For issues/PRs relating to @non_differentiable

Comments

@oxinabox
Member

Bools are not a differentiable type,
and thus neither are BitArrays.

@CarloLucibello
Contributor

CarloLucibello commented Nov 28, 2020

It looks like, since true isa Number and there are no specific rules defined in this repo that prevent Bools from being differentiated, they get differentiated by generic methods. Making arrays of Bools and BitArrays non-differentiable is weird if we don't fix that. For instance, in Zygote we have some inconsistencies:

julia> using Zygote

julia> gradient(x -> x, false) # bools are differentiable
(true,)

julia> gradient(x -> x + 1, false)
(1,)

julia> gradient(x -> sum(x), [false,false])   # array of bools are not
(nothing,)

julia> gradient(x -> x[1] + x[2], [false,false]) # this should give same result as previous line
([1, 1],)

I'm OK with Bools being non-differentiable, although they can be embedded in a differentiable manifold just as well as, e.g., integers can, so this choice is somewhat arbitrary.

Either way, we must handle single Bools and arrays of Bools consistently. Can we add rules to this repo that mark Bools as non-differentiable?
Another option would be to make Bools and arrays of Bools differentiable but BitArrays non-differentiable. I'm not sure this makes sense.

@mcabbott

@oxinabox
Member Author

We have dozens of rules setting Bools as non-differentiable.
Just not the ones you tried.
We don't have 100% coverage of Base.

@CarloLucibello
Contributor

Where are they? I tried to search for Bool here on GitHub but got just one hit.

@mcabbott
Member

mcabbott commented Nov 28, 2020

Yes it's a choice, but a useful one I think. (And it wasn't my idea!)

Also a relatively new one, Zygote's initial behaviour was to treat all numbers alike, including promoting to Complex un-asked for, which (IMO) is extremely surprising. Grouping (Int8 ... Float64) as all representing mathematical reals (for AD) seems like a good policy to me. This does promote integers to the continuum, but they are often used as easy-to-write real numbers, e.g. nobody in calculus class is confused by sin'(1).

Base treats Bools specially in a few places; I guess the first is where we can claim moral precedent:

julia> [1.0 Inf NaN -Inf] .* false  # strong zero
1×4 Array{Float64,2}:
 0.0  0.0  0.0  -0.0

julia> [1.0 Inf NaN -Inf] .* 0
1×4 Array{Float64,2}:
 0.0  NaN  NaN  NaN

julia> (1:3)[[true, false, true]]  # Array{Bool} & BitArray both give logical indexing
2-element Array{Int64,1}:
 1
 3

julia> (1:3)[Any[true, 0x02, Int32(3)]] 
3-element Array{Int64,1}:
 1
 2
 3

About adding rules, I'm still not entirely sure why the functions encoded in rrule can't be further modified. Wherever rrule ultimately gets called, call instead something like this:

function final_rrule(f, xs...)
    y, ruleback = rrule(f, xs...)
    function back(dy)
        map(xs, ruleback(dy)) do x,dx
            x isa Bool ? Zero() : dx
        end
    end
    y, back
end

or some more elegant version which handles BitArrays, Complex, and perhaps Diagonal, etc. (And keeps just the types, not the objects x, perhaps.) What am I missing?
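
One possible way to make the wrapper above runnable is sketched here. It assumes ChainRulesCore 1.0 or later, where NoTangent() plays the role that Zero() plays in the snippet above, and where a pullback returns the tangent for the function itself first, followed by one tangent per argument. The names NonDiffArg and scale are hypothetical and exist only for this illustration; this is not how ChainRules or Zygote actually handle Bools.

```julia
using ChainRulesCore

# Argument types that should never receive a gradient (an assumption of this sketch).
const NonDiffArg = Union{Bool, AbstractArray{Bool}}  # BitArray <: AbstractArray{Bool}

function final_rrule(f, xs...)
    res = rrule(f, xs...)
    res === nothing && return nothing      # no rule is defined for this call
    y, ruleback = res
    function back(dy)
        ds = ruleback(dy)
        # The first entry is the tangent for `f` itself; the rest pair up with `xs`.
        ∂self, ∂xs = first(ds), Base.tail(ds)
        masked = map(xs, ∂xs) do x, dx
            x isa NonDiffArg ? NoTangent() : dx
        end
        return (∂self, masked...)
    end
    return y, back
end

# Hypothetical function with a hand-written rule, used only to exercise the wrapper.
scale(flag, x) = flag * x
function ChainRulesCore.rrule(::typeof(scale), flag, x)
    scale_pullback(dy) = (NoTangent(), x * dy, flag * dy)
    return scale(flag, x), scale_pullback
end

y, back = final_rrule(scale, true, 2.0)
back(1.0)  # (NoTangent(), NoTangent(), 1.0)
```

The wrapper keeps the original rule untouched and only replaces the tangents of Bool-like arguments after the fact, so the rule still does the work of computing a gradient that is then discarded.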

@mcabbott
Member

Here's how I hacked this into ZygoteRules at some point. This was for real/complex but could trivially include Bool:

FluxML/ZygoteRules.jl@master...mcabbott:real

@gxyd
Contributor

gxyd commented Nov 28, 2020

> Where are they? I tried to search for Bool here on github but got just one hit

I also didn't see any, the only rule that I was already going to put in my next commit was:

@non_differentiable Bool(::Any)

@oxinabox
Member Author

oxinabox commented Nov 28, 2020

> Where are they? I tried to search for Bool here on github but got just one hit

Most of the ones we have right now are for functions that return Bools, e.g. isequal, isfile.
We could probably do with a bunch more of them still, you are correct.
Especially for ones where Bool is being used as a Number.
In practice, we rarely run into cases where those cause trouble.
Whereas I have seen code that has broken from ADing operations on BitArrays.

> About adding rules, I'm still not entirely sure why the functions encoded in rrule can't be further modified. Wherever rrule ultimately gets called, call instead something like this:

One could, but it is beyond the scope of this issue,
and there are bigger concerns as to whether it is even the right way to solve this, depending on the AD in question.
For operator-overloading AD, you would ideally swap to not tracking the type at all, which is done entirely differently.
(It also relates to JuliaDiff/ChainRulesCore.jl#248.)
I am by no means saying we can't do this, but I am saying it needs more thought than solving this issue,
which can be done just by the fairly mechanical action of listing out the methods we want to mark as non-differentiable.
That is a whole discussion of its own; feel free to open an issue about it.
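
As a rough illustration of that mechanical listing, the sketch below marks a single method as non-differentiable with the @non_differentiable macro from ChainRulesCore. The function count_set is a hypothetical stand-in defined only for this example (so that no Base method is touched), and the NoTangent() return values assume ChainRulesCore 1.0 or later; the rules actually added for this issue may look different.

```julia
using ChainRulesCore

# Hypothetical helper, defined here only so the example owns the function it annotates.
count_set(mask::BitArray) = count(mask)

# Declare that no gradient flows through this method for BitArray inputs.
@non_differentiable count_set(::BitArray)

y, pullback = rrule(count_set, trues(3))
pullback(1.0)  # (NoTangent(), NoTangent())
```

Each such declaration covers one method signature, which is why covering Base is a matter of enumerating the relevant functions one by one.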
