
Match layer output to weights #2156

Merged: 6 commits into FluxML:master from weight_type on Feb 8, 2023
Conversation

mcabbott (Member) commented Jan 8, 2023

The performance tips page spends ages explaining how not to accidentally promote to Float64. Despite that, I think some of the tutorials still do so accidentally, and are many times slower than they need to be.

Perhaps we should just automatically fix the problem, and convert to the element type of the weights? Are there any real uses where you'd want to rely on * promoting the weights (or the input) for you, rather than explicitly setting the types?
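A minimal sketch of the idea (illustrative only: match_eltype below is a hypothetical helper, not this PR's _match_eltype):

# Hypothetical helper: convert the input to the weight eltype, warning on mismatch.
function match_eltype(weight::AbstractArray{T}, x::AbstractArray{S}) where {T,S}
    if S != T
        @warn "Layer with $T parameters got $S input; converting the input." maxlog=1
        return convert(AbstractArray{T}, x)
    end
    return x
end

# A Dense-like layer would call it before the matrix multiplication, e.g.
# (a::Dense)(x::AbstractVecOrMat) = a.σ.(a.weight * match_eltype(a.weight, x) .+ a.bias)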

Fixes #1972

Does not break this example: #1755 (comment)

julia> Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu)

julia> loss_cpu(X, Y)
-10.739283f0

julia> loss_gpu(X_gpu, Y_gpu)
-10.739282f0

julia> rnn_gpu.
cell   state
julia> rnn_gpu.state
3×15 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 -0.896594   -0.879953  -0.546474  -0.513224  …  -0.92431   -0.00636907  -0.205964  -0.584286
  0.887761    0.803234   0.543203   0.416374      0.810287  -0.0428312    0.295843  -0.0679076
 -0.0956036  -0.38699    0.485399  -0.874913      0.775657   0.655159    -0.781047  -0.871242

julia> rnn.state
3×15 Matrix{Float32}:
 -0.896594   -0.879953  -0.546474  -0.513224  …  -0.92431   -0.0063691  -0.205965  -0.584287
  0.887761    0.803234   0.543203   0.416374      0.810287  -0.0428311   0.295843  -0.0679076
 -0.0956035  -0.38699    0.485399  -0.874913      0.775657   0.655159   -0.781047  -0.871242

Also fixes #1565, by adding an explicit method for Recur.

Originally part of #2137

Tutorial:

With minor fixes to run on tagged Flux, time the first epoch (on CPU):

julia> @time for x in train_loader
           # - Flatten the images from 28x28xbatchsize to 784xbatchsize
           real_data = Flux.flatten(x);

           # Train the discriminator
           noise = randn(latent_dim, size(x)[end]) |> gpu
           fake_data = generator(noise)
           loss_dscr = train_dscr!(discriminator, real_data, fake_data)
           # loss_sum_dscr += loss_dscr

           # Train the generator
           loss_gen = train_gen!(discriminator, generator)
           # loss_sum_gen += loss_gen
       end
749.946580 seconds (949.68 k allocations: 98.895 GiB, 94.96% gc time, 0.21% compilation time)

With this PR:

┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(100 => 256, #13)  # 25_856 parameters
│   summary(x) = "100×128 Matrix{Float64}"
└ @ Flux ~/.julia/dev/Flux/src/layers/stateless.jl:77
269.454910 seconds (908.34 k allocations: 30.413 GiB, 92.92% gc time, 0.60% compilation time)

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

codecov-commenter commented Jan 8, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 83.90%. Comparing base (e33de0c) to head (dc5821f).
Report is 296 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2156      +/-   ##
==========================================
- Coverage   86.38%   83.90%   -2.49%     
==========================================
  Files          19       19              
  Lines        1462     1485      +23     
==========================================
- Hits         1263     1246      -17     
- Misses        199      239      +40     


ToucheSir (Member) commented Jan 16, 2023

I've been putting off reading about how other DL libraries handle promotion and/or mixed precision because it's hardly a riveting read, but does anyone have observations to offer there? Thus far, it seems like PyTorch's promotion rules roughly match ours and (automatic) mixed precision training would be mostly orthogonal to this.

Edit: found some interesting related discussion of promotion in PyTorch at pytorch/pytorch#56356.

mcabbott (Member Author)

Many thanks for digging.

It sounds like PyTorch is torn between wanting roughly Flux/Julia's present rules (always widen) and wanting errors on mismatch (assumed to be a mistake), but at present it has a mixture of the two.

This PR does a third thing: it fixes mistakes by assuming the weights to be definitive, and prints a warning, which I hope is about as good as an error for debugging, and less unfriendly. It would be easy to add a switch like CUDA.allow_scalar to let you opt into an error instead.

Re mixed precision, this was my understanding of the basic idea: a copy of the model is run entirely in Float16 (weights, data & gradients), possibly with the loss function doing some calculation in higher precision.

Using Float16 gradients to update Float32 weights should now be trivial with Optimisers.jl, which is a nice advantage over Flux's old implicit way, which I think would be confused by the two copies of each weight array having different objectids.
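
A minimal sketch of that workflow, assuming Flux's f16 conversion helper and Optimisers.jl's structural gradient matching (not part of this PR):

using Flux, Optimisers

model32 = Dense(4 => 2)                                      # master weights kept in Float32
opt_state = Optimisers.setup(Optimisers.Adam(1f-3), model32)

x16 = rand(Float16, 4, 8)
model16 = Flux.f16(model32)                                  # Float16 copy for the forward & backward pass
grads16 = Flux.gradient(m -> sum(abs2, m(x16)), model16)[1]

# Optimisers.jl matches gradients to parameters by structure, not by objectid,
# so the Float16 gradients can update the Float32 master copy directly:
opt_state, model32 = Optimisers.update!(opt_state, model32, grads16)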

mcabbott (Member Author)

Doctests pass after using #2157 in some places.

ToucheSir (Member) left a comment

I'm on the fence about the complexity-performance tradeoff this brings, but at some point the only way we'll know is to try it. Just one comment on the RNN changes.

@@ -200,10 +200,11 @@ end
RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

-function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,I,H,V,T}
+function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T}
Member

Is there any risk of us having to handle Duals or such here?

Member Author

Since I forgot to explain: the intention here is that an RNNCell{F,I,H,V,<:AbstractMatrix{Float32}} should also accept Matrix{Float64} etc., which we are sure _match_eltype will convert to Float32, and hence the type of the hidden state will not be promoted. It does not accept Matrix{Nil}, nor dual numbers.

For the old code to accept Dual, the weight matrix would also have to be Duals with the same tag parameter. It's hard to imagine how that would happen, so I don't think this will break anything.

Member

Correct me if I'm wrong, but wouldn't the weights be Duals when doing nested diff with Zygote + ForwardDiff? Or would the tags differ there too?

Member Author

Maybe the tags would in fact match. Say destructure builds the network (so all weights have the same Dual type) and then data passes through an earlier layer; it would arrive here with the same tag.

I'm not sure this would work at all, though. Chunked-mode ForwardDiff is going to call the model several times, which sounds likely to screw up anything RNN-related.

Member

Would it be easy to just revert this part of the PR if we run into issues? If so, we can at least try it.

Member Author

Sure. We could keep precisely the old and new signatures, both calling the same internal function.
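
A rough sketch of that pattern (illustrative only: _rnncell_forward and its simplified body are assumptions, not the PR's code):

# Route the forward pass through one internal function, so reverting later only
# means narrowing the outer signature back to AbstractVecOrMat{T}.
function _rnncell_forward(m::RNNCell, h, x)
  h′ = m.σ.(m.Wi * x .+ m.Wh * h .+ m.b)   # simplified stand-in for the real cell body
  return h′, h′
end

(m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T} =
  _rnncell_forward(m, h, x)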

mcabbott (Member Author) commented Feb 4, 2023

Re complexity, it does seem a little sad to move from this:

(a::Dense)(x::AbstractVecOrMat) = (a.σ).(a.weight * x .+ a.bias)

to this opaque thing (after other PRs too):

function (a::Dense)(x::AbstractVecOrMat)
  _size_check(a, x, 1 => size(a.weight, 2))    # friendly error if the input has the wrong size
  xT = _match_eltype(a, x)                     # convert the input to the weight eltype, warning on mismatch
  NNlib.bias_act!(a.σ, a.weight * xT, a.bias)  # fused bias + activation broadcast
end

But each change here does serve a real purpose. Maybe we should leave more comments in the source about why. For more complicated layers, the source was never all that self-explanatory, so there's less to lose.

ToucheSir (Member)

This complexity was inevitable at some point. We can always consider encapsulating some of it into NNlib functions in the future.

mcabbott merged commit d511d7a into FluxML:master on Feb 8, 2023
mcabbott deleted the weight_type branch on February 8, 2023, 04:04
Comment on lines +85 to +88
# Other floats, and integers, silently fix.
function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T}
convert(AbstractArray{T}, x)
end
Member

This is not correct for all types, for example TrackedArray and Array{TrackedReal}. This is leading to a good amount of downstream breakage. Is there a reason this wasn't made part of a breaking release?

Member

Changing T<:Union{AbstractFloat,Integer} as well would likely be sufficient.
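
i.e. something like the following (a sketch of that suggestion, not necessarily the merged fix):

# Only convert when the target eltype T is itself a plain number type, so arrays
# whose weights have wrapper eltypes (e.g. TrackedReal) fall through to a
# catch-all method (assumed here to leave x unchanged) instead of being re-wrapped.
function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T<:Union{AbstractFloat, Integer}}
  convert(AbstractArray{T}, x)
end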

> This is leading to a good amount of downstream breakage

Can you provide some links? We made a point of double-checking the reverse CI results for DiffEqFlux, NeuralPDE and OperatorLearning, so if they have insufficient coverage then that's very concerning.

Member

SciMLSensitivity is the one that tests the adjoints in detail. It found it in two tests. SciML/SciMLSensitivity.jl#784 is really just a test of master. An MWE:

using Flux, Tracker

x = Float32[0.8; 0.8]
ann = Chain(Dense(2, 10, tanh), Dense(10, 1))
p2, re = Flux.destructure(ann)
u = Tracker.TrackedArray(x)
p = Tracker.TrackedArray(p2)
z = re(p)(u)[1]
@show typeof(z) # Tracker.TrackedReal{Tracker.TrackedReal{Float32}}

The issue seems to be that eltype(TrackedArray) is TrackedReal, so this makes a TrackedArray{TrackedReal}.

Member

Is there a reason the relevant test groups for SciMLSensitivity weren't added alongside https://github.com/FluxML/Flux.jl/blob/master/.github/workflows/Downstream.yml#L28-L30? I don't think anyone actively searches for new downstream tests unless they're in the same org.

Member Author

> Is there a reason this wasn't made part of a breaking release?

Yes, the reason is that there is no plan to stop supporting code which used to work.

This is a bug, and (once again) it turns out that people rely on Flux in ways which aren't tested. Besides adding more downstream tests, it would be extremely helpful if someone who needs Tracker could write up a small self-contained set of Tracker tests to include here.

Member

Geez, I didn't know that was the sentiment. Just tell people to go away and not come back next time...

Member Author

I mean, bug reports are welcome, and an MWE doesn't have to test all cases.

"would be extremely helpful if" is obviously suggesting useful contributions beyond a bug report. Which nobody is obliged to work on. But flippant suggestions that the MWE is somehow equivalent to upgraded tests get flippant responses.

ChrisRackauckas (Member) commented Feb 23, 2023

> Which nobody is obliged to work on. But flippant suggestions that the MWE is somehow equivalent to upgraded tests get flippant responses.

Nobody is obligated to work on it, which is why, when I am going to work on it, it's pretty demoralizing to get such a flippant reply. I was just asking whether that MWE was a good start for making a test, or whether you were seeing the issue as something different. All you had to say was that it would be good if it checked a few more layer types (and removed the restructure? Is that the Optimisers part you were mentioning?). I will include that in the PR, along with adding SciMLSensitivity core 1-5 as downstream tests, and a package extension to fix the Tracker case (though this really needs a new array primitive, so I was testing the promote-eltype stuff in ArrayInterface and patching all the odd array types). But at least I understand now why I get so many DMs from people scared of contributing to Flux...

Member

> Yes, the reason is that there is no plan to stop supporting code which used to work.

The purpose of the PR is to change the output of every call that was doing such mixed precision. It changes the element type and the result of a lot of things, and explains some other tests which broke (SciMLDocs). I think that would at least have been discussed as a breaking change, which is why I was confused that I didn't find such a discussion. Changing something that was 32-bit precision to 64-bit is borderline, but changing a function call from 64-bit to 32-bit is going to give a different output. I mean, it's fine if the answer is that the results are not guaranteed to that level and therefore it's not breaking, or sorry, we didn't catch it in any test so let's improve downstream testing, but I just didn't know where to find the conversation.

Member Author

Yea, it's a bug: we thought about and tested dual numbers, but apparently failed to think about Tracker. It won't be hard to fix, but it won't happen today. Tracker lies about element types, so a bit of care will be needed to get it right, and tests to ensure it cannot go wrong again.

Accidental promotion to Float64 was the most common Flux performance bug, and the one the docs spend pages trying to teach you to avoid. I was trying to re-write these to be clearer, but really it seems crazy to spend so much text teaching about a footgun we can simply remove. The other half of this footgun was accidental promotion of gradients, which is harder to see, and was similarly solved some time ago.

When we discussed this, nobody could picture a real reason to run a model with low-precision weights and high-precision data. For high-precision results, it is of course no problem to run a model with high-precision weights.
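
For example (a sketch using Flux's f64 helper, not code from this PR):

julia> m64 = Flux.f64(Dense(2 => 3));   # high-precision weights

julia> eltype(m64(rand(2)))             # rand(2) gives Float64 data; nothing is down-converted
Float64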

> changes the element type and the result of a lot of things

What does "a lot of things" mean?

Development

Successfully merging this pull request may close these issues.

  • Opaque error caused by Float64 input to RNN
  • Recurrent cell eltype restriction breaks outputsize
4 participants