From ef4fedd47ef2595e8d99d452901a39ba70ba3742 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 26 Aug 2022 20:38:34 -0400 Subject: [PATCH] Apply 3 suggestions from myself --- Project.toml | 2 +- docs/src/index.md | 42 +++++++++++++++--------------------------- test/runtests.jl | 5 ++++- 3 files changed, 20 insertions(+), 29 deletions(-) diff --git a/Project.toml b/Project.toml index 5f0c3498..ad3e6d9a 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" Functors = "0.2.8, 0.3" -Yota = "0.7.3" +Yota = "0.8.0" Zygote = "0.6.40" julia = "1.6" diff --git a/docs/src/index.md b/docs/src/index.md index cb8b3c42..8dcfb463 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -92,36 +92,24 @@ end; Unfortunately this example doesn't actually run right now. This is the error: ``` julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x - sum(m(x)) - end; -┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via: -│ (f, args) = Yota.RRULE_VIA_AD_STATE[] -└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160 -ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using - - ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ... - + sum(m(x)) + end; +ERROR: BoundsError: attempt to access Nothing at index [1] Stacktrace: - [1] error(s::String) - @ Base ./error.jl:35 - [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197 - [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238 - [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249 - [5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol) - @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276 - [6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}) - @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109 - [7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}) - @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153 + [1] _getfield(value::Nothing, fld::Int64) + @ Yota ~/.julia/packages/Yota/uu3H0/src/helpers.jl:40 + [2] mkcall(::Function, ::Umlaut.Variable, ::Vararg{Any}; val::Missing, line::Nothing, kwargs::NamedTuple{(), Tuple{}}, free_kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) + @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/tape.jl:192 + [3] mkcall + @ ~/.julia/packages/Umlaut/SvDaQ/src/tape.jl:174 [inlined] + [4] chainrules_transform!(tape::Umlaut.Tape{Yota.GradCtx}) + @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:183 + ... -(jl_GWa2lX) pkg> st -Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml` -⌃ [587475ba] Flux v0.13.4 - [cd998857] Yota v0.7.4 +(@v1.9) pkg> st Yota +Status `~/.julia/environments/v1.9/Project.toml` + [cd998857] Yota v0.8.0 ``` ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl) diff --git a/test/runtests.jl b/test/runtests.jl index a1cac5bd..f240b89e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,7 +13,10 @@ struct TwoThirds a; b; c; end Functors.@functor TwoThirds (a, c) Optimisers.trainable(x::TwoThirds) = (a = x.a,) -Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2]) +Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2])); +y2z(::AbstractZero) = nothing # we don't care about different flavours +y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) +y2z(x) = x @testset verbose=true "Optimisers.jl" begin @testset verbose=true "Features" begin