Transfer data directly to the device #5772
Conversation
The issue here is that I was calculating the […]. In the numerous special cases in […]. In cases where the source type does not match the output type, I believe we'll still have to "stage" the data in an intermediate copy.
I have everything working locally now. Separating the more tedious changes here into #5777. After this PR, we'll still have to make an intermediate copy if the input tensor type does not match the target type. I added a counter to capture this overhead, since it may have a performance impact. For what it's worth, casting the […]. Getting around the copy is simple: just create the tensors such that the type matches what it will be on the device. So if you want a […].
Force-pushed 13035ed to df239db
The TPU CI is currently hanging on […].
* Remove `populate_fn` from `TensorSource`
* Make TensorSource an interface
* Re-enable pjrt_computation_client_test
* server -> device
* add comment
* fix outbound data metric
* formatting
* implement byte_strides in TensorSource
* more formatting
* remove extra deps
* add missing deps
* Revert "server -> device" (this reverts commit 6384516)
* Revert "Simplify AtenSource" (this reverts commit 4225deb)
Force-pushed 30abda8 to 62cf72d
@jonb377 and I were able to figure out where the deadlock is. The hang is caused by a GIL deadlock when we try to retrieve data from the device before a transfer finishes, and the transfer is the only thing keeping an `at::Tensor` alive. Here's what's happening:
Here's the relevant stack trace through […].
I fixed a similar GIL deadlock bug about a year ago in #4504. In that case, the solution was to release the GIL during […].
I wrapped the GIL release and data transfer into a new utility, […]. I'm open to better names for […].
```diff
@@ -37,17 +37,17 @@ class PjRtComputationClient : public ComputationClient {
   std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) override;

   std::vector<DataPtr> TransferToServer(
-      absl::Span<const TensorSource> tensors) override;
+      absl::Span<const std::shared_ptr<const TensorSource>> tensors) override;
```
Should we require the caller to manage the ownership and lifetime of the `TensorSource`, i.e., take `const TensorSource*` instead? Or is it necessary to ensure that the memory is held during the client ops and within the client?
Good question. PJRT lets us tie the lifetime of an object to an operation by capturing it in a callback. `std::function`s have to be copyable, so `shared_ptr` is our best choice here. `TensorSource` itself may be expensive or impossible to copy.

The caller of `TransferToServer` will be much shorter-lived than the actual transfer, so ownership should pass down. We could tighten up the interface here and consume a `unique_ptr`, since we only need copyability within the implementation of `TransferToServer`. What do you think?
LGTM
I will try to take a look today.
* Transfer data directly to the device (pytorch#5752)
* Remove `populate_fn` from `TensorSource`
* Make TensorSource an interface
* Re-enable pjrt_computation_client_test
* server -> device
* add comment
* fix outbound data metric
* formatting
* implement byte_strides in TensorSource
* more formatting
* remove extra deps
* add missing deps
* Revert "server -> device" (this reverts commit 6384516)
* Use `at::Tensor`'s layout for byte strides
* Downcast at::Tensor if required
* formatting
* Simplify AtenSource
* fix build
* formatting
* fix typo that makes us ignore input type
* Revert "Simplify AtenSource" (this reverts commit 4225deb)
* Skip hanging test
* fix gil deadlock
* formatting
Take two. See the notes from the original PR #5752.

New changes:

* Use the byte strides from the `at::Tensor` instead of the strides in the `xla::Shape`.
* Downcast the `at::Tensor` if the target type differs from the actual type.
* Release the GIL in `XlaDataToTensors`, since `at::Tensor` destruction after `TransferToServer` can deadlock with `TransferFromServer`.