diff --git a/Manifest.toml b/Manifest.toml index 7210fa33bf..81aff3a323 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,15 +2,14 @@ [[AbstractFFTs]] deps = ["LinearAlgebra"] -git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716" +git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "0.5.0" +version = "1.0.1" [[AbstractTrees]] -deps = ["Markdown"] -git-tree-sha1 = "33e450545eaf7699da1a6e755f9ea65f14077a45" +git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5" uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.3.3" +version = "0.3.4" [[Adapt]] deps = ["LinearAlgebra"] @@ -46,15 +45,15 @@ version = "2.4.1" [[ChainRules]] deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "6ba8100fa9356807f1d0df6468ae463c67627c30" +git-tree-sha1 = "e01f521443e3700f40ad3c7c1c6aa3a6940aaea1" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.49" +version = "0.7.54" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "53fed426c9af1eb68e63b3999e96454c2db79757" +git-tree-sha1 = "de4f08843c332d355852721adb1592bce7924da3" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.27" +version = "0.9.29" [[CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -64,9 +63,9 @@ version = "0.7.0" [[ColorTypes]] deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "4bffea7ed1a9f0f3d1a131bbcd4b925548d75288" +git-tree-sha1 = "5e9769a17f17b587c951d57ba4319782b40c3513" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.10.9" +version = "0.10.10" [[Colors]] deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"] @@ -93,9 +92,9 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" version = "0.3.4+0" [[DataAPI]] -git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8" +git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.5.1" +version = "1.6.0" [[DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -134,9 +133,9 @@ version = "0.1.3" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "50eabdace27aa27b143f65b65e762bb0112a7708" +git-tree-sha1 = "4705cc4e212c3c978c60b1b18118ec49b4d731fd" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.11.1" +version = "0.11.5" [[FixedPointNumbers]] deps = ["Statistics"] @@ -152,9 +151,9 @@ version = "0.10.16" [[Functors]] deps = ["MacroTools"] -git-tree-sha1 = "cd79039c468eac0a15256c55f260eec7ce551d07" +git-tree-sha1 = "a7bb2af991c43dcf5c3455d276dd83976799634f" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.2.0" +version = "0.2.1" [[GPUArrays]] deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] @@ -252,9 +251,9 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" version = "0.5.3+4" [[OrderedCollections]] -git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23" +git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.3.3" +version = "1.4.0" [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -283,9 +282,9 @@ version = "1.0.0" [[Requires]] deps = ["UUIDs"] -git-tree-sha1 = "cfbac6c1ed70c002ec6361e7fd334f02820d6419" +git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.1.2" +version = "1.1.3" [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -318,9 +317,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] deps = ["ChainRulesCore", "OpenSpecFun_jll"] -git-tree-sha1 = "75394dbe2bd346beeed750fb02baa6445487b862" +git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.2.1" +version = "1.3.0" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] @@ -334,9 +333,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "7bab7d4eb46b225b35179632852b595a3162cb61" +git-tree-sha1 = "400aa43f7de43aeccc5b2e39a76a79d262202b76" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.2" +version = "0.33.3" [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] @@ -344,9 +343,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[TimerOutputs]] deps = ["Printf"] -git-tree-sha1 = "3318281dd4121ecf9713ce1383b9ace7d7476fdd" +git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.7" +version = "0.5.8" [[TranscodingStreams]] deps = ["Random", "Test"] diff --git a/Project.toml b/Project.toml index b51f143252..6995e695b6 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ Adapt = "2.0, 3.0" CUDA = "2.1" CodecZlib = "0.7" Colors = "0.12" -Functors = "0.1, 0.2" +Functors = "0.2.1" Juno = "0.8" MacroTools = "0.5" NNlib = "0.7.14" diff --git a/docs/src/utilities.md b/docs/src/utilities.md index a5306060d4..5edaf46c2a 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -91,6 +91,7 @@ Flux.outputsize ## Model Abstraction ```@docs +Flux.modules Flux.destructure Flux.nfan ``` diff --git a/src/functor.jl b/src/functor.jl index a92e207c32..1e7f9e1fc2 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -2,6 +2,7 @@ import Adapt: adapt, adapt_storage using LinearAlgebra: Cholesky using Zygote: IdSet import Functors: @functor, functor, fmap +import Functors trainable(m) = functor(m)[1] @@ -78,4 +79,4 @@ f64(m) = paramtype(Float64, m) # Functors for certain Julia data structures @functor Cholesky -trainable(c::Cholesky) = () \ No newline at end of file +trainable(c::Cholesky) = () diff --git a/src/utils.jl b/src/utils.jl index cec8a1d7a5..b75d1e5a23 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -701,3 +701,44 @@ function throttle(f, timeout; leading=true, trailing=false) return result end end + + +""" + modules(m) + +Return an iterator over non-leaf objects +that can be reached by recursing `m` over +the children given by [`functor`](@ref). + +Useful for applying a function (e.g. a regularizer) +over specific modules or subsets of the parameters +(e.g. the weights but not the biases). + +# Examples + +```jldoctest +julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu)) +Chain(Dense(784, 64), BatchNorm(64, relu)) + +julia> m2 = Chain(m1, Dense(64, 10)) +Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) + +julia> Flux.modules(m2) +5-element Array{Any,1}: + Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) + Chain(Dense(784, 64), BatchNorm(64, relu)) + Dense(784, 64) + BatchNorm(64, relu) + Dense(64, 10) + +julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense) +L2 (generic function with 1 method) +``` +""" +modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)] + +@nograd modules + +isleaflike(x) = Functors.isleaf(x) +isleaflike(::Tuple{Vararg{<:Number}}) = true +isleaflike(::Tuple{Vararg{<:AbstractArray{<:Number}}}) = true diff --git a/test/utils.jl b/test/utils.jl index bcf5d1b49f..47444ca5fb 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -393,3 +393,28 @@ end trainmode!(c) @test !c[1].testing end + +@testset "modules" begin + m1 = Conv((2,3), 4=>5; pad=6, stride=7) + m2 = LayerNorm(8) + m3 = m2.diag + m4 = SkipConnection(m1, +) + m5 = Chain(m4, m2) + modules = Flux.modules(m5) + # Depth-first descent + @test length(modules) == 5 + @test modules[1] === m5 + @test modules[2] === m4 + @test modules[3] === m1 + @test modules[4] === m2 + @test modules[5] === m3 + + modules = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4))) + @test length(modules) == 5 + + modules = Flux.modules(Chain(SkipConnection( + Conv((2,3), 4=>5; pad=6, stride=7), + +), + LayerNorm(8))) + @test length(modules) == 5 +end