Performance analysis #131

Closed
dfdx opened this issue Oct 20, 2022 · 1 comment

dfdx commented Oct 20, 2022

Here I'm going to track the performance issues with Yota.

Starter code
using Yota
using Yota.Umlaut
using Metalhead
using Profile
using ProfileView


# scalar loss so there is a single output to differentiate
loss(model, image) = sum(model(image))

function main()
    model = Metalhead.ResNet(18)
    image = rand(Float32, 224, 224, 3, 1)     # WHCN input expected by Metalhead models
    @time model(image)                        # plain forward pass
    @time trace(model, image; ctx=GradCtx())  # tracing only, no gradient
    @profile grad(loss, model, image)         # gradient (tracing + compilation) under the profiler
    Profile.print(mincount=100)
end
Currently, the gradient of ResNet takes forever (or at least > 30 minutes). Function execution takes ~10 seconds and tracing it takes ~60 seconds, so most of the time is spent in `grad()`.

Profiler output after 2 minutes of execution

My current interpretation is as follows:

  • Flux extensively uses higher-order functions.
  • During the initial tracing, Yota simply writes these higher-order functions to the tape, which is relatively fast.
  • During the gradient pass, these functions are rewritten into the corresponding rrules.
  • The rrules invoke rrule_via_ad, which in turn triggers tracing of the argument function (see the sketch after this list). So the bottleneck is indeed tracing, though not the initial tracer pass.
  • Tracing is slow because Julia turns on type inference and specializes each call (a lot of the slowness comes from mkcall()).
  • Simply wrapping the tape and the tracing into @nospecialize helped a bit in tests, but it is definitely not a game changer.
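
To make the rrule_via_ad point concrete, here is a minimal sketch of how an rrule for a higher-order function typically re-enters the calling AD system. apply_each and its rule are hypothetical and written against the public ChainRulesCore API; they are not Yota's or Flux's actual code.

using ChainRulesCore

# apply_each stands in for the higher-order functions Flux uses (map over a
# closure, etc.); the pattern below is what forces re-entry into the AD.
apply_each(f, xs) = map(f, xs)

function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(apply_each), f, xs)
    # rrule_via_ad asks the calling AD (here: Yota) to differentiate `f` itself,
    # i.e. `f` gets traced and compiled again inside the gradient pass.
    ys_and_pbs = map(x -> rrule_via_ad(cfg, f, x), xs)
    ys = map(first, ys_and_pbs)
    pbs = map(last, ys_and_pbs)
    function apply_each_pullback(dys)
        dxs = map((pb, dy) -> pb(dy)[2], pbs, unthunk(dys))
        return NoTangent(), NoTangent(), dxs   # gradient w.r.t. f's fields dropped in this sketch
    end
    return ys, apply_each_pullback
end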

In the original design, Yota wasn't supposed to trigger compilation during backpropagation. In fact, the design was very similar to JAX, with the only exception that we used IR-level tracing instead of operator overloading due to issues with multiple dispatch. Just to recap, the first versions of Yota worked like this (a toy sketch follows the list):

  1. Trace a function to turn it into a list of primitives for which we know the differentiation rules (forward pass).
  2. Record these differentiation rules to the same tape (backward pass).
  3. Compile the tape.
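
As a toy illustration of that flow (the names and tape representation here are invented for the sketch and do not match Yota's real internals):

# Step 1: the tracer reduces the program to known primitives, e.g. y = sin(x).
x = 1.0
tape = Any[(:y, sin, (x,))]

# Step 2: for each primitive, a locally known rule is appended to the same tape.
backrule(::typeof(sin), x, dy) = dy * cos(x)        # made-up rule table
push!(tape, (:dx, backrule, (sin, x, 1.0)))         # seed dy = 1.0

# Step 3: the whole tape is compiled once into a single gradient function
# (compilation itself is omitted here); no further tracing happens afterwards.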

So exactly one tracing and one compilation. However, the prevalence of ChainRules changed the game. The current design looks like this (a sketch follows the list):

  1. Trace a function.
  2. Replace all primitive calls y = f(xs...) with y, pb = rrule(f, xs...). Save pullbacks for the next step.
  3. Record pullback invocations.
  4. Compile the tape.
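
Here is a rough, single-call sketch of how steps 2 and 3 play out, using only the public ChainRules API (grad_sketch and its straight-line structure are made up; the real implementation works over the whole traced tape):

using ChainRules   # brings rrule plus rules for Base functions like sin

function grad_sketch(f, x)
    y, pb = rrule(f, x)          # step 2: the primitive call is replaced by its rrule
    dy = one(y)                  # seed for the scalar output
    df, dx = pb(dy)              # step 3: the pullback invocation is recorded/executed
    return y, unthunk(dx)        # step 4 (compiling the tape) is omitted in this sketch
end

# grad_sketch(sin, 1.0) returns (sin(1.0), cos(1.0))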

Now Yota has no control over what happens in an rrule and, in case of higher-order functions, cannot avoid additional tracing and compilation. Since Flux uses higher-order functions extensively, we get what we get.

As far as I can see, the only way forward is to speed up tracing. However, it requires a really, really good understanding of the Julia compiler, which I don't have. To my knowledge, the only autodiff package that managed to do it is Diffractor, and not too many people understand how it works.

So I'm pretty much puzzled.


dfdx commented Oct 30, 2022

Thanks to all the commenters in this thread, tracing now works ~2x faster. I also fixed a performance bug in todo_list(), and now the whole grad(loss, model, image) compiles and runs in 61 seconds, which is reasonably good.

dfdx closed this as completed Oct 30, 2022