actually try out the doc examples
mcabbott committed Dec 8, 2022
1 parent a7d575f commit 181c2f0
docs/src/index.md: 47 additions & 8 deletions
@@ -52,7 +52,7 @@ state = Optimisers.setup(rule, model); # initialise this optimiser's momentum e
```julia
end;

state, model = Optimisers.update(state, model, ∇model);
@show sum(model(image)); # reduced

```
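Optimisers.jl also provides a mutating `update!`, which the Lux example below uses; a minimal sketch of the same step with it, re-using the names above:

```julia
state, model = Optimisers.update!(state, model, ∇model);  # may overwrite arrays inside `model` and `state`
@show sum(model(image));  # loss after this second update
```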

@@ -82,14 +82,51 @@ To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad
or, for the Flux model above:

```julia
using Yota

loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
sum(m(x))
end;
```

Unfortunately this example doesn't actually run right now. This is the error:
```
julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
sum(m(x))
end;
┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
│ (f, args) = Yota.RRULE_VIA_AD_STATE[]
└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using
ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ...
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197
[3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238
[4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249
[5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276
[6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
@ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109
[7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
@ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153
...
(jl_GWa2lX) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml`
⌃ [587475ba] Flux v0.13.4
[cd998857] Yota v0.7.4
```
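Until that is resolved, the same gradient can be taken with Zygote instead (as in the Lux example below); a minimal sketch, assuming the same `model` and `image` as above:

```julia
using Zygote

# Structural gradient of the loss with respect to the model and the image;
# `∇model` is the nested gradient that `Optimisers.update` expects.
∇model, ∇image = Zygote.gradient(model, image) do m, x
    sum(m(x))
end;
```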

## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

The main design difference of Lux from Flux is that the tree of parameters is separate from
the layer structure. It is these parameters which `setup` and `update` need to know about.

Lux describes this separation of parameter storage from model description as "explicit" parameters.
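As a schematic illustration (using a hypothetical two-entry parameter tree, not the ResNet below), `setup` acts on the parameters alone, and the optimiser state mirrors that tree:

```julia
params = (weight = rand(Float32, 3, 3), bias = zeros(Float32, 3));  # hypothetical Lux-style parameter tree
opt_state = Optimisers.setup(Optimisers.Adam(), params);  # state has the same nested structure as `params`
```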
@@ -104,25 +141,27 @@
```julia
using Lux, Boltz, Zygote, Optimisers
lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu; # define and initialise model
images = rand(Float32, 224, 224, 3, 4) |> gpu; # batch of dummy data
y, lux_state = Lux.apply(lux_model, images, params, lux_state); # run the model
@show sum(y); # initial dummy loss

rule = Optimisers.Adam()
opt_state = Optimisers.setup(rule, params); # optimiser state based on model parameters

(loss, lux_state), back = Zygote.pullback(params, images) do p, x
y, st = Lux.apply(lux_model, x, p, lux_state)
sum(y), st # return both the loss, and the updated lux_state
end;
∇params, _ = back((one.(loss), nothing)); # gradient of only the loss, with respect to parameter tree
loss == sum(y) # not yet changed

opt_state, params = Optimisers.update!(opt_state, params, ∇params);

y, lux_state = Lux.apply(lux_model, images, params, lux_state);
@show sum(y); # now reduced

```

Besides the parameters stored in `params` and gradually optimised, any other model state
is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
If you are certain there is no model state, then the gradient calculation can
be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
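A minimal sketch of that simplification, assuming `lux_state` never needs updating:

```julia
∇params, _ = Zygote.gradient(params, images) do p, x
    y, _ = Lux.apply(lux_model, x, p, lux_state)  # discard the returned state
    sum(y)
end;
```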
