Primitivize rrules #103

Closed
dfdx opened this issue Feb 1, 2022 · 1 comment

Comments

@dfdx
Owner

dfdx commented Feb 1, 2022

Let's take the rrule for matrix multiplication as an example. At the moment we differentiate it by rewriting:

y = A * B

with

rr = rrule(*, A, B)
y = getfield(rr, 1)
pb = getfield(rr, 2)
...
drr = pb(dy)
dA = getfield(drr, 2)
dB = getfield(drr, 3)

There are several issues with this approach:

  1. The pullback pb is a closure and thus cannot be serialized, e.g. to ONNX.
  2. Since rrule is a single call, we cannot
  3. The code becomes much harder to read, and inconsistencies or mistakes become harder to find.

If we take a look at this rrule's code:

function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    function times_pullback(ȳ)
        Ȳ = unthunk(ȳ)
        dA = @thunk(project_A(Ȳ * B'))
        dB = @thunk(project_B(A' * Ȳ))
        return NoTangent(), dA, dB
    end
    return A * B, times_pullback
end

we can see that for ordinary dense matrices it can be replaced with this:

y = A * B
...
dA = dy * B'
dB = A' * dy

which is much easier to work with.
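
As a sanity check on the dense-matrix claim above, here is a minimal sketch in plain Julia (no ChainRules or Yota involved) that compares the primitive expressions `dy * B'` and `A' * dy` against central finite differences. The helper `fd_grad`, the scalar `loss`, and the concrete matrices are assumptions made up for this illustration only.

```julia
using LinearAlgebra

# Small dense inputs; shapes are (3×2) * (2×4) -> (3×4).
A  = [1.0 2.0; 3.0 4.0; 5.0 6.0]
B  = [0.5 -1.0 2.0 0.0; 1.5 0.25 -0.5 1.0]
dy = [1.0 0.0 2.0 -1.0; 0.5 1.0 0.0 0.0; -2.0 1.0 1.0 3.0]

# Primal and "primitivized" reverse pass: only plain matrix operations.
y  = A * B
dA = dy * B'
dB = A' * dy

# Scalar loss whose gradient w.r.t. y is exactly dy (hypothetical, for the check only).
loss(A, B) = sum(dy .* (A * B))

# Central finite-difference gradient of f at X (hypothetical helper).
function fd_grad(f, X; h=1e-6)
    G = zero(X)
    for i in eachindex(X)
        Xp = copy(X); Xp[i] += h
        Xm = copy(X); Xm[i] -= h
        G[i] = (f(Xp) - f(Xm)) / (2h)
    end
    return G
end

dA_fd = fd_grad(X -> loss(X, B), A)
dB_fd = fd_grad(X -> loss(A, X), B)

println(isapprox(dA, dA_fd; atol=1e-6))
println(isapprox(dB, dB_fd; atol=1e-6))
```

This only covers ordinary dense `Float64` matrices; the `ProjectTo` and thunk handling in the real rrule exist for cases (structured or complex arrays, lazily evaluated tangents) that this sketch deliberately ignores.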

I'm not sure if it will work well in the general case, but one way to implement it is to tweak record_primitive!() to trace rrule() and split its primal and pullback code into two separate lists of operations. Something like:

function record_primitive!(tape::Tape{GradCtx}, v_fargs...)
    v_f, v_args... = v_fargs
    f, args... = [v isa V ? tape[v].val : v for v in v_fargs]
    if isprimitive(ChainRulesCtx(), f, args...)
        t = tape.c.tracer   # a bit weird backref, but let it be for this example
        res = trace!(t, get_code_info(f, args...), v_fargs...)
        v_val, v_pb = tape[res].args    # destructure tuple constructed as the return value from rrule
        tape.c.pullbacks[v_val] = v_pb        
        return v_val
    else
        return push!(tape, mkcall(v_fargs...))
    end
end

Then, during the reverse pass, we can trace the saved pullback and re-map captured values to variables from the primal subtape.

This is a pretty sophisticated approach, but so far it looks doable.
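
To make the intended end state more concrete, below is a toy sketch in plain Julia. It is not Yota's API: `Input`, `Call`, `ToyTape`, `record_input!`, `record_call!` and `reverse_mul!` are all hypothetical names, and the reverse rule is written by hand as a stand-in for what tracing the saved pullback would produce. The point is only that the reverse pass ends up as ordinary primitive calls on the same tape, with the captured `A` and `B` referenced by their primal variable ids instead of living inside a closure.

```julia
using LinearAlgebra

# Toy tape: every entry is either an input or a call on earlier entries.
struct Input
    val::Any
end

struct Call
    fn::Function
    args::Vector{Int}   # ids of earlier tape entries
    val::Any
end

const ToyTape = Vector{Union{Input,Call}}

# Record an input value and return its variable id.
record_input!(tape, val) = (push!(tape, Input(val)); length(tape))

# Evaluate fn on earlier entries, record the call, return its variable id.
function record_call!(tape, fn, args::Int...)
    val = fn((tape[i].val for i in args)...)
    push!(tape, Call(fn, collect(args), val))
    return length(tape)
end

# Hand-written stand-in for a traced pullback of `*` (dense matrices only):
# emits dA = dy * B' and dB = A' * dy as primitive calls, reusing the
# primal variables vA and vB rather than capturing their values.
function reverse_mul!(tape, vA, vB, vdy)
    vBt = record_call!(tape, adjoint, vB)
    vdA = record_call!(tape, *, vdy, vBt)
    vAt = record_call!(tape, adjoint, vA)
    vdB = record_call!(tape, *, vAt, vdy)
    return vdA, vdB
end

tape = ToyTape()
vA  = record_input!(tape, [1.0 2.0; 3.0 4.0])
vB  = record_input!(tape, [0.0 1.0; 1.0 0.0])
vy  = record_call!(tape, *, vA, vB)        # primal: y = A * B
vdy = record_input!(tape, ones(2, 2))      # seed for the reverse pass
vdA, vdB = reverse_mul!(tape, vA, vB, vdy)

# Every non-input entry is now a plain call to * or adjoint.
for (id, op) in enumerate(tape)
    op isa Call && println("%", id, " = ", op.fn, op.args)
end
```

A tape like this contains no closures, so in principle each entry could be mapped to an operator of a serialization format. How much of that carries over to real rrules with projections, thunks and non-dense array types is exactly the open question raised above.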

(Todo: check out how JAX implements it)

@dfdx
Owner Author

dfdx commented Aug 21, 2022

With the current vision, this idea is unlikely to land in the foreseeable future.

dfdx closed this as completed Aug 21, 2022