diff --git a/docs/src/index.md b/docs/src/index.md index 00991651..9be3915a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -64,6 +64,14 @@ There is also [`Optimisers.update!`](@ref) which similarly returns a new model a but is free to mutate arrays within the old one for efficiency. The method of `apply!` for each rule is likewise free to mutate arrays within its state; they are defensively copied when this rule is used with `update`. +(The method of `apply!` above is likewise free to mutate arrays within its state; +they are defensively copied when this rule is used with `update`.) +For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`: + +```julia +Base.summarysize(model) / 1024^2 # about 45MB +Base.summarysize(state) / 1024^2 # about 90MB +``` Optimisers.jl does not depend on any one automatic differentiation package, but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl). @@ -72,6 +80,7 @@ This `∇model` is another tree structure, rather than the dictionary-like objec Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference. + ## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl) Yota is another modern automatic differentiation package, an alternative to Zygote. @@ -89,40 +98,6 @@ loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x end; ``` -Unfortunately this example doesn't actually run right now. This is the error: -``` -julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x - sum(m(x)) - end; -┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via: -│ (f, args) = Yota.RRULE_VIA_AD_STATE[] -└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160 -ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using - - ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ... - -Stacktrace: - [1] error(s::String) - @ Base ./error.jl:35 - [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197 - [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238 - [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249 - [5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276 - [6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}) - @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109 - [7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}) - @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153 -... - -(jl_GWa2lX) pkg> st -Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml` -⌃ [587475ba] Flux v0.13.4 - [cd998857] Yota v0.7.4 -``` ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl) @@ -163,6 +138,14 @@ y, lux_state = Lux.apply(lux_model, images, params, lux_state); Besides the parameters stored in `params` and gradually optimised, any other model state is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.) This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit. + +```julia +Base.summarysize(lux_model) / 1024 # just 2KB +Base.summarysize(params) / 1024^2 # about 45MB, same as Flux model +Base.summarysize(lux_state) / 1024 # 40KB +Base.summarysize(opt_state) / 1024^2 # about 90MB, with Adam +``` + If you are certain there is no model state, then the gradient calculation can be simplified to use `Zygote.gradient` instead of `Zygote.pullback`: