Add FloatNN(::Dual) #538

Open
wants to merge 6 commits into
base: master


mzgubic
Member

@mzgubic mzgubic commented Jul 23, 2021

Needed for Zygote on ChainRules 1.0 (only the Float64 case strictly speaking)

Similar to the recently merged https://github.com//pull/508/files

@codecov-commenter

codecov-commenter commented Jul 23, 2021

Codecov Report

Merging #538 (c5102fa) into master (6b393f4) will increase coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #538      +/-   ##
==========================================
+ Coverage   84.83%   84.85%   +0.01%     
==========================================
  Files           9        9              
  Lines         831      832       +1     
==========================================
+ Hits          705      706       +1     
  Misses        126      126              
Impacted Files Coverage Δ
src/dual.jl 72.82% <100.00%> (+0.09%) ⬆️


@oxinabox
Member

Needed for Zygote on ChainRules 1.0 (only the Float64 case strictly speaking)

Why does Zygote need this?
This is pretty weird to need.
Feels like Zygote is in the wrong.

@mzgubic
Member Author

mzgubic commented Jul 23, 2021

It's the ProjectTo really that errors:

diagonal hessian: Error During Test at /Users/mzgubic/JuliaEnvs/Zygote.jl/test/utils.jl:22
  Got exception outside of a @test
  MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 6})
  Closest candidates are:
    (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
    (::Type{T})(::T) where T<:Number at boot.jl:760
    (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
    ...
  Stacktrace:
    [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 6})
      @ Base ./number.jl:7
    [2] (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(dx::ForwardDiff.Dual{Nothing, Float64, 6})
      @ ChainRulesCore ~/JuliaEnvs/Zygote.jl/dev/ChainRulesCore/src/projection.jl:144
    [3] ^_pullback
      @ ~/JuliaEnvs/Zygote.jl/dev/ChainRules/src/rulesets/Base/fastmath_able.jl:172 [inlined]

@oxinabox
Member

I suspect that the correct thing to do here is not to drop the partial part.
These are not actually the mathematical objects called dual numbers, which can simply be projected onto the reals.
They are an abstraction for forward-mode AD, which is normally isomorphic to the dual numbers, but not in this case, I think.

I think the correct thing to do is to end up with more dual numbers.
Right now we have

julia> d = ForwardDiff.Dual(1,2)
Dual{Nothing}(1,2)

julia> AbstractFloat(d)
Dual{Nothing}(1.0,2.0)

So matching the behaviour of AbstractFloat would mean converting both the value and the partials to Float64, etc.
Then if Zygote really wants to drop the partials, it can by doing so explicitly.

@oxinabox
Member

I am now more convinced this is correct.
It is not a conversion of the Dual to a Float64 (etc.).
It is using a forward-mode operator-overloading AD to compute the derivative of the Float64 (etc.) constructor,
which means applying Float64 to the partial as well.

It is kinda gross that to do so it has to overload the constructor to return the Dual{Float64} type.
But that is operator overloading AD for you.
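
To make the argument above concrete, here is a minimal sketch of "operator-overloading AD of the constructor". It uses a hypothetical `ToyDual` type invented for this illustration, not ForwardDiff's actual `Dual`, and it deliberately leaves out the promotion logic discussed later in the thread.

```julia
# Hypothetical toy type for illustration only; not ForwardDiff's actual Dual.
struct ToyDual{T<:Real}
    value::T    # primal value
    partial::T  # derivative carried alongside the primal
end

# Assumed rule: each float constructor is applied to the primal and, because
# the constructor acts like the identity map on the reals, to the partial too.
for F in (:Float16, :Float32, :Float64)
    @eval Base.$F(d::ToyDual) = ToyDual($F(d.value), $F(d.partial))
end

d = ToyDual(1, 2)   # integer primal and partial
f = Float32(d)      # both fields become Float32; the result is still a ToyDual
```

Note that `Float32(d)` does not return a `Float32` here, which is exactly the `T(x)::T` concern raised further down.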

@oxinabox
Member

oxinabox commented Jul 26, 2021

@YingboMa raised a good question about how many invalidations this would cause.
It would make sense for this to be pretty invalidating, since world-splitting could quite possibly make code depend on all uses of Float64 returning a Float64 (which I think they should, but again, operator-overloading ADs are gross).

But checking with SnoopCompile (and remembering to use --startup-file=no this time):

julia> using SnoopCompileCore;

julia> invalidations = @snoopr begin 
       using ForwardDiff
       end

It seems
ForwardDiff#master has 1108 invalidations
This PR has only 896 invalidations
I am not sure how exactly that is possible.
Maybe I screwed something up?
I did use ]activate and add during the sessions I was measuring, before loading SnoopCompileCore.

@KristofferC
Collaborator

This needs a test for when this is a problem for ForwardDiff. What bug does this fix? The test just verifies that the implementation does what the implementation does but it lacks a motivating example.

@oxinabox
Member

oxinabox commented Jul 26, 2021

This needs a test for when this is a problem for ForwardDiff. What bug does this fix?

Done.
I have added an example of code that would error before, but now returns a correct result.

@mcabbott
Member

Late to the party, but this seems a bit odd.

If I understand right what's happening is this:

julia> using ChainRulesCore, ForwardDiff

julia> p = ProjectTo(1.0)  # Float64 in forward pass
ProjectTo{Float64}()

julia> p(ForwardDiff.Dual(1.0, true, false))  # Dual in backward pass
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 2})

julia> ProjectTo(ForwardDiff.Dual(1.0, true, false))  # this one we made not fussy
ProjectTo{Real}()

Could it be fixed elsewhere? The reason p converts to Float64 is mainly to protect Float32 from accidental promotion (or at least stop it propagating.) Maybe that method should only apply to AbstractFloat?

Or be done through an intermediate function which means "convert the numeric type to Float64, but preserve the Dual / units / etc"?

Second, this PR doesn't seem to achieve that, since the FloatN constructor doesn't always change the base numeric type. And it leads to weird error messages:

julia> Float32(Dual(1.0, true, false)) |> dump
Dual{Nothing, Float64, 2}
  value: Float64 1.0
  partials: ForwardDiff.Partials{2, Float64}
    values: Tuple{Float64, Float64}
      1: Float64 1.0
      2: Float64 0.0

julia> rand(3)[1] = Dual(1.0, true, false)
ERROR: TypeError: in typeassert, expected Float64, got a value of type Dual{Nothing, Float64, 2}
Stacktrace:
 [1] setindex!(A::Vector{Float64}, x::Dual{Nothing, Float64, 2}, i1::Int64)

@oxinabox
Member

oxinabox commented Jul 26, 2021

Second, this PR doesn't seem to achieve that, since the FloatN constructor doesn't always change the base numeric type.

Right now it has some promotion stuff in there,
because it's based on the one for AbstractFloat.
So it gives the higher of the current type and the requested one.
Possibly that could be removed?
I am not sure if it is required to make nested calls work.

And it leads to weird error messages:

That error message seems fine to me.
What else would it say?
The same thing, but as a MethodError?

Or be done through an intermediate function which means "convert the numeric type to Float64, but preserve the Dual / units / etc"?

Could be a thing.
Not sure it would be possible without ForwardDiff overloading something from ChainRulesCore.
But I guess we could have a fallback for other things that subtype Real to leave them alone; they probably know what they are doing.
It does feel a little weird to treat a Dual like it is a container for 2 Floats in the same way a Complex is.
It's not really a container whose elements should be converted in sync, except by the "coincidence" that the operation that needs to be performed on the partial when calling Float32 on the primal is also Float32.

I do still think this PR makes perfectly fine sense though.
It is operator overloading AD of the constructor.
The operation Zygote is trying to do is to AD the combined pullback (i.e. only the reverse part) using ForwardDiff.

From the perspective of trying to do forward-mode AD on some black box, ForwardDiff is failing to AD it, because right now it doesn't know what to do when it sees a call to Float32(x).
This PR fixes that.
The fact that the code includes a call to a pullback from ChainRules.jl is kinda immaterial.
In some sense.

@mcabbott
Member

mcabbott commented Jul 26, 2021

Right now it has some promotion stuff in there.

If Float32(Dual(...)) means anything other than conversion to Float32, then surely it means some kind of implicit broadcast over the numbers within the Dual, right? Which promotion defeats. (The tests are only for integers.)

Crossed out above, but #508 recently made Int(::Dual) work, but not in this sense of converting the base numbers. It returns an Int, and throws an InexactError if it doesn't fit.

I remain a little concerned that making T(x)::T fail is too weird. Especially for basic bitstypes. How widely would this idea be expected to propagate? Float32(RGB(0.1,0.2,0.3)) does not work (with ImageCore.jl); I see that Float32(1u"m") does actually work --- Unitful.jl does it for f in (:float, :BigFloat, :Float64, :Float32, :Float16). And of course things like Float32(Vec(1,2,3,4)) from VectorizationBase.jl do work.

Not sure it would be possible without ForwardDiff overloading something from ChainRulesCore

Yes, I have no idea where would be the right place for such a "convert the numbers inside this" to live, if it isn't overloaded on the constructors.

I do still think the idea of making ProjectTo less strict here deserves some thought. In fact I thought that, on some iteration of that, someone convinced me that Zygote.hessian would always have the same Dual numbers (same Tag) forward and backward passes. What's the MWE that triggers this?

That error message seems fine to me. What else would it say?

My complaint is that a failed typeassert isn't a "you the user gave me bad input" message, it's a "my understanding of what Julia can legally produce here is mistaken" message.

Independent of this PR, it's possible that it should say something more helpful. This appears to be the leading way to use ForwardDiff wrong, so possibly some friendly message (explaining that your buffer needs a wide enough type) would be a good idea.

But the basic fact that you cannot convert Duals to Floats because this would silently lose derivatives is essential to understanding how this package works. I don't think the non-conversion in this PR can in fact silently lose information; it seems super-important to be sure of that.
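
As a rough illustration of the "buffer needs a wide enough type" point, the sketch below uses a hypothetical `ToyDual` type (not ForwardDiff's `Dual`) with no conversion to `Float64` defined. Storing it in a `Float64` buffer throws instead of silently dropping the partial, while a buffer whose element type is derived from the input works.

```julia
# Hypothetical toy type for illustration only; not ForwardDiff's actual Dual.
struct ToyDual{T<:Real}
    value::T    # primal value
    partial::T  # derivative carried alongside the primal
end

x = ToyDual(1.0, 1.0)

# A Float64 buffer cannot hold the toy dual: no conversion to Float64 exists,
# so the assignment throws rather than losing the derivative.
narrow = zeros(1)
failed = try
    narrow[1] = x
    false
catch err
    err isa MethodError
end

# Deriving the element type from the input keeps the derivative information.
wide = Vector{typeof(x)}(undef, 1)
wide[1] = x
```

With the toy type the failure surfaces as a `MethodError` from `convert`; the `TypeError` quoted earlier in the thread is what the same mistake looks like once the constructor returns a dual instead of a plain float.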

@oxinabox
Member

oxinabox commented Jul 27, 2021

If Float32(Dual(...)) means anything other than conversion to Float32, then surely it means some kind of implicit broadcast over the numbers within the Dual, right?

Duals are not numbers.
Not in the sense that Complex are.
They are things to let operator overloading AD happen.
Which just so happens to mostly agree with the mathematical notion of a dual number, with the sqrt(0) != 0 thing.
Which kinda makes this whole thing a question of semantics.

But anyway, that's why its constructors don't return the right types.
Other operator-overloading ADs do this also.
Nabla does this to many constructors.

Which promotion defeats. (The tests are only for integers.)

Yeah, I will check tomorrow to see if removing promotion is a thing.
I am pretty sure it should be, just need to write some tests.

I do still think the idea of making ProjectTo less strict here deserves some thought.

Out of scope for this package though.

In fact I thought that, on some iteration of that, someone convinced me that Zygote.hessian would always have the same Dual numbers (same Tag) forward and backward passes. What's the MWE that triggers this?

That was me.
I was wrong.
I thought Zygote.hessian did reverse over forwards, but it does forwards over reverse.
Which I should have realised in retrospect; that is the fast direction.

MWE

xs, y = randn(2,3), rand()
f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments

dx, dy = diaghessian(f34, xs, y) 
@test size(dx) == size(xs) 
@test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
@test dy ≈ hessian(y -> f34(xs,y), y)

Crossed out above, but #508 recently made Int(::Dual) work, but not in this sense of converting the base numbers. It returns an Int, and throws an InexactError if it doesn't fit.

Indeed.
Which is not in the same sense as this.
OTOH, the forward derivative through an Int(x) primal is going to have zero for the partial anyway,
since you can't perturb x (it would error; so NoTangent() in CRC terms).
So dropping the partial component seems not to be a problem?

@KristofferC
Collaborator

KristofferC commented Jul 27, 2021

What is special about Float64?

Isn't what you are saying that for any type T, T(::Dual) should "broadcast" T onto the real and dual part? This feels kind of similar to the convert(::Type{<:Any}, x::Dual) idea in some packages which causes a lot of invalidation issues.

@mcabbott
Member

Not in the sense that Complex are. They are things to let operator overloading AD happen.

Yes. My example above was Vec for SIMD, which is similarly a way of threading more information through code which was written for scalars. Another I just thought of is Measurements.jl, which FWIW does not allow such conversions:

julia> λ = measurement(0, 0.1)
0.0 ± 0.1

julia> typeof(λ)
Measurement{Float64}

julia> Float32(λ)
ERROR: MethodError: no method matching Float32(::Measurement{Float64})

@dlfivefifty
Contributor

Wouldn't it be better to add float32(x) and float64(x) functions to avoid any issues with invalidation?

@dlfivefifty
Contributor

Or better yet float(Val(32), x) and float(Val(64), x)
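
A rough sketch of what such a separate conversion function could look like. The name `tofloat` and the `ToyDual` type are hypothetical, chosen so that nothing in Base is overloaded; the suggestion above extends `float` instead, which this sketch does not attempt.

```julia
# Hypothetical helper; neither Base nor ForwardDiff defines `tofloat`.
tofloat(::Val{16}, x::Real) = Float16(x)
tofloat(::Val{32}, x::Real) = Float32(x)
tofloat(::Val{64}, x::Real) = Float64(x)

# Hypothetical toy type for illustration only; not ForwardDiff's actual Dual.
struct ToyDual{T<:Real}
    value::T    # primal value
    partial::T  # derivative carried alongside the primal
end

# Wrapper types opt in by converting their fields and rewrapping the result.
tofloat(v::Val, d::ToyDual) = ToyDual(tofloat(v, d.value), tofloat(v, d.partial))

plain = tofloat(Val(32), 1.0)                  # a plain Float32
dual  = tofloat(Val(32), ToyDual(1.0, 0.5))    # a ToyDual{Float32}
```

Because only the new function gains methods, `Float32(x)::Float32` keeps holding for every type and no existing constructor method is touched.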

@devmotion
Member

devmotion commented Dec 19, 2021

I think the definition of Float64 in this PR is reasonable for operator overloading AD (and might be the right thing to do).

However, I am worried that Int (and Integer) and Float64 etc. would behave differently and that Float64(x)::Float64 would not be satisfied. A common use case for explicit Float64 constructors or conversions seems to be that the code actually requires values of type Float64, but this would require Float64(d::Dual) = Float64(value(d)) (probably together with a check of the partials, similar to Int and Integer). I came across this discussion here since e.g. it would break

using ForwardDiff
using Random

struct IsoNormal{V<:AbstractVector}
    mu::V
end

function Random.rand!(rng::Random.AbstractRNG, x::AbstractVector, d::IsoNormal)
    length(x) == length(d.mu) || throw(DimensionMismatch())
    randn!(rng, x)
    x .+= d.mu
    return x
end

Base.rand(d::IsoNormal) = rand(Float64, d)
Base.rand(::Type{T}, d::IsoNormal) where {T} = Base.rand(Random.GLOBAL_RNG, T, d)
Base.rand(rng::Random.AbstractRNG, d::IsoNormal) = rand(rng, Float64, d)
function Base.rand(rng::AbstractRNG, ::Type{T}, d::IsoNormal) where {T}
    return rand!(rng, Vector{T}(undef, length(d.mu)), d)
end

rand(IsoNormal(zeros(2))) # works, returns `Vector{Float64}`
rand(Float32, IsoNormal(zeros(2))) # works, returns `Vector{Float32}`
rand(typeof(ForwardDiff.Dual(0f0)), IsoNormal(zeros(2))) # works (but probably uncommon), returns `Vector{<:Dual}`

rand(IsoNormal([ForwardDiff.Dual(0.0) for _ in 1:2])) # fails: `Float64(d::Dual)` is either not defined or, as in this PR, returns a `Dual`
rand(Float32, IsoNormal([ForwardDiff.Dual(0.0) for _ in 1:2])) # fails: `Float32(d::Dual)` is either not defined or, as in this PR, returns a `Dual`
rand(typeof(ForwardDiff.Dual(0f0)), IsoNormal([ForwardDiff.Dual(0.0) for _ in 1:2])) # works as expected

which came up in JuliaStats/Distributions.jl#1433 (it's not clear from this example but intentionally the default type T is chosen as Float64 independently of mu there).
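
For comparison, the alternative mentioned above, a conversion that really returns a `Float64` and checks the partials, might look roughly like the sketch below. The `ToyDual` type and the function name `strict_float64` are hypothetical, and the policy of throwing an `InexactError` for nonzero partials is an assumption made for this illustration rather than something taken from #508.

```julia
# Hypothetical toy type for illustration only; not ForwardDiff's actual Dual.
struct ToyDual{T<:Real}
    value::T    # primal value
    partial::T  # derivative carried alongside the primal
end

# Assumed policy: narrowing to a plain Float64 is allowed only when the partial
# is zero, so that no derivative information can be lost silently.
function strict_float64(d::ToyDual)
    iszero(d.partial) || throw(InexactError(:strict_float64, Float64, d))
    return Float64(d.value)
end

ok = strict_float64(ToyDual(1.0, 0.0))   # a plain Float64

# A dual with a nonzero partial is rejected instead of being truncated.
rejected = try
    strict_float64(ToyDual(1.0, 2.0))
    false
catch err
    err isa InexactError
end
```

Under this policy the `Float64(x)::Float64` invariant holds, at the cost of erroring for any dual that actually carries a derivative.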

@KristofferC
Collaborator

I think the definition of Float64 in this PR is reasonable for operator overloading AD (and might be the right thing to do).

A common use case for explicit Float64 constructors or conversions seems to be that the code actually requires values of type Float64, b

Yes, if someone explicitly writes Float64 then they do want a Float64. You can make the argument that you want Constructor(d::Dual) == Dual(Constructor(d.val), 0.0) for a huge number of constructors and add overload after overload, and it will still not work well. At that point, you are "outside" the generic method system that ForwardDiff is designed for and things will likely error out just a little bit later.

@devmotion
Member

I just learnt some days ago that Tracker contains an even more general form of the approach in this PR (but, of course, it's a really nasty type piracy and it also drops the tag): FluxML/Tracker.jl#134

@mcabbott mcabbott changed the title Add Float(::Dual) Add FloatNN(::Dual) Sep 7, 2022
8 participants