
define modules function #1444

Merged
merged 6 commits into master from cl/modules on Mar 10, 2021

Conversation

@CarloLucibello (Member) commented Dec 27, 2020

A function returning an iterator over a model's non-leaf modules.
Mainly motivated by the need to apply L2 regularization to weights only (see #1284, #939),
but it may be of more general use.
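
As a usage sketch (assuming the modules utility added by this PR; the L2 helper and the toy model below are purely illustrative):

using Flux

model = Chain(Dense(2, 3, relu), Chain(Dense(3, 3), Dense(3, 1)))

L2(m) = 0f0                     # fallback: no penalty for other layers
L2(m::Dense) = sum(abs2, m.W)   # penalize weights only; m.b is never touched

reg = sum(L2(m) for m in Flux.modules(model))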

Fix #1294, Fix #1284

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • Final review from @dhairyagandhi96 (for API changes).

@DhairyaLGandhi (Member)

Better done with functors directly

@CarloLucibello (Member, Author)

> Better done with functors directly

how?

@CarloLucibello (Member, Author)

@DhairyaLGandhi bump

@CarloLucibello (Member, Author)

@DhairyaLGandhi bump

@mcabbott (Member)

Is this the intended behaviour:

julia> modules(Chain(SkipConnection(Conv((2,3),4=>5;pad=6,stride=7),+),LayerNorm(8)))
8-element Vector{Any}:
 Diagonal(8)
 (6, 6, 6, 6)
 Conv((2, 3), 4=>5)
 (1, 1)
 LayerNorm(8)
 (7, 7)
 Chain(SkipConnection(Conv((2, 3), 4=>5), +), LayerNorm(8))
 SkipConnection(Conv((2, 3), 4=>5), +)

@CarloLucibello (Member, Author)

Those tuples should be excluded I guess, so ideally it should look like:

julia> modules(Chain(SkipConnection(Conv((2,3),4=>5;pad=6,stride=7),+),LayerNorm(8)))
5-element Vector{Any}:
 Diagonal(8)
 Conv((2, 3), 4=>5)
 LayerNorm(8)
 Chain(SkipConnection(Conv((2, 3), 4=>5), +), LayerNorm(8))
 SkipConnection(Conv((2, 3), 4=>5), +)

(Diagonal looks weird there, but it is fine; it is contained in LayerNorm.)
Also, I don't understand the ordering; it doesn't look depth-first as I expected. Maybe it would be even better to have something breadth-first? I'd have to think about the implementation.

@mcabbott (Member) commented Jan 16, 2021

I was wondering why the containers Chain and SkipConnection are included too, though. Are there any containers which also have parameters (directly) which ought to be included?

I didn't think IdSet was sorted, but maybe this counts as depth-first:

using Flux: trainable
using Functors: isleaf

# Walk the trainable tree, collecting each node whose children are all
# leaf-like ("layers"); recurse into anything with deeper structure.
function layers(x, cache=[])
    if x in cache || _leaflike(x)
        return cache
    end
    children = trainable(x)
    if all(_leaflike, children)
        push!(cache, x)
    else
        foreach(y -> layers(y, cache), children)
    end
    return cache
end

_leaflike(x) = isleaf(x)
_leaflike(::Tuple{Vararg{<:Number}}) = true         # e.g. pad/stride tuples
_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true

layers(Chain(SkipConnection(Conv((2,3),4=>5;pad=6,stride=7),+),LayerNorm(8)))
layers(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4)))

@CarloLucibello (Member, Author)

> I was wondering why the containers Chain and SkipConnection are included too, though. Are there any containers which also have parameters (directly) which ought to be included?

Since this is supposed to also work on generic functors defined by users, we cannot exclude a mix of parameters and submodules. Also, the usage may be more generic than parameter extraction. In practice, I would like to replicate what is done by PyTorch's modules method here.

@CarloLucibello CarloLucibello mentioned this pull request Jan 16, 2021
@ToucheSir (Member) commented Jan 16, 2021

Shameless plug, but I think the conv tuple issue (and, to some extent, custom layer handling) will be ameliorated by FluxML/Functors.jl#7. Most layers mark too many fields as trainable, or have to override trainable, because they want non-trainable fields to survive functoring.

@darsnack (Member)

I think this is more in line with what I was thinking in my comment on #1144. Having something like this might make writing things like the L2 example easy enough for users to do without needing to know Functors.jl.

@DhairyaLGandhi (Member)

I'm inclined to use the functors approach here for clarity, and it's not difficult to write L2 as it is.

@ToucheSir (Member)

I'm not sure I follow; this uses functors (via trainable) under the hood. Were you thinking of something like an fmap(reduce) over leaf nodes?

@CarloLucibello (Member, Author) commented Feb 12, 2021

@DhairyaLGandhi I asked you multiple times here (with a "how?" and two bumps) to be more explicit about the functor approach you suggest. And even if there are alternative approaches, they don't seem very straightforward; I don't see why we cannot provide a convenience utility such as the one proposed here.

@DhairyaLGandhi (Member)

The comment above even mentions a simple fmap(reduce). I was thinking of letting users mark which fields are params (which we already do, and which is documented); we get the same effect. Marking fields as trainable can be done via the functor macro already, so I don't think there's any more need for this.

@CarloLucibello (Member, Author) commented Feb 12, 2021

I'd like to see the code that solves the following problem:
"single out weight params (and not biases) from the Dense and Conv layers of an arbitrary (and arbitrarily nested) model, so that L2 regularization can then be applied".
I can glimpse what fmap(reduce) could do, but I don't understand what @DhairyaLGandhi says in the last comment.
So please, let's be very concrete and explicit and show the code, so that we can have a proper discussion on whether it's reasonable that users should implement it by themselves, and compare it with what they'll have after this PR:

using LinearAlgebra: norm

apply_reg(m) = false
apply_reg(m::Union{Dense,Conv}) = true
L2reg(model) = sum(norm(m.W)^2 for m in modules(model) if apply_reg(m))

Btw, everyone understands iterators; fmap is much more problematic, and I think we should avoid mentally overloading users as much as we can.

@DhairyaLGandhi (Member) commented Feb 12, 2021

So, as in the docstring, you have to define what happens if one were to hit the Dense layer (and every other layer). This doesn't account for the differences in removing the bias from regularisation, which is what #1284 needs.

Also, if we have to define the L2 on every layer anyway, we might as well iterate over a chain or tuple. This doesn't add much benefit over that.

struct Mw{T,F}
  a::T
  b::F
end

@functor Mw (a,)

This (or trainable, or defining L2 on every layer) takes care of handling layers specifically, but there isn't a general route that one can take. The fmap(reduce) approach seems to me to generalise better.
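
For concreteness, a minimal sketch of what the restricted @functor buys you, reusing the Mw struct above (untested, just to illustrate the mechanism):

using Flux, Functors

m = Mw([1.0, 2.0], [3.0, 4.0])
Flux.params(m)        # collects only m.a; m.b is invisible to functoring
fmap(x -> 2 .* x, m)  # doubles a, reconstructs Mw with b untouched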

@darsnack (Member)

L2 is just a simple toy example of the value in this PR. The core problem is I have some function f that applies to Dense and Conv (just examples) in a special way, and I want to apply f to an arbitrary model, recursing through layers until I hit all leaves, then aggregate the results of applying f to those leaves. There are three (reasonable) ways to go about this:

  1. Pure dispatch. Define the following:

    f(m::Dense) = # ...
    f(m::Conv) = # ...
    f(m) = # default value (usually 0 or 1) ...
    f(m::Chain) = sum(f.(m.layers)) # some reduction operator, I used sum

    This is fine when you know m, but it doesn't scale to arbitrary m. For example, if there is a Parallel in the model, then this would break unless you define f(m::Parallel).

  2. Functors.jl. There are a couple of ways, I think. You can either define a new Dense or Conv where only the weight is functored and then use fmap, or you can use isleaf to recurse all the way except the last stage, to apply f on Dense instead of on the parameters of Dense.

  3. This PR. You only need to define f on the layers you care about. The recursion in (1) is handled for you through modules.

Writing code like the L2 example is extremely common. I cannot think of a single project in Flux or PyTorch where I have not written code like this.

(1) is probably what your beginner Julia user might be able to do. It's fine for solo projects, but if you want your code to be used by others, then you need (2) or (3).

I agree with @CarloLucibello that (2) is ugly and no user should ever need to understand Functors.jl or fmap for such a basic code pattern. Also, there is no way I am going to define a new layer just to compute things like L2. And I don't want to override Dense's trainable, because I want it to do the regular thing for the rest of my program.

(3) is basically (2) except all the ugly stuff is hidden behind modules allowing you to use iterators + dispatch to write very little code. IMO this PR brings the power of Functors.jl to the average Flux user. I would much prefer we merge some version of this over just documenting how to do it with fmap.

Now it's unclear to me whether the suggestion is to use Functors.jl directly or to change this PR to use fmap internally. You can't directly use fmap because the leaves in this tree are one layer higher than in Functors.jl. So, I don't really see exactly how you would use it internally, but even so, I think it makes more sense to define "leaf-ness" on trainable for this feature (which relies on Functors.jl anyways).

@DhairyaLGandhi (Member) commented Feb 12, 2021

Sorry for the late reply. I was thinking of something like this (it's not tested code exactly, but worth talking about). It would need some bug fixing to handle Chain neatly, but it allows for the indirection and works on all layers. It has similar issues around not being able to remove certain params, but I feel that is better handled by functor anyway.

using Functors: isleaf

# A node can be traversed if any of its fields is itself a non-leaf.
cantrav(x::T, fs = fieldnames(T)) where T = any(!isleaf, (getfield(x, f) for f in fs))

# Apply f at the first level that cannot be traversed any further.
function func(f, m::T, fs = fieldnames(T)) where T
  cantrav(m) || return f(m)
  (func(f, getfield(m, f_)) for f_ in fs)
end

@ToucheSir (Member) commented Feb 12, 2021

I'll preface this by saying that I don't see the harm in having modules. It should probably return a generator though.

Adapting the example from the PR:

L2(m) = 0f0
L2(m::Dense) = sum(abs2, m.W)

This is equivalent to (1) in @darsnack's example, but doesn't handle container types. If we add the constraint that it shouldn't have to handle containers (i.e. you shouldn't have to override for Chain), there are two options for calculating the total L2:

  1. L2reg(model) = sum(L2(m) for m in modules(model)), which is given in the PR
  2. L2reg(model) = fmapreduce(L2, sum, model), where fmapreduce is a stand-in name for a function that recurses to the _isleaf level and passes each node to L2.

Note how the only real difference here is in whether we use internal or external iteration. In fact, it's easy to see how one could implement modules in terms of fmapreduce or vice versa! The point of both is that a user should not be forced to create specific dispatches for types they don't care about or to perform their own traversal in the function being applied to each node. I personally find something like 2. more elegant because it resembles a catamorphism (which, being a recursion scheme, is literally a method for traversing arbitrary structures), but I don't find 1. inelegant either.
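
For concreteness, one possible shape for that hypothetical fmapreduce (a stand-in name, not an existing Functors.jl function): it applies f one level above the leaves, like the layers helper earlier in the thread, and folds the results with op.

using Flux: trainable
using Functors: isleaf

function fmapreduce(f, op, x)
    cs = trainable(x)
    # apply f at nodes with no children, or whose children are all leaves
    (isempty(cs) || all(isleaf, cs)) && return f(x)
    reduce(op, (fmapreduce(f, op, c) for c in cs))
end

L2reg(model) = fmapreduce(L2, +, model)   # + plays the role of sum above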

@darsnack or @CarloLucibello, feel free to throw up some more counterexamples and I'll adapt them so we can draw some more informed comparisons.

@CarloLucibello (Member, Author)

ahah, ok, I don't think I can learn about catamorphisms on a Friday night but I'll try tomorrow

@darsnack (Member)

Is _isleaf meant to be like Functors.isleaf but one level higher? I'm okay with using (2) in the implementation of this PR. One thing to note between @DhairyaLGandhi's snippet and fmapreduce (assuming this is a catamorphism) is that fmapreduce will apply f to the non-leaf as well. Of course, this is easily fixed; I'm just pointing it out.

The difference between fmapreduce and acc(f(i...) for i in itr) is subjective (much like whether you prefer mapreduce or acc(f(i...) for i in itr)). Personally, I feel that fmapreduce belongs in Functors.jl and that the generator syntax is easier for most users to comprehend. At the very least because it is easy to print out the iterator elements with modules (you could fmap(show, m) but you get my point). Lists may be a special case of fmap/functors, but at the end of the day, everyone understands lists.

One thing that neither fmapreduce nor modules does is intercept the recursion at an intermediate level. In (1) from my comment, I could write something like f(m::Parallel) to block recursion into Parallel and special-case it. Again, you could modify both options to handle these cases, but I'm just bringing it up to highlight things to consider.

@ToucheSir (Member)

So this is a bit of a rabbit hole, but the reason I brought up recursion schemes is that it's possible to do exactly what you mention. Part 5 of the series linked above goes into all of the nitty-gritty details.

Will users need that level of control in practice though? I think it's reasonable to ask anyone who does to handle their own recursion. We've already gone beyond pretty much every other DL framework at this point anyhow.

@darsnack (Member)

Good point. Yeah just pointing out random thoughts. I agree it doesn't come up often enough to support.

On this PR, my inclination is to have fmapreduce in Functors.jl then wrap that in a generator interface with modules.

@darsnack (Member)

Why can't I @functor M1 and @functor M2?

@darsnack (Member)

With the implementation in FluxML/Functors.jl#13 copied into the REPL:

julia> using Functors

julia> Functors.isleaf(x) = children(x) === ()

julia> children(x) = Functors.functor(x)[1]
children (generic function with 1 method)

julia> function fcollect(x; cache = [], exclude = v -> false)
         x in cache && return cache
         if !exclude(x)
           push!(cache, x)
           foreach(y -> fcollect(y; cache=cache, exclude=exclude), children(x))
         end
         return cache
       end
fcollect (generic function with 1 method)

julia> modules(m) = fcollect(m; exclude = Functors.isleaf)
modules (generic function with 1 method)

julia> struct M1
       a
       b
       end

julia> Functors.@functor M1

julia> struct M2
       a
       b
       end

julia> Functors.@functor M2

julia> m2 = M2(M1(1, 2), 3)
M2(M1(1, 2), 3)

julia> modules(m2)
2-element Vector{Any}:
 M2(M1(1, 2), 3)
 M1(1, 2)

We can discuss what exclude should be set to (I do think it will need to be a little more than Functors.isleaf).
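
For instance, a sketch of a slightly richer exclude that also treats tuples of numbers (such as Conv's pad and stride) as leaves, so they are never collected as modules:

_exclude(x) = Functors.isleaf(x)
_exclude(::Tuple{Vararg{Number}}) = true

modules(m) = fcollect(m; exclude = _exclude)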

@DhairyaLGandhi (Member)

#1504 (comment)

@darsnack (Member)

I think you'll need to give an MWE of the relevance of that comment to layers/models, which is what modules(m) is designed for.

@DhairyaLGandhi (Member)

When layers have components other than arrays as params...

@darsnack (Member)

To be honest, the generic solution to that is more a commentary on Functors.jl than this PR.

We could allow modules(m; ignore = []) to accept a list of ignore types which get excluded from the collected list in addition to what's already considered a leaf. Alternatively, the user can define Functors.isleaf(::RGB) = true.

@darsnack (Member)

I think the concern is very valid, but IMO this PR does the right thing under the current limitations of Functors.jl. The exclude argument in fcollect and fmap makes Functors.jl more flexible for use cases like this. How we expose that argument to the user is a good question.

@CarloLucibello CarloLucibello requested a review from darsnack March 9, 2021 04:32
@darsnack (Member) left a comment

Just some small doc changes, but looks good to me. This is a useful utility for exposing functors to users without needing them to understand Functors.jl. I think it should be included (interested if @ToucheSir agrees).

The concern about parameters that aren't AbstractArray{<:Number} being leaves is something that we need to address generically with Functors.jl. It's out of scope for this PR IMO.

CarloLucibello and others added 2 commits March 9, 2021 21:41
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
@CarloLucibello (Member, Author)

bors r+

@bors bot (Contributor) commented Mar 10, 2021

Build succeeded.

@bors bors bot merged commit e2b73e1 into master Mar 10, 2021
@DhairyaLGandhi (Member)

I thought the function was still producing results like in #1444 (comment), instead of:

[M1(...),
3,
M2(...)]

@DhairyaLGandhi DhairyaLGandhi deleted the cl/modules branch March 10, 2021 13:04
@darsnack (Member)

The 3 is unwanted in this case since it is a leaf. modules stops "one-level up" from fmap.

@DhairyaLGandhi (Member)

I think that isn't something this function should decide; instead it should return things at the "same level" once it finds a non-leaf.

@darsnack (Member)

Some layers go deeper than others. The goal here is to extract all the layers down to primitives like Dense or Conv. This function is opinionated to give users what they are looking for 95% of the time. The other 5% can use fcollect directly. That's why we separated out the core functionality into Functors.jl and provide a thin user-friendly, opinionated wrapper here.
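
For instance, a sketch of that escape hatch: passing a custom exclude to fcollect stops the traversal at a chosen type (Parallel here is purely illustrative; excluded nodes are neither collected nor recursed into).

using Flux, Functors

m = Chain(Parallel(+, Dense(2, 3), Dense(2, 3)), Dense(3, 1))
fcollect(m; exclude = x -> x isa Flux.Parallel)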

@DhairyaLGandhi (Member)

This isn't doing the right thing, is it? You want to stop at the first instance where you find a struct you can't look inside; for everything else, it's the user's responsibility to decide how far down we are allowed to recurse. For the case you describe, why not iterate over a Chain instead, which is much more intuitive anyway?

@darsnack (Member)

> You want to stop at the first instance where you find a struct you can't look inside

If I have Parallel(+, Chain(Dense(...), Dense(...)), Conv(...)), why would you want it to stop recursing the Chain because it hit Conv?

> it's the user's responsibility to decide how far down we are allowed to recurse

Again, this is why fcollect exists.

Successfully merging this pull request may close these issues:

  • define modules function
  • How to apply L2 regularization to a subset of parameters?