Let's take rrule for matrix multiplication as an example. At the moment we differentiate it by rewriting:
y = A * B
with
rr = rrule(*, A, B)
y = getfield(rr, 1)
pb = getfield(rr, 2)
...
drr = pb(dy)
dA = getfield(drr, 2)
dB = getfield(drr, 3)
There are several issues with this approach:
The pullback pb is a closure and thus cannot be serialized e.g. to ONNX.
Since rrule is a single call, we cannot
The code becomes much harder to read, and inconsistencies or mistakes become harder to find.
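To illustrate the first issue, here is a minimal sketch of what a pullback closure looks like from the outside. `toy_rrule` is a hypothetical stand-in, not the real ChainRules rule; it only assumes the standard Julia behavior that a closure is an object whose fields are its captured variables.

```julia
# Hypothetical stand-in for an rrule: returns the primal value and a pullback closure.
function toy_rrule(A, B)
    pullback(dy) = (dy * B', A' * dy)   # captures A and B from the enclosing scope
    return A * B, pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [0.5 -1.0; 2.0 0.0]
y, pb = toy_rrule(A, B)

# The captured variables are stored as fields of the closure object,
# so they are not visible as ordinary variables on a tape.
captured = fieldnames(typeof(pb))
captured_A = getfield(pb, :A)
```

The operations inside `pb` and the values it holds are hidden behind a single callable, which is what makes it hard to export or inspect.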
If we take a look at this rrule's code:
function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    function times_pullback(ȳ)
        Ȳ = unthunk(ȳ)
        dA = @thunk(project_A(Ȳ * B'))
        dB = @thunk(project_B(A' * Ȳ))
        return NoTangent(), dA, dB
    end
    return A * B, times_pullback
end
we can see that for ordinary dense matrices it can be replaced with this:
y = A * B
...
dA = dy * B'
dB = A' * dy
which is much easier to work with.
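As a sanity check on the simplified form, the sketch below compares `dy * B'` and `A' * dy` against central finite differences of the scalar `sum(dy .* (A * B))`. The helper `fd_grad` and the test matrices are made up for illustration, and only plain dense `Float64` matrices are covered.

```julia
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]                 # 3×2
B = [0.5 -1.0 2.0; 1.5 0.0 -0.5]                # 2×3
dy = [1.0 0.0 2.0; -1.0 3.0 0.5; 0.0 1.0 1.0]   # cotangent of y = A * B (3×3)

# Simplified pullback for dense matrices
dA = dy * B'
dB = A' * dy

# Scalar whose gradients w.r.t. A and B are exactly dA and dB
loss(A, B) = sum(dy .* (A * B))

# Hypothetical helper: central finite-difference gradient of f at X
function fd_grad(f, X; h = 1e-6)
    G = zero(X)
    for i in eachindex(X)
        Xp = copy(X); Xp[i] += h
        Xm = copy(X); Xm[i] -= h
        G[i] = (f(Xp) - f(Xm)) / (2h)
    end
    return G
end

err_A = maximum(abs.(dA .- fd_grad(X -> loss(X, B), A)))
err_B = maximum(abs.(dB .- fd_grad(X -> loss(A, X), B)))
```

Both errors should be at the level of floating-point noise, since the loss is linear in each argument.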
I'm not sure if it will work well in the general case, but one way to implement it is to tweak record_primitive!() to trace rrule() and split its primal and pullback code into two separate lists of operations. Something like:
function record_primitive!(tape::Tape{GradCtx}, v_fargs...)
    v_f, v_args... = v_fargs
    f, args... = [v isa V ? tape[v].val : v for v in v_fargs]
    if isprimitive(ChainRulesCtx(), f, args...)
        t = tape.c.tracer  # a bit weird backref, but let it be for this example
        res = trace!(t, get_code_info(f, args...), v_fargs...)
        v_val, v_pb = tape[res].args  # destructure tuple constructed as the return value from rrule
        tape.c.pullbacks[v_val] = v_pb
        return v_val
    else
        return push!(tape, mkcall(v_fargs...))
    end
end
Then, during the reverse pass, we can trace the saved pullback and re-map captured values to variables from the primal subtape.
This is a pretty sophisticated approach, but so far it looks doable.
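To make the idea of two separate operation lists concrete, here is a toy sketch that does not use the real tape API. `ToyOp`, `record_mul` and `run_ops!` are hypothetical names, and the "tape" is just a dictionary of named values; the point is only that the pullback operations refer to primal variables by name instead of capturing them.

```julia
# Hypothetical toy operation: out = f(inputs...), with inputs looked up by name.
struct ToyOp
    out::Symbol
    f::Function
    inputs::Vector{Symbol}
end

# Record `y = a * b` as one primal list and one pullback list of flat operations.
function record_mul(y::Symbol, a::Symbol, b::Symbol)
    dy, da, db = Symbol(:d, y), Symbol(:d, a), Symbol(:d, b)
    primal = [ToyOp(y, *, [a, b])]
    pullback = [
        ToyOp(da, (dy, B) -> dy * B', [dy, b]),   # dA = dy * B'
        ToyOp(db, (A, dy) -> A' * dy, [a, dy]),   # dB = A' * dy
    ]
    return primal, pullback
end

# Run a list of operations against an environment of named values.
function run_ops!(env::Dict{Symbol,Any}, ops)
    for op in ops
        env[op.out] = op.f((env[s] for s in op.inputs)...)
    end
    return env
end

A = [1.0 2.0; 3.0 4.0]
B = [0.5 -1.0; 2.0 0.0]
env = Dict{Symbol,Any}(:A => A, :B => B)

primal, pullback = record_mul(:y, :A, :B)
run_ops!(env, primal)                 # forward pass fills env[:y]
env[:dy] = ones(size(env[:y]))        # seed cotangent
run_ops!(env, pullback)               # reverse pass fills env[:dA] and env[:dB]
```

In a real implementation the two lists would come from tracing `rrule` itself rather than being written by hand, which is the harder part.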
(Todo: check out how JAX implements it)