Match layer output to weights #2156
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff             @@
##           master    #2156      +/-   ##
==========================================
- Coverage   86.38%   83.90%    -2.49%
==========================================
  Files          19       19
  Lines        1462     1485      +23
==========================================
- Hits         1263     1246      -17
- Misses        199      239      +40

☔ View full report in Codecov by Sentry.
I've been putting off reading about how other DL libraries handle promotion and/or mixed precision, because it's hardly a riveting read, but does anyone have observations to offer there? Thus far it seems that PyTorch's promotion rules roughly match ours, and (automatic) mixed-precision training would be mostly orthogonal to this.

Edit: found some interesting related discussion of promotion in PyTorch at pytorch/pytorch#56356.
Many thanks for digging. It sounds like PyTorch is conflicted between wanting roughly Flux/Julia's present rules (always widen) and wanting errors on mismatch (assumed to be a mistake), but at present has a mixture of the two. This PR does a third thing: it fixes mistakes by assuming the weights to be definitive, with a warning, which I hope may be about as good as an error for debugging, and less unfriendly. It would be easy to make a switch like

Re mixed precision, this was my understanding of the basic idea: a copy of the model is run entirely in Float16, weights, data & gradients (possibly with the loss function doing some calculation in higher precision). Using Float16 gradients to update Float32 weights should now be trivial with Optimisers.jl, which is in fact a nice advantage compared to Flux's old implicit way, which I think would be confused by the two copies of each weight array having different
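To make the mixed-precision idea concrete, here is a rough sketch of one step of such a loop, using Optimisers.jl explicitly. This is an illustration rather than anything in this PR, and the `fmap`-based Float16 copy is just one way to do the conversion:

```julia
using Flux, Optimisers

model32 = Chain(Dense(10 => 32, relu), Dense(32 => 1))   # Float32 "master" weights
state   = Optimisers.setup(Optimisers.Adam(1f-3), model32)

x, y = rand(Float16, 10, 64), rand(Float16, 1, 64)

# Run forward & backward entirely in Float16, on a converted copy of the model...
model16 = Flux.fmap(a -> a isa AbstractArray{Float32} ? Float16.(a) : a, model32)
grad16  = Flux.gradient(m -> Flux.mse(m(x), y), model16)[1]

# ...then use the Float16 gradient to update the Float32 master weights.
state, model32 = Optimisers.update!(state, model32, grad16)
```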
Doctests pass after using #2157 in some places.
I'm on the fence about the complexity-performance tradeoff this brings, but at some point the only way we'll know is to try it. Just one comment on the RNN changes.
```diff
@@ -200,10 +200,11 @@ end
 RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
   RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

-function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,I,H,V,T}
+function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T}
```
Is there any risk of us having to handle Duals or such here?
Since I forgot: the intention here is that an RNNCell{F,I,H,V,<:AbstractMatrix{Float32}} should also accept Matrix{Float64} etc., which we are sure _match_eltype will convert to Float32, and hence not promote the type of the hidden state. It does not accept Matrix{Nil}, nor dual numbers.

For the old code to accept Dual, the weight matrix would have to also be duals with the same tag parameter. It's hard to imagine how that would happen, so I don't think that this will break anything.
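If I follow, the intended behaviour is roughly the sketch below (written against this PR; the inline expectations are my reading of the intent, not captured output):

```julia
using Flux

r = Flux.RNN(3 => 2)        # Recur{RNNCell} with Float32 weights and hidden state

y = r(rand(Float64, 3))     # the Float64 input is converted to Float32 to match the
eltype(y)                   # weights (possibly with a warning), so the hidden state
                            # stays Float32 instead of widening to Float64
```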
Correct me if I'm wrong, but wouldn't the weights be Duals when doing nested diff with Zygote + ForwardDiff? Or would the tags differ there too?
Maybe the tags would in fact match. Say destructure builds the network (so all weights have the same Dual type) and then data passes through an earlier layer; then it would arrive here with the same tag.

I'm not sure this would work at all. Chunked-mode ForwardDiff is going to call the model several times, which sounds likely to screw up anything RNN.
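For reference, the kind of forward-mode use being discussed looks roughly like this (a sketch, not from the test suite). Here every weight is rebuilt from the same destructured vector, so all the Duals share one tag; an MLP like this seems fine, and the RNN case is the questionable one because of chunking:

```julia
using Flux, ForwardDiff

m = Chain(Dense(2 => 4, tanh), Dense(4 => 1))
p0, re = Flux.destructure(m)
x = rand(Float32, 2)

# re(p) rebuilds the layers with Dual weights sharing one tag; the Float32 input
# is then converted to that element type on the way through each layer.
g = ForwardDiff.gradient(p -> sum(re(p)(x)), p0)
```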
Would it be easy to just revert this part of the PR if we run into issues? If so, we can at least try it.
Sure. Could have precisely old and new signatures, calling the same internal function.
Re complexity, it does seem a little sad to move from this:

```julia
(a::Dense)(x::AbstractVecOrMat) = (a.σ).(a.weight * x .+ a.bias)
```

to this opaque thing (after other PRs too):

```julia
function (a::Dense)(x::AbstractVecOrMat)
  _size_check(a, x, 1 => size(a.weight, 2))
  xT = _match_eltype(a, x)
  NNlib.bias_act!(a.σ, a.weight * xT, a.bias)
end
```

But each change here does serve a real purpose. Maybe we should leave more comments in the source about why. For more complicated layers, the source was never all that self-explanatory, so there's less to lose.
This complexity was inevitable at some point. We can always consider encapsulating some of it into NNlib functions in the future.
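As a purely hypothetical illustration of that kind of encapsulation (the helper name `_checked_forward` is made up; the underscore functions are the Flux internals from the snippet above):

```julia
# Hypothetical helper bundling the bookkeeping; not an existing NNlib or Flux function.
function _checked_forward(layer, σ, weight, bias, x)
    _size_check(layer, x, 1 => size(weight, 2))
    xT = _match_eltype(layer, x)
    return NNlib.bias_act!(σ, weight * xT, bias)
end

# The layer definition then reads almost like the original one-liner:
(a::Dense)(x::AbstractVecOrMat) = _checked_forward(a, a.σ, a.weight, a.bias, x)
```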
```julia
# Other floats, and integers, silently fix.
function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T}
  convert(AbstractArray{T}, x)
end
```
This is not correct for all types, for example TrackedArray and Array{TrackedReal}. This is leading to a good amount of downstream breakage. Is there a reason this wasn't made to be in a breaking update?
Changing T<:Union{AbstractFloat,Integer} as well would likely be sufficient.
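That is, something like this (my reading of the suggestion, not an actual patch):

```julia
# Only convert when the weight eltype itself is a plain number type, so wrapper
# element types such as TrackedReal fall through to a more generic method instead.
function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T<:Union{AbstractFloat, Integer}}
  convert(AbstractArray{T}, x)
end
```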
> This is leading to a good amount of downstream breakage

Can you provide some links? We made a point of double-checking the reverse CI results for DiffEqFlux, NeuralPDE and OperatorLearning, so if they have insufficient coverage then that's very concerning.
SciMLSensitivity is the one that tests the adjoints in detail. It found it in two tests. SciML/SciMLSensitivity.jl#784 is really just a test of master. An MWE:

```julia
using Flux, Tracker
x = Float32[0.8; 0.8]
ann = Chain(Dense(2, 10, tanh), Dense(10, 1))
p2, re = Flux.destructure(ann)
u = Tracker.TrackedArray(x)
p = Tracker.TrackedArray(p2)
z = re(p)(u)[1]
@show typeof(z) # Tracker.TrackedReal{Tracker.TrackedReal{Float32}}
```

The issue seems to be because eltype(TrackedArray) is TrackedReal, so this makes a TrackedArray{TrackedReal}.
Is there a reason the relevant test groups for SciMLSensitivity weren't added alongside https://github.com/FluxML/Flux.jl/blob/master/.github/workflows/Downstream.yml#L28-L30? I don't think anyone actively searches for new downstream tests unless they're in the same org.
> Is there a reason this wasn't made to be in a breaking update?

Yes, the reason is that there is no plan to stop supporting code which used to work.
This is a bug, and (once again) it turns out that people rely on Flux in ways which aren't tested. Besides adding more downstream tests, it would be extremely helpful if someone who needs Tracker could write up some small self-contained set of Tracker tests, to include here.
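For example, a self-contained regression test along these lines (adapted from the MWE above; the exact assertion is a guess at what "not broken" should mean):

```julia
using Flux, Tracker, Test

@testset "Tracker input keeps its element type" begin
    x = Float32[0.8; 0.8]
    ann = Chain(Dense(2, 10, tanh), Dense(10, 1))
    p2, re = Flux.destructure(ann)
    z = re(Tracker.TrackedArray(p2))(Tracker.TrackedArray(x))[1]
    # The broken behaviour produced a doubly wrapped TrackedReal{TrackedReal{Float32}}.
    @test z isa Tracker.TrackedReal{Float32}
end
```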
Geez, I didn't know that was the sentiment. Just tell people to go away and not come back next time...
I mean, bug reports are welcome, and an MWE doesn't have to test all cases.
"would be extremely helpful if" is obviously suggesting useful contributions beyond a bug report. Which nobody is obliged to work on. But flippant suggestions that the MWE is somehow equivalent to upgraded tests get flippant responses.
> Which nobody is obliged to work on. But flippant suggestions that the MWE is somehow equivalent to upgraded tests get flippant responses.

Nobody is obligated to work on it, which is why, since I'm going to work on it, it's pretty demoralizing to get such a flippant reply. I was just asking whether that MWE was a good start for a test, or whether you were seeing the issue as something different. All you had to say was that it would be good if it checked a few more layer types (and remove the restructure? Is that the Optimisers part you were mentioning?), and I will include that in the PR, along with adding SciMLSensitivity core 1-5 as downstream tests, and a package extension to fix the Tracker case (though this really needs a new array primitive, so I was testing the promote-eltype stuff in ArrayInterface and patching all the odd array types). But at least I understand now why I get so many DMs from people scared of contributing to Flux...
> Yes, the reason is that there is no plan to stop supporting code which used to work.

The purpose of the PR is to change the output of every call that was doing such mixed precision. It changes the element type and the result of a lot of things, and explains some other tests which broke (SciMLDocs). I think that would at least have been discussed as a breaking change, which is why I was confused that I didn't find such a discussion. Changing something that was 32-bit precision to 64-bit is borderline, but changing a function call from 64 to 32 is going to give a different output. I mean, it's fine if the answer is that the results are not guaranteed to that level and therefore it's not breaking, or sorry, we didn't catch it in any test so let's improve downstream testing, but I just didn't know where to find the conversation.
Yeah, it's a bug: we thought about and tested dual numbers, but apparently failed to think about Tracker. It won't be hard to fix, but it won't happen today. Tracker lies about element types, so a bit of care will be needed to get it right & to test that it cannot go wrong again.
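For anyone following along, "lies about element types" is roughly this (illustrative; Tracker.param and Tracker.data are the standard Tracker entry points):

```julia
using Tracker

x = Tracker.param(rand(Float32, 2))   # a TrackedArray wrapping a Vector{Float32}

eltype(x)                  # Tracker.TrackedReal{Float32}, which is what generic code sees
eltype(Tracker.data(x))    # Float32, the element type of the underlying storage
```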
Accidental promotion to Float64 was the most common Flux performance bug, and the one the docs spend pages trying to teach you to avoid. I was trying to re-write these to be clearer, but really it seems crazy to spend so much text teaching about a footgun we can simply remove. The other half of this footgun was accidental promotion of gradients, which is harder to see, and similarly solved some time ago.
When we discussed this, nobody could picture a real reason to run a model with low-precision weights and high-precision data. For high-precision results, it is of course no problem to run a model with high-precision weights.
> changes the element type and the result of a lot of things

What does "a lot of things" mean?
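For concreteness, the footgun is the silent widening that happens without this PR (a sketch; the sizes are arbitrary):

```julia
using Flux

m = Dense(10 => 1000)    # Float32 weights, as Flux initialises by default
x = rand(10, 64)         # oops: rand() defaults to Float64

# Without this PR: Float32 * Float64 promotes, so the whole forward (and backward)
# pass silently runs in Float64, typically several times slower than Float32.
# With this PR: x is instead converted to Float32 to match the weights.
y = m(x)
```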
The performance tips page spends ages talking about how not to accidentally promote to Float64. Despite which, I think some of the tutorials still do so accidentally, and are many times slower than they need to be.
Perhaps we should just automatically fix the problem, and convert to the element type of the weights? Are there any real uses where you'd want to rely on * promoting the weights (or the input) for you, rather than explicitly setting the types?

Fixes #1972
Does not break this example: #1755 (comment)
Also fixes #1565, by an explicit method for Recur.
Originally part of #2137
Tutorial:
With minor fixes to run on tagged Flux, time the first epoch (on CPU):
With this PR:
PR Checklist