Frozen parameters #107
Comments
Optax's design may be of interest here. They of course can get away with making everything immutable. However, if we think of a masked state tree as a temporary view of the original, perhaps we can do something similar. One concern with wholesale replacement vs mutating a flag is tied weights. We'd either have to do a two-pass solution to also catch those, or document that they won't be picked up.
Inserting this struct in place of leaf nodes instead of entire subtrees would mean more time spent traversing, but it would preserve the structural equivalence.
`Accessors.@set state.layers[1].enc[3] = state.layers[1].dec[3].parent` would be the direct equivalent. Not prettier, but at least more familiar syntax. At this stage, I think figuring out how to tie immutable params in bulk can be left to users.
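A minimal sketch of that in context (assuming `state` came from `Optimisers.setup` on an encoder/decoder `Chain`, and the decoder weight is a `Transpose` of the encoder's, hence the `.parent`):

```julia
using Accessors

# @set is non-mutating: it rebuilds the tree, so rebind the result.
# Afterwards both addresses hold the same Leaf object (assuming the mutable Leaf
# of #106), so updates to one are seen at the other.
state = @set state.layers[1].enc[3] = state.layers[1].dec[3].parent
```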
To be clear, I've only considered reversible modifications. So the closest thing is 1. (#49), which replaces the `Leaf` with a different one. Irreversibly truncating the state tree is another option. But then perhaps we need a way to merge the old one back in, and the user needs to keep two objects.
Good question. First, what's the desired behaviour here at all? If you freeze …
How common are state tree merges? If you make …
There doesn't seem to be a clear answer, yeah. For safety then, we should try to match current Flux semantics and freeze ties as well, I think.
For shared parameters, maybe the word "freeze" implies they never change (hence tied ones must be frozen too), while "mask" could be read as only caring about some gradients. As you say, Flux freezes both.
The Optax one sees the entire tree. In their example they map over all leaves in the callback function, but as long as you return something with the same shape as the original state, they don't care how you do it. Optax in general tends to write its rules "vectorized", however (instead of taking in a leaf, each rule takes in a state tree and is responsible for mapping some function over each leaf), so a direct comparison to Optimisers.jl needs to account for that.
Is there anything else to do here?
Only the part about having a mechanism (e.g. the proposed …).
It would be nice to be able to temporarily exclude some parameters from training.
(Edit: I forgot that there is FluxML/Flux.jl#1931, now folded in here.)
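For concreteness, the state tree in question is the one built by the usual setup/update walk; a small sketch (the model, loss, and optimiser choice here are arbitrary):

```julia
using Flux, Optimisers

model = Chain(enc = Dense(2 => 3, relu), dec = Dense(3 => 2))
state = Optimisers.setup(Adam(1e-3), model)           # a tree of Leaf objects mirroring `model`

grad = gradient(m -> sum(abs2, m(rand(Float32, 2, 8))), model)[1]
state, model = Optimisers.update(state, model, grad)  # the per-leaf walk which "freezing" must skip
```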
1. One mechanism is to alter `Leaf` to record whether it is frozen. This is what Per-leaf freezing #49 does, and what Allow shared parameters, take III #106 suggests as an aside. The former is immutable, changed by walking & re-building. The latter makes `Leaf` mutable (for other reasons) so this can be changed in place. (Edit: implemented in Add `freeze!`/`thaw!` #112, merged.)
2. Another mechanism would be to insert some `Frozen` struct into the state tree which stops further exploration. This may make it easier to freeze a whole branch, but it will result in a tree with different labels to the model: a pattern like `model.layers[1].layers[1].weight` will no longer translate directly to one for the state tree.
3. A similar struct could equally be inserted into the model, not the state. Or into both. Since gradient calculation never sees the state, changing the model may allow for faster gradients. Does Optimisers.jl own the `struct Frozen`, if it is to recognise it?
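To make 2. and 3. concrete, a minimal sketch of what such a wrapper might look like (purely hypothetical; nothing below exists in Optimisers.jl):

```julia
# Hypothetical marker type: hides a whole sub-tree of the state (or of the model).
struct Frozen{T}
    tree::T
end

# A hypothetical update walk would stop on reaching it, returning the branch untouched,
# e.g. roughly:  update(frozen::Frozen, model_branch, grad_branch) = frozen, model_branch
```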
Maybe independently, it needs a friendly way to set & remove these labels.
4. PR Per-leaf freezing #49 proposes that you give an address like `freeze(state, (:layers, 1, :enc, 3))`. It seems a bit awkward to require you to know all the field names from the root.
5. It would also be possible to work based on just one field name: `freeze(state, :enc)` acts on anything within any field called `enc` (which in practice is some `Chain(enc = ..., dec = ...)`). Likewise `freeze(state, :bias)` could affect every layer.
6. Another possibility is to allow control based on the type in the model. Then it has to walk both, `state = freeze(state, model, cond)` or perhaps `state = freeze(f, state, model)` where `f` is a `do` block which tests `x isa Dense` or whatever. Doesn't lend itself so obviously to freezing only some fields, `enc` or `bias`... unless `f` returns not a Bool but a list of fields, like `x isa Chain && return :enc`.
7. If the modification is to the model, then 6. becomes `model = freeze(f, model)`.
8. If `Leaf` is mutable, then instead of an address you can just pass a part of the tree: `freeze!(tree.layers[1].enc[3])`, after confirming that `model.layers[1].enc[3]` is the part you want. (Edit: implemented as Add `freeze!`/`thaw!` #112, merged.)
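Since 8. is the route that was merged, a minimal usage sketch of it (API as in #112; the encoder/decoder model is just illustrative):

```julia
using Flux, Optimisers

model = Chain(enc = Dense(2 => 3), dec = Dense(3 => 2))
state = Optimisers.setup(Adam(1e-3), model)

Optimisers.freeze!(state.layers.enc)   # flips the flag on every Leaf under this branch
# ... Optimisers.update!(state, model, grad) now leaves the encoder untouched ...
Optimisers.thaw!(state)                # clears the flag everywhere below `state`
```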
There's a related API question for shared weights. At present Flux (and Functors) rely on objectid. This won't work for immutable arrays.
One idea is to wrap them in a struct like `TiedWeight(array, Ref())` to get an objectid (and possibly remove this later).
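A sketch of that wrapper idea (illustrative only, not an existing Optimisers.jl type):

```julia
# The mutable Ref gives the wrapper a stable identity (objectid(w.id)), even when
# `array` is an immutable array type for which objectid is not a reliable tie marker.
struct TiedWeight{A}
    array::A
    id::Base.RefValue{Nothing}
end
TiedWeight(array) = TiedWeight(array, Ref(nothing))
```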
The idea of Transparent handling of tied weights #100 is that instead the state tree can have the same (mutable) `Leaf` struct at the location of tied arrays. How do you construct this? With 4. this might be `tie(state, (:layers, 1, :enc, 3) => (:layers, 1, :dec, 3, :parent))` where the `:parent` is because of a Transpose. Is there a less ugly way?
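Whatever the construction syntax ends up being, what #100 asks for is plain aliasing; a sketch of the end state (addresses as in the example above, purely illustrative):

```julia
# After tying, both addresses hold one and the same mutable Leaf,
# so whatever the update walk does to that optimiser state is seen at both locations.
state.layers[1].enc[3] === state.layers[1].dec[3].parent   # true
```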