Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transparent handling of tied weights #100

Closed
wants to merge 2 commits into from
Closed

Transparent handling of tied weights #100

wants to merge 2 commits into from

Conversation

ToucheSir
Copy link
Member

This makes Leaf a mutable type so that tied weights are represented by the same leaf instance.

Although only mutable array types are automatically detected as tied, one can also tie immutable parameters by manually creating shared Leafs.

The test suite is practically the same as #42, with some slight modifications since there is no equivalent to Tied in this PR.

This makes `Leaf` a mutable type so that tied weights are represented by the same leaf instance.

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Comment on lines +48 to +51
_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, _...) = nothing
_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, ::Zero, ::Zero...) = nothing
_accumulate!(::AbstractDict{Leaf,Any}, ℓ::Leaf, _, ::Zero, ::Zero...) = nothing
_accumulate!(::AbstractDict{Leaf,Any}, _, _, ::Zero, ::Zero...) = nothing
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a lot of overloads with high degrees of overlap across multiple functions. I couldn't think of a way to deduplicate some of them, so if anyone has ideas that would be swell.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it can be just 4 methods, if the state tree has () instead of nothing, as in #106.

I also think it would be clearer to write variable names more often, not _, since 5 arguments is quite a few to count.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I tried that, but ran into ambiguities. This is the smallest number of methods I could come up with that didn't have ambiguities. If you can narrow that down, that would be superb.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The underscores are mostly to appease the linter and possibly improve latency(??) Perhaps ::Any would work better, though I'm not sure that addresses your point about clarity?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how names can affect latency. I just mean they let your eye know what the 4th argument means, which ::Any doesn't help.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My impression was it implicitly acted as a @nospecialize, but looking at https://github.com/JuliaLang/julia/blob/98e1b13a7db5aa1d05b9a48085993375cf2298d0/src/method.c#L656 that may not be the case.

Comment on lines +88 to +92
tree′ = fmap(tree; cache, exclude = Base.Fix2(isa, Leaf)) do ℓ
Leaf(ℓ.rule, fmap(copy, ℓ.state; cache, exclude = iswriteable))
end
x′ = fmap(copy, x; cache = empty!(cache), exclude = iswriteable)
x̄s′ = fmap(copy, x̄s; cache = empty!(cache), exclude = iswriteable)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out we were not defensively copying state or gradients before, so they could still be mutated by a call to update.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems fine never to copy gradients. It's never safe to mutate them anyway, a rule which does so (or an rrule likewise) is simply a bug.

For copying state, can't we just say @functor Leaf (state,) and let fmap do it?

Copy link
Member Author

@ToucheSir ToucheSir Aug 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For copying state, can't we just say @functor Leaf (state,) and let fmap do it?

That breaks Leaf identity, unfortunately. fmap will end up untying shared parameters by creating new leaves at each location during reconstruction.

Not defensively copying gradients seems fine though, good point.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't fmap will preserve the Leaf identifications? That's what its cache is for.

Copy link
Member Author

@ToucheSir ToucheSir Aug 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a recollection that it sometimes only preserved leaves, but re-reading the code you are correct.

@ToucheSir ToucheSir closed this Jul 10, 2022
@ToucheSir ToucheSir reopened this Jul 10, 2022
@ToucheSir
Copy link
Member Author

Doctests appear to be picking up changes on master that aren't present on this branch, is that expected? I can't tweak the test because it doesn't exist here!

@ToucheSir ToucheSir requested review from MikeInnes and removed request for MikeInnes July 10, 2022 03:17
@mcabbott mcabbott added the enhancement New feature or request label Jul 25, 2022
Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I tried to read this, and in the process of understanding what's going on, wrote #106. Maybe that explains my thoughts more clearly than the comments here.

Comment on lines +9 to 12
mutable struct Leaf{R,S}
rule::R
state::S
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once this is mutable, then update!(tree, model, grad) can be guaranteed to alter the state tree in place. This opens the possibility of simplifying the interface, and never returning multiple things whose order you have to remember.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines +81 to 83
xtree = map(cb, tree, x′, x̄s′...)
return map(first, xtree), re(map(last, xtree))
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This complication exists I think to reconstruct both the tree and the model on the way out of the recursion. But once Leaf is mutable, can't we skip that, and just mutate it? Just call fmap?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, absolutely. I held off from doing that here in case some user was stashing old state trees and would be blindsided by the values in those leaves suddenly changing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow. update! claimed it would mutate the states if it wanted to, and would typically alter arrays. (And update claimed not to, but had a bug.)

Copy link
Member Author

@ToucheSir ToucheSir Aug 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if you had immutable arrays in your state tree before, the original state tree would be unchanged after update!. Perhaps we don't feel that was ever a solid guarantee (I don't), but we ought to get that point out in writing for posterity.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I think we should change the doc for update! to be explicit that it now guarantees to update the state tree. (I thought the old one said "inputs are trash afterwards" but in fact it is explicit only about the model.)

update! need not in fact return two arguments, but whether that is too confusing to change (and to differ from update which must) is another question.

Comment on lines +88 to +92
tree′ = fmap(tree; cache, exclude = Base.Fix2(isa, Leaf)) do ℓ
Leaf(ℓ.rule, fmap(copy, ℓ.state; cache, exclude = iswriteable))
end
x′ = fmap(copy, x; cache = empty!(cache), exclude = iswriteable)
x̄s′ = fmap(copy, x̄s; cache = empty!(cache), exclude = iswriteable)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems fine never to copy gradients. It's never safe to mutate them anyway, a rule which does so (or an rrule likewise) is simply a bug.

For copying state, can't we just say @functor Leaf (state,) and let fmap do it?

end
end

_add!(x, x̄) = iswriteable(x) ? (x .= x .+ x̄) : eltype(x).(x .+ x̄)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this is what we want. We should never ever mutate a gradient, but I think we can just call @lazy x̄old + x̄new and lazily accumulate?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My worry with the lazy accumulation approach is threefold. First, it blows any chance of making this type stable out the window. Secondly, it's possible the lazy Broadcasted may be evaluated multiple times as it passes through a chain of rules and thus incur accumulation overhead more than once. Lastly, complicated broadcasts come with a lot of compilation latency (especially on GPU) and I'm wary of making optimizers worse than they already are on that front.

Copy link
Member

@mcabbott mcabbott Aug 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do it eagerly to avoid this. But we cannot mutate the gradients, as they may be shared with others (e.g. from the rule for +).

Lazy .+ is almost free, it's very difficult to picture evaluating this twice ever costing as much as a copy. Not sure about compile times.

Copy link
Member Author

@ToucheSir ToucheSir Aug 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point about aliased gradients. If this is a correctness issue, we don't have much of a choice :)

Comment on lines +48 to +51
_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, _...) = nothing
_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, ::Zero, ::Zero...) = nothing
_accumulate!(::AbstractDict{Leaf,Any}, ℓ::Leaf, _, ::Zero, ::Zero...) = nothing
_accumulate!(::AbstractDict{Leaf,Any}, _, _, ::Zero, ::Zero...) = nothing
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it can be just 4 methods, if the state tree has () instead of nothing, as in #106.

I also think it would be clearer to write variable names more often, not _, since 5 arguments is quite a few to count.

Comment on lines +64 to +68
# slightly cleaner way of closing over update! internal state
struct UpdateCallback
acc_grads::IdDict{Leaf,Any}
param_cache::IdDict{Leaf,Any}
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The limitation might be mine but I have to say I find this struct really hard to read, compared to just closing over things which have one name.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what "just closing over things which have one name." entails here, can you elaborate? Another reason for the struct over a normal closure is self-recursion, which I use here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean things like this, which define a dict & then use it:

   cache = IdDict{Leaf,Any}()
   _accumulate!(cache, tree, x, x̄s...)

With no further names: no structs, no field names.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recall trying this first, and deciding to bundle things into a struct after seeing a lot of long, long lines from threading the two IdDicts through multiple levels of functions. It may also have been tricky to get that working in a backwards compatible way, but it's been long enough that I don't remember the whole context.

@mcabbott mcabbott mentioned this pull request Aug 28, 2022
@mcabbott
Copy link
Member

Here's an evil case for shared parameters:

mutable struct MutTwo; x; y; end
Functors.@functor MutTwo

tmp = MutTwo([1.0], [2.0])
model = (a=tmp, b=tmp, c=MutTwo(tmp.x, tmp.y))
state = Optimisers.setup(Momentum(), model)

model.a === model.b
model.a !== model.c  # fields are identified, but struct is not

state.a.x === state.b.x
state.a === state.b
state.a === state.c  # unavoidable, but means we can't use Leaf ID alone?

mgrad = (a=(x=[1.], y=[10.]), b=(x=[100], y=[1000]), c=(x=[1/3], y=[1/30]))
state2, model2 = Optimisers.update(state, model, mgrad)
model2.a === model2.b
model2.a !== model2.c 

The state of all 3 components is (x=Leaf(...), y=Leaf(...)). A cache which is IdDict{Leaf,Any} can't identify the two structs. A cache which also stores higher levels of the state tree will instead identify all three structs.

One answer here is to store tuples (x, Leaf(...)). Then identifying the Leafs can be used as a trick to tie StaticArray parameters. But cannot tie Array parameters (which aren't already tied by ===).

@ToucheSir
Copy link
Member Author

ToucheSir commented Aug 29, 2022

I don't think we ever guaranteed model.a !== model.c => state.a !== state.c. model2.a !== model2.c after an update seems more like a bug rather than an intrinsic limitation? If x and y are === for all 3 components, then this should just work. I had to read the comments in #106 for context. We weren't preserving identity at higher levels before, so I think that is orthogonal to the issue of tied leaves. It would be nice if we could though, which I see has been done there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
Development

Successfully merging this pull request may close these issues.

2 participants