
Nested Differentiation #35

Closed
MikeInnes opened this issue Apr 5, 2019 · 3 comments

@MikeInnes

I assume this is an upstream issue with nesting Cassette contexts that carry metadata, but I'm filing it here just in case.

julia> using Yota

julia> D(f, x) = grad(f, x)[2][1]
D (generic function with 1 method)

julia> D(x -> x*x, 5)
10

julia> D(x -> D(sin, x), 0.5)
ERROR: MethodError: no method matching metadatatype(::Type{Cassette.Context{nametype(TraceCtx),Yota.TapeBox,Cassette.Tag{nametype(TraceCtx),0xae68220057761640,Nothing},getfield(Cassette, Symbol("##PassType#365")),IdDict{Module,Dict{Symbol,Cassette.BindingMeta}},Nothing}}, ::Type{Type})
@dfdx (Owner) commented Apr 7, 2019

Thanks for posting this. grad() has a very specific use case: generating highly optimized code for first-order derivatives. It includes caching, memory pre-allocation, etc., which are hard to reuse in nested differentiation. A closer fit is simplegrad(), which avoids all of these tricks and returns plain old Julia code, although it has one more obstacle to clear before it supports nested differentiation; I should be able to look at that obstacle soon.

The code below is mostly for my own reference:

julia> import Yota: simplegrad

julia> foo(x) = x * x
foo (generic function with 1 method)

julia> simplegrad(foo, 5)
##tape_fn#362 (generic function with 1 method)

julia> ∇foo = simplegrad(foo, 5)
##tape_fn#363 (generic function with 1 method)

julia> ∇foo(5)
(25, 10)

julia> ∇foo(5)[2]
10

julia> ∇nested = simplegrad(x -> ∇foo(x)[2], 5)
ERROR: BoundsError: attempt to access 3-element Array{Yota.AbstractOp,1} at index [-1]
Stacktrace:
 [1] getindex(::Array{Yota.AbstractOp,1}, ::Int64) at ./array.jl:729
 [2] getindex(::Yota.Tape, ::Int64) at /home/slipslop/work/Yota/src/tape.jl:112
 [3] back!(::Yota.Tape) at /home/slipslop/work/Yota/src/grad.jl:166
 [4] _grad(::Yota.Tape) at /home/slipslop/work/Yota/src/grad.jl:201
 [5] simplegrad(::Function, ::Int64) at /home/slipslop/work/Yota/src/grad.jl:213
 [6] top-level scope at none:0

julia> _, tape = Yota.trace(x -> ∇foo(x)[2], 5)
(10, Tape
  inp %1::Int64
  %2 = *(%1, %1)::Int64
  const %3 = 1::Int64
  %4 = *(%1, %3)::Int64
  const %5 = 1::Int64
  %6 = *(%1, %5)::Int64
  const %7 = +::typeof(+)
  %8 = broadcast(%7, %6, %4)::Int64
  const %9 = 2::Int64
  %10 = getindex(%-1, %9)::Int64
)
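
For reference, the tape above corresponds roughly to the following hand-written code (a reconstruction for illustration, not actual generated output). The last op indexes %-1, a dangling reference, which is exactly what raises the BoundsError at index [-1] above:

function nested_plain(x)   # x is inp %1
    y  = x * x             # %2: forward pass of foo(x) = x * x
    t1 = x * 1             # %4: reverse pass, seeded with 1
    t2 = x * 1             # %6
    dx = t2 + t1           # %8: broadcast(+, ...) on scalars; d(x*x)/dx == 2x
    # %10 = getindex(%-1, %9) indexes op %-1, which does not exist on the
    # tape; presumably it should point at the (value, grad) tuple of ∇foo
    return dx              # the element %10 was apparently meant to select
end

nested_plain(5)  # 10, matching the traced value above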

@MikeInnes (Author)

It might be interesting to look at how JAX handles this kind of thing. When the tracer hits optimised code it could, for example, trace through the original un-optimised code instead, and then you could optimise the resulting nested AD trace as if it were first-order.
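
A minimal sketch of that fallback idea, with hypothetical names (UNOPTIMISED, traceable); this isn't Yota's or JAX's actual machinery, just the shape of it:

const UNOPTIMISED = IdDict{Function,Function}()

foo(x) = x * x

# Plain, traceable gradient code, the kind a simplegrad-style generator emits.
∇foo_plain(x) = (foo(x), x * 1 + x * 1)

# Stand-in for the compiled, optimised tape function (caching, prealloc, ...).
∇foo_fast(x) = (x * x, 2x)
UNOPTIMISED[∇foo_fast] = ∇foo_plain

# Tracer hook: before recording a call, swap compiled code for its original,
# so a nested trace only ever sees plain first-order operations.
traceable(f) = get(UNOPTIMISED, f, f)

nested(x) = traceable(∇foo_fast)(x)[2]   # same value as ∇foo_fast(x)[2]

nested(5)  # 10

The registry is the only extra state: each compiled function just has to remember which un-optimised code it came from, and the optimiser can then run over the whole nested trace afterwards.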

@dfdx (Owner) commented Jul 3, 2021

No work has been done towards nested differentiation in the last 2 years, so it's fair to say it's out of Yota's scope for the indefinite future.

dfdx closed this as completed Jul 3, 2021