derivative of derivative does not work #135
Comments
The short answer is that higher-order derivatives are not supported. A more detailed answer is that Yota can indeed compute higher-order derivatives for a well-behaved subset of functions and their first-order derivatives, but the approach is quite different from what you might expect. To understand how and why, consider the following example:
julia> using Yota
julia> # the function of interest
f(x, y, z) = x*y*z
f (generic function with 1 method)
julia> Yota.Umlaut.trace(f, 1.0f0, 2.0f0, 3.0f0)
(6.0f0, Tape{Umlaut.BaseCtx}
inp %1::typeof(f)
inp %2::Float32
inp %3::Float32
inp %4::Float32
%5 = *(%2, %3, %4)::Float32
)
julia> tape = Yota.gradtape(f, 1.f0, 2.f0, 3.f0);
julia> Yota.Umlaut.show_compact(stdout, tape)
Tape{Yota.GradCtx}
inp %1::typeof(f)
inp %2::Float32
inp %3::Float32
inp %4::Float32
%6, %7 = [%5] = rrule(Yota.YotaRuleConfig(), *, %2, %3, %4)
const %8 = 1::Int64
_, %10, %11, %12 = [%9] = %7(%8)
%13 = tuple(ChainRulesCore.ZeroTangent(), %10, %11, %12)::Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32}
%14 = map(ChainRulesCore.unthunk, %13)::Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32}
%15 = tuple(%6, %14)::Tuple{Float32, Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32}}
The most important lines here are:
%6, %7 = [%5] = rrule(Yota.YotaRuleConfig(), *, %2, %3, %4)
_, %10, %11, %12 = [%9] = %7(%8)
which is equivalent to:
val, pullback = rrule(Yota.YotaRuleConfig(), *, x, y, z)
df, dx, dy, dz = pullback(seed)
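To make the `val, pullback` pattern concrete, here is a minimal pure-Julia sketch of what a reverse rule for a three-argument product could look like. The names `my_rrule` and `times_pullback` are hypothetical stand-ins, not the actual ChainRules or Yota code, and `nothing` is used in place of a ChainRulesCore tangent type so the snippet runs without any packages.

```julia
# Hypothetical stand-in for a reverse rule of x*y*z (not the real ChainRules rule).
function my_rrule(::typeof(*), x::Real, y::Real, z::Real)
    val = x * y * z
    # The pullback is a closure that captures x, y and z from the forward pass.
    # Its first output corresponds to the tangent of the function itself.
    times_pullback(seed) = (nothing, seed * y * z, seed * x * z, seed * x * y)
    return val, times_pullback
end

# Forward pass: compute the value and obtain the pullback closure.
val, pullback = my_rrule(*, 1.0f0, 2.0f0, 3.0f0)

# Reverse pass: seed the pullback with 1 to get the partial derivatives.
df, dx, dy, dz = pullback(1)
```

The point of the sketch is that the first derivative lives inside a closure. Differentiating a second time means differentiating the call to that closure, which in turn needs a rule for the closure type itself.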
Some AD engines like Diffractor or Enzyme can still go through these functions, but the performance implications are sometimes terrible. Yota prefers to throw an error whenever it encounters something it's not aware of (usually internal Julia functions not covered by ChainRules). If your first derivative contains such functions, you are out of luck. If your first derivative is differentiable, you should be able to compute higher-order derivatives. But you'll have to unwrap the gradient calculation from all the caching and
import Yota: gradtape, grad_compile
tape = gradtape(f, args...; seed=seed)
gf = grad_compile(tape)
gradtape(gf, ...)
or, even better, work with the tape directly:
tape = gradtape(f, args...; seed=seed)
tape.result = ... # define the target variable
gradtape!(tape)
It's not as scary as it may look, but it is still far beyond the high-level user API like |
Sorry for the long delay. I tried the following:
Approach 1:
Which gives me the error:
Approach 2:
Which gives me the error: |
gradtape(gf1, tape1.ops[5].args[4])
In this case,
tape2 = gradtape(f, x...)
tape2.result = tape2.ops[5].args[4] # define the target variable
gradtape!(tape2)
I'm not sure what you are trying to do here, but
julia> tape2 = gradtape(f, x...)
Tape{Yota.GradCtx}
inp %1::typeof(f)
inp %2::Float32
inp %3::Float32 <--- you point here
inp %4::Float32
%5 = rrule(Yota.YotaRuleConfig(), *, %2, %3, %4)::Tuple{Float32, ChainRules.var"#times_pullback3#1347"{Float32, Float32, Float32}}
%6 = _getfield(%5, 1)::Float32
%7 = _getfield(%5, 2)::ChainRules.var"#times_pullback3#1347"{Float32, Float32, Float32}
const %8 = 1::Int64
%9 = %7(%8)::Tuple{ChainRulesCore.NoTangent, Float32, Float32, Float32}
%10 = getfield(%9, 2)::Float32
%11 = getfield(%9, 3)::Float32
%12 = getfield(%9, 4)::Float32
%13 = tuple(ChainRulesCore.ZeroTangent(), %10, %11, %12)::Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32}
%14 = map(ChainRulesCore.unthunk, %13)::Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32}
%15 = tuple(%6, %14)::Tuple{Float32, Tuple{ChainRulesCore.ZeroTangent, Float32, Float32, Float32}}
julia> tape2.ops[5].args[4]
%3
You need to choose or add another variable as the output of the forward pass, e.g.
julia> gradtape!(tape2)
ERROR: No derivative rule found for op %9 = %7(%8)::Tuple{ChainRulesCore.NoTangent, Float32, Float32, Float32} , try defining it using
ChainRulesCore.rrule(::ChainRules.var"#times_pullback3#1347"{Float32, Float32, Float32}, ::Int64) = ...
Indeed, there's no chainrule for
julia> import Yota.Umlaut.primitivize!
julia> primitivize!(tape2)
julia> tape2
Tape{Yota.GradCtx}
inp %1::typeof(f)
inp %2::Float32
inp %3::Float32
inp %4::Float32
%5 = tuple(%3, %4, %5, %6)::Tuple{typeof(*), Float32, Float32, Float32}
%6 = check_variable_length(%5, 4, 7)::Nothing
const %7 = ChainRules.var"#times_pullback3#1347"::UnionAll
%8 = typeof(%4)::DataType
%9 = typeof(%5)::DataType
%10 = typeof(%6)::DataType
But pullback functions are not designed to be differentiable themselves, so the "primitivized" tape is already too low-level for practical use. So, as I said, it only works for a well-behaved subset of functions and their first-order derivatives. It can be useful if you design your differentiable system from scratch and are ready to invest a lot into your own closed system of primitives. Some people indeed do this using Yota/Umlaut. But if you want something that will just work, then Yota is limited to first-order derivatives. |
Taking the derivative of derivative with Yota gives an error. See the following MWE.
Stacktrace:
What I additionally noticed is that fx returns Any instead of Float32, see here: