Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

define modules function #1444

Merged
merged 6 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 26 additions & 27 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716"
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "0.5.0"
version = "1.0.1"

[[AbstractTrees]]
deps = ["Markdown"]
git-tree-sha1 = "33e450545eaf7699da1a6e755f9ea65f14077a45"
git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.3.3"
version = "0.3.4"

[[Adapt]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -46,15 +45,15 @@ version = "2.4.1"

[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"]
git-tree-sha1 = "6ba8100fa9356807f1d0df6468ae463c67627c30"
git-tree-sha1 = "e01f521443e3700f40ad3c7c1c6aa3a6940aaea1"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.49"
version = "0.7.54"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "53fed426c9af1eb68e63b3999e96454c2db79757"
git-tree-sha1 = "de4f08843c332d355852721adb1592bce7924da3"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.27"
version = "0.9.29"

[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
Expand All @@ -64,9 +63,9 @@ version = "0.7.0"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "4bffea7ed1a9f0f3d1a131bbcd4b925548d75288"
git-tree-sha1 = "5e9769a17f17b587c951d57ba4319782b40c3513"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.10.9"
version = "0.10.10"

[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"]
Expand All @@ -93,9 +92,9 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.3.4+0"

[[DataAPI]]
git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8"
git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.5.1"
version = "1.6.0"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
Expand Down Expand Up @@ -134,9 +133,9 @@ version = "0.1.3"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "50eabdace27aa27b143f65b65e762bb0112a7708"
git-tree-sha1 = "4705cc4e212c3c978c60b1b18118ec49b4d731fd"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.1"
version = "0.11.5"

[[FixedPointNumbers]]
deps = ["Statistics"]
Expand All @@ -152,9 +151,9 @@ version = "0.10.16"

[[Functors]]
deps = ["MacroTools"]
git-tree-sha1 = "cd79039c468eac0a15256c55f260eec7ce551d07"
git-tree-sha1 = "a7bb2af991c43dcf5c3455d276dd83976799634f"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.2.0"
version = "0.2.1"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
Expand Down Expand Up @@ -252,9 +251,9 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+4"

[[OrderedCollections]]
git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23"
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.3.3"
version = "1.4.0"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand Down Expand Up @@ -283,9 +282,9 @@ version = "1.0.0"

[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "cfbac6c1ed70c002ec6361e7fd334f02820d6419"
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.2"
version = "1.1.3"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Expand Down Expand Up @@ -318,9 +317,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
git-tree-sha1 = "75394dbe2bd346beeed750fb02baa6445487b862"
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.2.1"
version = "1.3.0"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
Expand All @@ -334,19 +333,19 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "7bab7d4eb46b225b35179632852b595a3162cb61"
git-tree-sha1 = "400aa43f7de43aeccc5b2e39a76a79d262202b76"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.2"
version = "0.33.3"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Printf"]
git-tree-sha1 = "3318281dd4121ecf9713ce1383b9ace7d7476fdd"
git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.7"
version = "0.5.8"

[[TranscodingStreams]]
deps = ["Random", "Test"]
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Adapt = "2.0, 3.0"
CUDA = "2.1"
CodecZlib = "0.7"
Colors = "0.12"
Functors = "0.1, 0.2"
Functors = "0.2.1"
Juno = "0.8"
MacroTools = "0.5"
NNlib = "0.7.14"
Expand Down
1 change: 1 addition & 0 deletions docs/src/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Flux.outputsize
## Model Abstraction

```@docs
Flux.modules
Flux.destructure
Flux.nfan
```
Expand Down
3 changes: 2 additions & 1 deletion src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import Adapt: adapt, adapt_storage
using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: @functor, functor, fmap
import Functors

trainable(m) = functor(m)[1]

Expand Down Expand Up @@ -78,4 +79,4 @@ f64(m) = paramtype(Float64, m)

# Functors for certain Julia data structures
@functor Cholesky
trainable(c::Cholesky) = ()
trainable(c::Cholesky) = ()
41 changes: 41 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -701,3 +701,44 @@ function throttle(f, timeout; leading=true, trailing=false)
return result
end
end


"""
modules(m)

Return an iterator over non-leaf objects
that can be reached from `m` through recursion
on the children given by [`functor`](@ref).
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

It can be used to apply a regularization
over certain specific modules or subsets of
the parameters (e.g. the weights but not the biases).
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

# Examples

```jldoctest
julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu))
Chain(Dense(784, 64), BatchNorm(64, relu))

julia> m2 = Chain(m1, Dense(64, 10))
Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))

julia> Flux.modules(m2)
5-element Vector{Any}:
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
Chain(Dense(784, 64), BatchNorm(64, relu))
Dense(784, 64)
BatchNorm(64, relu)
Dense(64, 10)

julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)
```
"""
modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]

@nograd modules

isleaflike(x) = Functors.isleaf(x)
isleaflike(::Tuple{Vararg{<:Number}}) = true
isleaflike(::Tuple{Vararg{<:AbstractArray{<:Number}}}) = true
25 changes: 25 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,28 @@ end
trainmode!(c)
@test !c[1].testing
end

@testset "modules" begin
m1 = Conv((2,3), 4=>5; pad=6, stride=7)
m2 = LayerNorm(8)
m3 = m2.diag
m4 = SkipConnection(m1, +)
m5 = Chain(m4, m2)
modules = Flux.modules(m5)
# Depth-first descent
@test length(modules) == 5
@test modules[1] === m5
@test modules[2] === m4
@test modules[3] === m1
@test modules[4] === m2
@test modules[5] === m3

modules = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4)))
@test length(modules) == 5

modules = Flux.modules(Chain(SkipConnection(
Conv((2,3), 4=>5; pad=6, stride=7),
+),
LayerNorm(8)))
@test length(modules) == 5
end