Match layer output to weights #2156
@@ -57,3 +57,47 @@ true
  σ = std(x, dims=dims, mean=μ, corrected=false)
  return @. (x - μ) / (σ + ϵ)
end

"""
    _match_eltype(layer, ::Type{T}, x)
    _match_eltype(layer, x)

This internal function corrects most layer input to match the type of the weights.
The second method uses `T = eltype(layer.weight)`.

It solves a common performance bug: Before, accidentally supplying `Float64` input,
or an activation function which produces `Float64`, would silently run the
entire forward pass in this precision.
"""
_match_eltype(layer, ::Type{T}, x::AbstractArray{T}) where {T} = x

# A common mistake, print a friendly warning, and fix it:
function _match_eltype(layer, ::Type{Float32}, x::AbstractArray{Float64})
  # This warning is the only reason this needs to take the layer.
  @warn "Layer with Float32 parameters got Float64 input.
  The input will be converted, but any earlier layers may be very slow." layer summary(x) maxlog=1
  convert(AbstractArray{Float32}, x)
end

# Allow OneHot to reach specialisation of * etc:
_match_eltype(layer, ::Type, x::OneHotLike) = x

# Other floats, and integers, silently fix.
function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T}
  convert(AbstractArray{T}, x)
end
Comment on lines +85 to +88

This is not correct for all types, for example TrackedArray and Array{TrackedReal}. This is leading to a good amount of downstream breakage. Is there a reason this wasn't made to be in a breaking update?
Can you provide some links? We made a point of double-checking the reverse CI results for DiffEqFlux, NeuralPDE and OperatorLearning, so if they have insufficient coverage then that's very concerning.

SciMLSensitivity is the one that tests the adjoints in detail. It found it in two tests. SciML/SciMLSensitivity.jl#784 is really just a test of master. An MWE:

using Flux, Tracker
x = Float32[0.8; 0.8]
ann = Chain(Dense(2, 10, tanh), Dense(10, 1))
p2, re = Flux.destructure(ann)
u = Tracker.TrackedArray(x)
p = Tracker.TrackedArray(p2)
z = re(p)(u)[1]
@show typeof(z)  # Tracker.TrackedReal{Tracker.TrackedReal{Float32}}

The issue seems to be because eltype(TrackedArray) is TrackedReal, so this makes a TrackedArray{TrackedReal}.

Is there a reason the relevant test groups for SciMLSensitivity weren't added alongside https://github.com/FluxML/Flux.jl/blob/master/.github/workflows/Downstream.yml#L28-L30? I don't think anyone actively searches for new downstream tests unless they're in the same org.
Yes, the reason is that there is no plan to stop supporting code which used to work. This is a bug, and (once again) it turns out that people rely on Flux in ways which aren't tested. Besides adding more downstream tests, it would be extremely helpful if someone who needs Tracker could write up some small, self-contained set of Tracker tests to include here.

Geez, I didn't know that was the sentiment. Just tell people to go away and not come back next time...

I mean, bug reports are welcome, and an MWE doesn't have to test all cases. "Would be extremely helpful if" is obviously suggesting useful contributions beyond a bug report, which nobody is obliged to work on. But flippant suggestions that the MWE is somehow equivalent to upgraded tests get flippant responses.

Nobody is obligated to work on it, which is why, when I'm going to work on it, it's pretty demoralizing to get such a flippant reply. I was just asking whether that MWE was a good start for a test, or whether you were seeing the issue as something different. All you had to say was that it would be good if it checked a few more layer types (and removed the restructure? Is that the Optimisers part you were mentioning?). I will include that in the PR, along with adding SciMLSensitivity core 1-5 as downstream tests, and a package extension to fix the Tracker case (though this really needs a new array primitive, so I was testing the promote-eltype stuff in the array interface and patching all odd array types). But at least I understand now why I get so many DMs from people scared of contributing to Flux...

The purpose of the PR is to change the output of every call that was doing such mixed precision. It changes the element type and the result of a lot of things, and explains some other tests which broke (SciMLDocs). I think that would at least have been discussed as a breaking change, which is why I was confused that I didn't find such a discussion. Changing something that was 32-bit precision to 64-bit is borderline, but changing a function call from 64 to 32 is going to give a different output. It's fine if the answer is that the results are not guaranteed to that level and therefore it's not breaking, or "sorry, we didn't catch it in any test, so let's improve downstream testing", but I just didn't know where to find the conversation.

Yeah, it's a bug: we thought about and tested dual numbers, but apparently failed to think about Tracker. It won't be hard to fix, but it won't happen today. Tracker lies about element types, so a bit of care will be needed to get it right, and to test that it cannot go wrong again.

Accidental promotion to Float64 was the most common Flux performance bug, and the one the docs spend pages trying to teach you to avoid. I was trying to re-write those docs to be clearer, but really it seems crazy to spend so much text teaching about a footgun we can simply remove (a short sketch of the footgun follows below). The other half of this footgun was accidental promotion of gradients, which is harder to see, and was similarly solved some time ago. When we discussed this, nobody could picture a real reason to run a model with low-precision weights and high-precision data. For high-precision results, it is of course no problem to run a model with high-precision weights.
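To make that footgun concrete, here is a minimal sketch of the behaviour this PR changes. It assumes only a default Dense layer (whose parameters are Float32); the sizes are hypothetical:

using Flux

layer = Dense(2 => 3, tanh)   # default initialisation gives Float32 weights
x64 = rand(2)                 # Float64 input, easy to create by accident

y = layer(x64)
# Before this PR, eltype(y) == Float64: the whole forward pass silently ran
# in double precision, which is typically much slower.
# With _match_eltype, the input is converted to Float32 (with a one-time @warn),
# so eltype(y) == Float32.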
What does "a lot of things" mean?
# Weird types like Nil, Dual, etc, we allow through:
_match_eltype(layer, ::Type, x::AbstractArray) = x

# 2-arg method, for common layers with layer.weight
_match_eltype(layer, x) = _match_eltype(layer, eltype(layer.weight), x)

# Trivial rule:
function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T}
  _match_eltype(layer, T, x), dx -> (NoTangent(), ZeroTangent(), NoTangent(), dx)  # does not un-thunk dx
end
function ChainRulesCore.rrule(::typeof(_match_eltype), layer, x::AbstractArray)
  _match_eltype(layer, x), dx -> (ZeroTangent(), NoTangent(), dx)  # does not un-thunk dx
end
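As a rough illustration of how these methods dispatch (a sketch, not part of the diff; it assumes a default Dense layer with Float32 weights, and uses Flux's onehotbatch):

using Flux

d = Dense(3 => 2)                                         # Float32 weights
Flux._match_eltype(d, rand(Float32, 3))                   # same eltype: returned unchanged
Flux._match_eltype(d, rand(3))                            # Float64: converted to Float32, with a one-time warning
Flux._match_eltype(d, rand(1:9, 3))                       # integers: converted silently
Flux._match_eltype(d, Flux.onehotbatch("abc", 'a':'c'))   # one-hot: passed through to keep fast multiplication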
Is there any risk of us having to handle Duals or such here?

Since I forgot: the intention here is that an RNNCell{F,I,H,V,<:AbstractMatrix{Float32}} should also accept Matrix{Float64} etc., which we are sure that _match_eltype will convert to Float32, and hence not promote the type of the hidden state. It does not accept Matrix{Nil}, nor dual numbers.

For the old code to accept Dual, the weight matrix would have to also be duals with the same tag parameter. It's hard to imagine how that would happen, so I don't think that this will break anything.

Correct me if I'm wrong, but wouldn't the weights be Duals when doing nested diff with Zygote + ForwardDiff? Or would the tags differ there too?
Maybe the tags would in fact match. Say destructure builds the network (all weights have the same Dual type) and then data passes through an earlier layer; then it would arrive here with the same tag. I'm not sure this would work at all. Chunked-mode ForwardDiff is going to call the model several times, which sounds likely to screw up anything RNN.
Would it be easy to just revert this part of the PR if we run into issues? If so, we can at least try it.

Sure. Could have precisely the old and new signatures, calling the same internal function.
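For reference, a small sketch of the recurrent-layer behaviour described a few comments up (hypothetical sizes; it assumes the current Recur/RNNCell API with default Float32 parameters):

using Flux

m = Flux.RNN(2 => 3)        # Recur(RNNCell(...)) with Float32 weights and state
y = m(rand(2, 5))           # Float64 input batch
# _match_eltype converts the input to Float32, so neither the output nor the
# stored hidden state is promoted to Float64:
eltype(y), eltype(m.state)  # expected to be (Float32, Float32)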