diff --git a/Project.toml b/Project.toml
index 41b23a8b..15fc6479 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,13 +13,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [compat]
 ChainRulesCore = "1"
 Functors = "0.3, 0.4"
+Yota = "0.8.2"
 Zygote = "0.6.40"
 julia = "1.6"
 
 [extras]
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Yota = "cd998857-8626-517d-b929-70ad188a48f0"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "StaticArrays", "Zygote"]
+test = ["Test", "StaticArrays", "Yota", "Zygote"]
diff --git a/docs/src/index.md b/docs/src/index.md
index 9ebfac0b..863428b7 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -38,7 +38,7 @@ to adjust the model:
 
 ```julia
 
-using Flux, Metalhead, Optimisers
+using Flux, Metalhead, Zygote, Optimisers
 
 model = Metalhead.ResNet(18) |> gpu  # define a model to train
 image = rand(Float32, 224, 224, 3, 1) |> gpu;  # dummy data
@@ -52,7 +52,7 @@ state = Optimisers.setup(rule, model);  # initialise this optimiser's momentum e
 end;
 
 state, model = Optimisers.update(state, model, ∇model);
-@show sum(model(image));
+@show sum(model(image));  # reduced
 
 ```
 
@@ -62,8 +62,14 @@ tree formed by the model and update the parameters using the gradients.
 
 There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
 but is free to mutate arrays within the old one for efficiency.
-The method of `apply!` for each rule is likewise free to mutate arrays within its state;
-they are defensively copied when this rule is used with `update`.
+(The method of `apply!` above is likewise free to mutate arrays within its state;
+they are defensively copied when this rule is used with `update`.)
+For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
+
+```julia
+Base.summarysize(model) / 1024^2  # about 45MB
+Base.summarysize(state) / 1024^2  # about 90MB
+```
 
 Optimisers.jl does not depend on any one automatic differentiation package,
 but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
@@ -72,14 +78,34 @@ This `∇model` is another tree structure, rather than the dictionary-like objec
 Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
+
+## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)
+
+Yota is another modern automatic differentiation package, an alternative to Zygote.
+
+Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
+but also returns a gradient component for the loss function.
+To extract what Optimisers.jl needs, you can write (for the Flux model above):
+
+```julia
+using Yota
+
+loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
+    sum(m(x))
+end;
+
+# Or else, this may save computing ∇image:
+loss, (_, ∇model) = grad(m -> sum(m(image)), model);
+```
+
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 
-The main design difference of Lux is that the tree of parameters is separate from
+The main design difference of Lux from Flux is that the tree of parameters is separate from
 the layer structure. It is these parameters which `setup` and `update` need to know about.
 
 Lux describes this separation of parameter storage from model description as "explicit" parameters.
 Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
-(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will often be
+(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will be
 nearly identical trees of nested `NamedTuple`s.)
 
 ```julia
@@ -88,27 +114,47 @@ using Lux, Boltz, Zygote, Optimisers
 lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu;  # define and initialise model
 images = rand(Float32, 224, 224, 3, 4) |> gpu;  # batch of dummy data
 
-y, _ = Lux.apply(lux_model, images, params, lux_state);  # run the model
-@show sum(y)  # initial dummy loss
+y, lux_state = Lux.apply(lux_model, images, params, lux_state);  # run the model
+@show sum(y);  # initial dummy loss
 
 rule = Optimisers.Adam()
 opt_state = Optimisers.setup(rule, params);  # optimiser state based on model parameters
 
-∇params, _ = gradient(params, images) do p, x  # gradient with respect to parameter tree
-  y, _ = Lux.apply(lux_model, x, p, lux_state)
-  sum(y)
+(loss, lux_state), back = Zygote.pullback(params, images) do p, x
+  y, st = Lux.apply(lux_model, x, p, lux_state)
+  sum(y), st  # return both the loss, and the updated lux_state
 end;
+∇params, _ = back((one.(loss), nothing));  # gradient of only the loss, with respect to parameter tree
+loss == sum(y)  # not yet changed
 
 opt_state, params = Optimisers.update!(opt_state, params, ∇params);
 
-y, _ = Lux.apply(lux_model, images, params, lux_state);
-@show sum(y)
+y, lux_state = Lux.apply(lux_model, images, params, lux_state);
+@show sum(y);  # now reduced
 
 ```
 
 Besides the parameters stored in `params` and gradually optimised, any other model state
-is stored in `lux_state`. For simplicity this example does not show how to propagate the
-updated `lux_state` to the next iteration, see Lux's documentation.
+is stored in `lux_state`, and updated by `Lux.apply`.  (In this example, BatchNorm has state.)
+This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
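+
+To carry the updated `lux_state` from one optimisation step to the next, simply re-bind
+`params`, `opt_state` and `lux_state` on every iteration. A minimal sketch of such a loop
+(the `dataloader` iterator of input batches here is a hypothetical stand-in) might be:
+
+```julia
+for images in dataloader  # hypothetical iterator of input batches
+  (loss, lux_state), back = Zygote.pullback(params, images) do p, x
+    y, st = Lux.apply(lux_model, x, p, lux_state)
+    sum(y), st  # loss, plus the new model state
+  end
+  ∇params, _ = back((one.(loss), nothing))  # gradient of the loss only
+  opt_state, params = Optimisers.update!(opt_state, params, ∇params)
+end
+```
+
+For comparison, the sizes of the objects involved here are roughly: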
+
+```julia
+Base.summarysize(lux_model) / 1024  # just 2KB
+Base.summarysize(params) / 1024^2  # about 45MB, same as Flux model
+Base.summarysize(lux_state) / 1024  # 40KB
+Base.summarysize(opt_state) / 1024^2  # about 90MB, with Adam
+```
+
+If you are certain there is no model state, then the gradient calculation can
+be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
+
+```julia
+∇params, _ = gradient(params, images) do p, x
+  y, _ = Lux.apply(lux_model, x, p, lux_state)  # discards new lux_state
+  sum(y)
+end;
+```
+
 
 ## Non-`trainable` Parameters
 
diff --git a/test/destructure.jl b/test/destructure.jl
index 0b3a482a..90f28fb4 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -104,6 +104,31 @@ end
     # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
     # Diffractor error in perform_optic_transform
   end
+
+  VERSION < v"1.9-" && @testset "using Yota" begin
+    @test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
+    @test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
+    @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
+    @test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
+    @test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])
+
+    g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1]
+    @test g5.a[1].x == [0,0,1]
+    @test g5.a[2] === nothing
+
+    g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1]
+    @test g6.a == [0,0,0]
+    @test g6.a isa Vector{Float64}
+    @test g6.b == [0+im]
+
+    g8 = Yota_gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
+    @test g8[1].x == [2,4,6]
+    @test g8[2].b.x == [8]
+    @test g8[3] == [[10.0]]
+
+    g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
+    @test g9.c === nothing
+  end
 end
 
 @testset "gradient of rebuild" begin
@@ -149,6 +174,36 @@ end
     # Not fixed by this:
     # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
   end
+
+  VERSION < v"1.9-" && @testset "using Yota" begin
+    re1 = destructure(m1)[2]
+    @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
+    re2 = destructure(m2)[2]
+    @test Yota_gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]
+    re3 = destructure(m3)[2]
+    @test Yota_gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
+    @test Yota_gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]
+
+    re4 = destructure(m4)[2]
+    @test Yota_gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
+    @test Yota_gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
+    @test Yota_gradient(rand(6)) do x
+      m = re4(x)
+      m.x[1] + 2*m.y[2] + 3*m.z[3]
+    end[1] == [1,2,0, 0,0,3]
+
+    re7 = destructure(m7)[2]
+    @test Yota_gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1]
+    @test Yota_gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
+    @test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]
+
+    v8, re8 = destructure(m8)
+    @test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
+    @test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]
+
+    re9 = destructure(m9)[2]
+    @test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14]
+  end
 end
 
 @testset "Flux issue 1826" begin
diff --git a/test/rules.jl b/test/rules.jl
index 0fbe1a00..fd9660a1 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -229,3 +229,18 @@ end
     @test static_loss(static_model) < 1.9
   end
 end
+
+VERSION < v"1.9-" && @testset "using Yota" begin
+  @testset "$(name(o))" for o in RULES
+    w′ = (abc = (α = rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
+    w = (abc = (α = 5rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
+    st = Optimisers.setup(o, w)
+    loss(x, y) = mean((x.abc.α .* x.abc.β .- y.abc.α .* y.abc.β) .^ 2)  # does not use γ, δ, ε
+    @test loss(w, w′) > 0.5
+    for i = 1:10^4
+      _, (_, g, _) = Yota.grad(loss, w, w′)
+      st, w = Optimisers.update(st, w, g)
+    end
+    @test loss(w, w′) < 0.001
+  end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 1ad2a09a..23d474c0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,5 @@
 using Optimisers
-using ChainRulesCore, Functors, StaticArrays, Zygote
+using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
 using LinearAlgebra, Statistics, Test, Random
 
 using Optimisers: @.., @lazy
@@ -37,6 +37,13 @@ function Optimisers.apply!(o::BiRule, state, x, dx, dx2)
   return state, dx
 end
 
+# Make Yota's output look like Zygote's:
+
+Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
+y2z(::AbstractZero) = nothing  # we don't care about different flavours of zero
+y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t)))  # namedtuples!
+y2z(x) = x
+
 @testset verbose=true "Optimisers.jl" begin
   @testset verbose=true "Features" begin