diff --git a/README.md b/README.md index 1973874dd0..cc0a7f1803 100644 --- a/README.md +++ b/README.md @@ -22,19 +22,25 @@ Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, a Works best with [Julia 1.10](https://julialang.org/downloads/) or later. Here's a very short example to try it out: ```julia -using Flux, Plots -data = [([x], 2x-x^3) for x in -2:0.1f0:2] +using Flux +data = [(x, 2x-x^3) for x in -2:0.1f0:2] -model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only) +model = let + w, b, v = (randn(Float32, 23) for _ in 1:3) # parameters + x -> sum(v .* tanh.(w*x .+ b)) # callable +end +# model = Chain(vcat, Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only) opt_state = Flux.setup(Adam(), model) -for epoch in 1:1000 +for epoch in 1:100 Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, opt_state) end -plot(x -> 2x-x^3, -2, 2, legend=false) -scatter!(x -> model([x]), -2:0.1f0:2) +using Plots +plot(x -> 2x-x^3, -2, 2, label="truth") +scatter!(model, -2:0.1f0:2, label="learned") ``` +In Flux 0.15, almost any parameterised function in Julia is a valid Flux model -- such as this closure over `w, b, v`. The same function can also be implemented with built-in layers as shown. The [quickstart page](https://fluxml.ai/Flux.jl/stable/guide/models/quickstart/) has a longer example. See the [documentation](https://fluxml.github.io/Flux.jl/) for details, or the [model zoo](https://github.com/FluxML/model-zoo/) for examples. Ask questions on the [Julia discourse](https://discourse.julialang.org/) or [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866). diff --git a/docs/make.jl b/docs/make.jl index 6bdfcbb638..4367639d8e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -30,7 +30,6 @@ makedocs( "Quick Start" => "guide/models/quickstart.md", "Fitting a Line" => "guide/models/overview.md", "Gradients and Layers" => "guide/models/basics.md", - "Custom Layers" => "guide/models/custom_layers.md", "Training" => "guide/training/training.md", "Recurrence" => "guide/models/recurrence.md", "GPU Support" => "guide/gpu.md", @@ -63,6 +62,7 @@ makedocs( # Or perhaps those should just be trashed, model zoo versions are newer & more useful. "Linear Regression" => "tutorials/linear_regression.md", "Logistic Regression" => "tutorials/logistic_regression.md", + "Custom Layers" => "tutorials/custom_layers.md", "Model Zoo" => "tutorials/model_zoo.md", #= # "Multi-layer Perceptron" => "tutorials/mlp.md", diff --git a/docs/src/assets/zygote-crop.png b/docs/src/assets/zygote-crop.png new file mode 100644 index 0000000000..ddc04b3d17 Binary files /dev/null and b/docs/src/assets/zygote-crop.png differ diff --git a/docs/src/guide/models/basics.md b/docs/src/guide/models/basics.md index 85d54ee58b..7ad62ee207 100644 --- a/docs/src/guide/models/basics.md +++ b/docs/src/guide/models/basics.md @@ -1,241 +1,452 @@ -# [How Flux Works: Gradients and Layers](@id man-basics) +# [How Flux Works: Parameters, Gradients, and Layers](@id man-basics) -## [Taking Gradients](@id man-taking-gradients) +A neural network is a function with *parameters*. +That is, it takes some input `x` and gives you some output `y`, +whose value also depends on some other numbers `θ`. -Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.) 
+A sufficiently flexible function can, by adjusting the parameters just right, +be made to do many things. And the one magic trick for adjusting parameters +is to follow a *gradient*. -```jldoctest basics -julia> using Flux +This page describes Flux's take on how to construct such flexible functions +containing many parameters, and how to handle their gradients. -julia> f(x) = 3x^2 + 2x + 1; +## Parameterised Functions -julia> df(x) = gradient(f, x)[1]; # df/dx = 6x + 2 +Let's start with very simple functions. This is a polynomial in `x::Real`, +returning another real number `y` which depends on some coefficients stored in a vector: -julia> df(2) -14.0 -``` +```jldoctest poly; output = false +θ = [10, 1, 0.1] + +poly1(x::Real) = θ[1] + θ[2]*x + θ[3]*x^2 -When a function has many parameters, we can get gradients of each one at the same time: +poly1(5) == 17.5 # true -```jldoctest basics -julia> f(x, y) = sum((x .- y).^2); +# output -julia> gradient(f, [2, 1], [2, 0]) -([0.0, 2.0], [-0.0, -2.0]) +true ``` -These gradients are based on `x` and `y`. Flux works by instead taking gradients based on the weights and biases that make up the parameters of a model. +Here the parameters are a global variable `θ`. They could be handled in other ways, +for instance by explicitly passing them as an additional argument to the function: + +```jldoctest poly; output = false +poly2(x::Real, θ2) = evalpoly(x, θ2) # built-in, from Base.Math + +poly2(5, θ) == 17.5 # true -Machine learning often can have *hundreds* of parameter arrays. -Instead of passing them to `gradient` individually, we can store them together in a structure. -The simplest example is a named tuple, created by the following syntax: +# output -```jldoctest basics -julia> nt = (a = [2, 1], b = [2, 0], c = tanh); +true +``` + +Flux chooses a third path, by *encapsulating* the parameters within the function. +The simplest way to do this is a *closure*, an anonymous function which Julia knows +to depend on some local variable `θ3`: -julia> g(x::NamedTuple) = sum(abs2, x.a .- x.b); +```jldoctest poly; output = false +poly3 = let θ3 = [10, 1, 0.1] + x -> evalpoly(x, θ3) +end -julia> g(nt) -1 +poly3(5) == 17.5 # true -julia> dg_nt = gradient(g, nt)[1] -(a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing) +# output + +true ``` -Notice that `gradient` has returned a matching structure. The field `dg_nt.a` is the gradient -for `nt.a`, and so on. Some fields have no gradient, indicated by `nothing`. +An equivalent, but tidier, way is to construct a `struct` in which to store the parameters. +Any struct can be made callable, allowing its instances to act just like function: -Rather than define a function like `g` every time (and think up a name for it), -it is often useful to use anonymous functions: this one is `x -> sum(abs2, x.a .- x.b)`. 
-Anonymous functions can be defined either with `->` or with `do`, -and such `do` blocks are often useful if you have a few steps to perform: +```jldoctest poly; output = false +struct Poly3{T} # container struct + θ3::T +end +(p::Poly3)(x::Real) = evalpoly(x, p.θ3) # make this callable -```jldoctest basics -julia> gradient((x, y) -> sum(abs2, x.a ./ y .- x.b), nt, [1, 2]) -((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25]) +poly3s = Poly3([10, 1, 0.1]) # construct an instance -julia> gradient(nt, [1, 2]) do x, y - z = x.a ./ y - sum(abs2, z .- x.b) - end -((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25]) +poly3s(5) == 17.5 # true + +# output + +true ``` -Sometimes you may want to know the value of the function, as well as its gradient. -Rather than calling the function a second time, you can call [`withgradient`](@ref Zygote.withgradient) instead: +Internally, there is little difference between a closure and a struct. +They have the same fields, and equivalent methods: +```julia +dump(poly3), dump(poly3s) # both contain θ3: Array +poly3s.θ3 == poly3.θ3 == θ # field called :θ3 has same value +methods(poly3) +methods(poly3s) # each has 1 method, accepting x ``` -julia> Flux.withgradient(g, nt) -(val = 1, grad = ((a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing),)) + +The virtue of encapsulation is that it makes composition very easy. +We can make more complicated functions by combining simple ones, +and each will keep track of its own parameters. +Juia writes function composition as `∘`, for instance `(inv ∘ sin)(pi/6) ≈ 2`, +and we can use exactly this for our parameterised polynomials: + +```jldoctest poly; output = false +poly4 = Poly3([1, 0.5, 0]) ∘ Poly3([10, 1, 0.1]) + +poly4 isa ComposedFunction # ∘ creates another struct... +poly4.outer.θ3 == θ # which has fields :inner & :outer + +poly4(5) == 9.75 # true + +# output + +true ``` -## Building Simple Models +Flux models are precisely made by such function composition. +In fact, `poly3` and `poly4` are already valid Flux models. -Consider a simple linear regression, which tries to predict an output array `y` from an input `x`. -```julia -predict(W, b, x) = W*x .+ b +## [Structural Gradients](@id man-taking-gradients) -function loss(W, b, x, y) - ŷ = predict(W, b, x) - sum((y .- ŷ).^2) -end +The derivative of a scalar function is its slope: how fast the output changes as the input is changed slightly. +This may be found approximately by evaluating at two nearby points, and exactly by taking the limit in +which the distance between them approaches zero: -x, y = rand(5), rand(2) # Dummy data -W = rand(2, 5) -b = rand(2) +```jldoctest poly +julia> (poly1(5 + 0.1) - poly1(5)) / 0.1 +2.010000000000005 -loss(W, b, x, y) # ~ 3 +julia> (poly1(5 + 0.001) - poly1(5)) / 0.001 # answer is getting close to 2 +2.000100000003613 ``` -To improve the prediction we can take the gradients of the loss with respect to `W` and `b` and perform gradient descent. +Flux's `gradient(f, x)` works this out for `f(x)`, and gives exactly `∂f/∂x = 2.0` here: -```julia -using Flux +```jldoctest poly +julia> using Flux -dW, db = gradient((W, b) -> loss(W, b, x, y), W, b) +julia> gradient(poly1, 5) +(2.0,) ``` -Now that we have gradients, we can pull them out and update `W` to train the model. +The reason `gradient` returns a tuple, not just the number `2.0`, is to allow for +functions taking several arguments. (That's also why it's not called "derivative".) 
+For instance, this returns `∂f/∂x, ∂f/∂y, ∂f/∂z`: -```julia -W .-= 0.1 .* dW +```jldoctest poly +julia> gradient((x,y,z) -> (x*y)+z, 30, 40, 50) +(40.0, 30.0, 1.0) +``` + +For our parameterised polynomial, we have `∂f/∂x` but we are really more interested +in `∂f/∂θ`, as this will tell us about how the parameters are affecting the answer. +It is not impossible to track gradients with respect to global `θ`, but much clearer to track explicit arguments. +Here's how this works for `poly2` (which takes `θ` as a 2nd argument) and `poly3` (which encapsulates `θ`): + +```jldoctest poly +julia> grad2 = gradient(poly2, 5, θ) +(2.0, [1.0, 5.0, 25.0]) + +julia> grad3 = gradient((x,p) -> p(x), 5, poly3s) +(2.0, (θ3 = [1.0, 5.0, 25.0],)) +``` + +The first entry is `∂f/∂x` as before, but the second entry is more interesting. +For `poly2`, we get `∂f/∂θ` as `grad2[2]` directly. +It is a vector, because `θ` is a vector, and has elements `[∂f/∂θ[1], ∂f/∂θ[2], ∂f/∂θ[3]]`. + +For `poly3s`, however, we get a `NamedTuple` whose fields correspond to those of the struct `Poly3`. +This is called a *structural gradient*. And the nice thing about them is that they work for +arbitrarily complicated structures, for instance: -loss(W, b, x, y) # ~ 2.5 +```jldoctest poly +julia> grad4 = gradient(|>, 5, poly4) +(1.0, (outer = (θ3 = [1.0, 17.5, 306.25],), inner = (θ3 = [0.5, 2.5, 12.5],))) ``` -The loss has decreased a little, meaning that our prediction `x` is closer to the target `y`. If we have some data we can already try [training the model](../training/training.md). +Here `grad4.inner.θ3` corresponds to `poly4.inner.θ3`. +These matching nested structures are at the core of how Flux works. + +!!! note "Implicit gradients" + Earlier versions of Flux used a different way to relate parameters and gradients, + which looks like this: + ```julia + g1 = gradient(() -> poly1(5), Params([θ])) + g1[θ] == [1.0, 5.0, 25.0] + ``` + Here `Params` is a set of references to global variables using `objectid`, + and `g1 isa Grads` is a dictionary from these to their gradients. + This method of `gradient` takes a zero-argument function, which only *implicitly* + depends on `θ`. + +```@raw html +
+<h3>Zygote.jl</h3>
+```
-All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can *look* very different – they might have millions of parameters or complex control flow. Let's see how Flux handles more complex models.
+Flux's [`gradient`](@ref) function by default calls a companion package called [Zygote](https://github.com/FluxML/Zygote.jl).
+Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
+hooks into Julia's compiler to find out what operations `f` contains, and transforms this
+to produce code for computing `∂f/∂x`.
-## Building Layers
+Zygote can in principle differentiate almost any Julia code. However, it's not perfect,
+and you may eventually want to read its [page about limitations](https://fluxml.ai/Zygote.jl/dev/limitations/).
+In particular, a major limitation is that mutating an array is not allowed.
-It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) in between them. We could write this as:
+Flux can also be used with other automatic differentiation (AD) packages.
+It was originally written using [Tracker](https://github.com/FluxML/Tracker.jl), a more traditional operator-overloading approach.
+The future might be [Enzyme](https://github.com/EnzymeAD/Enzyme.jl), and Flux now builds in an easy way to use this instead, turned on by wrapping the model in `Duplicated`. (For details, see the [Enzyme page](@ref autodiff-enzyme) in the manual.)
 ```julia
-using Flux
+julia> using Enzyme: Const, Duplicated
-W1 = rand(3, 5)
-b1 = rand(3)
-layer1(x) = W1 * x .+ b1
+julia> grad3e = Flux.gradient((x,p) -> p(x), Const(5.0), Duplicated(poly3s))
+(nothing, (θ3 = [1.0, 5.0, 25.0],))
-W2 = rand(2, 3)
-b2 = rand(2)
-layer2(x) = W2 * x .+ b2
+```
-model(x) = layer2(sigmoid.(layer1(x)))
+`Flux.gradient` follows Zygote's convention that arguments with no derivative are marked `nothing`.
+Here, this is because `Const(5.0)` is explicitly constant.
+Below, we will see an example where `nothing` shows up because the model struct has fields containing things other than parameters, such as an activation function.
+(It also adopts the convention that `gradient(f, x, y)` returns a tuple `(∂f/∂x, ∂f/∂y)`, without a "`∂f/∂f`" term for the function. This is why we had to write `gradient(|>, 5, poly4)` above, not just `gradient(poly4, 5)`.)
-model(rand(5)) # => 2-element vector
+Finally, the function [`withgradient`](@ref) works the same way, but also returns the value of the function:
 ```
-This works but is fairly unwieldy, with a lot of repetition – especially as we add more layers. One way to factor this out is to create a function that returns linear layers.
+```jldoctest poly
+julia> Flux.withgradient((x,p) -> p(x), 5.0, poly3s)
+(val = 17.5, grad = (2.0, (θ3 = [1.0, 5.0, 25.0],)))
-```julia
-function linear(in, out)
-  W = randn(out, in)
-  b = randn(out)
-  x -> W * x .+ b
-end
+```
-linear1 = linear(5, 3) # we can access linear1.W etc
-linear2 = linear(3, 2)
+## Simple Neural Networks
-model(x) = linear2(sigmoid.(linear1(x)))
+The polynomial functions above send a number `x` to another number `y`.
+Neural networks typically take a vector of numbers, mix them all up, and return another vector.
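+Concretely, the "mixing" is usually a matrix–vector product followed by a pointwise nonlinearity. As a quick, illustrative sketch of just the shapes involved (made-up numbers, not part of any model here):
+
+```julia
+W = [1 2 3; 4 5 6]      # a 2×3 matrix of weights
+x = [1.0, 2.0, 3.0]     # a length-3 input vector
+W * x                   # a length-2 output: [14.0, 32.0]
+```
+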
+Here's a very simple one, which will take a vector like `x = [1.0, 2.0, 3.0]`
+and return another vector `y = layer1(x)` with `length(y) == 2`:
-model(rand(5)) # => 2-element vector
+
+```jldoctest poly; output = false
+W = randn(2, 3)
+b = zeros(2)
-Another (equivalent) way is to create a struct that explicitly represents the affine layer.
+
+sigmoid(x::Real) = 1 / (1 + exp(-x))
+layer1(x) = sigmoid.(W*x .+ b)
-```julia
-struct Affine
-  W
-  b
-end
+
+# output
+
+layer1 (generic function with 1 method)
 ```
-Affine(in::Integer, out::Integer) =
-  Affine(randn(out, in), zeros(out))
+
+Here `sigmoid` is a nonlinear function, applied element-wise
+because of the dot syntax `.()`, known as broadcasting.
-# Overload call, so the object can be used as a function
-(m::Affine)(x) = m.W * x .+ m.b
+
+Like `poly1` above, this `layer1` has as its parameters the global variables `W, b`.
+We can similarly define a version which takes these as arguments (like `poly2`),
+and a version which encapsulates them (like `poly3` above):
-a = Affine(10, 5)
+
+```jldoctest poly; output = false
+layer2(x, W2, b2) = sigmoid.(W2*x .+ b2)  # explicit parameter arguments
-a(rand(10)) # => 5-element vector
+
+layer3 = let
+  W3 = randn(2, 3)
+  b3 = zeros(2)
+  x -> sigmoid.(W3*x .+ b3)  # closure over local variables
+end
-Congratulations! You just built the [`Dense`](@ref) layer that comes with Flux. Flux has many interesting layers available, but they're all things you could have built yourself very easily.
+
+layer3([1.0, 2.0, 3.0]) isa Vector  # check that it runs
-(There is one small difference with `Dense` – for convenience it also takes an activation function, like `Dense(10 => 5, sigmoid)`.)
+
+# output
+
+true
 ```
-## Stacking It Up
+
+This third way is precisely a Flux model. And we can again make a tidier version
+using a `struct` to hold the parameters:
-It's pretty common to write models that look something like:
+
+```jldoctest poly; output = false, filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
+struct Layer          # container struct
+  W::Matrix
+  b::Vector
+  act::Function
+end
-```julia
-layer1 = Dense(10 => 5, relu)
-# ...
-model(x) = layer3(layer2(layer1(x)))
+
+(d::Layer)(x) = d.act.(d.W*x .+ d.b)  # make it callable
+
+Layer(in::Int, out::Int, act::Function=sigmoid) =
+  Layer(randn(Float32, out, in), zeros(Float32, out), act)
+
+layer3s = Layer(3, 2)  # instance with its own parameters
+
+# output
+
+Layer(Float32[0.6911411 0.47683495 -0.75600505; 0.5247729 1.2508286 0.27635413], Float32[0.0, 0.0], sigmoid)
 ```
-For long chains, it might be a bit more intuitive to have a list of layers, like this:
+
+The one new thing here is a friendly constructor `Layer(in, out, act)`.
+This is because we anticipate composing several instances of this thing,
+with independent parameter arrays of different sizes and different random initial values.
-```julia
-using Flux
+
+Let's try this out, and look at its gradient:
+
+```jldoctest poly; output = false, filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
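+# Evaluate the layer, then take a structural gradient; indexing with [2] keeps ∂f/∂layer rather than ∂f/∂x: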
+x = Float32[0.1, 0.2, 0.3] # input -layers = [Dense(10 => 5, relu), Dense(5 => 2), softmax] +layer3s(x) # output, 2-element Vector{Float32} -model(x) = foldl((x, m) -> m(x), layers, init = x) +Flux.gradient((x,d) -> d(x)[1], x, layer3s)[2] # NamedTuple{(:W, :b, :act)} -model(rand(10)) # => 2-element vector +# output + +(W = Float32[0.024975738 0.049951475 0.07492722; 0.0 0.0 0.0], b = Float32[0.24975738, 0.0], act = nothing) ``` -Handily, this is also provided for in Flux: +This `∂f/∂layer3s` is a named tuple with the same fields as `Layer`. +Within it, the gradient with respect to `W` is a matrix of seemingly random numbers. +Notice that there is also an entry for `act`, which is `nothing`, +as this field of the struct is not a smoothly adjustible parameter. -```julia -model2 = Chain( - Dense(10 => 5, relu), - Dense(5 => 2), - softmax) +We can compose these layers just as we did the polynomials above, in `poly4`. +Here's a composition of 3 functions, in which the last step is the function `only` +which takes a 1-element vector and gives us the number inside: + +```jldoctest poly; output = false, filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" +model1 = only ∘ Layer(20, 1) ∘ Layer(1, 20) + +y = model1(Float32[0.1]) # output is a Float32 number + +grad = Flux.gradient(|>, [1f0], model1)[2] -model2(rand(10)) # => 2-element vector +# output + +(outer = (outer = nothing, inner = (W = Float32[0.058179587 0.1276911 … 0.08071162 0.034993216], b = Float32[0.14223717], act = nothing)), inner = (W = Float32[-0.048111934; -0.0008379104; … ; 0.017658396; -0.015104223;;], b = Float32[-0.048111934, -0.0008379104, 0.017207285, 0.026828118, -0.024858447, -0.015956078, 0.0020494608, -0.012577536, -0.044770215, 0.01478136, 0.034534186, -0.004748393, 0.026848236, -0.016794706, -0.041044597, 0.016186379, -0.036814954, 0.034786277, 0.017658396, -0.015104223], act = nothing)) ``` -This quickly starts to look like a high-level deep learning library; yet you can see how it falls out of simple abstractions, and we lose none of the power of Julia code. +This gradient is starting to be a complicated nested structure. +But it works just like before: `grad.outer.inner.W` corresponds to `model1.outer.inner.W`. + +We don't have to use `∘` (which makes a `ComposedFunction` struct) to combine layers. +Instead, we could define our own container struct, or use a closure. +This `model2` will work the same way (although its fields have different names): + +```jldoctest poly; output = false, filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" +model2 = let + lay1 = Layer(1, 20) # local variables containing layers + lay2 = Layer(20, 1) + function fwd(x) # equivalent to x -> only(lay2(lay1(x))) + mid = lay1(x) + lay2(mid) |> only + end +end + +model2(Float32[0.1]) -A nice property of this approach is that because "models" are just functions (possibly with trainable parameters), you can also see this as simple function composition. 
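+# The structural gradient again mirrors the model: its fields are named after the captured variables lay1 and lay2: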
+Flux.gradient(|>, [1f0], model2)[2] -```julia -m = Dense(5 => 2) ∘ Dense(10 => 5, σ) +# output -m(rand(10)) +(lay2 = (W = Float32[0.051824596 0.03971491 … 0.038365345 0.051143322], b = Float32[0.09477656], act = nothing), lay1 = (W = Float32[-0.00049770635; 0.002891017; … ; -0.0022540581; 0.0039325757;;], b = Float32[-0.00049770635, 0.002891017, -0.00865399, -0.015051818, -0.005504916, -0.004188145, -0.01533527, -0.0059600063, -0.003092169, -0.00697084, -0.012470333, -0.0048766206, -0.010671042, -0.006604657, -0.0086712, -0.0044975257, -0.0028462198, -0.009992857, -0.0022540581, 0.0039325757], act = nothing)) ``` -Likewise, `Chain` will happily work with any Julia function. +```@raw html +
+<h3>Flux's layers</h3>
+```
-```julia
-m = Chain(x -> x^2, x -> x+1)
+Rather than define everything from scratch every time, Flux provides a library of
+commonly used layers. The same model could be defined:
-m(5) # => 26
+```jldoctest poly; output = false
+model3 = Chain(Dense(1 => 20, σ), Dense(20 => 1), only)
-```
+# output
-## Layer Helpers
+Chain(
+  Dense(1 => 20, σ), # 40 parameters
+  Dense(20 => 1), # 21 parameters
+  only,
+) # Total: 4 arrays, 61 parameters, 452 bytes.
 ```
+How does this `model3` differ from the `model1` we had before?
+
+* Flux's [`Chain`](@ref Flux.Chain) works left-to-right, the reverse of Base's `∘`.
+  Its contents are stored in a tuple, so `model3.layers[1].weight` is an array.
+* Flux's layer [`Dense`](@ref Flux.Dense) has only minor differences from our `struct Layer`:
+  - Like `struct Poly3{T}` above, it has type parameters for its fields, whereas with our `Layer` the compiler does not know exactly what type `layer3s.W` will be, which costs speed.
+  - Its initialisation uses [`glorot_uniform`](@ref) by default, not `randn` (the normal distribution).
+  - It reshapes some inputs (to allow several batch dimensions), and produces more friendly errors on wrong-size input.
+  - And it has some performance tricks: making sure element types match, and re-using some memory.
+* The function [`σ`](@ref NNlib.sigmoid) is calculated in a slightly better way,
+  and has a rule telling Zygote how to differentiate it efficiently.
+* Flux overloads `Base.show` to give pretty printing at the REPL prompt.
+  Calling [`Flux.@layer Layer`](@ref Flux.@layer) will add this, and some other niceties.
+
+All Flux layers accept a batch of samples: instead of mapping one sample `x::Vector` to one output `y::Vector`, they map columns of a matrix `xs::Matrix` to columns of the output. This looks like `f(xs) ≈ stack(f(x) for x in eachcol(xs))` but is done more efficiently.
+
+If what you need isn't covered by Flux's built-in layers, it's easy to write your own.
+There are more details [later](@ref man-advanced), but the steps are invariably those shown for `struct Layer` above:
+1. Define a `struct` which will hold the parameters.
+2. Make it callable, to define how it uses those parameters to transform the input `x`.
+3. Define a constructor which initialises the parameters (if the default constructor doesn't do what you want).
+4. Annotate with `@layer` to opt-in to pretty printing, and other enhancements.
+
+```@raw html
+<h3>Functors.jl</h3>
+```
-We can give our layer some additional functionality, like nice printing, using the [`@layer`](@ref Flux.@layer) macro:
+To deal with such nested structures, Flux relies heavily on an associated package
+called Functors. Its basic function is [`fmap`](@ref Functors.fmap),
+which generalises `map(f, x)` to work on almost anything.
+
+For example, this is how [gpu](@ref Flux.gpu) moves all arrays within a model to the GPU,
+reconstructing another `only ∘ Layer(...) ∘ Layer(...)` (or a `Chain` etc.) around the new `CuArray`s:
 ```julia
-Flux.@layer Affine
+using CUDA, Functors
+fmap(cu, model1)
 ```
-Finally, most Flux layers make bias optional, and allow you to supply the function used for generating random weights. We can easily add these refinements to the `Affine` layer as follows, using the helper function [`create_bias`](@ref Flux.create_bias):
+And this is a very simple gradient update of the parameters, walking over `model` and `grad` simultaneously:
 ```julia
-function Affine((in, out)::Pair; bias=true, init=glorot_uniform)
-  W = init(out, in)
-  b = Flux.create_bias(W, bias, out)
-  return Affine(W, b)
+fmap((x, dx) -> x isa Array ? (x - dx/100) : x, model, grad)
+```
+
+!!! note
+    Before Flux v0.15 (and Functors v0.5), this exploration of structs was opt-in.
+    After defining `struct Layer` it was necessary to call `@functor Layer` (or `@layer Layer`) before Flux would look inside.
+    This has now changed to be opt-out: Functors (and hence Flux) will explore arbitrary structs, unless told not to (using `Functors.@leaf`).
+    This is why even "anonymous structs" created by closures, like `poly3` and `layer3` above, are now valid Flux models, although the use of named structs is still recommended practice.
+
+## Curve Fitting
+
+Above we took gradients of the output, or sometimes of the first element
+of the output -- it must be a number, not a vector. Adjusting the parameters
+to make this smaller won't lead us anywhere interesting. Instead, we should minimise
+some *loss function* which compares the actual output to our desired output.
+
+Perhaps the simplest example is curve fitting. The [previous page](@ref man-overview) fitted
+a linear model to data. With our two-layer model, we can fit a nonlinear function.
+For example, let us use `f(x) = 2x - x^3` evaluated at some points `x in -2:0.1:2` as the data,
+and adjust the parameters of `model3` from above so that its output is similar.
+
+```jldoctest poly; output = false
+data = [([x], 2x-x^3) for x in -2:0.1f0:2]  # training points (x, y)
+
+for _ in 1:1000  # adjust parameters to minimise the error:
+    Flux.train!((m,x,y) -> (m(x) - y)^2, model3, data, Descent(0.01))
 end
-Affine(3 => 1, bias=false) |> gpu
+# output
+
+```
+
+The same code will also work with `model1` or `model2` instead.
+Here's how to plot the desired and actual outputs:
+
+```julia
+using Plots
+plot(x -> 2x-x^3, -2, 2, label="truth")
+scatter!(x -> model3([x]), -2:0.1f0:2, label="fitted")
 ```
+
+More detail about what exactly the function `train!` is doing, and how to use rules other than simple [`Descent`](@ref Optimisers.Descent), is what the next page in this guide is about: [training](@ref man-training).
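+
+As a rough preview (a minimal sketch written out with explicit calls, not exactly what `train!` does internally), the same fit could be driven by `gradient` and `Flux.update!` directly:
+
+```julia
+opt_state = Flux.setup(Descent(0.01), model3)          # optimiser state matching the model's structure
+for (x, y) in data
+    grads = Flux.gradient(m -> (m(x) - y)^2, model3)   # structural gradient of the loss
+    Flux.update!(opt_state, model3, grads[1])          # walk the model and gradient together
+end
+```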
diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md index b798a35291..8e5c0e873c 100644 --- a/docs/src/reference/models/layers.md +++ b/docs/src/reference/models/layers.md @@ -1,9 +1,4 @@ -```@meta -CurrentModule = Flux -CollapsedDocStrings = true -``` - -# Built-in Layer Types +# [Built-in Layer Types](@id man-layers) If you started at the beginning of the guide, then you have already met the basic [`Dense`](@ref) layer, and seen [`Chain`](@ref) for combining layers. @@ -58,7 +53,7 @@ documented in NNlib's [Attention](@ref) section. ```@docs MultiHeadAttention -``` +``` ### Pooling @@ -75,7 +70,7 @@ GlobalMeanPool ## Upsampling -The opposite of pooling, these layers increase the size of an array. They have no trainable parameters. +The opposite of pooling, these layers increase the size of an array. They have no trainable parameters. ```@docs Upsample @@ -135,7 +130,7 @@ Flux.normalise ### Test vs. Train -Several normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. +Several normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. !!! warning This automatic train/test detection works best with Zygote, the default diff --git a/docs/src/guide/models/custom_layers.md b/docs/src/tutorials/custom_layers.md similarity index 100% rename from docs/src/guide/models/custom_layers.md rename to docs/src/tutorials/custom_layers.md