
load flat parameters without mutation or restructure #2026

Open
baedan opened this issue Jul 24, 2022 · 6 comments

Comments

@baedan

baedan commented Jul 24, 2022

is there a way to create a copy of an existing model with a new, flat parameter vector, without mutating, and without using restructure? the reason for the latter two is that i need higher-order derivatives wrt the flat parameter vector.

i've been fumbling through different threads (#1979 was particularly relevant) to find a solution, without success, and would appreciate any pointers!

@mcabbott
Member

mcabbott commented Jul 24, 2022

You can try; writing this with fmap is quite simple. But its gradient needs the gradient of getindex, and IIRC this is still a problem for Zygote at 2nd order. Although today I hit a different error...

using Functors: fmap
using ChainRulesCore

function fromflat(x, flat::AbstractVector)
  off = offsets(x)
  fmap(x, off; exclude = y -> y isa AbstractArray{<:Number}) do y, o
    v = flat[o .+ (1:length(y))]  # this needs the gradient of getindex
    reshape(v, axes(y))
  end
end

function offsets(x)
  len = Ref(0)
  fmap(x; exclude = y -> y isa AbstractArray{<:Number}) do y
    o = len[]
    len[] = o + length(y)
    o
  end
end
ChainRulesCore.@non_differentiable offsets(::Any)

model = (x = rand(2), y = (sin, rand(2)))
fromflat(model, [1,2,3,4])  # ok

using Zygote
gradient(v -> fromflat(model, v).y[2] |> sum |> abs2, [1,2,3,4])  # ok

gradient([1,2,3,4], model) do v, m
  gradient(w -> fromflat(m, w).y[2] |> sum |> abs2, v)[1] |> sum
end # UndefVarError: S not defined

(At 1st order, the gradient of getindex makes a zero vector the same size as the whole model, and avoiding that allocation is one of the reasons that there are gradient rules here at all. The mutation of this array of zeros is why getindex at 2nd order is a problem.)
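A quick, hedged illustration of that first-order behaviour (assuming Zygote is available): indexing a single element still produces a gradient the size of the whole input vector.

```julia
using Zygote

# Gradient of getindex: picking out one element of a 4-vector
# yields a gradient the same length as the whole input,
# zero everywhere except the indexed position.
g = gradient(v -> v[2], [1.0, 2.0, 3.0, 4.0])[1]
# g == [0.0, 1.0, 0.0, 0.0]
```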

In FluxML/Optimisers.jl#54, I do think the logic was broadly correct for 2nd order derivatives -- each order reverses the arrows between flat and structured. What goes wrong is something about exactly how Tangent types are constructed, or perhaps how they are converted to & from Zygote's types. It ought to be possible to straighten that out.

@mcabbott
Member

ps. Zygote has many difficulties with 2nd derivatives, see https://github.com/FluxML/Zygote.jl/labels/second%20order . Some of these are about its handling of Tangent types / translation to & from ChainRules, which may be related.

The most reliable option tends to be ForwardDiff over Zygote (as e.g. in Zygote.hessian). Some people also try mixing ReverseDiff & Zygote.
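As a minimal sketch of that forward-over-reverse route (the scalar test function f here is made up for illustration), Zygote.hessian can be called directly:

```julia
using Zygote

f(v) = abs2(sum(sin, v))   # illustrative scalar-valued function, N -> 1
x = [0.1, 0.2, 0.3]

# Zygote.hessian runs ForwardDiff over the Zygote gradient internally:
H = Zygote.hessian(f, x)
size(H)  # (3, 3); H is symmetric
```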

@baedan
Author

baedan commented Jul 26, 2022

thank you for your help!

The most reliable option tends to be ForwardDiff over Zygote (as e.g. in Zygote.hessian). Some people also try mixing ReverseDiff & Zygote.

will take a look at them. the second derivative is of a scalar wrt a large parameter vector, which would make reverse mode a good candidate here, right?

two other lingering questions:

  1. what’s your recommended way to use Flux for something like a hypernetwork, where a function of the output of one network is used as the loss for another, whose output determines the shape/weights of the former? the most intuitive approach to me is to use re/destructure. when updates are sparse, wouldn’t this incur a lot of cost from creating a new model each time, compared to updating in place (which Zygote does not support)? this is not specifically my case, but they share similarities in this respect.
  2. how can i modify fromflat to use it on a Chain? now that i think about it, it should work, since a Chain is just a 1-tuple of layers, i think, but when i tried yesterday, it didn’t.

@mcabbott
Member

second derivative is of a scalar wrt a large parameter vector, which would make reverse mode a good candidate

If f is N -> 1, then the reverse-mode gradient should be more efficient. But gradient(f) is N -> N, and the Hessian of f is the Jacobian of this function, for which reverse mode no longer has a theoretical advantage.

In practice ForwardDiff is very simple, robust & low-overhead. But it only knows about arrays, and it cannot use BLAS for matrix multiplication (Dual numbers fall back to generic matmul).
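The Hessian-as-Jacobian-of-the-gradient point can be made explicit. This is a hedged sketch (f and x are made up; for this f the gradient is v itself, so the Hessian is the identity):

```julia
using Zygote, ForwardDiff

f(v) = sum(abs2, v) / 2          # gradient of f is v; Hessian is I
x = [1.0, 2.0, 3.0]

g(w) = Zygote.gradient(f, w)[1]  # the N -> N gradient function
H = ForwardDiff.jacobian(g, x)   # forward-mode Jacobian of the reverse-mode gradient
# Here H is the 3×3 identity matrix.
```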

would this not incur a lot of cost from creating a new model each time, as compared to updating in place

I think that mostly you shouldn't worry about this cost, at least to start. Taking the gradient of a model with Zygote will typically allocate as much as a few complete copies. It would not be hard to write a destructure! whose re mutates the original model, since rrules hide this from Zygote's sight, but this will save only one copy & thus nobody has got around to it.
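As a minimal sketch of what such an in-place loader could look like (fromflat! is a hypothetical name, not an existing Flux/Optimisers API, and without the custom rrule mentioned above it is not Zygote-differentiable as written):

```julia
using Functors: fmap

# Hypothetical in-place counterpart of fromflat: overwrite each
# parameter array of `x` with the matching slice of `flat`,
# instead of building a new structure.
function fromflat!(x, flat::AbstractVector)
  len = Ref(0)
  fmap(x; exclude = y -> y isa AbstractArray{<:Number}) do y
    o = len[]
    len[] = o + length(y)
    copyto!(y, flat[o .+ (1:length(y))])  # mutates the original array
    y
  end
  x
end

model = (x = rand(2), y = (sin, rand(2)))
fromflat!(model, [1.0, 2.0, 3.0, 4.0])
model.x  # now [1.0, 2.0]
```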

how can i modify fromflat to use on a Chain?

I expected this to work, as fmap already knows how to walk the contents of a Chain, but indeed it does not. It's cut down from the one in Optimisers.jl, but I don't immediately see what I missed.

@baedan
Author

baedan commented Jul 27, 2022

haven't checked correctness yet, but happy to report that a second-order ForwardDiff.gradient over a first-order Zygote gradient works, using restructure.

using Zygote for both yields the error Type NamedTuple has no field backing. should i open an issue for that on Zygote?

@DhairyaLGandhi
Member

It should be possible to use views. It incurs the same getindex penalty but avoids allocating a new parameter array. I was working on NIF, which requires heavy use of restructure, creating a new model for every timestep. As you can imagine, that would quickly lead to slowdowns. So I had come up with a way to minimise re usage and instead reuse the same vector (a matrix in the case of NIF), and got good performance. We'll see if that can be used instead of the current restructure approach.
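A hedged sketch of the view-based idea (fromflat_views is a hypothetical name; it avoids copying each slice into a fresh array, though the gradient-of-getindex issue remains):

```julia
using Functors: fmap

# Variant of fromflat whose leaves are reshaped views into `flat`,
# so no new parameter arrays are allocated.
function fromflat_views(x, flat::AbstractVector)
  len = Ref(0)
  fmap(x; exclude = y -> y isa AbstractArray{<:Number}) do y
    o = len[]
    len[] = o + length(y)
    reshape(view(flat, o+1:o+length(y)), axes(y))  # a view, not a copy
  end
end

model = (x = rand(2), y = (sin, rand(2)))
m2 = fromflat_views(model, [1.0, 2.0, 3.0, 4.0])
m2.x  # 2-element view: [1.0, 2.0]
```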
