Add dlpack support #7025
Conversation
Currently, the tests fail with this callstack: https://gist.github.com/vanbasten23/a42b196b2fcd5e985a083752386dd3f8. It fails at https://github.com/openxla/xla/blob/0aa2ab4df32bdb099664b0edeef991f40ff1af49/xla/pjrt/pjrt_stream_executor_client.cc#L1322 when i==1 (kExternalReference). I think we are calling the
Edit: I figured it out and fixed it.
mostly lgtm, minor nits.
std::shared_ptr<xla::PjRtBuffer> PjRtComputationClient::GetPjRtBuffer(
    const DataPtr handle) {
  std::shared_ptr<PjRtData> pjrt_data =
      std::dynamic_pointer_cast<PjRtData>(handle);
  XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString();
  return pjrt_data->buffer;
}
Why does this method have to be a class method of PjRtComputationClient? It doesn't access any class private members. If it is just a helper, you don't need it to be a class method. @will-cromar wdyt?
The struct PjRtData is private to PjRtComputationClient. That's why I have to make it a class method.
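For context, a minimal sketch of the access rule being discussed (illustrative names, not the real torch_xla code): a private nested type can only be named by members or friends of the enclosing class, so the dynamic_pointer_cast to the PjRtData-like type cannot live in a free helper function.

// Minimal illustrative sketch (not the real torch_xla code): the nested type is
// private, so only members (or friends) of the enclosing class can name it in a
// dynamic_pointer_cast. A free helper function would not compile.
#include <iostream>
#include <memory>

class ClientSketch {
 public:
  struct Data {                       // public base, like ComputationClient::Data
    virtual ~Data() = default;
  };
  using DataPtr = std::shared_ptr<Data>;

  static DataPtr MakeData() { return std::make_shared<PrivateData>(7); }

  // Has to be a (possibly static) member: it refers to the private PrivateData type.
  static int Unwrap(const DataPtr& handle) {
    auto typed = std::dynamic_pointer_cast<PrivateData>(handle);
    return typed ? typed->payload : -1;
  }

 private:
  struct PrivateData : Data {         // analogous to PjRtComputationClient::PjRtData
    explicit PrivateData(int p) : payload(p) {}
    int payload;                      // stand-in for the wrapped PjRtBuffer
  };
};

int main() { std::cout << ClientSketch::Unwrap(ClientSketch::MakeData()) << "\n"; }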
ComputationClient::DataPtr PjRtComputationClient::CreateData(
    std::string device, xla::Shape shape,
    std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) {
  return std::make_shared<PjRtData>(std::move(device), std::move(shape),
                                    pjrt_buffer);
}
I think you just need a helper, not the class method.
The struct PjRtData is private to PjRtComputationClient, so I can't make it a helper.
Does the input pjrt_buffer depend on the current instance of PjRtClient? I.e., do you need to check here that this buffer belongs to this client? If you don't depend on any members of PjRtComputationClient, can you make this a static method to make that clearer?
Also, this should be private and probably not part of ComputationClient. The public API of PjRtComputationClient shouldn't operate on PJRT primitives. It should be a complete wrapper, to leave the door open to other runtime interfaces (namely IFRT).
Right, the pjrt_buffer doesn't depend on the current instance of PjRtClient. I have made it a static method.
I also made it not a part of ComputationClient. I'm not sure how I can make it private and still call it in dl_convertor.cpp at the same time.
Hey @will-cromar @JackCaoG, it seems the CI skips tests that require PyTorch CUDA support; for example, the tests I added in this PR are not actually run by the CI. What do you think of adding a new CI workflow that builds pytorch with CUDA, builds pytorch/xla, and runs the tests? By preserving the existing CI workflow (building pytorch with CUDA disabled), people can still get fast feedback.
Yeah, that's a loose end from all of the refactoring I did. I agree with adding a new workflow, or a branch of the existing workflow, that adds CUDA separately. If it's okay with you, the easiest option would be to download the pre-built nightly CUDA wheel from https://download.pytorch.org/whl/nightly/cu121. This will be faster and require less maintenance, but it will also break periodically if there's a breaking change on head since the last nightly build.
Thanks! This is really neat
@@ -84,6 +91,15 @@ class IfrtComputationClient : public ComputationClient {
      absl::AsciiStrToUpper(client_->platform_name()));
  };

  xla::PjRtPlatformId GetPlatformID() const override {
    return client_->platform_id();
Heads up: this method doesn't exist in the PJRT C API: https://github.com/openxla/xla/blob/80b854f9ef3e8b913ed9ca2d930c81e32c6d02da/xla/pjrt/pjrt_c_api_client.h#L209-L211
Since dlpack is a very special case, what do you think of just adding an is_cuda or supports_dlpack attribute to PjRtComputationClient and setting it during init?
> Heads up: this method doesn't exist in the PJRT C API

The client_ here has type xla::ifrt::PjRtClient, which implements platform_id at https://github.com/openxla/xla/blob/62568f0ef380c883defd44a060722dd62e81df1b/xla/python/pjrt_ifrt/pjrt_client.h#L147-L150. I wonder how the xla::ifrt::PjRtClient relates to the PJRT C API PjRtCApiTopologyDescription.

> what do you think of just adding a is_cuda or supports_dlpack attribute to PjRtComputationClient and setting it during init?

Can you elaborate more? Do you mean adding a supports_dlpack attribute to ComputationClient and setting it to true for PjRtComputationClient and false for IfrtComputationClient?
The PJRT C API is a subset of the PJRT C++ API, so some methods of the C++ -> C wrapper are unimplemented. I would add a field to PjRtComputationClient to indicate whether dlpack is expected to work with the underlying PjRtClient (only true when PJRT_DEVICE=CUDA), since that depends on the specific device type, which we should not rely on after initialization.
I'm not concerned about IFRT right now. Also, remember that xla::PjRtClient is different from xla::ifrt::PjRtClient.
I see. If the concern is that dlpack should only be used when PJRT_DEVICE=CUDA, then there is a check in this PR (https://github.com/pytorch/xla/pull/7025/files#diff-cf3922091c803bce3341e4e55b2c54d277812059adeafa297eb1c7b444213b2aR45-R53) to make sure dlpack is used for CUDA and not TPU.
Really great work, @vanbasten23. I have left a few comments in the PR. I do have a couple of questions, though:
- Could you add an overview of the functions involved in the DLPack-XLA (to/from) execution paths to the PR description? I feel like that would make your PR easier to follow.
- Before converting CUDA to DLPack to XLA, don't we need to call torch.cuda.synchronize()? That's because, IIRC, some CUDA computation is also lazy.
Added the execution paths.
Good callout. I couldn't find this in pytorch's documentation, but from their test pytorch/test/test_dlpack.py, it seems it is only needed when …
@will-cromar Do you mean that, at some point, we stopped using …
>>> a = torch.rand(1024, 1024, 1024, device="cuda")
# This returns instantly
>>> for i in range(1000):
...     a = a @ a
# Does NOT return instantly. Can hear the GPU fans going up.
>>> torch.cuda.synchronize()
I think that, at first, we can just call …
Yeah, that's right. The TPU CI, CPU/GPU CI, and nightly build were all using different build scripts. While splitting the XLA CUDA build from the …
Right, I observed the same in #7025 (comment). I'm adding a new workflow that builds pytorch with cuda enabled and only exercises those tests requiring pytorch cuda. For now, I'll just run the tests locally, and they pass.
Thanks for the suggestion. I've seen a few incompatible torch and torch_xla wheels recently, so I feel it is not very uncommon. If they are incompatible, our CI will be red for a day or two and we'll have to communicate it to the team, which is kind of a hassle. So I'll strive for building pytorch with cuda from source: #7073
I think we need more comments, e.g.:
- where each function was copied from
- how to use the added API correctly
Otherwise, LGTM.
I will try to take a look tomorrow.
Thanks for the review. I've added the comments you suggested.
// If ext_data is the result of a CUDA computation, we should synchronize
// (wait for all kernels in all streams on a CUDA device to complete) if the
// current stream is different from ext_data's stream. Otherwise, we risk
// getting incorrect results.
I believe this is _from_dlpack's comment.
Hey @JackCaoG, could you take a look at the PR when you are available? Thanks!
Thanks for the review, Jack and Yukio!
XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null.";

runtime::ComputationClient::DataPtr data =
    runtime::PjRtComputationClient::CreateData(
QQ: Does this only apply to PjRtComputationClient?
Background: I am working on a new ComputationClient, and it reports this error:
torch_xla/csrc/dl_convertor.cpp:337:16: error: 'torch_xla::runtime::PjRtComputationClient' has not been declared
337 | runtime::PjRtComputationClient::CreateData(
| ^~~~~~~~~~~~~~~~~~~~~
Looks like we have to use PjRtComputationClient with this file.
I think so. As far as I understand it, the DLPack API only works with XLA:CUDA, i.e. PJRT.
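To make that coupling concrete, here is a self-contained sketch with stub types (FromDlpackSketch and the classes below are illustrative names, not torch_xla's real code): conversion code that assumes a PjRtComputationClient can check that assumption explicitly with a dynamic_cast and fail with a clear error on any other ComputationClient.

// Self-contained sketch with stub types (not torch_xla's real classes): the
// DLPack path assumes the runtime client is a PjRtComputationClient, so a
// dynamic_cast plus an explicit error makes that assumption visible.
#include <iostream>
#include <stdexcept>

struct ComputationClient { virtual ~ComputationClient() = default; };
struct PjRtComputationClient : ComputationClient {};
struct OtherComputationClient : ComputationClient {};

void FromDlpackSketch(ComputationClient* client) {
  auto* pjrt = dynamic_cast<PjRtComputationClient*>(client);
  if (pjrt == nullptr) {
    throw std::runtime_error(
        "DLPack interop currently requires the PJRT (XLA:CUDA) runtime client");
  }
  std::cout << "client is PJRT: proceeding with DLPack conversion\n";
}

int main() {
  PjRtComputationClient pjrt;
  OtherComputationClient other;
  FromDlpackSketch(&pjrt);     // ok
  try {
    FromDlpackSketch(&other);  // rejected with a clear error
  } catch (const std::exception& e) {
    std::cout << e.what() << "\n";
  }
}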
This PR adds basic dlpack support for pytorch/xla.
Execution path:
Test plan: PJRT_DEVICE=CUDA python pytorch/xla/test/test_operations.py -k TestDLPack.test_
References: