tidy, add summarysize
mcabbott committed Oct 17, 2022
1 parent a6e85fc commit 2bb637a
51 changes: 17 additions & 34 deletions docs/src/index.md
@@ -64,6 +64,14 @@
There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` for each rule is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
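
As a minimal sketch of the difference (assuming the `model`, `state`, and gradient `∇model` from the example above):

```julia
# update never mutates its inputs, returning a new state and model:
state2, model2 = Optimisers.update(state, model, ∇model)

# update! may instead re-use the arrays inside the old state and model:
state, model = Optimisers.update!(state, model, ∇model)
```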
For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:

```julia
Base.summarysize(model) / 1024^2 # about 45MB
Base.summarysize(state) / 1024^2 # about 90MB
```

Optimisers.jl does not depend on any one automatic differentiation package,
but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
@@ -72,6 +80,7 @@
This `∇model` is another tree structure, rather than the dictionary-like object from
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
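
As a sketch of the explicit style (assuming a Flux-style `model` and an input `image`, as elsewhere on this page):

```julia
using Zygote

∇model, ∇image = Zygote.gradient(model, image) do m, x
    sum(m(x))  # a throwaway scalar loss
end;
```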

## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)

Yota is another modern automatic differentiation package, an alternative to Zygote.
@@ -89,40 +98,6 @@
```julia
loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
    sum(m(x))
end;
```

Unfortunately this example doesn't actually run right now. This is the error:
```
julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
sum(m(x))
end;
┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
│ (f, args) = Yota.RRULE_VIA_AD_STATE[]
└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using
ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ...
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197
[3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238
[4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249
[5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276
[6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
@ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109
[7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
@ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153
...
(jl_GWa2lX) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml`
⌃ [587475ba] Flux v0.13.4
[cd998857] Yota v0.7.4
```

## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

@@ -163,6 +138,14 @@
y, lux_state = Lux.apply(lux_model, images, params, lux_state);
Besides the parameters stored in `params` and gradually optimised, any other model state
is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
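
The `opt_state` measured below is Optimisers.jl's own state for these parameters. A minimal sketch of how it might be created (assuming the `params` from the example above):

```julia
opt_state = Optimisers.setup(Optimisers.Adam(), params)  # tree of Adam momenta matching `params`
```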

```julia
Base.summarysize(lux_model) / 1024 # just 2KB
Base.summarysize(params) / 1024^2 # about 45MB, same as Flux model
Base.summarysize(lux_state) / 1024 # 40KB
Base.summarysize(opt_state) / 1024^2 # about 90MB, with Adam
```

If you are certain there is no model state, then the gradient calculation can
be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
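
A minimal sketch (assuming the `lux_model`, `params`, `lux_state`, and `images` from the example above, and a sum-of-outputs loss):

```julia
∇params, _ = Zygote.gradient(params, images) do p, x
    y, _ = Lux.apply(lux_model, x, p, lux_state)  # returned state is discarded
    sum(y)
end;
```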

