From 33a2e784742b918ae96437f11792c9eb5f1d4f22 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 30 Oct 2022 20:20:48 -0400
Subject: [PATCH] Yota 0.8.2, etc

---
 Project.toml      |  2 +-
 docs/src/index.md | 11 +++++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/Project.toml b/Project.toml
index 2105f2da..cf06eeb2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [compat]
 ChainRulesCore = "1"
 Functors = "0.3"
-Yota = "0.8.1"
+Yota = "0.8.2"
 Zygote = "0.6.40"
 julia = "1.6"
 
diff --git a/docs/src/index.md b/docs/src/index.md
index a5cb1040..ed34344c 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -62,8 +62,6 @@ tree formed by the model and update the parameters using the gradients.
 There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
 but is free to mutate arrays within the old one for efficiency.
-The method of `apply!` for each rule is likewise free to mutate arrays within its state;
-they are defensively copied when this rule is used with `update`.
 (The method of `apply!` above is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.)
 
 For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
@@ -87,17 +85,18 @@ Yota is another modern automatic differentiation package, an alternative to Zygote.
 
 Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
 but also returns a gradient component for the loss function.
-To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad(f, model, data)`
-or, for the Flux model above:
+To extract what Optimisers.jl needs, you can write (for the Flux model above):
 
 ```julia
 using Yota
 
 loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
-    sum(m(x)
+    sum(m(x))
 end;
 
-```
+# Or else, this may save computing ∇image:
+loss, (_, ∇model) = grad(m -> sum(m(image)), model);
+```
 
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
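
For reference, the workflow the amended docs section describes might look like the sketch below. It is illustrative only: it assumes `model` and `image` are defined as in the Flux example the docs refer to, with Yota 0.8.2 and Optimisers.jl loaded, and the names `state` and `∇model` are just local variables.

```julia
using Optimisers, Yota

# Set up optimiser state for every trainable array in the model.
# For Adam() this stores two momenta per parameter, so `state` is
# roughly twice the size of `model`.
state = Optimisers.setup(Optimisers.Adam(), model)

# Yota.grad returns the loss plus one gradient per argument, including
# a component for the loss function itself, discarded here with `_`.
loss, (_, ∇model) = Yota.grad(m -> sum(m(image)), model)

# update! returns a new state and model, but is free to mutate arrays
# within the old ones for efficiency; update copies defensively instead.
state, model = Optimisers.update!(state, model, ∇model)
```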