Add in-place destructure! #165

Draft
mcabbott wants to merge 3 commits into master from in-place-destructure
Conversation

@mcabbott (Member) commented Nov 2, 2023

This adds a variant of destructure with minimal changes such that it writes back into the original model, instead of creating a copy. This may close #146, cc @glatteis

Marked as draft because it seems surprisingly slow -- why?

julia> model = Chain(Dense(28^2 => 1024, relu), Dense(1024 => 10));

julia> params, re = destructure(model);  # old

julia> params, re! = destructure!(model);  # new

julia> @btime $re($params);  # This is the reconstruction cost
  min 229.334 μs, mean 374.473 μs (70 allocations, 3.11 MiB)

julia> @btime copy($params);  # ... and it's mostly allocation, same mean:
  min 219.417 μs, mean 367.168 μs (3 allocations, 3.11 MiB)

julia> @btime $re!($params);  # this avoids the allocations, but is quite slow.
  min 432.917 μs, mean 472.293 μs (58 allocations, 2.02 KiB)
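
For context, a minimal sketch of the intended difference between the two (assuming re! simply overwrites the captured model, as the description above says, rather than building a new one):

w2 = 2 .* params;   # some other flat parameter vector of the right length

m_new  = re(w2);    # old: allocates and returns a fresh model; `model` is untouched
m_same = re!(w2);   # new: writes w2 back into `model` itself, avoiding the ~3 MiB copy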

@linusheck

Thank you! Weird that it is so slow.

@ToucheSir (Member)

Could the in-place version be tripping some aliasing heuristic and hitting a slow path? I guess a profile would be illuminating.
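
In case it helps, the stdlib Profile is enough to see where the time goes (nothing PR-specific in this sketch):

using Profile

Profile.clear()
@profile for _ in 1:1_000      # repeat the ~0.4 ms call so it dominates the samples
  re!(params)
end
Profile.print(format = :flat)  # look for copyto! / reshape / fmap frames on the hot path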

@kishore-nori commented Apr 11, 2024

Thank you for initiating and implementing this idea; I think it's a great one and would be very useful. I was trying this out because I am interested in copying parameters in place into a model from a flat vector. From my comparisons, I suspect that one reason for the slowness of the in-place version is cache behaviour around copyto! (it handles two large blocks of memory at once), which doesn't occur in the usual restructure; this can be observed with a relatively smaller model.

Additionally, the following version (which just exploits the fact that copyto! works on the underlying memory, so no reshape is needed) is faster in my tests:

function _rebuild_alt!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
  len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
  fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
    vecy = vec(y)
    copyto!(y, _getat_alt(vecy, o, flat, view))  # copy straight into y's memory, no reshape needed
  end
  x
end

_getat_alt(y::AbstractVector, o::Int, flat::AbstractVector, get = getindex) =
  ProjectTo(y)(get(flat, o .+ (1:length(y))))

and it gets better relative to the usual restructure as the model gets smaller, which again points towards cache issues for copyto!:

using Flux, Optimisers, Zygote, BenchmarkTools

N = 1024
model = Chain(Dense(28^2 => N, relu), Dense(N => 10));

params, re = destructure(model)
params!, re! = destructure!(model)
params_alt!, re_alt! = destructure_alt!(model)  # using the alternatives above

@btime $re($params)
@btime $re!($params)
@btime $re_alt!($params)

  106.964 μs (44 allocations: 3.11 MiB)  # re
  250.546 μs (35 allocations: 1.53 KiB)  # re!
  156.664 μs (39 allocations: 1.69 KiB)  # re_alt!

When I choose N = 100 I get the following timings:

  12.184 μs (43 allocations: 312.61 KiB)  # re
  21.374 μs (35 allocations: 1.53 KiB)    # re!
   7.651 μs (39 allocations: 1.69 KiB)    # re_alt!

@mcabbott (Member, Author)

Ah that looks great, thanks for digging! For me, with the example at top:

julia> @btime $re($params);  # This is the reconstruction cost
  min 92.167 μs, mean 301.432 μs (44 allocations, 3.11 MiB)

julia> @btime copy($params);  # ... and it's mostly allocation, same mean:
  min 97.333 μs, mean 309.699 μs (2 allocations, 3.11 MiB)

julia> @btime $re!($params);  # new version without reshape
  min 58.333 μs, mean 62.932 μs (39 allocations, 1.69 KiB)

and with N=100:

julia> @btime $re($params);
  min 7.167 μs, mean 25.760 μs (43 allocations, 312.67 KiB)

julia> @btime copy($params);
  min 4.944 μs, mean 31.490 μs (2 allocations, 310.67 KiB)

julia> @btime $re!($params);
  min 8.812 μs, mean 9.047 μs (39 allocations, 1.69 KiB)

When allocations differ this much, I think the mean times are probably a better indication of the cost in actual use, although they're not a perfect measure either.

@kishore-nori commented Apr 11, 2024

Nice, that's good to know; the in-place version seems to be pretty stable in its timings. How do I make @btime output both min and mean like you have? (Sorry for going off-topic.)

And is the PR good to go?

@mcabbott (Member, Author)

The mean is from JuliaCI/BenchmarkTools.jl#258, which I should eventually rewrite as @btime / @btimes, or @bmin / @btime, or something.
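
(For reference, stock BenchmarkTools can also report both, by keeping the whole trial instead of using @btime -- a minimal sketch:)

using BenchmarkTools, Statistics

t = @benchmark $re!($params)   # the full Trial, not just a one-line summary
minimum(t), mean(t)            # both estimates, as TrialEstimates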

I see I did write some tests for this; it could all use one more look over. There's a commented-out destructure!(flat::AbstractVector, x) which would let the other direction avoid allocating a whole copy too... not sure whether that's almost ready or a whole other project.

@kishore-nori

Mean is from JuliaCI/BenchmarkTools.jl#258

Interesting! That would be very useful! I'll see if I can take some pirate code out of it to start using locally :)

I didn't take a look at destructure!(flat::AbstractVector, x) earlier, but looking at it now it seems good, except that the length-compatibility check needs to happen after all the copyto! calls into flat are done (or in the last iteration). Could we do it up front using Flux.State? (Earlier I was using Flux.Params for these purposes, but both walk over all the layers, so there is some redundancy :/.) Another question is the type check on flat -- is it required here?

Anyway, I think for most purposes destructure! is used just once (I can't think of cases other than pruning), so even if destructure!(flat::AbstractVector, x) is not ready, the rest should be useful in its own right. But is there a reason why it was commented out? I can try looking into it.
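
For the record, a rough sketch of what that other direction could look like, with the length check deferred until after all the copyto! calls. This is not the commented-out method from the PR; the name flatten_into!, and the use of Functors.fmap with Optimisers.isnumeric, are just my guesses at one way to write it:

using Optimisers, Functors

function flatten_into!(flat::AbstractVector, x)     # hypothetical helper, not the PR's method
  offset = 0
  fmap(x; exclude = Optimisers.isnumeric) do y      # visit each numeric leaf once
    n = length(y)
    offset + n <= length(flat) || throw(DimensionMismatch("flat vector too short"))
    copyto!(view(flat, offset .+ (1:n)), vec(y))    # write this leaf's parameters into flat
    offset += n
    y                                               # leave the model itself unchanged
  end
  offset == length(flat) ||                         # the compatibility check, after all copies
    throw(DimensionMismatch("expected a vector of length $offset, got $(length(flat))"))
  flat
end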

@mcabbott force-pushed the in-place-destructure branch from 89c8d43 to 058a25b on April 11, 2024 at 15:30
@mcabbott (Member, Author)

I commented the method out just to focus on getting one thing working first. I believe it still needs tests, but otherwise this is nearly done.

Maybe I should check that my scary warning is true. I think something like this will return zero gradient:

v, re! = destructure!(model)
gradient(v) do w
  _ = re!(w)  # mutates model, in a way Zygote does not see
  sum(abs2, model(x))
end
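
For comparison, the usual non-mutating pattern keeps the reconstruction inside the chain of rrules, so Zygote does see it (this is just standard destructure usage):

v, re = destructure(model)
gradient(v) do w
  m = re(w)          # builds a new model that Zygote can track
  sum(abs2, m(x))
end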

@kishore-nori

That's great!

Yes, you are right, it returns (nothing,); the _ breaks the connection in the chain of rrules.

Successfully merging this pull request may close: Restructure makes a copy (#146).