diff --git a/README.md b/README.md
index 1973874dd0..cc0a7f1803 100644
--- a/README.md
+++ b/README.md
@@ -22,19 +22,25 @@ Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, a
Works best with [Julia 1.10](https://julialang.org/downloads/) or later. Here's a very short example to try it out:
```julia
-using Flux, Plots
-data = [([x], 2x-x^3) for x in -2:0.1f0:2]
+using Flux
+data = [(x, 2x-x^3) for x in -2:0.1f0:2]
-model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)
+model = let
+ w, b, v = (randn(Float32, 23) for _ in 1:3) # parameters
+ x -> sum(v .* tanh.(w*x .+ b)) # callable
+end
+# model = Chain(vcat, Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)
opt_state = Flux.setup(Adam(), model)
-for epoch in 1:1000
+for epoch in 1:100
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, opt_state)
end
-plot(x -> 2x-x^3, -2, 2, legend=false)
-scatter!(x -> model([x]), -2:0.1f0:2)
+using Plots
+plot(x -> 2x-x^3, -2, 2, label="truth")
+scatter!(model, -2:0.1f0:2, label="learned")
```
+In Flux 0.15, almost any parameterised function in Julia is a valid Flux model -- such as this closure over `w, b, v`. The commented-out line shows how the same function can be built using Flux's built-in layers.
The [quickstart page](https://fluxml.ai/Flux.jl/stable/guide/models/quickstart/) has a longer example. See the [documentation](https://fluxml.github.io/Flux.jl/) for details, or the [model zoo](https://github.com/FluxML/model-zoo/) for examples. Ask questions on the [Julia discourse](https://discourse.julialang.org/) or [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866).
diff --git a/docs/make.jl b/docs/make.jl
index 6bdfcbb638..4367639d8e 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -30,7 +30,6 @@ makedocs(
"Quick Start" => "guide/models/quickstart.md",
"Fitting a Line" => "guide/models/overview.md",
"Gradients and Layers" => "guide/models/basics.md",
- "Custom Layers" => "guide/models/custom_layers.md",
"Training" => "guide/training/training.md",
"Recurrence" => "guide/models/recurrence.md",
"GPU Support" => "guide/gpu.md",
@@ -63,6 +62,7 @@ makedocs(
# Or perhaps those should just be trashed, model zoo versions are newer & more useful.
"Linear Regression" => "tutorials/linear_regression.md",
"Logistic Regression" => "tutorials/logistic_regression.md",
+ "Custom Layers" => "tutorials/custom_layers.md",
"Model Zoo" => "tutorials/model_zoo.md",
#=
# "Multi-layer Perceptron" => "tutorials/mlp.md",
diff --git a/docs/src/assets/zygote-crop.png b/docs/src/assets/zygote-crop.png
new file mode 100644
index 0000000000..ddc04b3d17
Binary files /dev/null and b/docs/src/assets/zygote-crop.png differ
diff --git a/docs/src/guide/models/basics.md b/docs/src/guide/models/basics.md
index 85d54ee58b..7ad62ee207 100644
--- a/docs/src/guide/models/basics.md
+++ b/docs/src/guide/models/basics.md
@@ -1,241 +1,452 @@
-# [How Flux Works: Gradients and Layers](@id man-basics)
+# [How Flux Works: Parameters, Gradients, and Layers](@id man-basics)
-## [Taking Gradients](@id man-taking-gradients)
+A neural network is a function with *parameters*.
+That is, it takes some input `x` and gives you some output `y`,
+whose value also depends on some other numbers `θ`.
-Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
+A sufficiently flexible function can, by adjusting the parameters just right,
+be made to do many things. And the one magic trick for adjusting parameters
+is to follow a *gradient*.
-```jldoctest basics
-julia> using Flux
+This page describes Flux's take on how to construct such flexible functions
+containing many parameters, and how to handle their gradients.
-julia> f(x) = 3x^2 + 2x + 1;
+## Parameterised Functions
-julia> df(x) = gradient(f, x)[1]; # df/dx = 6x + 2
+Let's start with very simple functions. This is a polynomial in `x::Real`,
+returning another real number `y` which depends on some coefficients stored in a vector:
-julia> df(2)
-14.0
-```
+```jldoctest poly; output = false
+θ = [10, 1, 0.1]
+
+poly1(x::Real) = θ[1] + θ[2]*x + θ[3]*x^2
-When a function has many parameters, we can get gradients of each one at the same time:
+poly1(5) == 17.5 # true
-```jldoctest basics
-julia> f(x, y) = sum((x .- y).^2);
+# output
-julia> gradient(f, [2, 1], [2, 0])
-([0.0, 2.0], [-0.0, -2.0])
+true
```
-These gradients are based on `x` and `y`. Flux works by instead taking gradients based on the weights and biases that make up the parameters of a model.
+Here the parameters are stored in a global variable `θ`. They could be handled in other ways,
+for instance by explicitly passing them as an additional argument to the function:
+
+```jldoctest poly; output = false
+poly2(x::Real, θ2) = evalpoly(x, θ2) # built-in, from Base.Math
+
+poly2(5, θ) == 17.5 # true
-Machine learning often can have *hundreds* of parameter arrays.
-Instead of passing them to `gradient` individually, we can store them together in a structure.
-The simplest example is a named tuple, created by the following syntax:
+# output
-```jldoctest basics
-julia> nt = (a = [2, 1], b = [2, 0], c = tanh);
+true
+```
+
+Flux chooses a third path, by *encapsulating* the parameters within the function.
+The simplest way to do this is a *closure*, an anonymous function which Julia knows
+to depend on some local variable `θ3`:
-julia> g(x::NamedTuple) = sum(abs2, x.a .- x.b);
+```jldoctest poly; output = false
+poly3 = let θ3 = [10, 1, 0.1]
+ x -> evalpoly(x, θ3)
+end
-julia> g(nt)
-1
+poly3(5) == 17.5 # true
-julia> dg_nt = gradient(g, nt)[1]
-(a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing)
+# output
+
+true
```
-Notice that `gradient` has returned a matching structure. The field `dg_nt.a` is the gradient
-for `nt.a`, and so on. Some fields have no gradient, indicated by `nothing`.
+An equivalent, but tidier, way is to construct a `struct` in which to store the parameters.
+Any struct can be made callable, allowing its instances to act just like functions:
-Rather than define a function like `g` every time (and think up a name for it),
-it is often useful to use anonymous functions: this one is `x -> sum(abs2, x.a .- x.b)`.
-Anonymous functions can be defined either with `->` or with `do`,
-and such `do` blocks are often useful if you have a few steps to perform:
+```jldoctest poly; output = false
+struct Poly3{T} # container struct
+ θ3::T
+end
+(p::Poly3)(x::Real) = evalpoly(x, p.θ3) # make this callable
-```jldoctest basics
-julia> gradient((x, y) -> sum(abs2, x.a ./ y .- x.b), nt, [1, 2])
-((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25])
+poly3s = Poly3([10, 1, 0.1]) # construct an instance
-julia> gradient(nt, [1, 2]) do x, y
- z = x.a ./ y
- sum(abs2, z .- x.b)
- end
-((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25])
+poly3s(5) == 17.5 # true
+
+# output
+
+true
```
-Sometimes you may want to know the value of the function, as well as its gradient.
-Rather than calling the function a second time, you can call [`withgradient`](@ref Zygote.withgradient) instead:
+Internally, there is little difference between a closure and a struct.
+They have the same fields, and equivalent methods:
+```julia
+dump(poly3), dump(poly3s) # both contain θ3: Array
+poly3s.θ3 == poly3.θ3 == θ # field called :θ3 has same value
+methods(poly3)
+methods(poly3s) # each has 1 method, accepting x
```
-julia> Flux.withgradient(g, nt)
-(val = 1, grad = ((a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing),))
+
+The virtue of encapsulation is that it makes composition very easy.
+We can make more complicated functions by combining simple ones,
+and each will keep track of its own parameters.
+Julia writes function composition as `∘`, for instance `(inv ∘ sin)(pi/6) ≈ 2`,
+and we can use exactly this for our parameterised polynomials:
+
+```jldoctest poly; output = false
+poly4 = Poly3([1, 0.5, 0]) ∘ Poly3([10, 1, 0.1])
+
+poly4 isa ComposedFunction # ∘ creates another struct...
+poly4.inner.θ3 == θ              # which has fields :inner & :outer
+
+poly4(5) == 9.75 # true
+
+# output
+
+true
```
-## Building Simple Models
+Flux models are built by precisely this sort of function composition.
+In fact, `poly3` and `poly4` are already valid Flux models.
-Consider a simple linear regression, which tries to predict an output array `y` from an input `x`.
-```julia
-predict(W, b, x) = W*x .+ b
+## [Structural Gradients](@id man-taking-gradients)
-function loss(W, b, x, y)
- ŷ = predict(W, b, x)
- sum((y .- ŷ).^2)
-end
+The derivative of a scalar function is its slope: how fast the output changes as the input is changed slightly.
+This may be found approximately by evaluating at two nearby points, and exactly by taking the limit in
+which the distance between them approaches zero:
-x, y = rand(5), rand(2) # Dummy data
-W = rand(2, 5)
-b = rand(2)
+```jldoctest poly
+julia> (poly1(5 + 0.1) - poly1(5)) / 0.1
+2.010000000000005
-loss(W, b, x, y) # ~ 3
+julia> (poly1(5 + 0.001) - poly1(5)) / 0.001 # answer is getting close to 2
+2.000100000003613
```
-To improve the prediction we can take the gradients of the loss with respect to `W` and `b` and perform gradient descent.
+Flux's `gradient(f, x)` works this out for `f(x)`, and gives exactly `∂f/∂x = 2.0` here:
-```julia
-using Flux
+```jldoctest poly
+julia> using Flux
-dW, db = gradient((W, b) -> loss(W, b, x, y), W, b)
+julia> gradient(poly1, 5)
+(2.0,)
```
-Now that we have gradients, we can pull them out and update `W` to train the model.
+The reason `gradient` returns a tuple, not just the number `2.0`, is to allow for
+functions taking several arguments. (That's also why it's not called "derivative".)
+For instance, this returns `∂f/∂x, ∂f/∂y, ∂f/∂z`:
-```julia
-W .-= 0.1 .* dW
+```jldoctest poly
+julia> gradient((x,y,z) -> (x*y)+z, 30, 40, 50)
+(40.0, 30.0, 1.0)
+```
+
+For our parameterised polynomial, we have `∂f/∂x` but we are really more interested
+in `∂f/∂θ`, as this tells us how the parameters affect the answer.
+It is possible to track gradients with respect to the global `θ`, but it is much clearer to track explicit arguments.
+Here's how this works for `poly2` (which takes `θ` as a 2nd argument) and `poly3s` (which encapsulates `θ`):
+
+```jldoctest poly
+julia> grad2 = gradient(poly2, 5, θ)
+(2.0, [1.0, 5.0, 25.0])
+
+julia> grad3 = gradient((x,p) -> p(x), 5, poly3s)
+(2.0, (θ3 = [1.0, 5.0, 25.0],))
+```
+
+The first entry is `∂f/∂x` as before, but the second entry is more interesting.
+For `poly2`, we get `∂f/∂θ` as `grad2[2]` directly.
+It is a vector, because `θ` is a vector, and has elements `[∂f/∂θ[1], ∂f/∂θ[2], ∂f/∂θ[3]]`.
+
+For `poly3s`, however, we get a `NamedTuple` whose fields correspond to those of the struct `Poly3`.
+This is called a *structural gradient*. And the nice thing about them is that they work for
+arbitrarily complicated structures, for instance:
-loss(W, b, x, y) # ~ 2.5
+```jldoctest poly
+julia> grad4 = gradient(|>, 5, poly4)
+(1.0, (outer = (θ3 = [1.0, 17.5, 306.25],), inner = (θ3 = [0.5, 2.5, 12.5],)))
```
-The loss has decreased a little, meaning that our prediction `x` is closer to the target `y`. If we have some data we can already try [training the model](../training/training.md).
+Here `grad4.inner.θ3` corresponds to `poly4.inner.θ3`.
+These matching nested structures are at the core of how Flux works.
+
+!!! note "Implicit gradients"
+ Earlier versions of Flux used a different way to relate parameters and gradients,
+ which looks like this:
+ ```julia
+ g1 = gradient(() -> poly1(5), Params([θ]))
+ g1[θ] == [1.0, 5.0, 25.0]
+ ```
+ Here `Params` is a set of references to global variables using `objectid`,
+ and `g1 isa Grads` is a dictionary from these to their gradients.
+ This method of `gradient` takes a zero-argument function, which only *implicitly*
+ depends on `θ`.
+
+```@raw html
+
+```
-All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can *look* very different – they might have millions of parameters or complex control flow. Let's see how Flux handles more complex models.
+Flux's [`gradient`](@ref) function by default calls a companion package called [Zygote](https://github.com/FluxML/Zygote.jl).
+Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
+hooks into Julia's compiler to find out what operations `f` contains, and transforms this
+to produce code for computing `∂f/∂x`.
-## Building Layers
+Zygote can in principle differentiate almost any Julia code. However, it's not perfect,
+and you may eventually want to read its [page about limitations](https://fluxml.ai/Zygote.jl/dev/limitations/).
+In particular, a major limitation is that mutating an array is not allowed.
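+
+For example, a function which mutates its argument in place will make Zygote throw an error
+rather than return a gradient. This little function is made up purely to illustrate the point:
+
+```julia
+function double!(x)
+    x .= 2 .* x    # in-place mutation, which Zygote cannot differentiate
+    sum(x)
+end
+
+gradient(double!, [1.0, 2.0])  # throws an error instead of returning a gradient
+```
+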
-It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) in between them. We could write this as:
+Flux can also be used with other automatic differentiation (AD) packages.
+It was originally written using [Tracker](https://github.com/FluxML/Tracker.jl), a more traditional operator-overloading approach.
+The future might be [Enzyme](https://github.com/EnzymeAD/Enzyme.jl), and Flux now builds in an easy way to use this instead, turned on by wrapping the model in `Duplicated`. (For details, see the [Enzyme page](@ref autodiff-enzyme) in the manual.)
```julia
-using Flux
+julia> using Enzyme: Const, Duplicated
-W1 = rand(3, 5)
-b1 = rand(3)
-layer1(x) = W1 * x .+ b1
+julia> grad3e = Flux.gradient((x,p) -> p(x), Const(5.0), Duplicated(poly3s))
+(nothing, (θ3 = [1.0, 5.0, 25.0],))
+```
-W2 = rand(2, 3)
-b2 = rand(2)
-layer2(x) = W2 * x .+ b2
+`Flux.gradient` follows Zygote's convention that arguments with no derivative are marked `nothing`.
+Here, this is because `Const(5.0)` is explicitly constant.
+Below, we will see an example where `nothing` shows up because the model struct has fields containing things other than parameters, such as an activation function.
+(It also adopts the convention that `gradient(f, x, y)` returns a tuple `(∂f/∂x, ∂f/∂y)`, without a "`∂f/∂f`" term for the function. This is why we had to write `gradient(|>, 5, poly4)` above, not just `gradient(poly4, 5)`.)
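+
+If only the parameter gradient is wanted, the model itself can be made the differentiated
+argument. For instance, this is another way of obtaining `grad4[2]` from above:
+
+```julia
+gradient(p -> p(5), poly4)[1]  # (outer = (θ3 = ...,), inner = (θ3 = ...,))
+```
+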
-model(x) = layer2(sigmoid.(layer1(x)))
+Finally, the function [`withgradient`](@ref) works the same way, but also returns the value of the function:
-model(rand(5)) # => 2-element vector
+```jldoctest poly
+julia> Flux.withgradient((x,p) -> p(x), 5.0, poly3s)
+(val = 17.5, grad = (2.0, (θ3 = [1.0, 5.0, 25.0],)))
```
-This works but is fairly unwieldy, with a lot of repetition – especially as we add more layers. One way to factor this out is to create a function that returns linear layers.
+## Simple Neural Networks
-```julia
-function linear(in, out)
- W = randn(out, in)
- b = randn(out)
- x -> W * x .+ b
-end
+The polynomial functions above send a number `x` to another number `y`.
+Neural networks typically take a vector of numbers, mix them all up, and return another vector.
+Here's a very simple one, which will take a vector like `x = [1.0, 2.0, 3.0]`
+and return another vector `y = layer1(x)` with `length(y) == 2`:
-linear1 = linear(5, 3) # we can access linear1.W etc
-linear2 = linear(3, 2)
+```jldoctest poly; output = false
+W = randn(2, 3)
+b = zeros(2)
-model(x) = linear2(sigmoid.(linear1(x)))
+sigmoid(x::Real) = 1 / (1 + exp(-x))
+layer1(x) = sigmoid.(W*x .+ b)
-model(rand(5)) # => 2-element vector
+# output
+
+layer1 (generic function with 1 method)
```
-Another (equivalent) way is to create a struct that explicitly represents the affine layer.
+Here `sigmoid` is a nonlinear function, applied element-wise
+because it is called with `.()`, Julia's broadcasting syntax.
-```julia
-struct Affine
- W
- b
-end
+Like `poly1` above, this `layer1` has as its parameters the global variables `W, b`.
+We can similarly define a version which takes these as arguments (like `poly2`),
+and a version which encapsulates them (like `poly3` above):
+
+```jldoctest poly; output = false
+layer2(x, W2, b2) = sigmoid.(W2*x .+ b2) # explicit parameter arguments
-Affine(in::Integer, out::Integer) =
- Affine(randn(out, in), zeros(out))
+layer3 = let
+ W3 = randn(2, 3)
+ b3 = zeros(2)
+ x -> sigmoid.(W3*x .+ b3) # closure over local variables
+end
-# Overload call, so the object can be used as a function
-(m::Affine)(x) = m.W * x .+ m.b
+layer3([1.0, 2.0, 3.0]) isa Vector # check that it runs
-a = Affine(10, 5)
+# output
-a(rand(10)) # => 5-element vector
+true
```
-Congratulations! You just built the [`Dense`](@ref) layer that comes with Flux. Flux has many interesting layers available, but they're all things you could have built yourself very easily.
+This third way is precisely a Flux model. And we can again make a tidier version
+using a `struct` to hold the parameters:
-(There is one small difference with `Dense` – for convenience it also takes an activation function, like `Dense(10 => 5, sigmoid)`.)
+```jldoctest poly; output = false, filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
+struct Layer # container struct
+ W::Matrix
+ b::Vector
+ act::Function
+end
-## Stacking It Up
+(d::Layer)(x) = d.act.(d.W*x .+ d.b)    # make it callable
-It's pretty common to write models that look something like:
+Layer(in::Int, out::Int, act::Function=sigmoid) =
+ Layer(randn(Float32, out, in), zeros(Float32, out), act)
-```julia
-layer1 = Dense(10 => 5, relu)
-# ...
-model(x) = layer3(layer2(layer1(x)))
+layer3s = Layer(3, 2) # instance with its own parameters
+
+# output
+
+Layer(Float32[0.6911411 0.47683495 -0.75600505; 0.5247729 1.2508286 0.27635413], Float32[0.0, 0.0], sigmoid)
```
-For long chains, it might be a bit more intuitive to have a list of layers, like this:
+The one new thing here is a friendly constructor `Layer(in, out, act)`.
+We add this because we anticipate composing several instances of this layer,
+each with its own independent parameter arrays, of different sizes and with
+different random initial values.
-```julia
-using Flux
+Let's try this out, and look at its gradient:
+
+```jldoctest poly; output = false, filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
+x = Float32[0.1, 0.2, 0.3] # input
-layers = [Dense(10 => 5, relu), Dense(5 => 2), softmax]
+layer3s(x) # output, 2-element Vector{Float32}
-model(x) = foldl((x, m) -> m(x), layers, init = x)
+Flux.gradient((x,d) -> d(x)[1], x, layer3s)[2] # NamedTuple{(:W, :b, :act)}
-model(rand(10)) # => 2-element vector
+# output
+
+(W = Float32[0.024975738 0.049951475 0.07492722; 0.0 0.0 0.0], b = Float32[0.24975738, 0.0], act = nothing)
```
-Handily, this is also provided for in Flux:
+This `∂f/∂layer3s` is a named tuple with the same fields as `Layer`.
+Within it, the gradient with respect to `W` is a matrix of seemingly random numbers.
+Notice that there is also an entry for `act`, which is `nothing`,
+as this field of the struct is not a smoothly adjustable parameter.
-```julia
-model2 = Chain(
- Dense(10 => 5, relu),
- Dense(5 => 2),
- softmax)
+We can compose these layers just as we did the polynomials above, in `poly4`.
+Here's a composition of 3 functions, in which the last step is the function `only`
+which takes a 1-element vector and gives us the number inside:
+
+```jldoctest poly; output = false, filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
+model1 = only ∘ Layer(20, 1) ∘ Layer(1, 20)
+
+y = model1(Float32[0.1]) # output is a Float32 number
+
+grad = Flux.gradient(|>, [1f0], model1)[2]
-model2(rand(10)) # => 2-element vector
+# output
+
+(outer = (outer = nothing, inner = (W = Float32[0.058179587 0.1276911 … 0.08071162 0.034993216], b = Float32[0.14223717], act = nothing)), inner = (W = Float32[-0.048111934; -0.0008379104; … ; 0.017658396; -0.015104223;;], b = Float32[-0.048111934, -0.0008379104, 0.017207285, 0.026828118, -0.024858447, -0.015956078, 0.0020494608, -0.012577536, -0.044770215, 0.01478136, 0.034534186, -0.004748393, 0.026848236, -0.016794706, -0.041044597, 0.016186379, -0.036814954, 0.034786277, 0.017658396, -0.015104223], act = nothing))
```
-This quickly starts to look like a high-level deep learning library; yet you can see how it falls out of simple abstractions, and we lose none of the power of Julia code.
+This gradient is starting to be a complicated nested structure.
+But it works just like before: `grad.outer.inner.W` corresponds to `model1.outer.inner.W`.
+
+We don't have to use `∘` (which makes a `ComposedFunction` struct) to combine layers.
+Instead, we could define our own container struct, or use a closure.
+This `model2` will work the same way (although its fields have different names):
+
+```jldoctest poly; output = false, filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
+model2 = let
+ lay1 = Layer(1, 20) # local variables containing layers
+ lay2 = Layer(20, 1)
+ function fwd(x) # equivalent to x -> only(lay2(lay1(x)))
+ mid = lay1(x)
+ lay2(mid) |> only
+ end
+end
+
+model2(Float32[0.1])
-A nice property of this approach is that because "models" are just functions (possibly with trainable parameters), you can also see this as simple function composition.
+Flux.gradient(|>, [1f0], model2)[2]
-```julia
-m = Dense(5 => 2) ∘ Dense(10 => 5, σ)
+# output
-m(rand(10))
+(lay2 = (W = Float32[0.051824596 0.03971491 … 0.038365345 0.051143322], b = Float32[0.09477656], act = nothing), lay1 = (W = Float32[-0.00049770635; 0.002891017; … ; -0.0022540581; 0.0039325757;;], b = Float32[-0.00049770635, 0.002891017, -0.00865399, -0.015051818, -0.005504916, -0.004188145, -0.01533527, -0.0059600063, -0.003092169, -0.00697084, -0.012470333, -0.0048766206, -0.010671042, -0.006604657, -0.0086712, -0.0044975257, -0.0028462198, -0.009992857, -0.0022540581, 0.0039325757], act = nothing))
```
-Likewise, `Chain` will happily work with any Julia function.
+```@raw html
+
+```
-```julia
-m = Chain(x -> x^2, x -> x+1)
+Rather than define everything from scratch every time, Flux provides a library of
+commonly used layers. The same model can be defined like this:
+
+```jldoctest poly; output = false
+model3 = Chain(Dense(1 => 20, σ), Dense(20 => 1), only)
-m(5) # => 26
+# output
+
+Chain(
+ Dense(1 => 20, σ), # 40 parameters
+ Dense(20 => 1), # 21 parameters
+ only,
+) # Total: 4 arrays, 61 parameters, 452 bytes.
```
-## Layer Helpers
+How does this `model3` differ from the `model1` we had before?
+
+* Flux's [`Chain`](@ref Flux.Chain) works left-to-right, the reverse of Base's `∘`.
+  Its contents are stored in a tuple, thus `model3.layers[1].weight` is an array.
+* Flux's layer [`Dense`](@ref Flux.Dense) has only minor differences from our `struct Layer`:
+    - Like `struct Poly3{T}` above, it has type parameters for its fields, whereas our `Layer` leaves them abstract -- so the compiler does not know exactly what type `layer3s.W` will be, which costs speed.
+ - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref) by default.
+ - It reshapes some inputs (to allow several batch dimensions), and produces more friendly errors on wrong-size input.
+ - And it has some performance tricks: making sure element types match, and re-using some memory.
+* The function [`σ`](@ref NNlib.sigmoid) is calculated in a slightly better way,
+ and has a rule telling Zygote how to differentiate it efficiently.
+* Flux overloads `Base.show` to give pretty printing at the REPL prompt.
+ Calling [`Flux.@layer Layer`](@ref Flux.@layer) will add this, and some other niceties.
+
+All Flux layers accept a batch of samples: instead of mapping one sample `x::Vector` to one output `y::Vector`, they map the columns of a matrix `xs::Matrix` to columns of the output. This looks like `f(xs) ≈ stack(f(x) for x in eachcol(xs))` but is done more efficiently.
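+
+For instance, the hand-made `layer3s` above already behaves this way, because both the matrix
+multiplication and the broadcasting act column-wise here:
+
+```julia
+xs = randn(Float32, 3, 5)  # a batch of 5 samples, each a column of length 3
+
+layer3s(xs) ≈ stack(layer3s(x) for x in eachcol(xs))  # true; both sides are 2×5 Matrices
+```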
+
+If what you need isn't covered by Flux's built-in layers, it's easy to write your own.
+There are more details [later](@ref man-advanced), but the steps are invariably those shown for `struct Layer` above (and sketched again just below):
+1. Define a `struct` which will hold the parameters.
+2. Make it callable, to define how it uses them to transform the input `x`.
+3. Define a constructor which initialises the parameters (if the default constructor doesn't do what you want).
+4. Annotate with `@layer` to opt in to pretty printing and other enhancements.
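+
+Applied to a tiny scaling layer, those four steps might look like this (the name `Scale2` is ours,
+chosen to avoid clashing with Flux's built-in `Scale`):
+
+```julia
+struct Scale2{T}                               # 1. a struct holding the parameters
+    s::T
+end
+(m::Scale2)(x) = m.s .* x                      # 2. make it callable
+Scale2(n::Integer) = Scale2(ones(Float32, n))  # 3. a friendly constructor
+Flux.@layer Scale2                             # 4. opt in to pretty printing etc.
+```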
+
+```@raw html
+
+```
-We can give our layer some additional functionality, like nice printing, using the [`@layer`](@ref Flux.@layer) macro:
+To deal with such nested structures, Flux relies heavily on an associated package
+called Functors. Its basic function is [`fmap`](@ref Functors.fmap),
+which generalises `map(f, x)` to work on almost anything.
+
+For example, this is how [`gpu`](@ref Flux.gpu) moves all arrays within a model to the GPU,
+reconstructing another `only ∘ Layer(...) ∘ Layer(...)` (or a `Chain` etc.) around the new `CuArray`s:
```julia
-Flux.@layer Affine
+using CUDA, Functors
+fmap(cu, model1)
```
-Finally, most Flux layers make bias optional, and allow you to supply the function used for generating random weights. We can easily add these refinements to the `Affine` layer as follows, using the helper function [`create_bias`](@ref Flux.create_bias):
+And this is a very simple gradient update of the parameters, walking over `model1` and its gradient `grad` simultaneously:
```julia
-function Affine((in, out)::Pair; bias=true, init=glorot_uniform)
- W = init(out, in)
- b = Flux.create_bias(W, bias, out)
- return Affine(W, b)
+fmap((x, dx) -> x isa Array ? (x - dx/100) : x, model1, grad)
+```
+
+!!! note
+ Before Flux v0.15 (and Functors v0.5), this exploration of structs was opt-in.
+ After defining `struct Layer` it was necessary to call `@functor Layer` (or `@layer Layer`) before Flux would look inside.
+ This has now changed to be opt-out: Functors (and hence Flux) will explore arbitrary structs, unless told not to (using `Functors.@leaf`).
+ This is why even "anonymous structs" created by closures, like `poly3` and `layer3` above, are now valid Flux models, although the use of named structs is still recommended practice.
+
+## Curve Fitting
+
+Above we took gradients of the output, or sometimes of its first element --
+the quantity we differentiate must be a number, not a vector. Adjusting the parameters
+to make this smaller won't lead us anywhere interesting. Instead, we should minimise
+some *loss function* which compares the actual output to our desired output.
+
+Perhaps the simplest example is curve fitting. The [previous page](@ref man-overview) fitted
+a linear model to data. With our two-layer model, we can fit a nonlinear function.
+For example, let us use `f(x) = 2x - x^3` evaluated at some points `x in -2:0.1:2` as the data,
+and adjust the parameters of `model3` from above so that its output is similar.
+
+```jldoctest poly; output = false
+data = [([x], 2x-x^3) for x in -2:0.1f0:2] # training points (x, y)
+
+for _ in 1:1000 # adjust parameters to minimise the error:
+ Flux.train!((m,x,y) -> (m(x) - y)^2, model3, data, Descent(0.01))
end
-Affine(3 => 1, bias=false) |> gpu
+# output
+
+```
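+
+To see that this has worked, we can measure the remaining error ourselves, by summing the
+squared differences over the data (the exact value depends on the random initial parameters):
+
+```julia
+sum(abs2, model3(x) - y for (x, y) in data)  # should now be fairly small
+```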
+
+The same code will also work with `model1` or `model2` instead.
+Here's how to plot the desired and actual outputs:
+
+```julia
+using Plots
+plot(x -> 2x-x^3, -2, 2, label="truth")
+scatter!(x -> model3([x]), -2:0.1f0:2, label="fitted")
```
+More detail about what exactly the function `train!` is doing, and how to use rules other than simple [`Descent`](@ref Optimisers.Descent), is what the next page in this guide is about: [training](@ref man-training).
diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md
index b798a35291..8e5c0e873c 100644
--- a/docs/src/reference/models/layers.md
+++ b/docs/src/reference/models/layers.md
@@ -1,9 +1,4 @@
-```@meta
-CurrentModule = Flux
-CollapsedDocStrings = true
-```
-
-# Built-in Layer Types
+# [Built-in Layer Types](@id man-layers)
If you started at the beginning of the guide, then you have already met the
basic [`Dense`](@ref) layer, and seen [`Chain`](@ref) for combining layers.
@@ -58,7 +53,7 @@ documented in NNlib's [Attention](@ref) section.
```@docs
MultiHeadAttention
-```
+```
### Pooling
@@ -75,7 +70,7 @@ GlobalMeanPool
## Upsampling
-The opposite of pooling, these layers increase the size of an array. They have no trainable parameters.
+The opposite of pooling, these layers increase the size of an array. They have no trainable parameters.
```@docs
Upsample
@@ -135,7 +130,7 @@ Flux.normalise
### Test vs. Train
-Several normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference.
+Several normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference.
!!! warning
This automatic train/test detection works best with Zygote, the default
diff --git a/docs/src/guide/models/custom_layers.md b/docs/src/tutorials/custom_layers.md
similarity index 100%
rename from docs/src/guide/models/custom_layers.md
rename to docs/src/tutorials/custom_layers.md