derivative of derivative does not work #135

Open
MariusDrulea opened this issue Mar 4, 2023 · 3 comments
@MariusDrulea
Taking the derivative of a derivative with Yota gives an error; see the following MWE.

using Yota

# the function of interest
f(x, y, z) = x*y*z

# yota gradients
fx(x, y, z) = grad(x->f(x, y, z), x)[2][2]
fxy(x, y, z) = grad(y->fx(x, y, z), y)[2][2]
fxyz(x, y, z) = grad(z->fxy(x, y, z), z)

fx(1.f0, 2.f0, 3.f0)   # works: returns 6.f0 (= y*z)
fxy(1.f0, 2.f0, 3.f0)  # errors, see below
fxyz(1.f0, 2.f0, 3.f0) # never reached; would fail the same way

Error and stacktrace:

ERROR: No derivative rule found for op %69 = %62(%59, %63, %50, %20, %4)::Tuple{Float32, Tuple{ChainRulesCore.Tangent{var"#13#14"{Float32, Float32}, NamedTuple{(:y, :z), Tuple{Float32, Float32}}}, Float32}} , try defining it using

        ChainRulesCore.rrule(::Base.var"#invokelatest##kw", ::NamedTuple{(:seed,), Tuple{Int64}}, ::typeof(Base.invokelatest), ::Yota.var"###tape_#13#316", ::var"#13#14"{Float32, Float32}, ::Float32) = ...

Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:35
 [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
   @ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:178
 [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
   @ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:220
 [4] #gradtape!#77
   @ D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:245 [inlined]
 [5] gradtape(f::Function, args::Float32; ctx::Yota.GradCtx, seed::Int64)
   @ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:264
 [6] grad(f::Function, args::Float32; seed::Int64)
   @ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:356
 [7] grad
   @ D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:348 [inlined]
 [8] fxy(x::Float32, y::Float32, z::Float32) grad_grad_yota.jl:8
 [9] top-level scope grad_grad_yota.jl:12

I also noticed that fx is inferred to return Any instead of Float32; see below:

@code_warntype fx(1.f0, 2.f0, 3.f0)

MethodInstance for fx(::Float32, ::Float32, ::Float32)
  from fx(x, y, z) in grad_grad_yota.jl:7
Arguments
  #self#::Core.Const(fx)
  x::Float32
  y::Float32
  z::Float32
Locals
  #23::var"#23#24"{Float32, Float32}
Body::Any
1 ─ %1 = Main.:(var"#23#24")::Core.Const(var"#23#24")
│   %2 = Core.typeof(y)::Core.Const(Float32)
│   %3 = Core.typeof(z)::Core.Const(Float32)
│   %4 = Core.apply_type(%1, %2, %3)::Core.Const(var"#23#24"{Float32, Float32})
│        (#23 = %new(%4, y, z))
│   %6 = #23::var"#23#24"{Float32, Float32}
│   %7 = Main.grad(%6, x)::Any
│   %8 = Base.getindex(%7, 2)::Any
│   %9 = Base.getindex(%8, 2)::Any
└──      return %9
@dfdx (Owner) commented Mar 4, 2023

The short answer is that higher-order derivatives are not supported.

A more detailed answer is that Yota can indeed compute higher-order derivatives for a well-behaved subset of functions and their first-order derivatives, but the approach is quite different from what you might expect. To understand how and why, consider the following example:

julia> using Yota

julia> # the function of interest
       f(x, y, z) = x*y*z
f (generic function with 1 method)

julia> Yota.Umlaut.trace(f, 1.0f0, 2.0f0, 3.0f0)
(6.0f0, Tape{Umlaut.BaseCtx}
  inp %1::typeof(f)
  inp %2::Float32
  inp %3::Float32
  inp %4::Float32
  %5 = *(%2, %3, %4)::Float32 
)

Tape is the computation graph that Yota works on. trace(), as its name suggests, traces the function to produce this graph. Our graph has 4 inputs (including the function itself) and a single operation, *. This graph is differentiable. Internally, Yota takes this graph and creates another one:

julia> tape = Yota.gradtape(f, 1.f0, 2.f0, 3.f0);

julia> Yota.Umlaut.show_compact(stdout, tape)
Tape{Yota.GradCtx}
  inp %1::typeof(f)
  inp %2::Float32
  inp %3::Float32
  inp %4::Float32
  %6, %7 = [%5] = rrule(Yota.YotaRuleConfig(), *, %2, %3, %4) 
  const %8 = 1::Int64
  _, %10, %11, %12 = [%9] = %7(%8) 
  %13 = tuple(ChainRulesCore.ZeroTangent(), %10, %11, %12)::Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32} 
  %14 = map(ChainRulesCore.unthunk, %13)::Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32} 
  %15 = tuple(%6, %14)::Tuple{Float32, Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32}}

The most important lines here are:

  %6, %7 = [%5] = rrule(Yota.YotaRuleConfig(), *, %2, %3, %4) 
  _, %10, %11, %12 = [%9] = %7(%8) 

which is equivalent to:

  val, pullback = rrule(Yota.YotaRuleConfig(), *, x, y, z) 
  df, dx, dy, dz = pullback(seed) 

rrule and pullbacks are part of the ChainRules machinery. To differentiate this second tape, all the rrules and pullbacks must themselves be differentiable, but ChainRules doesn't guarantee that: rrules are allowed to contain functions that don't have their own rrules.
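
For concreteness, here is a minimal sketch of what such a rule looks like, written for a hypothetical mymul instead of * itself (so as not to clash with the rule ChainRules already ships). The point is that the returned pullback is an ordinary Julia closure, and nothing obliges it to have an rrule of its own:

using ChainRulesCore  # provides rrule and NoTangent

# Hypothetical primitive standing in for `*`.
mymul(x, y) = x * y

function ChainRulesCore.rrule(::typeof(mymul), x::Number, y::Number)
    val = mymul(x, y)
    # A plain closure over x and y: the math inside is differentiable,
    # but no rrule exists for the closure itself, so differentiating a
    # tape that contains it fails exactly as in the error above.
    mymul_pullback(Δ) = (NoTangent(), Δ * y, Δ * x)
    return val, mymul_pullback
end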

Some AD engines, such as Diffractor or Enzyme, can still go through these functions, but the performance implications are sometimes terrible. Yota prefers to throw an error whenever it encounters something it doesn't know how to handle (usually internal Julia functions not covered by ChainRules). If your first derivative contains such functions, you are out of luck.

If your first derivative is differentiable, you should be able to compute higher-order derivatives. But you'll have to unwrap the gradient calculation from all the caching and invokelatest in grad(). Perhaps you can use:

import Yota: gradtape, grad_compile

tape = gradtape(f, args...; seed=seed)
gf = grad_compile(tape)
gradtape(gf, ...)

or, even better, work with the tape directly:

tape = gradtape(f, args...; seed=seed)
tape.result = ...  # define the target variable
gradtape!(tape)

It's not as scary as it may look, but it is still far beyond the high-level user API like grad(f, args...). Which brings us back to the initial point: higher-order derivatives are not supported :)
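
For reference, the supported first-order path looks like this (a small sketch; the tuple layout follows the tape shown above, where the first gradient slot corresponds to the function itself):

using Yota

f(x, y, z) = x*y*z

# grad returns (value, gradients); gradients[1] is the tangent w.r.t. f
# itself (ZeroTangent), gradients[2:end] are w.r.t. the positional args.
val, g = grad(f, 1.0f0, 2.0f0, 3.0f0)
# val == 6.0f0; g[2] == 6.0f0 (y*z), g[3] == 3.0f0 (x*z), g[4] == 2.0f0 (x*y)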

@MariusDrulea (Author) commented Sep 5, 2023

Sorry for the long delay. The question is what to put in the ... for gradtape(gf, ...) and for tape.result = ... # define the target variable.

I tried the following:

Approach 1:

x = Float32[1, 2, 3]
tape1 = gradtape(f, x...)
gf1 = grad_compile(tape1)
gradtape(gf1, tape1.ops[5].args[4])

Which gives me the error: AssertionError: No IR found for ##tape_f#301(DataType[Umlaut.Variable]...)

Approach 2:

tape2 = gradtape(f, x...)
tape2.result = tape2.ops[5].args[4] # define the target variable
gradtape!(tape2)

Which gives me the error: AssertionError: The tape's result is expected to be a Call, but instead Umlaut.Input was encountered

@dfdx (Owner) commented Sep 7, 2023

gradtape(gf1, tape1.ops[5].args[4])

In this case, gradtape expects a list of appropriate values, e.g. numbers, arrays, functions, etc., but instead you provide a Variable object. For example, gradtape(x -> sin(x), 1.0) is valid, but gradtape(x -> sin(x), V(1)) is not, because sin(::V) is not defined.
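
To restate that distinction in runnable form (a minimal sketch, assuming Umlaut's V shorthand for Variable as used elsewhere in this thread):

using Yota
import Yota: gradtape
import Yota.Umlaut: V   # V(i) constructs a tape Variable

gradtape(x -> sin(x), 1.0)    # OK: 1.0 is a concrete value that sin accepts
# gradtape(x -> sin(x), V(1)) # fails: sin(::Umlaut.Variable) is not defined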

tape2 = gradtape(f, x...)
tape2.result = tape2.ops[5].args[4] # define the target variable
gradtape!(tape2)

I'm not sure what you are trying to do here, but tape2.ops[5].args[4] is indeed an input operation, not something you can (or need to) differentiate :)

julia> tape2 = gradtape(f, x...)
Tape{Yota.GradCtx}
  inp %1::typeof(f)
  inp %2::Float32
  inp %3::Float32     <--- you point here
  inp %4::Float32
  %5 = rrule(Yota.YotaRuleConfig(), *, %2, %3, %4)::Tuple{Float32, ChainRules.var"#times_pullback3#1347"{Float32, Float32, Float32}} 
  %6 = _getfield(%5, 1)::Float32 
  %7 = _getfield(%5, 2)::ChainRules.var"#times_pullback3#1347"{Float32, Float32, Float32} 
  const %8 = 1::Int64
  %9 = %7(%8)::Tuple{ChainRulesCore.NoTangent, Float32, Float32, Float32} 
  %10 = getfield(%9, 2)::Float32 
  %11 = getfield(%9, 3)::Float32 
  %12 = getfield(%9, 4)::Float32 
  %13 = tuple(ChainRulesCore.ZeroTangent(), %10, %11, %12)::Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32} 
  %14 = map(ChainRulesCore.unthunk, %13)::Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32} 
  %15 = tuple(%6, %14)::Tuple{Float32, Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32}} 


julia> tape2.ops[5].args[4]
%3

You need to choose or add another variable as the output of the forward pass, e.g. tape2.result = V(tape2, 12) (i.e. %12 = getfield(%9, 4)::Float32). But even that is not enough, because:

julia> gradtape!(tape2)
ERROR: No derivative rule found for op %9 = %7(%8)::Tuple{ChainRulesCore.NoTangent, Float32, Float32, Float32} , try defining it using 

	ChainRulesCore.rrule(::ChainRules.var"#times_pullback3#1347"{Float32, Float32, Float32}, ::Int64) = ...

Indeed, there's no chain rule for ChainRules.var"#times_pullback3#1347", which is the pullback function of *. You can "unwrap" all such complex function calls into a list of primitives using Umlaut.primitivize!:

julia> import Yota.Umlaut.primitivize!

julia> primitivize!(tape2)

julia> tape2
Tape{Yota.GradCtx}
  inp %1::typeof(f)
  inp %2::Float32
  inp %3::Float32
  inp %4::Float32
  %5 = tuple(%3, %4, %5, %6)::Tuple{typeof(*), Float32, Float32, Float32} 
  %6 = check_variable_length(%5, 4, 7)::Nothing 
  const %7 = ChainRules.var"#times_pullback3#1347"::UnionAll
  %8 = typeof(%4)::DataType 
  %9 = typeof(%5)::DataType 
  %10 = typeof(%6)::DataType 

But pullback functions are not designed to be differentiable themselves, and so the "primitivized" tape is already too low-level for practical use.

So, as I said, it only works for a well-behaved subset of functions and their first-order derivatives. It can be useful if you design your differentiable system from scratch and are ready to invest a lot into your own closed system of primitives. Some people indeed do this using Yota/Umlaut. But if you want something that will just work, then Yota is limited to first-order derivatives.
