Skip to content

Commit

Permalink
[pjrt] Removed deprecated PjRtBuffer::CopyToDevice
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723097372
  • Loading branch information
superbobry authored and Google-ML-Automation committed Feb 5, 2025
1 parent b58aee5 commit 2859972
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 57 deletions.
7 changes: 4 additions & 3 deletions xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2034,9 +2034,10 @@ PJRT_Error* PJRT_Buffer_CopyToDevice(PJRT_Buffer_CopyToDevice_Args* args) {
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
"PJRT_Buffer_CopyToDevice_Args",
PJRT_Buffer_CopyToDevice_Args_STRUCT_SIZE, args->struct_size));
PJRT_ASSIGN_OR_RETURN(
std::unique_ptr<xla::PjRtBuffer> dst_buffer,
args->buffer->buffer->CopyToDevice(args->dst_device->device));
PJRT_ASSIGN_OR_RETURN(xla::PjRtMemorySpace * memory_space,
args->dst_device->device->default_memory_space());
PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtBuffer> dst_buffer,
args->buffer->buffer->CopyToMemorySpace(memory_space));
args->dst_buffer =
new PJRT_Buffer{std::move(dst_buffer), args->dst_device->client};
return nullptr;
Expand Down
5 changes: 3 additions & 2 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -944,8 +944,9 @@ TEST(StreamExecutorGpuClientTest, AsyncCopyToDevice) {
auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice({src_literal.shape()}, d0));
auto src_buffer = transfer_manager->RetrieveBuffer(0);
// CopyToDevice won't be enqueued until src_buffer is available.
auto local_recv_buffer = *src_buffer->CopyToDevice(d1);
// CopyToMemorySpace won't be enqueued until src_buffer is available.
auto local_recv_buffer =
*src_buffer->CopyToMemorySpace(*d1->default_memory_space());

TF_ASSERT_OK(
transfer_manager->TransferLiteralToBuffer(0, src_literal, []() {}));
Expand Down
6 changes: 0 additions & 6 deletions xla/pjrt/interpreter/interpreter_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,6 @@ class InterpreterLiteralWrapperBuffer final : public PjRtBuffer {

bool IsDeleted() override { return is_deleted_; }

// Device-targeted copy is not implemented for this buffer type: the call
// ignores `dst_device` and always returns an Unimplemented error.
absl::StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) override {
return absl::UnimplementedError(
"CopyToDevice not supported by InterpreterLiteralWrapperBuffer.");
}

absl::StatusOr<std::unique_ptr<PjRtBuffer>> CopyToMemorySpace(
PjRtMemorySpace* dst_memory_space) override {
return absl::UnimplementedError(
Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/pjrt_c_api_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ TEST(PjRtClientTest, CreateViewAndCopyToDeviceAsyncExternalCpuOnly) {

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PjRtBuffer> result,
buffer->CopyToDevice(client->addressable_devices()[1]));
buffer->CopyToMemorySpace(client->memory_spaces()[1]));
buffer.reset();
ASSERT_TRUE(result);
TF_ASSERT_OK_AND_ASSIGN(auto literal, result->ToLiteralSync());
Expand Down
15 changes: 0 additions & 15 deletions xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -1248,21 +1248,6 @@ class PjRtBuffer {
// True if and only if Delete or Release has previously been called.
virtual bool IsDeleted() = 0;

// Copies the buffer to device `dst_device`, performing a device-to-device
// (d2d) transfer when `dst_device` belongs to the same Client, and a
// device-to-host (d2h) followed by a host-to-device (h2d) copy when
// `dst_device` lives on a different Client.
// Returns an error if the buffer is already on dst_device.
//
// See note on semantics of cross-device copies in the class definition
// comment for PjRtClient.
ABSL_DEPRECATED("Use CopyToMemorySpace instead")
virtual absl::StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) {
// Default implementation: look up the default memory space of `dst_device`
// and forward to CopyToMemorySpace. TF_ASSIGN_OR_RETURN presumably returns
// early with the lookup's error status if it fails -- confirm against the
// macro definition.
TF_ASSIGN_OR_RETURN(PjRtMemorySpace * dst_memory_space,
dst_device->default_memory_space());
return CopyToMemorySpace(dst_memory_space);
};

// Copies the buffer to memory space `dst_memory_space`.
//
// The destination memory space may be attached to any client, but optimized
Expand Down
9 changes: 6 additions & 3 deletions xla/pjrt/pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ TEST(PjRtClientTest, CopyToDevice) {

auto* device_1 = client->addressable_devices()[1];

TF_ASSERT_OK_AND_ASSIGN(auto result, buffer->CopyToDevice(device_1));
TF_ASSERT_OK_AND_ASSIGN(auto result, buffer->CopyToMemorySpace(
*device_1->default_memory_space()));

TF_ASSERT_OK_AND_ASSIGN(auto literal, result->ToLiteralSync());

Expand Down Expand Up @@ -434,7 +435,8 @@ TEST(PjRtClientTest, CopyToDeviceAsync) {
constexpr int kConcurrentCopy = 16;
std::vector<std::unique_ptr<PjRtBuffer>> results(kConcurrentCopy);
for (int i = 0; i < kConcurrentCopy; ++i) {
TF_ASSERT_OK_AND_ASSIGN(results[i], buffer->CopyToDevice(device_1));
TF_ASSERT_OK_AND_ASSIGN(results[i], buffer->CopyToMemorySpace(
*device_1->default_memory_space()));
}

// The destructor of TfrtCpuBuffer should wait for outstanding copy.
Expand Down Expand Up @@ -480,7 +482,8 @@ TEST(PjRtClientTest, CopyToDeviceAsyncExternalCpuOnly) {
constexpr int kConcurrentCopy = 16;
std::vector<std::unique_ptr<PjRtBuffer>> results(kConcurrentCopy);
for (int i = 0; i < kConcurrentCopy; ++i) {
TF_ASSERT_OK_AND_ASSIGN(results[i], buffer->CopyToDevice(device_1));
TF_ASSERT_OK_AND_ASSIGN(results[i], buffer->CopyToMemorySpace(
*device_1->default_memory_space()));
}

// The destructor of TfrtCpuBuffer should wait for outstanding copy.
Expand Down
44 changes: 17 additions & 27 deletions xla/python/pjrt_ifrt/pjrt_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,15 +426,13 @@ absl::StatusOr<tsl::RCReference<Array>> PjRtArray::Copy(
TF_ASSIGN_OR_RETURN(Device * buffer_device,
client_->LookupPjRtDevice(pjrt_buffers_[i]->device()));
bool devices_equal = buffer_device == new_sharding_devices[i];
bool memories_supported = pjrt_buffers_[i]->memory_space() != nullptr;
bool memory_kind_equal =
new_sharding_has_memory_kind && memories_supported &&
new_sharding_has_memory_kind &&
pjrt_buffers_[i]->memory_space()->kind() ==
canonicalized_sharding_memory_kind.memory_kind();

// No need for data transfer.
if (devices_equal && (!new_sharding_has_memory_kind ||
!memories_supported || memory_kind_equal)) {
if (devices_equal && (!new_sharding_has_memory_kind || memory_kind_equal)) {
switch (semantics) {
case ArrayCopySemantics::kAlwaysCopy:
// HBM is the only thing that doesn't support same-device copy and
Expand Down Expand Up @@ -480,38 +478,30 @@ absl::StatusOr<tsl::RCReference<Array>> PjRtArray::Copy(
return InvalidArgument("Cannot copy array to non-addressable device %s",
pjrt_device->DebugString());
}
// Use `PjRtBuffer::CopyToMemorySpace` instead of
// `PjRtBuffer::CopyToDevice` when memories are supported, because the
// latter always copies to the default memory space of the device.
if (new_sharding_has_memory_kind && memories_supported) {
PjRtMemorySpace* pjrt_memory_space = nullptr;
if (new_sharding_has_memory_kind) {
TF_ASSIGN_OR_RETURN(
auto memory,
GetMemorySpaceFromMemoryKind(new_sharding_devices[i],
canonicalized_sharding_memory_kind));
PjRtMemory* pjrt_memory = llvm::dyn_cast<PjRtMemory>(memory);
TF_RET_CHECK(pjrt_memory != nullptr);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> copied_buffer,
pjrt_buffers_[i]->CopyToMemorySpace(pjrt_memory->pjrt_memory()));
if (semantics == ArrayCopySemantics::kDonateInput) {
if (!memory_kind_equal) {
return Unimplemented(
"Donation across different memory kinds is not implemented.");
}
pjrt_buffers_[i] = nullptr;
}
buffers.push_back(std::move(copied_buffer));
pjrt_memory_space = pjrt_memory->pjrt_memory();
} else {
// Use `PjRtBuffer::CopyToDevice` when memories are not supported.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<xla::PjRtBuffer> copied_buffer,
pjrt_buffers_[i]->CopyToDevice(pjrt_device->pjrt_device()));
if (semantics == ArrayCopySemantics::kDonateInput) {
pjrt_buffers_[i] = nullptr;
TF_ASSIGN_OR_RETURN(pjrt_memory_space,
pjrt_device->pjrt_device()->default_memory_space());
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> copied_buffer,
pjrt_buffers_[i]->CopyToMemorySpace(pjrt_memory_space));
if (semantics == ArrayCopySemantics::kDonateInput) {
if (!memory_kind_equal) {
return Unimplemented(
"Donation across different memory kinds is not implemented.");
}
buffers.push_back(std::move(copied_buffer));
pjrt_buffers_[i] = nullptr;
}
buffers.push_back(std::move(copied_buffer));
}
}
return std::visit(
Expand Down

0 comments on commit 2859972

Please sign in to comment.