Add in-place destructure!
#165
base: master
Conversation
Thank you! Weird that it is so slow. |
Could the in-place version be tripping some aliasing heuristic and hitting a slow path? I guess a profile would be illuminating. |
Thank you for initiating and implementing this idea! I think this is a great idea and it would be very useful. I was trying this out because I am interested in in-place copying of parameters into a model from a flat vector. From my comparisons, I suspect that one of the reasons for the slowness of the in-place version is cache issues involving `reshape`. Additionally, the following version (which works on `vec(y)` and reads from `flat` with a `view`) does better:

```julia
function _rebuild_alt!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
  len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
  fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
    vecy = vec(y)
    copyto!(y, _getat_alt(vecy, o, flat, view))
  end
  x
end

_getat_alt(y::AbstractVector, o::Int, flat::AbstractVector, get = getindex) =
  ProjectTo(y)(get(flat, o .+ (1:length(y))))
```

and it does better than the usual version:

```julia
using Flux, Optimisers, Zygote, BenchmarkTools
N = 1024
model = Chain(Dense(28^2 => N, relu), Dense(N => 10));
params, re = destructure(model)
params!, re! = destructure!(model)
params_alt!, re_alt! = destructure_alt!(model)  # using the above alternatives
@btime $re($params)
@btime $re!($params)
@btime $re_alt!($params)
```

```
106.964 μs (44 allocations: 3.11 MiB)
250.546 μs (35 allocations: 1.53 KiB)
156.664 μs (39 allocations: 1.69 KiB)
```

When I choose `N = 100`:

```
12.184 μs (43 allocations: 312.61 KiB)
21.374 μs (35 allocations: 1.53 KiB)
7.651 μs (39 allocations: 1.69 KiB)
```
|
Ah that looks great, thanks for digging! For me, with the example at the top:

```julia
julia> @btime $re($params);  # this is the reconstruction cost
  min 92.167 μs, mean 301.432 μs (44 allocations, 3.11 MiB)

julia> @btime copy($params);  # ... and it's mostly allocation, same mean:
  min 97.333 μs, mean 309.699 μs (2 allocations, 3.11 MiB)

julia> @btime $re!($params);  # new version without reshape
  min 58.333 μs, mean 62.932 μs (39 allocations, 1.69 KiB)
```

and likewise with N=100. I think the mean times are probably a better indication of the cost in actual use when allocations differ so much, although possibly not perfect. |
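For reference, stock BenchmarkTools can already report a mean over a collected trial, along these lines (a sketch, not the forked `@btime` output shown above):

```julia
# Sketch: getting mean (not just minimum) timings from stock BenchmarkTools.
using BenchmarkTools, Statistics

t = @benchmark sum(x) setup = (x = rand(1000))
minimum(t)  # the estimate that @btime prints
mean(t)     # mean over all collected samples, sensitive to allocation/GC cost
```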
Nice, that's good to know! The in-place version seems to be pretty stable in its timings. How do I make `@btime` report those mean times? And is the PR good to go? |
Mean is from JuliaCI/BenchmarkTools.jl#258, which I should eventually re-write. I see I did write some tests of this; it could all use one more look over. There's a commented-out method, too. |
Interesting! That would be very useful! I'll see if I can take some pirate code out of it to start using locally :) I didn't take a look at it yet. |
(force-pushed 89c8d43 to 058a25b)
I commented the method out just to focus on getting one thing working first. I believe it still needs tests, but otherwise this is nearly done. Maybe I should check that my scary warning is true. I think something like this will return zero gradient:
|
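A minimal sketch of the zero-gradient pitfall described above (hypothetical code; it assumes this PR's `destructure!` API, and the original snippet may have differed):

```julia
# Hypothetical sketch, assuming this PR's `destructure!` is available.
using Flux, Optimisers, Zygote

model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
flat, re! = destructure!(model)

# The loss below reads `model`'s fields directly, ignoring the value
# returned by `re!`. Since `re!` wrote into `model`'s arrays in place,
# Zygote sees no dependence of the loss on `w`, so the gradient would
# presumably come back as zero.
g = gradient(flat) do w
    re!(w)
    sum(model.layers[1].weight)
end
```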
That's great! Yes, you are right, it returns zero gradient. |
This adds a variant of `destructure`, with minimal changes, such that it writes back into the original model instead of creating a copy. This may close #146, cc @glatteis.

Marked draft as it seems surprisingly slow -- why?