define modules function #1444
Conversation
Better done with functors directly |
how? |
@DhairyaLGandhi bump |
@DhairyaLGandhi bump |
Is this the intended behaviour: |
Those tuples should be excluded I guess, so ideally it should look like:

```julia
julia> modules(Chain(SkipConnection(Conv((2,3),4=>5;pad=6,stride=7),+),LayerNorm(8)))
5-element Vector{Any}:
 Diagonal(8)
 Conv((2, 3), 4=>5)
 LayerNorm(8)
 Chain(SkipConnection(Conv((2, 3), 4=>5), +), LayerNorm(8))
 SkipConnection(Conv((2, 3), 4=>5), +)
```

(Diagonal looks weird there, but it is fine: it is contained in LayerNorm.) |
I was wondering why we include the containers. I didn't think IdSet was sorted, but maybe this counts as depth-first:

```julia
using Flux
using Flux: trainable
using Functors: isleaf

function layers(x, cache=[])
  if x in cache || _leaflike(x)
    return cache
  end
  children = trainable(x)
  if all(_leaflike, children)
    push!(cache, x)
  else
    foreach(y -> layers(y, cache), children)
  end
  return cache
end

_leaflike(x) = isleaf(x)
_leaflike(::Tuple{Vararg{<:Number}}) = true
_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true

layers(Chain(SkipConnection(Conv((2,3),4=>5;pad=6,stride=7),+),LayerNorm(8)))
layers(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4)))
```
|
Since this is supposed to also work on generic functors defined by users, we cannot exclude a mix of parameters and submodules. Also, the usage may be more generic than that of parameter extraction. In practice, I would like to replicate what is done by PyTorch's modules method here |
Shameless plug, but I think the conv tuple issue (and to some extent custom layer handling) will be ameliorated by FluxML/Functors.jl#7. Most layers are either marking too many fields as trainable or having to override it. |
I think this is more in line with what I was thinking with my comment on #1144. Having something like this might make writing things like the L2 example easy enough for users to do without needing to know Functors.jl. |
I'm inclined to use the functors approach here for clarity, and it's not difficult to write L2 as it is. |
I'm not sure I follow; this uses functors (via trainable) under the hood? Were you thinking of something like an |
@DhairyaLGandhi I asked you multiple times here (with a "how?" and two bumps) to be more explicit about the functor approach you suggest. And even if there are alternative approaches, they don't seem very straightforward; I don't see why we cannot provide a convenience utility such as the one proposed here. |
So the comment above even mentions a simple fmap(reduce). I was thinking of letting users mark which fields are params (which we already do, and which is documented); we get the same effect. Marking fields as trainable can be done via the functor macro already, so I don't think there's any more need for this. |
I'd like to see the code that solves the following problem:

```julia
apply_reg(m) = false
apply_reg(m::Union{Dense,Conv}) = true
L2reg(model) = sum(norm(m.W)^2 for m in model if apply_reg(m))
```

Btw, everyone understands iterators, |
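As a hedged aside: with the modules function this PR proposes (and assuming the then-current W field on Dense and Conv), that snippet becomes fully runnable, e.g.:

```julia
using Flux, LinearAlgebra

apply_reg(m) = false
apply_reg(m::Union{Dense,Conv}) = true

# `modules` walks the model recursively, so nesting such as Chain or
# SkipConnection needs no special handling in user code.
L2reg(model) = sum(norm(m.W)^2 for m in Flux.modules(model) if apply_reg(m))

L2reg(Chain(Dense(2, 3), BatchNorm(3), Dense(3, 1)))
```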
So as in the docstring, you have to define what happens if one were to hit the Dense layer (and every other layer). This doesn't account for the differences in removing the bias from regularisation, which is what #1284 needs. Also, if we have to define

```julia
struct Mw{T,F}
  a::T
  b::F
end
@functor Mw (a,)
```

or trainable, then that, like defining L2 on every layer, takes care of handling layers specifically, but there isn't a general route that one can take. The fmap(reduce) approach seems to me to generalise better. |
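One way to make the fmap(reduce) idea concrete (a rough sketch, not code from this PR; note it also sweeps up biases, which is exactly the #1284 limitation mentioned above):

```julia
using Functors, LinearAlgebra

function l2_via_fmap(model)
    total = Ref(0f0)
    fmap(model) do leaf
        # fmap visits each functored leaf once; accumulate array norms.
        leaf isa AbstractArray{<:Number} && (total[] += norm(leaf)^2)
        leaf  # return unchanged so the structure is rebuilt as-is
    end
    return total[]
end
```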
L2 is just a simple toy example of the value in this PR. The core problem is I have some function I want to apply to certain layers.

Writing code like the L2 example is extremely common. I cannot think of a single project in Flux or PyTorch where I have not written code like this. (1) is probably what your beginner Julia user might be able to do. It's fine for solo projects, but if you want your code to be used by others, then you need (2) or (3). I agree with @CarloLucibello that (2) is ugly and no user should ever need to understand Functors.jl. (3) is basically (2), except all the ugly stuff is hidden behind it. Now it's unclear to me whether the suggestion is to use Functors.jl directly or to change this PR to use |
Sorry for the late reply. I was thinking of something like this (it's not tested code exactly, but worth talking about). This would need some bug fixing to handle Chain neatly, but it allows for the indirection and works on all layers. It has similar issues around not being able to remove certain params, but I feel that is better handled by something like:

```julia
using Functors: isleaf

# Can we traverse into x? True if any field is itself a non-leaf.
cantrav(x::T, fs = fieldnames(T)) where T = any(!isleaf, (getfield(x, f) for f in fs))

# Apply f at the first level we can no longer traverse into.
function func(f, m::T, fs = fieldnames(T)) where T
  cantrav(m) || return f(m)
  (func(f, getfield(m, f_)) for f_ in fs)
end
```
|
I'll preface this by saying that I don't see the harm in having this function. Adapting the example from the PR:

```julia
L2(m) = 0f0
L2(m::Dense) = sum(abs2, m.W)
```

This is equivalent to 1) in @darsnack's example, but doesn't handle container types. If we add the constraint that it shouldn't have to handle containers (i.e. you shouldn't have to override for Chain and the like), note how the only real difference here is in whether we use internal or external iteration. In fact, it's easy to see how one could implement one in terms of the other.

@darsnack or @CarloLucibello, feel free to throw up some more counterexamples and I'll adapt them so we can draw some more informed comparisons. |
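For concreteness, with the per-layer L2 methods above, the external-iteration version is a one-liner over this PR's modules (hedged sketch):

```julia
# Containers like Chain hit the L2(m) = 0f0 fallback, so only layers
# with a specific method actually contribute to the sum.
total_L2(model) = sum(L2, Flux.modules(model))
```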
Ahah, ok, I don't think I can learn about catamorphisms on a Friday night, but I'll try tomorrow.
Is the difference between the two just internal vs. external iteration? One thing that both approaches seem to lack is user-level control over how deep the recursion goes. |
So this is a bit of a rabbit hole, but the reason I brought up recursion schemes is that it's possible to do exactly what you mention. Part 5 of the series linked above goes into all of the nitty-gritty details. Will users need that level of control in practice though? I think it's reasonable to ask anyone who does to handle their own recursion. We've already gone beyond pretty much every other DL framework at this point anyhow. |
Good point. Yeah, just pointing out random thoughts. I agree it doesn't come up often enough to support. On this PR, my inclination is to have |
Why can't I |
With the implementation in FluxML/Functors.jl#13 copied into the REPL:

```julia
julia> using Functors

julia> Functors.isleaf(x) = children(x) === ()

julia> children(x) = Functors.functor(x)[1]
children (generic function with 1 method)

julia> function fcollect(x; cache = [], exclude = v -> false)
         x in cache && return cache
         if !exclude(x)
           push!(cache, x)
           foreach(y -> fcollect(y; cache=cache, exclude=exclude), children(x))
         end
         return cache
       end
fcollect (generic function with 1 method)

julia> modules(m) = fcollect(m; exclude = Functors.isleaf)
modules (generic function with 1 method)

julia> struct M1
         a
         b
       end

julia> Functors.@functor M1

julia> struct M2
         a
         b
       end

julia> Functors.@functor M2

julia> m2 = M2(M1(1, 2), 3)
M2(M1(1, 2), 3)

julia> modules(m2)
2-element Vector{Any}:
 M2(M1(1, 2), 3)
 M1(1, 2)
```

We can discuss what |
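Continuing that sketch (a hypothetical session, same definitions as above): the cache also deduplicates, so a layer shared in two places is collected only once.

```julia
julia> shared = M1(1, 2);

julia> modules(M2(shared, shared))
2-element Vector{Any}:
 M2(M1(1, 2), M1(1, 2))
 M1(1, 2)
```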
I think you'll need to give an MWE of the relevance of that comment to layers/models, which is what |
When layers have components other than arrays as params... |
To be honest, the generic solution to that is more a commentary on Functors.jl than this PR. We could allow |
I think the concern is very valid, but IMO this PR does the right thing under the current limitations of Functors.jl. The |
Just some small doc changes, but looks good to me. This is a useful utility for exposing functors to users without needing them to understand Functors.jl. I think it should be included (interested if @ToucheSir agrees).
The concern about parameters that aren't AbstractArray{<:Number} being leaves is something that we need to address generically with Functors.jl. It's out of scope for this PR IMO.
bors r+ |
Build succeeded. |
I thought the function is still producing results like in #1444 (comment), instead of:

```julia
[M1(...),
 3,
 M2(...)]
```
|
I think that isn't something this function should decide; instead, it should return things at the "same level" once it finds a non-leaf. |
Some layers go deeper than others. The goal here is to extract all the layers down to primitives like |
This isn't doing the right thing, is it? You want to stop at the first struct you can't look inside; for everything else, it's the user's responsibility to decide how far down we are allowed to recurse. For the case you describe, why not iterate over a Chain instead, which is much more intuitive anyway? |
If I have
Again, this is why |
A function returning an iterator over non-leaf parameters. Mainly motivated by the need to apply L2 regularization to weights only (see #1284, #939), but may be of more general use.
Fix #1294, Fix #1284
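A hedged sketch of the motivating usage (assuming the modules function as proposed, and the W weight field Flux layers had at the time):

```julia
using Flux

sqnorm(x) = sum(abs2, x)
penalty(m) = sum(sqnorm(l.W) for l in Flux.modules(m) if l isa Union{Dense,Conv})

# Weight-only L2 folded into the loss; biases and norm layers untouched.
loss(m, x, y) = Flux.Losses.mse(m(x), y) + 0.01f0 * penalty(m)
```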