From 76b681c1b55a238141f81205c370d33f69d50167 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 17 Aug 2022 16:25:38 -0600 Subject: [PATCH 1/8] test with Yota too, and document this --- Project.toml | 4 +++- docs/src/index.md | 49 +++++++++++++++++++++++++++++++++++------------ test/rules.jl | 15 +++++++++++++++ test/runtests.jl | 2 +- 4 files changed, 56 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 41b23a8b..7cf05fad 100644 --- a/Project.toml +++ b/Project.toml @@ -13,13 +13,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" Functors = "0.3, 0.4" +Yota = "0.7.3" Zygote = "0.6.40" julia = "1.6" [extras] StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Yota = "cd998857-8626-517d-b929-70ad188a48f0" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "StaticArrays", "Zygote"] +test = ["Test", "StaticArrays", "Yota", "Zygote"] diff --git a/docs/src/index.md b/docs/src/index.md index 9ebfac0b..68d56628 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -38,7 +38,7 @@ to adjust the model: ```julia -using Flux, Metalhead, Optimisers +using Flux, Metalhead, Zygote, Optimisers model = Metalhead.ResNet(18) |> gpu # define a model to train image = rand(Float32, 224, 224, 3, 1) |> gpu; # dummy data @@ -72,6 +72,21 @@ This `∇model` is another tree structure, rather than the dictionary-like objec Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference. +## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl) + +Yota is another modern automatic differentiation package, an alternative to Zygote. + +Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`) +but also returns a gradient component for the loss function. +To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad(f, model, data)` +or, for the Flux model above: + +```julia +loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x + sum(m(x)) +end; +``` + ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl) The main design difference of Lux is that the tree of parameters is separate from @@ -79,7 +94,7 @@ the layer structure. It is these parameters which `setup` and `update` need to k Lux describes this separation of parameter storage from model description as "explicit" parameters. Beware that it has nothing to do with Zygote's notion of "explicit" gradients. -(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will often be +(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will be nearly identical trees of nested `NamedTuple`s.) 
```julia @@ -88,27 +103,37 @@ using Lux, Boltz, Zygote, Optimisers lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu; # define and initialise model images = rand(Float32, 224, 224, 3, 4) |> gpu; # batch of dummy data -y, _ = Lux.apply(lux_model, images, params, lux_state); # run the model +y, lux_state = Lux.apply(lux_model, images, params, lux_state); # run the model @show sum(y) # initial dummy loss rule = Optimisers.Adam() opt_state = Optimisers.setup(rule, params); # optimiser state based on model parameters -∇params, _ = gradient(params, images) do p, x # gradient with respect to parameter tree - y, _ = Lux.apply(lux_model, x, p, lux_state) - sum(y) -end; +(loss, lux_state), back = Zygote.pullback(params, images) do p, x + y, st = Lux.apply(lux_model, x, p, lux_state) + sum(y), st # return both the loss, and the updated lux_state +end +∇params, _ = back((one.(loss), nothing)) # gradient of only the loss, with respect to parameter tree -opt_state, params = Optimisers.update!(opt_state, params, ∇params); +@show sum(loss) -y, _ = Lux.apply(lux_model, images, params, lux_state); -@show sum(y) +opt_state, params = Optimisers.update!(opt_state, params, ∇params); ``` Besides the parameters stored in `params` and gradually optimised, any other model state -is stored in `lux_state`. For simplicity this example does not show how to propagate the -updated `lux_state` to the next iteration, see Lux's documentation. +is stored in `lux_state`, and returned by `Lux.apply`. +This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit. +If you are certain there is no model state, then the gradient calculation can +be simplified to use `Zygote.gradient` instead of `Zygote.pullback`: + +```julia +∇params, _ = gradient(params, images) do p, x + y, _ = Lux.apply(lux_model, x, p, lux_state) # discards new lux_state + sum(y) +end; +``` + ## Non-`trainable` Parameters diff --git a/test/rules.jl b/test/rules.jl index 0fbe1a00..766ae148 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -229,3 +229,18 @@ end @test static_loss(static_model) < 1.9 end end + +@testset "using Yota" begin + @testset "$(name(o))" for o in RULES + w′ = (abc = (α = rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps)) + w = (abc = (α = 5rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps)) + st = Optimisers.setup(o, w) + loss(x, y) = mean((x.abc.α .* x.abc.β .- y.abc.α .* y.abc.β) .^ 2) # does not use γ, δ, ε + @test loss(w, w′) > 0.5 + for i = 1:10^4 + _, (_, g, _) = Yota.grad(loss, w, w′) + st, w = Optimisers.update(st, w, g) + end + @test loss(w, w′) < 0.001 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 1ad2a09a..8340e62f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Optimisers -using ChainRulesCore, Functors, StaticArrays, Zygote +using ChainRulesCore, Functors, StaticArrays, Zygote, Yota using LinearAlgebra, Statistics, Test, Random using Optimisers: @.., @lazy From a7d575f31b869645d7706980a53aab044d9a4807 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 17 Aug 2022 16:30:54 -0600 Subject: [PATCH 2/8] also test destructure --- test/destructure.jl | 56 +++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++++ 2 files changed, 60 insertions(+) diff --git a/test/destructure.jl b/test/destructure.jl index 0b3a482a..fee24368 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -104,6 +104,32 @@ end # Zygote error in 
(::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z) # Diffractor error in perform_optic_transform end + + @testset "using Yota" begin + @test_broken Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] # Unexpected expression: $(Expr(:static_parameter, 1)) + # These are all broken! + @test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) + @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], ZeroTangent()) + @test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = ZeroTangent(), z = [0,0,0]) + @test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = ZeroTangent(), z = [0,0,0]) + + g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1] + @test g5.a[1].x == [0,0,1] + @test g5.a[2] === ZeroTangent() + + g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1] + @test g6.a == [0,0,0] + @test g6.a isa Vector{Float64} + @test g6.b == [0+im] + + g8 = Yota_gradient(m -> sum(abs2, destructure(m)[1]), m8)[1] + @test g8[1].x == [2,4,6] + @test g8[2].b.x == [8] + @test g8[3] == [[10.0]] + + g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1] + @test g9.c === ZeroTangent() + end end @testset "gradient of rebuild" begin @@ -149,6 +175,36 @@ end # Not fixed by this: # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,) end + + @testset "using Yota" begin + re1 = destructure(m1)[2] + @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] + re2 = destructure(m2)[2] + @test Yota_gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0] + re3 = destructure(m3)[2] + @test Yota_gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0] + @test Yota_gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0] + + re4 = destructure(m4)[2] + @test Yota_gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0] + @test Yota_gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0] + @test Yota_gradient(rand(6)) do x + m = re4(x) + m.x[1] + 2*m.y[2] + 3*m.z[3] + end[1] == [1,2,0, 0,0,3] + + re7 = destructure(m7)[2] + @test Yota_gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1] + @test Yota_gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] + @test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] + + v8, re8 = destructure(m8) + @test_broken Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] # MethodError: no method matching zero(::Type{Any}) + @test_broken Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] # MethodError: no method matching !(::Expr) + + re9 = destructure(m9)[2] + @test_broken Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] # MethodError: no method matching zero(::Type{Array}) + end end @testset "Flux issue 1826" begin diff --git a/test/runtests.jl b/test/runtests.jl index 8340e62f..294f5a6f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,6 +37,10 @@ function Optimisers.apply!(o::BiRule, state, x, dx, dx2) return state, dx end +# Make Yota's output look like Zygote's: + +Yota_gradient(f, xs...) 
= Base.tail(Yota.grad(f, xs...)[2]) + @testset verbose=true "Optimisers.jl" begin @testset verbose=true "Features" begin From 181c2f040e2a2b8cdeabffab1397cf0fa51774a7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 19 Aug 2022 16:26:13 -0400 Subject: [PATCH 3/8] actually try out the doc examples --- docs/src/index.md | 55 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 68d56628..b1cb8532 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -52,7 +52,7 @@ state = Optimisers.setup(rule, model); # initialise this optimiser's momentum e end; state, model = Optimisers.update(state, model, ∇model); -@show sum(model(image)); +@show sum(model(image)); # reduced ``` @@ -82,14 +82,51 @@ To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad or, for the Flux model above: ```julia +using Yota + loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x sum(m(x)) end; ``` +Unfortunately this example doesn't actually run right now. This is the error: +``` +julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x + sum(m(x)) + end; +┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via: +│ (f, args) = Yota.RRULE_VIA_AD_STATE[] +└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160 +ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using + + ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ... + +Stacktrace: + [1] error(s::String) + @ Base ./error.jl:35 + [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable) + @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197 + [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol) + @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238 + [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol) + @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249 + [5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol) + @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276 + [6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}) + @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109 + [7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}) + @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153 +... 
+ +(jl_GWa2lX) pkg> st +Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml` +⌃ [587475ba] Flux v0.13.4 + [cd998857] Yota v0.7.4 +``` + ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl) -The main design difference of Lux is that the tree of parameters is separate from +The main design difference of Lux from Flux is that the tree of parameters is separate from the layer structure. It is these parameters which `setup` and `update` need to know about. Lux describes this separation of parameter storage from model description as "explicit" parameters. @@ -104,7 +141,7 @@ using Lux, Boltz, Zygote, Optimisers lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu; # define and initialise model images = rand(Float32, 224, 224, 3, 4) |> gpu; # batch of dummy data y, lux_state = Lux.apply(lux_model, images, params, lux_state); # run the model -@show sum(y) # initial dummy loss +@show sum(y); # initial dummy loss rule = Optimisers.Adam() opt_state = Optimisers.setup(rule, params); # optimiser state based on model parameters @@ -112,17 +149,19 @@ opt_state = Optimisers.setup(rule, params); # optimiser state based on model pa (loss, lux_state), back = Zygote.pullback(params, images) do p, x y, st = Lux.apply(lux_model, x, p, lux_state) sum(y), st # return both the loss, and the updated lux_state -end -∇params, _ = back((one.(loss), nothing)) # gradient of only the loss, with respect to parameter tree - -@show sum(loss) +end; +∇params, _ = back((one.(loss), nothing)); # gradient of only the loss, with respect to parameter tree +loss == sum(y) # not yet changed opt_state, params = Optimisers.update!(opt_state, params, ∇params); +y, lux_state = Lux.apply(lux_model, images, params, lux_state); +@show sum(y); # now reduced + ``` Besides the parameters stored in `params` and gradually optimised, any other model state -is stored in `lux_state`, and returned by `Lux.apply`. +is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.) This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit. If you are certain there is no model state, then the gradient calculation can be simplified to use `Zygote.gradient` instead of `Zygote.pullback`: From 1a426ce7e65dae8b52898051a99e3ad8ea66764b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 19 Aug 2022 16:54:29 -0400 Subject: [PATCH 4/8] tidy, add summarysize --- docs/src/index.md | 51 ++++++++++++++++------------------------------- 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index b1cb8532..5a1e5210 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -64,6 +64,14 @@ There is also [`Optimisers.update!`](@ref) which similarly returns a new model a but is free to mutate arrays within the old one for efficiency. The method of `apply!` for each rule is likewise free to mutate arrays within its state; they are defensively copied when this rule is used with `update`. +(The method of `apply!` above is likewise free to mutate arrays within its state; +they are defensively copied when this rule is used with `update`.) 
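+For example, `update!` fits naturally into a training loop; schematically (here `train_set` and
+`loss` are placeholders, not defined on this page):
+
+```julia
+for (image, label) in train_set            # hypothetical iterator of data batches
+  ∇model, _ = gradient(model, image) do m, x
+    loss(m(x), label)                      # any scalar-valued loss function
+  end
+  state, model = Optimisers.update!(state, model, ∇model)  # returns new state & model
+end
+```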
+For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`: + +```julia +Base.summarysize(model) / 1024^2 # about 45MB +Base.summarysize(state) / 1024^2 # about 90MB +``` Optimisers.jl does not depend on any one automatic differentiation package, but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl). @@ -72,6 +80,7 @@ This `∇model` is another tree structure, rather than the dictionary-like objec Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference. + ## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl) Yota is another modern automatic differentiation package, an alternative to Zygote. @@ -89,40 +98,6 @@ loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x end; ``` -Unfortunately this example doesn't actually run right now. This is the error: -``` -julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x - sum(m(x)) - end; -┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via: -│ (f, args) = Yota.RRULE_VIA_AD_STATE[] -└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160 -ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using - - ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ... - -Stacktrace: - [1] error(s::String) - @ Base ./error.jl:35 - [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197 - [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238 - [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249 - [5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276 - [6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}) - @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109 - [7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}) - @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153 -... 
- -(jl_GWa2lX) pkg> st -Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml` -⌃ [587475ba] Flux v0.13.4 - [cd998857] Yota v0.7.4 -``` ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl) @@ -163,6 +138,14 @@ y, lux_state = Lux.apply(lux_model, images, params, lux_state); Besides the parameters stored in `params` and gradually optimised, any other model state is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.) This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit. + +```julia +Base.summarysize(lux_model) / 1024 # just 2KB +Base.summarysize(params) / 1024^2 # about 45MB, same as Flux model +Base.summarysize(lux_state) / 1024 # 40KB +Base.summarysize(opt_state) / 1024^2 # about 90MB, with Adam +``` + If you are certain there is no model state, then the gradient calculation can be simplified to use `Zygote.gradient` instead of `Zygote.pullback`: From 86a23bbe9ff67decfd18315e95f51b7ca962dbb3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 17 Oct 2022 14:40:21 -0400 Subject: [PATCH 5/8] add again changes made on website which got lost in a local rebase without checking first because I forgot about this for ages --- Project.toml | 2 +- test/destructure.jl | 19 +++++++++---------- test/runtests.jl | 5 ++++- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 7cf05fad..ef4a65da 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" Functors = "0.3, 0.4" -Yota = "0.7.3" +Yota = "0.8.1" Zygote = "0.6.40" julia = "1.6" diff --git a/test/destructure.jl b/test/destructure.jl index fee24368..592b9be9 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -106,16 +106,15 @@ end end @testset "using Yota" begin - @test_broken Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] # Unexpected expression: $(Expr(:static_parameter, 1)) - # These are all broken! 
+ @test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] @test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) - @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], ZeroTangent()) - @test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = ZeroTangent(), z = [0,0,0]) - @test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = ZeroTangent(), z = [0,0,0]) + @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing) + @test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) + @test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0]) g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1] @test g5.a[1].x == [0,0,1] - @test g5.a[2] === ZeroTangent() + @test g5.a[2] === nothing g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1] @test g6.a == [0,0,0] @@ -128,7 +127,7 @@ end @test g8[3] == [[10.0]] g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1] - @test g9.c === ZeroTangent() + @test g9.c === nothing end end @@ -199,11 +198,11 @@ end @test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] v8, re8 = destructure(m8) - @test_broken Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] # MethodError: no method matching zero(::Type{Any}) - @test_broken Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] # MethodError: no method matching !(::Expr) + @test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] + @test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] re9 = destructure(m9)[2] - @test_broken Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] # MethodError: no method matching zero(::Type{Array}) + @test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] end end diff --git a/test/runtests.jl b/test/runtests.jl index 294f5a6f..23d474c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,7 +39,10 @@ end # Make Yota's output look like Zygote's: -Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2]) +Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2])) +y2z(::AbstractZero) = nothing # we don't care about different flavours of zero +y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples! +y2z(x) = x @testset verbose=true "Optimisers.jl" begin @testset verbose=true "Features" begin From e4a21d9592786b8d800dbe18c624f98b12b0cae9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 30 Oct 2022 20:20:48 -0400 Subject: [PATCH 6/8] Yota 0.8.2, etc --- Project.toml | 2 +- docs/src/index.md | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index ef4a65da..15fc6479 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" Functors = "0.3, 0.4" -Yota = "0.8.1" +Yota = "0.8.2" Zygote = "0.6.40" julia = "1.6" diff --git a/docs/src/index.md b/docs/src/index.md index 5a1e5210..863428b7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -62,8 +62,6 @@ tree formed by the model and update the parameters using the gradients. There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state, but is free to mutate arrays within the old one for efficiency. 
-The method of `apply!` for each rule is likewise free to mutate arrays within its state;
-they are defensively copied when this rule is used with `update`.
 (The method of `apply!` above is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.)
 For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
 
@@ -85,17 +85,18 @@ Yota is another modern automatic differentiation package, an alternative to Zygo
 
 Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
 but also returns a gradient component for the loss function.
-To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad(f, model, data)`
-or, for the Flux model above:
+To extract what Optimisers.jl needs, you can write (for the Flux model above):
 
 ```julia
 using Yota
 
 loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
     sum(m(x))
 end;
-```
 
+# Alternatively, this may save computing ∇image:
+loss, (_, ∇model) = grad(m -> sum(m(image)), model);
+```
 
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 

From 8562963248bb5a6bc1926ca2a5d3d26cf70a2bad Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 7 Dec 2022 19:45:58 -0500
Subject: [PATCH 7/8] skip Yota tests on 1.9 & later

---
 test/rules.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/rules.jl b/test/rules.jl
index 766ae148..fd9660a1 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -230,7 +230,7 @@ end
   end
 end
 
-@testset "using Yota" begin
+VERSION < v"1.9-" && @testset "using Yota" begin
   @testset "$(name(o))" for o in RULES
     w′ = (abc = (α = rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
     w = (abc = (α = 5rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))

From ce3cc0ca4491aeb1cfc99c31251b9a61b0473e58 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 7 Dec 2022 21:27:03 -0500
Subject: [PATCH 8/8] skip more tests

---
 test/destructure.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/destructure.jl b/test/destructure.jl
index 592b9be9..90f28fb4 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -105,7 +105,7 @@ end
     # Diffractor error in perform_optic_transform
   end
 
-  @testset "using Yota" begin
+  VERSION < v"1.9-" && @testset "using Yota" begin
     @test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
     @test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
     @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
@@ -175,7 +175,7 @@ end
     # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
   end
 
-  @testset "using Yota" begin
+  VERSION < v"1.9-" && @testset "using Yota" begin
     re1 = destructure(m1)[2]
     @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
     re2 = destructure(m2)[2]