Skip to content

Commit

Permalink
Apply 3 suggestions from myself
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Aug 27, 2022
1 parent 7eddf0e commit ef4fedd
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ChainRulesCore = "1"
Functors = "0.2.8, 0.3"
Yota = "0.7.3"
Yota = "0.8.0"
Zygote = "0.6.40"
julia = "1.6"

Expand Down
42 changes: 15 additions & 27 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,36 +92,24 @@ end;
Unfortunately this example doesn't actually run right now. This is the error:
```
julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
sum(m(x))
end;
┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
│ (f, args) = Yota.RRULE_VIA_AD_STATE[]
└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using
ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ...
sum(m(x))
end;
ERROR: BoundsError: attempt to access Nothing at index [1]
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197
[3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238
[4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249
[5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol)
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276
[6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
@ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109
[7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
@ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153
[1] _getfield(value::Nothing, fld::Int64)
@ Yota ~/.julia/packages/Yota/uu3H0/src/helpers.jl:40
[2] mkcall(::Function, ::Umlaut.Variable, ::Vararg{Any}; val::Missing, line::Nothing, kwargs::NamedTuple{(), Tuple{}}, free_kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/tape.jl:192
[3] mkcall
@ ~/.julia/packages/Umlaut/SvDaQ/src/tape.jl:174 [inlined]
[4] chainrules_transform!(tape::Umlaut.Tape{Yota.GradCtx})
@ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:183
...
(jl_GWa2lX) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml`
⌃ [587475ba] Flux v0.13.4
[cd998857] Yota v0.7.4
(@v1.9) pkg> st Yota
Status `~/.julia/environments/v1.9/Project.toml`
[cd998857] Yota v0.8.0
```

## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
Expand Down
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ struct TwoThirds a; b; c; end
Functors.@functor TwoThirds (a, c)
Optimisers.trainable(x::TwoThirds) = (a = x.a,)

Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2])
Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]));
y2z(::AbstractZero) = nothing # we don't care about different flavours
y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t)))
y2z(x) = x

@testset verbose=true "Optimisers.jl" begin
@testset verbose=true "Features" begin
Expand Down

0 comments on commit ef4fedd

Please sign in to comment.