Add destructure, take II #54
Conversation
src/destructure.jl
Outdated
```julia
len = Ref(0)
off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
  push!(arrays, vec(y))
  o = len[]
  len[] = o + length(y)
  o
end
```
Off-topic, but this (and `gamma!`) is a concrete example of where something like FluxML/Functors.jl#32 or `fmapreduce` would help. Instead of writing to a ref, you'd just carry an accumulated total through the traversal.
Earlier commits, FWIW, just `append!`ed to a vector and used its length instead. This is faster, and delays working out the promoted type instead of needing another, basically `fmapreduce`, pass. A fancy way to pass the integer forwards might avoid the Ref, but we don't want to do pairwise `vcat` array-by-array.
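A minimal toy sketch of that running-length idea (not the PR's code; the `arrays` container and the example leaves are invented here):

```julia
# Toy sketch, not the PR's code: compute each leaf's offset from the running
# length of what has been collected so far, instead of carrying a Ref counter.
arrays = AbstractVector[]                    # collected flattened pieces
offsets = map([rand(2, 2), rand(3)]) do y    # stand-in for the fmap walk over leaves
    o = sum(length, arrays; init = 0)        # offset = total length collected so far
    push!(arrays, vec(y))
    o
end
flat = reduce(vcat, arrays)                  # promote & concatenate once, at the end
```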
Yes, I was referring to just the total instead of using carried state for the actual array of arrays (or single array). I assume the single array path was nixed because of not knowing the container type without traversing?
Yes, getting the type needed another walk: 5a7bfc8#diff-2fc6059e247338b5ac149900b866865ae69bdf3693d9ce37cef19f230ddb8e30L104
(And forgetting to make that walk non-differentiable gave some surprising bugs...)
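For reference, a hedged sketch of the kind of guard being described; the helper name `_paramtype` is hypothetical, not the PR's function:

```julia
using ChainRulesCore

# Hypothetical helper that walks the collected leaves only to work out the
# promoted element type for the flat vector:
_paramtype(arrays) = mapreduce(eltype, promote_type, arrays; init = Bool)

# Mark it non-differentiable so Zygote/ChainRules never try to trace it:
ChainRulesCore.@non_differentiable _paramtype(::Any)
```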
src/destructure.jl
Outdated
```julia
struct Restructure{T,S}
  model::T
  offsets::S
  length::Int
end
```
I'm really happy this is being reified. The current implementation in Flux tries to dance around having to explicitly represent the state, but that's a big part of why it's so inflexible.
My main reason to replace this:

```julia
function destructure(x)
  flat, off, len = alpha(x)
  flat, v -> beta(x, off, v; len)
end
```

with 7 more lines was that the anonymous function's type contains the huge offset struct's type, so it fills your screen...

Glad you don't object though! This struct is very much internal still, of course.
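A hedged sketch of that replacement, reusing the placeholder names `alpha`/`beta` from the snippet above (the real helper names may differ):

```julia
# Sketch only, not a verbatim quote of the PR: a named struct stands in for the
# old closure, so it can be given a compact `show` method.
struct Restructure{T,S}
    model::T
    offsets::S
    length::Int
end

# The call method plays the role of the old `v -> beta(x, off, v; len)` closure:
(re::Restructure)(flat::AbstractVector) = beta(re.model, re.offsets, flat; len = re.length)

function destructure(x)
    flat, off, len = alpha(x)
    return flat, Restructure(x, off, len)
end
```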
Nice, that'll be a great QoL change too.
This should also make it easier to port some of the speedy stuff (if we want to) from DiffEqFlux for `FastChain` etc. We can consider adding

```julia
(re::Restructure)(x, flat) = re(flat)(x)
```

for that.
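For example, a hypothetical usage sketch (`model` and `x` are assumptions, as is the exact return of `destructure`):

```julia
# Hypothetical usage of the proposed two-argument method:
flat, re = destructure(model)
y1 = re(flat)(x)     # rebuild the model, then apply it to x
y2 = re(x, flat)     # proposed shorthand: same result, closer to FastChain's call style
```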
No downsides to that.

One question here is whether we want overwriting versions. Should there be a `destructure!` whose `re` writes back into the same model? And if we want `(c::Chain)(x, p::Vector)` to call something like that, but never make the vector, can these share code?
I've been thinking a bit around that in light of FluxML/Flux.jl#1855 and FluxML/Flux.jl#1809 (comment), but it's not clear to me how it could be done somewhat generically within Flux's existing design. This would be a lot easier if layers didn't own their trainable parameters, Flax-style.
Just a +1 comment from me. Brian already covered anything I would have.
```julia
@testset "second derivative" begin
  @test_broken gradient(collect(1:6.0)) do y
    sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
  end[1] ≈ [8,16,24,0,0,0]
  # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
  # with Zygote, which can be fixed by:
  # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)
```
I think the remaining question on this PR is whether and how much to care about 2nd derivatives. Some work, some don't. I convinced myself there is no bug in the basic logic. But in the details of when to wrap what in a Tangent, or unwrap it for Zygote, there might be bugs, here or upstream.
If we want to be pedantic we could make all 2nd derivatives an error, rather than risk any being wrong. Or a warning.
At least a warning sounds good to me.
Done! All warnings are `maxlog=3`, so as not to be too annoying if something does actually work.
Good to go?
```julia
# These are only needed for 2nd derivatives:
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
  @warn "second derivatives of Restructure may not work yet, sorry!" maxlog=3
```
Given these were completely busted most of the time before, I don't think we need to apologize so profusely 😆
Heh. May as well be friendly!
Also, I think the point is that they ought to work; the structure does allow for them. It just has bugs right now.
This adds a `destructure` function, like Flux's, which should handle awkward cases. Requires FluxML/Functors.jl#37.

Alternative to #40. This seems like a tidier approach, although not quite as short and elegant as it was before I started adding tests. The key idea is that, on the first walk over the model to flatten it, you can make a tree of vector offsets, which simplifies the reconstruction step and the gradients. The gradient of reconstruction isn't an `fmap` walk, but because it already knows the offsets, it does not care if the walks' orders don't match.

Should work with numbers too if `isnumeric` is widened to allow them. Should work with mixed element types too: promote for the vector, project back for the reconstruction.

The reason to put it here and not in Functors is that this package must already depend on ChainRulesCore, and that this builds in `trainable` deeply enough to make having another version without it a pain. And, Functors at present doesn't have `isnumeric`.

Closes #40, closes FluxML/Functors.jl#31 if it can.
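To make the "tree of vector offsets" idea concrete, here is a toy illustration (not the package's actual code; the two-leaf `model`, the `offsets` layout, and the `rebuild` helper are invented for this example):

```julia
# Toy illustration of the offsets idea, not the package's actual code.
model = (W = [1.0 2.0; 3.0 4.0], b = [5.0, 6.0])

# Flattening walks the leaves once, recording where each one starts:
flat = vcat(vec(model.W), vec(model.b))     # length-6 vector
offsets = (W = 0, b = length(model.W))      # same tree shape as `model`

# Reconstruction needs no second walk to find positions, only the recorded offsets:
rebuild(flat, x::AbstractArray, o) = reshape(flat[o+1:o+length(x)], size(x))
model2 = (W = rebuild(flat, model.W, offsets.W),
          b = rebuild(flat, model.b, offsets.b))

@assert model2.W == model.W && model2.b == model.b
```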