[RFC] PyTorch/XLA eager mode as default #7253
Comments
Is it possible to support the mixed use of PyTorch eager (for CUDA tensors) and XLA eager (for XLA tensors), for instance in cases where some operations are not supported by XLA, or where operations that introduce dynamic shapes need to be executed by PyTorch?
@baoleai You can do that. Now that @vanbasten23 added the dlpack support, you can just use that API to do a zero-copy conversion from an XLA:GPU tensor to a CUDA tensor, and those operations will then happen on eager CUDA. IMO the downside is that PyTorch/XLA execution (even for eager) is async, so you will need to add a …
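A minimal sketch of the zero-copy handoff described above, assuming the dlpack helpers live in `torch_xla.utils.dlpack` (`to_dlpack`/`from_dlpack`) as in the referenced work; verify the exact module path for your torch_xla version:

```python
import torch
import torch.utils.dlpack as torch_dlpack
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as xla_dlpack  # assumed location of the helpers

device = xm.xla_device()              # XLA:GPU device
x = torch.randn(4, 4, device=device)

xm.mark_step()                        # materialize any pending XLA work
xm.wait_device_ops()                  # XLA execution is async; sync before sharing

# XLA tensor -> CUDA tensor (zero copy via a DLPack capsule)
cuda_x = torch_dlpack.from_dlpack(xla_dlpack.to_dlpack(x))

# Run an op on eager CUDA that is awkward for XLA (dynamic output shape).
cuda_y = torch.nonzero(cuda_x > 0)

# CUDA tensor -> XLA tensor (zero copy back)
y = xla_dlpack.from_dlpack(torch_dlpack.to_dlpack(cuda_y))
```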
That being said, @baoleai, I don't think there is much point in using PyTorch/XLA:GPU eager; you can just use PyTorch eager, which is likely faster, and use …
I was able to get most of the basic stuff working (…). I also realized the performance of the eager mode is really model dependent: a decoder-only model can get ~45% of the throughput, while resnet is really slow (on my current run, eager-mode throughput is 1/400 of the compiled run). I chatted with Blake and …
Context
Objective
In this RFC I will talk about the roadmap to enable eager mode as the default computation mode for PyTorch/XLA users and how to enable graph compilation in this mode.
Background
PyTorch/XLA has been using tracing mode as the default mode since the project started. All of the torch operations a user issues are accumulated in the background and sent to XLA for compilation and execution upon a `mark_step` call. The upside of this approach is that users don't need to change their model code too much; as long as the user adds a `mark_step` at the right place, everything should just work. However, user feedback over the last couple of years shows that this approach creates too much confusion and frustration. Both PyTorch and JAX took the approach of using eager mode as the default and asking users to specify the function they want to compile. PyTorch/XLA should take the same approach.
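For illustration, a minimal sketch of the tracing-mode flow just described (shapes and values are arbitrary):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(2, 2, device=device)
b = torch.randn(2, 2, device=device)

c = a @ b          # traced, not executed yet
d = torch.relu(c)  # still just extends the pending graph

xm.mark_step()     # compile and execute the accumulated graph on the device
```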
Design

Eager mode
There is no real eager mode on TPU. However, we can fake eager mode by compiling and executing each torch operation as it is issued. Such a mode already exists as a debug-only mode today; it was contributed by @aws-rhsoln two years ago in #3306. The work here is to do a better API-level wrapping and make sure this mode works with other features (debug output, SPMD, multiprocess, etc.). This approach was way too slow a couple of years ago because XRT could not execute small computations very efficiently, but with PJRT the performance is much better.
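A small sketch of what this looks like from the user's side, using the experimental `eager_mode` toggle named later in this RFC (shapes arbitrary):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)   # compile & execute each op eagerly

device = xm.xla_device()
a = torch.randn(2, 2, device=device)
b = torch.relu(a @ a)   # runs immediately; no pending graph accumulates
print(b)                # no mark_step needed to see the value
```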
The whole eager mode still builds on top of the existing lazy tensor framework, but it becomes invisible to the user. A couple of things we need to do to accommodate the eager mode:
Compile
For the compile part we currently have two options: lazy tensor and torch dynamo (`torch.compile`).
For the lazy-tensor-based compile I will add a new API, `torch_xla.experimental.compile`, which under the hood just enables tracing mode upon running the function and executes the traced graph before returning. Here is the implementation. For `torch.compile` we can just use the existing API.
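A hedged sketch of the two entry points; the function body is arbitrary, and `backend="openxla"` is the dynamo backend name torch_xla registers (worth verifying for your version):

```python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)

def fn(a, b):
    return torch.relu(a @ b)

# Option 1: lazy-tensor-based compile via the new API proposed here.
fn_lazy = torch_xla.experimental.compile(fn)

# Option 2: the existing dynamo path, using torch_xla's openxla backend.
fn_dynamo = torch.compile(fn, backend="openxla")
```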
Example UX

Note that the two changes users need to make are to enable eager mode via `torch_xla.experimental.eager_mode(True)` and then compile the step function with `torch_xla.experimental.compile` or `torch.compile` (see the sketch below). Users can also choose to run the whole model in eager mode.
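A minimal training-step sketch under these assumptions; the model, optimizer, and data are placeholders invented for illustration, not the RFC's actual example:

```python
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)   # change 1: enable eager mode

device = xm.xla_device()
model = nn.Linear(16, 2).to(device)       # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

def step_fn(data, target):
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    return loss

# change 2: compile just the step function; everything else stays eager
step_fn = torch_xla.experimental.compile(step_fn)

data = torch.randn(8, 16, device=device)
target = torch.randint(0, 2, (8,), device=device)
loss = step_fn(data, target)
```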
Why
IMO using tracing mode as the default has a couple of very significant drawbacks, most notably the confusion around when graphs execute and where to place `mark_step`. Both JAX and PyTorch took the approach of asking users to explicitly mark the region/function for compilation, and this methodology seems well received by users who want a compilation mode. I think this proposal makes for a much better usability story, since only the `compiled_fn` should generate graphs.

Benchmark
I am running a 2-layer decoder-only model training (it is pretty much just a llama2) with fake data on a single chip of a v4-8 for 300 steps. This is not a very scientific benchmark, so take it with a grain of salt. Eager mode can achieve ~45% of the performance of the fully compiled model for the decoder-only model. The trainer I used to test can be found here and here.
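A hedged sketch of the kind of timing harness behind numbers like these; the stand-in model, sizes, and batch are illustrative (not the actual trainer), and only the 300-step count matches the run above:

```python
import time

import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
# Stand-in for the 2-layer decoder-only model (hypothetical, for illustration).
model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512)).to(device)
data = torch.randn(8, 512, device=device)

def step_fn(x):
    return model(x).sum()

# Comment this line out to measure pure eager throughput instead.
step_fn = torch_xla.experimental.compile(step_fn)

step_fn(data)                # warm up (triggers the first compilation)
xm.wait_device_ops()
start = time.time()
for _ in range(300):
    step_fn(data)
xm.wait_device_ops()         # execution is async; drain before reading the clock
print(f"steps/sec: {300 / (time.time() - start):.2f}")
```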
Work Breakdown
- `torch_xla.experimental.compile` (done)
- `torch.compile` (pr)

Timeline
- 2.4 release -> experimental
- 2.5 release -> beta
- 2.6 release -> enable by default
cc @ezyang @bdhirsh @wconstab @baoleai @amithrm @jeffhataws @albanD @gkroiz @Liyang90