diff --git a/README.md b/README.md index 4120701a18..1973874dd0 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2] model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only) -optim = Flux.setup(Adam(), model) +opt_state = Flux.setup(Adam(), model) for epoch in 1:1000 - Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim) + Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, opt_state) end plot(x -> 2x-x^3, -2, 2, legend=false) diff --git a/docs/src/guide/saving.md b/docs/src/guide/saving.md index fb00454eec..57f2b9bdb9 100644 --- a/docs/src/guide/saving.md +++ b/docs/src/guide/saving.md @@ -21,7 +21,12 @@ julia> Flux.@layer MyModel julia> MyModel() = MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))); julia> model = MyModel() -MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) # 67 parameters +MyModel( + Chain( + Dense(10 => 5, relu), # 55 parameters + Dense(5 => 2), # 12 parameters + ), +) # Total: 4 arrays, 67 parameters, 484 bytes. julia> model_state = Flux.state(model); diff --git a/src/Flux.jl b/src/Flux.jl index 189db6d6c7..31a78e5c96 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -8,6 +8,7 @@ using MacroTools, Reexport, ProgressLogging, SpecialFunctions using MacroTools: @forward @reexport using NNlib +using NNlib: conv, ∇conv_data, depthwiseconv, output_size using MLUtils using Optimisers: Optimisers, destructure, freeze!, thaw!, adjust!, trainables, update! @@ -27,7 +28,7 @@ export gradient CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice, # get_device, # we define get_device here for retrocompatibility - # gpu_backend!, # have to define here due to https://github.com/JuliaPackaging/Preferences.jl/issues/39 + gpu_backend!, get_device_type, DeviceIterator @@ -118,7 +119,7 @@ include("losses/Losses.jl") using .Losses include("devices.jl") -export get_device, gpu_backend! +export get_device # Distributed Training include("distributed/backend.jl") diff --git a/src/functor.jl b/src/functor.jl index e049959d09..af28ca4906 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -102,18 +102,6 @@ julia> m.bias """ cpu(x) = cpu_device()(x) -# TODO remove after https://github.com/LuxDL/Lux.jl/pull/1089 -ChainRulesCore.@non_differentiable cpu_device() - - -# Remove when -# https://github.com/JuliaPackaging/Preferences.jl/issues/39 -# is resolved -function gpu_backend!(backend::String) - @set_preferences!("gpu_backend" => backend) - MLDataDevices.gpu_backend!(backend) -end - """ gpu(m) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index d4a33283d9..50c023d7ca 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2} out_proj::P2 end -@layer MultiHeadAttention +@layer :noexpand MultiHeadAttention function MultiHeadAttention(dims; nheads::Int = 8, diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 546f8f29ce..5cfefb86fd 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -60,7 +60,7 @@ end @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex, Base.keys, Base.firstindex -@layer :expand Chain # the option :expand opts-in to container-style pretty-printing +@layer Chain (c::Chain)(x) = _applychain(c.layers, x) (c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...)) @@ -334,7 +334,7 @@ end Maxout(layers...) = Maxout(layers) Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...) 
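For readers following along, here is a self-contained sketch tying the renamed `opt_state` loop from the README hunk above to the `Flux.state`/`Flux.loadmodel!` round trip that the new `saving.md` output refers to. The file name and the fresh `model2` copy are arbitrary choices for illustration, and JLD2 is assumed to be installed alongside Flux:

```julia
using Flux, JLD2

data  = [([x], 2x - x^3) for x in -2:0.1f0:2]
model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

opt_state = Flux.setup(Adam(), model)        # optimiser state, named as in the README
for epoch in 1:1000
    Flux.train!((m, x, y) -> (m(x) - y)^2, model, data, opt_state)
end

model_state = Flux.state(model)              # nested NamedTuple of plain arrays
jldsave("regression.jld2"; model_state)      # save the state, not the struct itself

model2 = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)
Flux.loadmodel!(model2, JLD2.load("regression.jld2", "model_state"))
```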
-@layer :expand Maxout +@layer Maxout function (mo::Maxout)(input::AbstractArray) # Perhaps surprisingly, pairwise max broadcast is often faster, @@ -381,7 +381,7 @@ struct SkipConnection{T,F} connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b end -@layer :expand SkipConnection +@layer SkipConnection function (skip::SkipConnection)(input) skip.connection(skip.layers(input), input) @@ -575,7 +575,7 @@ end Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) = throw(ArgumentError("cannot construct a Parallel layer with no sub-layers")) -@layer :expand Parallel +@layer Parallel (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument @@ -705,7 +705,7 @@ end end applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x) -@layer :expand PairwiseFusion +@layer PairwiseFusion Base.getindex(m::PairwiseFusion, i) = m.layers[i] Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i]) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index b2186f9abf..a5a7734313 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,4 +1,3 @@ -using NNlib: conv, ∇conv_data, depthwiseconv, output_size # pad dims of x with dims of y until ndims(x) == ndims(y) _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...) diff --git a/src/layers/macro.jl b/src/layers/macro.jl index 56ead6dbcf..cd711f3b9c 100644 --- a/src/layers/macro.jl +++ b/src/layers/macro.jl @@ -1,21 +1,24 @@ """ - @layer Dense - @layer :expand Chain - @layer BatchNorm trainable=(β,γ) - + @layer [showtype] MyModel [trainable=(field1,...)] + This macro adds convenience functionality to a custom type to serve -as a neural network layer, module, or entire model. +as a neural network layer, as a module, or as an entire model. -The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`. +The optional keyword `trainable` allows you to specify which fields of your model can be trained, +instead of assuming all `fieldnames(MyModel)` are trainable. Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes. +This can also be done by defining [`trainable(::MyModel)`](@ref Optimisers.trainable) for your type. + +The macro also handles overloads of the 3-arg `show(::IO, ::MIME"text/plain", ::MyModel)` for pretty printing. +The optional argument `showtype` can take any of the following values: -The macro also handles overloads of `show` for pretty printing. -* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`. -* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents. -* To disable all `show` overloads, there is an `:ignore` option too. +- `:expand` (default): This will expand the representation of container types like `Chain`, + while maintaining a compact representation of types like `Dense` containing only arrays. +- `:noexpand`: Use this when your type contains other layers but you want to keep the representation compact. +- `:ignore`: To opt out of the pretty printing. -(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.) +You probably still want to define 2-arg `show(::IO, ::MyModel)`; the macro does not touch this. Note that re-running the macro with different options may not remove all methods, you will need to restart.
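To make the documented `showtype` options concrete, a small sketch with a hypothetical `Sandwich` type (not part of Flux). The alternative invocations are left commented out because, as the docstring warns, re-running the macro with different options may require a restart:

```julia
using Flux

struct Sandwich; pre; post; scale; end
Sandwich(n::Int) = Sandwich(Dense(n => n, relu), Dense(n => n), ones(Float32, n))

Flux.@layer Sandwich                           # :expand is now the default: unfolds like Chain
# Flux.@layer :noexpand Sandwich               # compact, Dense-style one-line printing
# Flux.@layer :ignore Sandwich                 # leave `show` untouched
# Flux.@layer Sandwich trainable=(pre, post)   # keep `scale` out of the optimiser state

Sandwich(3)   # in the REPL this prints each field on its own line, with per-field parameter counts
```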
@@ -26,7 +29,7 @@ julia> struct Trio; a; b; c end julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4)) Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4)) -julia> Flux.@layer :expand Trio +julia> Flux.@layer Trio julia> tri # now the layer is printed like Chain Trio( @@ -34,8 +37,14 @@ Trio( Dense(1 => 1; bias=false), # 1 parameters Dropout(0.4), ) # Total: 3 arrays, 4 parameters, 240 bytes. -``` +julia> Flux.@layer :noexpand Trio trainable=(a,b) + +julia> tri # now the layer is printed compactly +Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4)) # 4 parameters + +julia> opt_state = Flux.setup(Adam(), tri); # `c` is not in the optimizer state +``` """ macro layer(exs...) _layer_macro(exs...) @@ -46,14 +55,17 @@ function _layer_macro(exs...) # These functions are defined in show.jl, and each return an expression overloading Base.show type, rest... = if exs[1] == QuoteNode(:expand) - push!(out.args, _macro_big_show(esc(exs[2]))) + push!(out.args, _macro_big_show(esc(exs[2]))) + exs[2:end] + elseif exs[1] == QuoteNode(:noexpand) + push!(out.args, _macro_layer_show(esc(exs[2]))) exs[2:end] elseif exs[1] == QuoteNode(:ignore) exs[2:end] elseif exs[1] isa QuoteNode - error("`@layer` accepts only two options before the layer type, `:expand` and `:ignore` (to control `show`)") + error("`@layer` accepts only the options `:ignore`, `:noexpand`, and `:expand` before the layer type (to control `show`).") else - push!(out.args, _macro_layer_show(esc(exs[1]))) + push!(out.args, _macro_big_show(esc(exs[1]))) exs end diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 9d294e3e6e..dded9ab306 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -198,7 +198,7 @@ end LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...) LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...) -@layer LayerNorm +@layer :noexpand LayerNorm function (a::LayerNorm)(x::AbstractArray) ChainRulesCore.@ignore_derivatives if a.diag isa Scale diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 20b7c9c5aa..25642f2187 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -158,7 +158,7 @@ struct Model h0::AbstractVector end -Flux.@layer :expand Model +Flux.@layer Model (m::Model)(x) = m.rnn(x, m.h0) @@ -169,7 +169,7 @@ struct RNN{M} cell::M end -@layer :expand RNN +@layer RNN function RNN((in, out)::Pair, σ = tanh; cell_kwargs...) cell = RNNCell(in => out, σ; cell_kwargs...) @@ -344,7 +344,7 @@ struct Model c0::AbstractVector end -Flux.@layer :expand Model +Flux.@layer Model (m::Model)(x) = m.lstm(x, (m.h0, m.c0)) @@ -359,7 +359,7 @@ struct LSTM{M} cell::M end -@layer :expand LSTM +@layer LSTM function LSTM((in, out)::Pair; cell_kwargs...) cell = LSTMCell(in => out; cell_kwargs...) @@ -531,7 +531,7 @@ struct GRU{M} cell::M end -@layer :expand GRU +@layer GRU function GRU((in, out)::Pair; cell_kwargs...) cell = GRUCell(in => out; cell_kwargs...) @@ -669,7 +669,7 @@ struct GRUv3{M} cell::M end -@layer :expand GRUv3 +@layer GRUv3 function GRUv3((in, out)::Pair; cell_kwargs...) cell = GRUv3Cell(in => out; cell_kwargs...) 
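As a rough illustration of the parsing that `_layer_macro` performs in the hunk above — an optional leading `QuoteNode` selects the show style, and the bare form now falls through to the expanded one — here is a toy macro, not Flux's code, that follows the same pattern:

```julia
macro showstyle(exs...)
    # Peel off an optional leading :option, defaulting to :expand like @layer now does.
    style, rest = if exs[1] isa QuoteNode
        exs[1].value, exs[2:end]
    else
        :expand, exs
    end
    style in (:expand, :noexpand, :ignore) ||
        error("`@showstyle` accepts only `:expand`, `:noexpand`, or `:ignore`")
    return :(println("defining show for ", $(QuoteNode(rest[1])), " in style ", $(QuoteNode(style))))
end

@showstyle :noexpand Trio   # prints: defining show for Trio in style noexpand
@showstyle Trio             # prints: defining show for Trio in style expand
```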
diff --git a/src/layers/show.jl b/src/layers/show.jl index b68340886d..8b5aa716b5 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -1,6 +1,6 @@ @nospecialize # just for this file, for startup time -# This is called by @layer :expand, on layers which should be treated like Chain, and returns an expression: +# This is called by @layer and returns an expression: function _macro_big_show(ex) quote # Entry point: @@ -83,7 +83,7 @@ function _flat_children(x) gamma = ((beta...)...,) end -# This is called by @layer, on layers which should be treated like Dense, and returns an expression: +# This is called by @layer :noexpand, on layers which should be treated like Dense, and returns an expression: function _macro_layer_show(ex) quote # Entry point: @@ -176,40 +176,53 @@ _all(f, xs) = !_any(!f, xs) #= -julia> struct Tmp2; x; y; end; Flux.@functor Tmp2 +julia> struct Tmp2; x; y; end; -# Before, notice Array(), NamedTuple(), and values +julia> t = Tmp2([Dense(2,3), randn(3,4)'], (x=1:4, y=Dense(3,4), z=rand(3))) +Tmp2(Any[Dense(2 => 3), [-0.559390071462934 -0.6357914190386781 -0.8516823037180543; -2.187495592853204 -0.6807254521505784 -1.2334639710489697; -0.12790952072543338 -1.4672700459421741 1.3687526519721238; 0.5232171922680576 -1.012045481192333 1.4953790632112915]], (x = 1:4, y = Dense(3 => 4), z = [0.29222096031585143, 0.6562195256556428, 0.9741896713499167])) -julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3)))) +julia> Chain(t) Chain( Tmp2( - Array( + [ Dense(2 => 3), # 9 parameters - [0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters - ), - NamedTuple( - 1:3, # 3 parameters - Dense(3 => 4), # 16 parameters - [0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters + 4×3 Adjoint{Float64,...}, # 12 parameters + ], + (; + x = 4-element UnitRange{Int64}, + y = Dense(3 => 4), # 16 parameters + z = 3-element Vector{Float64}, # 3 parameters ), ), -) # Total: 7 arrays, 43 parameters, 644 bytes. +) # Total: 6 trainable arrays, 40 parameters, + # plus 1 non-trainable, 4 parameters, summarysize 620 bytes. + -# After, (; x=, y=, z=) and "3-element Array" +julia> Flux.@layer Tmp2 -julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3)))) +julia> t +Tmp2( + [ + Dense(2 => 3), # 9 parameters + 4×3 Adjoint{Float64,...}, # 12 parameters + ], + 4-element UnitRange{Int64}, + Dense(3 => 4), # 16 parameters + 3-element Vector{Float64}, # 3 parameters +) # Total: 6 trainable arrays, 40 parameters, + # plus 1 non-trainable, 4 parameters, summarysize 620 bytes. + +julia> Chain(t) Chain( Tmp2( [ Dense(2 => 3), # 9 parameters - 4×3 Adjoint, # 12 parameters + 4×3 Adjoint{Float64,...}, # 12 parameters ], - (; - x = 3-element UnitRange, # 3 parameters - y = Dense(3 => 4), # 16 parameters - z = 3-element Array, # 3 parameters - ), + 4-element UnitRange{Int64}, + Dense(3 => 4), # 16 parameters + 3-element Vector{Float64}, # 3 parameters ), -) # Total: 7 arrays, 43 parameters, 644 bytes. - +) # Total: 6 trainable arrays, 40 parameters, + # plus 1 non-trainable, 4 parameters, summarysize 620 bytes. 
=# diff --git a/test/ext_common/recurrent_gpu_ad.jl b/test/ext_common/recurrent_gpu_ad.jl index d2ef3fe34b..704f147f60 100644 --- a/test/ext_common/recurrent_gpu_ad.jl +++ b/test/ext_common/recurrent_gpu_ad.jl @@ -27,7 +27,7 @@ end h0::AbstractVector end - Flux.@layer :expand ModelRNN + Flux.@layer ModelRNN (m::ModelRNN)(x) = m.rnn(x, m.h0) @@ -74,7 +74,7 @@ end c0::AbstractVector end - Flux.@layer :expand ModelLSTM + Flux.@layer ModelLSTM (m::ModelLSTM)(x) = m.lstm(x, (m.h0, m.c0)) @@ -113,7 +113,7 @@ end h0::AbstractVector end - Flux.@layer :expand ModelGRU + Flux.@layer ModelGRU (m::ModelGRU)(x) = m.gru(x, m.h0) @@ -150,7 +150,7 @@ end h0::AbstractVector end - Flux.@layer :expand ModelGRUv3 + Flux.@layer ModelGRUv3 (m::ModelGRUv3)(x) = m.gru(x, m.h0) diff --git a/test/layers/macro.jl b/test/layers/macro.jl index 53585fb427..b96fc30cf4 100644 --- a/test/layers/macro.jl +++ b/test/layers/macro.jl @@ -4,7 +4,7 @@ module MacroTest using Flux: @layer struct Duo{T,S}; x::T; y::S; end - @layer :expand Duo + @layer Duo struct Trio; a; b; c end # @layer Trio trainable=(a,b) test=(c) # should be (c,) but it lets you forget @@ -33,7 +33,7 @@ end m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6]) # Check that we can use the macro with a qualified type name, outside the defining module: - Flux.@layer :expand MacroTest.TwoThirds trainable=(:a) # documented as (a,c) but allow quotes + Flux.@layer MacroTest.TwoThirds trainable=(:a) # documented as (a,c) but allow quotes m23re = Functors.functor(m23)[2]((a = [10 20], b = [3 4], c = [50 60])) @test m23re isa MacroTest.TwoThirds diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 6da1f73ee9..f882cdccc2 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -68,7 +68,7 @@ end h0::AbstractVector end - Flux.@layer :expand ModelRNN + Flux.@layer ModelRNN (m::ModelRNN)(x) = m.rnn(x, m.h0) @@ -138,7 +138,7 @@ end c0::AbstractVector end - Flux.@layer :expand ModelLSTM + Flux.@layer ModelLSTM (m::ModelLSTM)(x) = m.lstm(x, (m.h0, m.c0)) @@ -215,7 +215,7 @@ end h0::AbstractVector end - Flux.@layer :expand ModelGRU + Flux.@layer ModelGRU (m::ModelGRU)(x) = m.gru(x, m.h0) @@ -265,7 +265,7 @@ end h0::AbstractVector end - Flux.@layer :expand ModelGRUv3 + Flux.@layer ModelGRUv3 (m::ModelGRUv3)(x) = m.gru(x, m.h0)
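In the same spirit as the updated `test/layers/macro.jl`, a sketch of an extra check one could write against the `trainable` keyword; `TwoFields` is a hypothetical struct, not something from the test suite:

```julia
using Flux, Test

struct TwoFields; a; b; end
Flux.@layer TwoFields trainable=(a,)

m  = TwoFields([1.0 2.0], [3.0 4.0])
st = Flux.setup(Descent(0.1), m)       # only `a` gets optimiser state
g  = (a = [1.0 1.0], b = [1.0 1.0])    # gradients for both fields, as Zygote would return
Flux.update!(st, m, g)

@test m.a ≈ [0.9 1.9]      # stepped by 0.1 × gradient
@test m.b == [3.0 4.0]     # the field left out of `trainable` is untouched
```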